# hitech-extract / app.py
# ripguy's picture
# feat(extract): mirror VLM deskew preprocessing into the extract Space
# 6dc2242
# Raw
# History Blame Contribute Delete
# 8.69 kB
"""hitech-extract — numind/NuExtract3 schema-constrained extraction Space.
Serves the contract that `core/clients.py::_predict_default` expects:
predict(prompt: str, schema_json: str, image_path: str | None) -> str # JSON
NuExtract3 is template-driven: it takes a `template=` chat-template kwarg whose
values are NuExtract's own type DSL (`"verbatim-string"`, `"number"`, nested
objects, `[...]` arrays), NOT JSON Schema. Passing the raw Pydantic JSON Schema as
the template makes it extract *nothing* (verified live: all-null output). So this
Space converts the incoming `schema_json` → NuExtract DSL before extraction
(`_schema_to_template`). The noisy composed `prompt` is fine as the document text
— NuExtract follows the template and ignores the surrounding prose.
The converter drops `needs_review` / `needs_review_reason`: NuExtract is a pure
extractor and cannot self-assess uncertainty, so for extract-routed tasks that
flag is set downstream by Supabase-grounded validators (Step 4), not the model.
Both fields carry Pydantic defaults, so omitting them validates cleanly. Open
maps (`dict[str, str]`, e.g. SpecLookup.properties) aren't expressible in the DSL
and are likewise dropped (default `{}`).
ZeroGPU notes (see huggingface-zerogpu skill): eager module-scope load; `import
spaces` is unconditional and omitted from requirements.txt; greedy decode with
thinking disabled for clean JSON.
"""
from __future__ import annotations
import json
import re
from typing import Any
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor
MODEL_ID = "numind/NuExtract3"  # Hub repo id for both the processor and the model
MAX_NEW_TOKENS = 2048  # generation cap passed to model.generate() in predict()
# Fields the model cannot meaningfully populate; set downstream instead.
# _object_to_dsl skips these names when building the NuExtract template.
_DROP_FIELDS = {"needs_review", "needs_review_reason"}
def _node_to_dsl(node: dict[str, Any], defs: dict[str, Any]) -> Any | None:
"""Map one JSON-Schema node to a NuExtract DSL value (None = drop the field)."""
if "$ref" in node:
node = defs.get(node["$ref"].split("/")[-1], {})
if "enum" in node: # constrained choice -> list of allowed strings
return [str(v) for v in node["enum"]]
if "anyOf" in node: # nullable / unions -> first concrete (non-null) branch
concrete = [s for s in node["anyOf"] if s.get("type") != "null"]
return _node_to_dsl(concrete[0], defs) if concrete else "verbatim-string"
node_type = node.get("type")
if node_type == "object":
props = node.get("properties")
return _object_to_dsl(props, defs) if props else None # open map -> drop
if node_type == "array":
item = _node_to_dsl(node.get("items", {}), defs)
return [item] if item is not None else []
if node_type == "integer":
return "integer"
if node_type == "number":
return "number"
return "verbatim-string" # strings (and rare booleans) extract verbatim
def _object_to_dsl(props: dict[str, Any], defs: dict[str, Any]) -> dict[str, Any]:
out: dict[str, Any] = {}
for name, sub in props.items():
if name in _DROP_FIELDS:
continue
value = _node_to_dsl(sub, defs)
if value is not None:
out[name] = value
return out
def _schema_to_template(schema_json: str) -> str:
    """Convert a Pydantic JSON Schema string into a NuExtract DSL template string.

    Args:
        schema_json: A JSON Schema document serialized to a string, e.g. the output
            of Pydantic's ``model_json_schema()``.

    Returns:
        The NuExtract template as pretty-printed JSON (4-space indent).

    Raises:
        json.JSONDecodeError: If ``schema_json`` is not valid JSON.
    """
    schema = json.loads(schema_json)
    # Pydantic v2 stores shared sub-schemas under "$defs"; draft-07 style generators
    # (e.g. Pydantic v1) use "definitions". Refs are resolved by their last path
    # segment, so a merged table serves both.
    defs = {**schema.get("definitions", {}), **schema.get("$defs", {})}
    template = _object_to_dsl(schema.get("properties", {}), defs)
    # ensure_ascii=False keeps non-ASCII field names and enum values (e.g. "Ø", "µm")
    # readable to the model, instead of \uXXXX escapes in the prompt template.
    return json.dumps(template, indent=4, ensure_ascii=False)
# Eager module-scope load. NuExtract3 is small (~9 GB bf16) and ships its own
# qwen3_5 modeling code via trust_remote_code.
# Both objects are module globals reused by every predict() call: `processor`
# builds the chat-templated inputs (including the `template=` kwarg) and decodes
# the output; `model` runs generation.
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    dtype=torch.bfloat16,  # weights loaded in bf16
    device_map="auto",  # let transformers choose device placement
    trust_remote_code=True,
).eval()  # inference only: put layers such as dropout in eval mode
def _estimate_skew_deg(lines: np.ndarray | None) -> float | None:
    """Return the median tilt (degrees) of near-horizontal Hough segments, or None."""
    if lines is None or len(lines) == 0:
        return None
    angles: list[float] = []
    for line in lines:
        x1, y1, x2, y2 = line[0]
        angle = float(np.degrees(np.arctan2(y2 - y1, x2 - x1)))
        # Fold into [-90, 90). A segment whose endpoints come back right-to-left is
        # the same line, but arctan2 reports it near +/-180.
        if angle >= 90.0:
            angle -= 180.0
        elif angle < -90.0:
            angle += 180.0
        # Keep only near-horizontal segments (text baselines, row rules). Vertical
        # rules (table borders) sit near +/-90 and must not enter the median.
        if abs(angle) < 45.0:
            angles.append(angle)
    if not angles:
        return None
    # Median: robust against the occasional outlier segment.
    return float(np.median(angles))


