"""hitech-extract — numind/NuExtract3 schema-constrained extraction Space.
Serves the contract that `core/clients.py::_predict_default` expects:
predict(prompt: str, schema_json: str, image_path: str | None) -> str # JSON
NuExtract3 is template-driven: it takes a `template=` chat-template kwarg whose
values are NuExtract's own type DSL (`"verbatim-string"`, `"number"`, nested
objects, `[...]` arrays), NOT JSON Schema. Passing the raw Pydantic JSON Schema as
the template makes it extract *nothing* (verified live: all-null output). So this
Space converts the incoming `schema_json` → NuExtract DSL before extraction
(`_schema_to_template`). The noisy composed `prompt` is fine as the document text
— NuExtract follows the template and ignores the surrounding prose.
The converter drops `needs_review` / `needs_review_reason`: NuExtract is a pure
extractor and cannot self-assess uncertainty, so for extract-routed tasks that
flag is set downstream by Supabase-grounded validators (Step 4), not the model.
Both fields carry Pydantic defaults, so omitting them validates cleanly. Open
maps (`dict[str, str]`, e.g. SpecLookup.properties) aren't expressible in the DSL
and are likewise dropped (default `{}`).
ZeroGPU notes (see huggingface-zerogpu skill): eager module-scope load; `import
spaces` is unconditional and omitted from requirements.txt; greedy decode with
thinking disabled for clean JSON.
"""
from __future__ import annotations
import json
import re
from typing import Any
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor
# Hugging Face Hub repo id used for both the processor and the model below.
MODEL_ID = "numind/NuExtract3"
# Upper bound on tokens generated per extraction call (passed to generate()).
MAX_NEW_TOKENS = 2_048
# Fields the model cannot meaningfully populate; set downstream instead.
_DROP_FIELDS = {"needs_review", "needs_review_reason"}


def _node_to_dsl(
    node: dict[str, Any],
    defs: dict[str, Any],
    _expanding: frozenset[str] = frozenset(),
) -> Any | None:
    """Map one JSON-Schema node to a NuExtract DSL value (None = drop the field).

    Args:
        node: One JSON-Schema fragment (a property, an array ``items`` entry or
            a union branch).
        defs: Definitions table; ``$ref`` is resolved by its last path segment.
        _expanding: Internal recursion guard holding the names of definitions
            currently being expanded on this path. A ``$ref`` back into one of
            them (a self-referential model) is dropped instead of recursing
            until ``RecursionError``. Sibling references to the same definition
            are unaffected because the set is path-scoped.

    Returns:
        ``"integer"`` / ``"number"`` / ``"verbatim-string"`` for scalars, a list
        of option strings for enums and single-value consts, a dict for objects
        with declared properties, ``[item]`` (or ``[]``) for arrays, or ``None``
        to drop the field (open maps, recursive back-references, null-only enums).
    """
    if "$ref" in node:
        ref_name = node["$ref"].split("/")[-1]
        if ref_name in _expanding:
            # The DSL cannot express unbounded nesting; stop at the cycle.
            return None
        _expanding = _expanding | {ref_name}
        node = defs.get(ref_name, {})
    if node.get("const") is not None:  # single-value Literal -> one-option enum
        return [str(node["const"])]
    if "enum" in node:  # constrained choice -> list of allowed strings
        # null is not a choice the model should emit as the string "None";
        # NuExtract already yields null for absent values.
        return [str(v) for v in node["enum"] if v is not None] or None
    # Nullable fields / unions (anyOf, oneOf for discriminated unions, allOf for
    # Pydantic-v1-style `$ref` wrappers) -> first concrete (non-null) branch.
    for union_key in ("anyOf", "oneOf", "allOf"):
        if union_key in node:
            concrete = [s for s in node[union_key] if s.get("type") != "null"]
            if not concrete:
                return "verbatim-string"
            return _node_to_dsl(concrete[0], defs, _expanding)
    node_type = node.get("type")
    if isinstance(node_type, list):  # JSON-Schema shorthand, e.g. ["integer", "null"]
        node_type = next((t for t in node_type if t != "null"), None)
    if node_type == "object":
        props = node.get("properties")
        # Open maps (no declared properties) are not expressible -> drop.
        return _object_to_dsl(props, defs, _expanding) if props else None
    if node_type == "array":
        item = _node_to_dsl(node.get("items", {}), defs, _expanding)
        return [item] if item is not None else []
    if node_type in ("integer", "number"):
        return node_type
    return "verbatim-string"  # strings (and rare booleans) extract verbatim


def _object_to_dsl(
    props: dict[str, Any],
    defs: dict[str, Any],
    _expanding: frozenset[str] = frozenset(),
) -> dict[str, Any]:
    """Convert a ``properties`` mapping into a DSL object.

    Skips ``_DROP_FIELDS`` and any property whose node maps to ``None``.
    ``_expanding`` is the recursion guard forwarded from ``_node_to_dsl``.
    """
    out: dict[str, Any] = {}
    for name, sub in props.items():
        if name in _DROP_FIELDS:
            continue
        value = _node_to_dsl(sub, defs, _expanding)
        if value is not None:
            out[name] = value
    return out


def _schema_to_template(schema_json: str) -> str:
    """Convert a Pydantic JSON Schema string into a NuExtract DSL template string.

    Reads the top-level ``properties``. Definitions come from ``$defs``
    (Pydantic v2), falling back to ``definitions`` (Pydantic v1). The result is
    pretty-printed JSON (indent=4).

    Raises:
        json.JSONDecodeError: if ``schema_json`` is not valid JSON.
    """
    schema = json.loads(schema_json)
    defs = schema.get("$defs") or schema.get("definitions") or {}
    template = _object_to_dsl(schema.get("properties", {}), defs)
    return json.dumps(template, indent=4)
# Loaded eagerly at import time (see the ZeroGPU notes in the module docstring).
# NuExtract3 (~9 GB in bf16) ships its own qwen3_5 modeling code, which is why
# both loaders pass trust_remote_code.
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    device_map="auto",
    dtype=torch.bfloat16,
)
model.eval()  # inference-only: switch modules out of training mode
def _preprocess_image(pil_img: Image.Image) -> Image.Image:
    """Deskew a scanned document before extraction (best-effort).

    Scanned POs and spec sheets, the high-frequency input routed to this Space,
    often arrive slightly tilted from a phone camera or flatbed scanner. That
    tilt hurts NuExtract's field extraction.

    Steps:
      1. Convert to grayscale and Otsu-threshold to isolate text/content.
      2. Detect line segments with a probabilistic Hough transform. Fold each
         segment's angle into [-45°, 45°) and take the median as the skew.
         Folding modulo 90° lets segments with reversed endpoints (≈±180°) and
         vertical rules such as table borders (≈90°) report the same tilt as
         horizontal text lines. Without it they pulled the median past 45° and
         silently disabled correction.
      3. Rotate by that angle when 0.5° <= |skew| < 45° and return an RGB image.

    Only small tilts are corrected; the page is NOT reoriented by 90°/180°.
    The input image is returned unchanged when no usable skew is found or when
    preprocessing fails for any reason.

    NOTE(review): this was written to mirror the VLM Space's preprocessing; the
    angle folding is a fix over that copy — port it there to keep them in sync.
    """
    try:
        rgb = np.array(pil_img.convert("RGB"))
        gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
        _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
        segments = cv2.HoughLinesP(
            binary, 1, np.pi / 180, threshold=100, minLineLength=100, maxLineGap=10
        )
        if segments is None or len(segments) == 0:
            return pil_img
        x1, y1, x2, y2 = segments.reshape(-1, 4).astype(np.float64).T
        angles = np.degrees(np.arctan2(y2 - y1, x2 - x1))
        # Fold into [-45, 45): 178° -> -2°, 92° -> 2°, -90° -> 0°.
        folded = (angles + 45.0) % 90.0 - 45.0
        # Median angle — robust against outlier segments.
        skew = float(np.median(folded))
        # Skip negligible tilts and the ambiguous ±45° diagonal.
        if abs(skew) < 0.5 or abs(skew) >= 45.0:
            return pil_img
        h, w = rgb.shape[:2]
        matrix = cv2.getRotationMatrix2D((w / 2, h / 2), skew, 1.0)
        rotated = cv2.warpAffine(
            rgb, matrix, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE
        )
        return Image.fromarray(rotated)
    except Exception:
        # Deliberate best-effort: a preprocessing failure must never block extraction.
        return pil_img