def _preprocess_image(pil_img: Image.Image) -> Image.Image:
    """Deskew a scanned document before extraction (best effort).

    Scanned POs and spec sheets — the high-frequency input routed to this Space —
    often arrive slightly tilted from a phone camera or flatbed scanner, which hurts
    NuExtract's field extraction.

    Steps:
    1. Convert to grayscale and apply an inverted Otsu threshold to isolate
       text/content.
    2. Detect line segments with the probabilistic Hough transform. Take the median
       angle of the near-horizontal ones (|angle| < 45°).
    3. Rotate by that angle when it is at least 0.5°, replicating border pixels.

    Only small skew is corrected; there is no 90°/180° orientation detection.

    This is derived from the VLM Space's preprocessing, with one fix. The original
    took the median over *all* segments. On table-heavy pages that mix horizontal
    and vertical rules, an even-count median (the mean of the two middle angles,
    e.g. 1° and 88°) could land just under 45°, pass the guard, and rotate the page
    by ~45°.
    NOTE(review): port this fix to the VLM Space if its copy still takes the median
    over all segments.

    Args:
        pil_img: Input page image.

    Returns:
        The deskewed page as an RGB image, or ``pil_img`` itself when no correction
        is needed or any step fails.
    """
    try:
        img = np.array(pil_img.convert("RGB"))
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
        lines = cv2.HoughLinesP(binary, 1, np.pi / 180, threshold=100, minLineLength=100, maxLineGap=10)
        skew = _estimate_skew_deg(lines)
        # |skew| < 45 is guaranteed by the helper; skip tilts too small to matter.
        if skew is None or abs(skew) < 0.5:
            return pil_img
        h, w = img.shape[:2]
        # Positive angle = counter-clockwise in OpenCV's top-left-origin frame,
        # which undoes a baseline whose right end sits lower (positive angle).
        matrix = cv2.getRotationMatrix2D((w / 2, h / 2), skew, 1.0)
        rotated = cv2.warpAffine(img, matrix, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE)
        return Image.fromarray(rotated)
    except Exception:
        # Deliberately best effort: any failure falls back to the unmodified input
        # so extraction can still run on the original image.
        return pil_img
def _clean_json(text: str) -> str:
"""Drop <think> blocks, code fences and prose; keep the JSON object.
`core/clients.py` does an unforgiving `json.loads` on the return value.
"""
text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
fence = re.search(r"```(?:json)?\s*(.*?)```", text, flags=re.DOTALL)
if fence:
text = fence.group(1)
# Return the FIRST complete JSON object. A plain find('{')..rfind('}') span
# breaks when the model emits two objects (e.g. an example then the answer):
# the slice would span the gap between them and fail the caller's json.loads.
start = text.find("{")
if start != -1:
try:
obj, _ = json.JSONDecoder().raw_decode(text[start:])
return json.dumps(obj)
except json.JSONDecodeError:
end = text.rfind("}") # fallback: original brace-span heuristic
if end > start:
return text[start : end + 1].strip()
return text.strip()
def _build_messages(prompt: str, image_path: str | None) -> list[dict]:
content: list[dict] = []
if image_path:
img = _preprocess_image(Image.open(image_path).convert("RGB"))
content.append({"type": "image", "image": img})
content.append({"type": "text", "text": prompt})
return [{"role": "user", "content": content}]
@spaces.GPU(duration=90)
def predict(prompt: str, schema_json: str, image_path: str | None) -> str:
    """Extract the fields described by ``schema_json`` from one document.

    Args:
        prompt: Document text (plus any surrounding instructions), sent as the
            user turn.
        schema_json: Pydantic JSON Schema string, converted to a NuExtract
            template by ``_schema_to_template``.
        image_path: Optional path to a page image, loaded and preprocessed by
            ``_build_messages``.

    Returns:
        The model's answer reduced to a JSON string by ``_clean_json``.
    """
    chat = _build_messages(prompt, image_path)
    template = _schema_to_template(schema_json)  # JSON Schema -> NuExtract DSL
    encoded = processor.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        template=template,
        enable_thinking=False,  # no reasoning block: the answer should be bare JSON
    )
    encoded = encoded.to(model.device)
    prompt_len = encoded["input_ids"].shape[1]

    with torch.inference_mode():
        output_ids = model.generate(**encoded, max_new_tokens=MAX_NEW_TOKENS, do_sample=False)

    # Decode only the newly generated tokens, not the echoed prompt.
    completion = processor.batch_decode(
        output_ids[:, prompt_len:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    return _clean_json(completion)
# gr.Interface exposes `fn` at api_name="/predict", which is what core.clients calls.
# Input component values are passed positionally to
# predict(prompt, schema_json, image_path), so the list order must match that
# signature. type="filepath" hands predict a file path rather than image data.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="prompt"),
        gr.Textbox(label="schema_json"),
        gr.Image(type="filepath", label="image_path"),
    ],
    outputs=gr.Textbox(label="json"),  # predict returns the JSON as a plain string
    title="Hi-Tech Extract",
    description=(
        "NuExtract3 — schema-constrained document→JSON extraction (text or scanned "
        "image) for the Hi-Tech AI Platform. Returns a JSON string for core.clients "
        "to Pydantic-validate."
    ),
)
if __name__ == "__main__":
    # Start the Gradio server when this file is executed as a script.
    demo.launch()