def _clean_json(text: str) -> str:
    """Reduce raw model output to a single JSON object string.

    Processing order:
      1. Strip ``<think>...</think>`` blocks.
      2. Unwrap the first fenced code block, if any.
      3. Try each ``{`` position left to right and return the first one that
         decodes as a complete JSON object, re-serialized with ``json.dumps``.
         Scanning every candidate handles two failure modes: a stray brace in
         prose (e.g. "see {ref}") no longer derails extraction, and when the
         model emits two objects (an example, then the answer) only the first
         is kept.
      4. If nothing decodes, fall back to the first-``{`` .. last-``}`` span.
      5. With no ``{`` at all, return the stripped text.

    `core/clients.py` does an unforgiving `json.loads` on the return value.
    """
    text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
    fence = re.search(r"```(?:json)?\s*(.*?)```", text, flags=re.DOTALL)
    if fence:
        text = fence.group(1)
    decoder = json.JSONDecoder()
    first = text.find("{")
    pos = first
    while pos != -1:
        try:
            obj, _ = decoder.raw_decode(text[pos:])
        except json.JSONDecodeError:
            pos = text.find("{", pos + 1)  # not a JSON object here; try the next brace
        else:
            return json.dumps(obj)
    if first != -1:
        end = text.rfind("}")  # fallback: original brace-span heuristic
        if end > first:
            return text[first : end + 1].strip()
    return text.strip()
def _build_messages(prompt: str, image_path: str | None) -> list[dict]:
    """Assemble the single-turn chat message for the processor's chat template.

    Args:
        prompt: Document text / composed prompt, sent as the text part.
        image_path: Optional path to a scanned page. It may be None or empty
            when no image is supplied. When set, the image is oriented per its
            EXIF tag, converted to RGB and deskewed. It is then placed before
            the text part.

    Returns:
        A one-element list holding a ``user`` message. Its ``content`` is a list
        of typed parts: the image first (if any), then the text.
    """
    content: list[dict] = []
    if image_path:
        from PIL import ImageOps  # same package as the module-level `Image` import

        # Image.open is lazy and keeps the file handle open (notably for
        # multi-frame TIFF scans); the context manager closes it. EXIF
        # orientation is applied first so portrait/landscape phone photos are
        # upright — _preprocess_image only corrects small tilts.
        with Image.open(image_path) as src:
            upright = ImageOps.exif_transpose(src).convert("RGB")
        content.append({"type": "image", "image": _preprocess_image(upright)})
    content.append({"type": "text", "text": prompt})
    return [{"role": "user", "content": content}]
@spaces.GPU(duration=90)
def predict(prompt: str, schema_json: str, image_path: str | None) -> str:
    """Extract the fields described by ``schema_json`` and return them as JSON text.

    ``schema_json`` (a Pydantic JSON Schema) is translated into a NuExtract
    template and handed to the chat template as ``template=``.

    Decoding details:
      - greedy (``do_sample=False``), with thinking disabled;
      - capped at ``MAX_NEW_TOKENS``;
      - only the newly generated tokens are decoded and cleaned into a single
        JSON object string.

    The call runs under ``spaces.GPU(duration=90)``.
    """
    messages = _build_messages(prompt, image_path)
    nuextract_template = _schema_to_template(schema_json)  # JSON Schema -> NuExtract DSL
    model_inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        template=nuextract_template,
        enable_thinking=False,
    ).to(model.device)
    prompt_len = model_inputs["input_ids"].shape[1]
    with torch.inference_mode():
        output_ids = model.generate(
            **model_inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
        )
    new_tokens = output_ids[:, prompt_len:]  # drop the echoed prompt tokens
    decoded = processor.batch_decode(
        new_tokens,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return _clean_json(decoded[0])
# core.clients calls api_name="/predict", the endpoint gr.Interface creates for
# its `fn` — so `fn` must remain `predict`.
demo = gr.Interface(
    title="Hi-Tech Extract",
    description=(
        "NuExtract3 — schema-constrained document→JSON extraction (text or scanned "
        "image) for the Hi-Tech AI Platform. Returns a JSON string for core.clients "
        "to Pydantic-validate."
    ),
    fn=predict,
    inputs=[
        gr.Textbox(label="prompt"),
        gr.Textbox(label="schema_json"),
        gr.Image(type="filepath", label="image_path"),
    ],
    outputs=gr.Textbox(label="json"),
)

if __name__ == "__main__":
    # Start the Gradio server when this file is executed directly.
    demo.launch()