File size: 8,692 Bytes
cdd2fbf
 
 
 
 
 
158d06d
 
 
 
 
 
 
 
 
 
 
 
 
 
cdd2fbf
 
 
 
 
 
 
 
158d06d
cdd2fbf
158d06d
cdd2fbf
6dc2242
cdd2fbf
6dc2242
cdd2fbf
 
 
 
 
 
 
 
158d06d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cdd2fbf
 
 
 
 
 
 
 
 
 
 
6dc2242
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cdd2fbf
 
 
 
 
 
 
 
 
9e4faa1
 
 
 
 
 
 
 
 
 
 
 
cdd2fbf
 
 
 
 
 
6dc2242
 
cdd2fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
158d06d
cdd2fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
"""hitech-extract — numind/NuExtract3 schema-constrained extraction Space.

Serves the contract that `core/clients.py::_predict_default` expects:

    predict(prompt: str, schema_json: str, image_path: str | None) -> str   # JSON

NuExtract3 is template-driven: it takes a `template=` chat-template kwarg whose
values are NuExtract's own type DSL (`"verbatim-string"`, `"number"`, nested
objects, `[...]` arrays), NOT JSON Schema. Passing the raw Pydantic JSON Schema as
the template makes it extract *nothing* (verified live: all-null output). So this
Space converts the incoming `schema_json` → NuExtract DSL before extraction
(`_schema_to_template`). The noisy composed `prompt` is fine as the document text
— NuExtract follows the template and ignores the surrounding prose.

The converter drops `needs_review` / `needs_review_reason`: NuExtract is a pure
extractor and cannot self-assess uncertainty, so for extract-routed tasks that
flag is set downstream by Supabase-grounded validators (Step 4), not the model.
Both fields carry Pydantic defaults, so omitting them validates cleanly. Open
maps (`dict[str, str]`, e.g. SpecLookup.properties) aren't expressible in the DSL
and are likewise dropped (default `{}`).

ZeroGPU notes (see huggingface-zerogpu skill): eager module-scope load; `import
spaces` is unconditional and omitted from requirements.txt; greedy decode with
thinking disabled for clean JSON.
"""

from __future__ import annotations

import json
import re
from typing import Any

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor

# Hugging Face Hub id of the model and processor loaded at module scope.
MODEL_ID = "numind/NuExtract3"
# Upper bound on newly generated tokens per request (passed to `model.generate`).
MAX_NEW_TOKENS = 2048

# Fields the model cannot meaningfully populate; set downstream instead.
# `_object_to_dsl` skips any property whose name appears here.
_DROP_FIELDS = {"needs_review", "needs_review_reason"}


def _node_to_dsl(node: dict[str, Any], defs: dict[str, Any]) -> Any | None:
    """Map one JSON-Schema node to a NuExtract DSL value (None = drop the field)."""
    if "$ref" in node:
        node = defs.get(node["$ref"].split("/")[-1], {})
    if "enum" in node:  # constrained choice -> list of allowed strings
        return [str(v) for v in node["enum"]]
    if "anyOf" in node:  # nullable / unions -> first concrete (non-null) branch
        concrete = [s for s in node["anyOf"] if s.get("type") != "null"]
        return _node_to_dsl(concrete[0], defs) if concrete else "verbatim-string"
    node_type = node.get("type")
    if node_type == "object":
        props = node.get("properties")
        return _object_to_dsl(props, defs) if props else None  # open map -> drop
    if node_type == "array":
        item = _node_to_dsl(node.get("items", {}), defs)
        return [item] if item is not None else []
    if node_type == "integer":
        return "integer"
    if node_type == "number":
        return "number"
    return "verbatim-string"  # strings (and rare booleans) extract verbatim


def _object_to_dsl(props: dict[str, Any], defs: dict[str, Any]) -> dict[str, Any]:
    out: dict[str, Any] = {}
    for name, sub in props.items():
        if name in _DROP_FIELDS:
            continue
        value = _node_to_dsl(sub, defs)
        if value is not None:
            out[name] = value
    return out


def _schema_to_template(schema_json: str) -> str:
    """Turn a Pydantic JSON Schema string into a NuExtract DSL template string.

    The root `properties` are converted with `$defs` as the reference table;
    the result is serialised as 4-space-indented JSON.
    """
    parsed = json.loads(schema_json)
    root_props = parsed.get("properties", {})
    definitions = parsed.get("$defs", {})
    return json.dumps(_object_to_dsl(root_props, definitions), indent=4)

# Eager module-scope load. NuExtract3 is small (~9 GB bf16) and ships its own
# qwen3_5 modeling code via trust_remote_code.
# Both objects are created once at import time and reused by every `predict` call.
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
# Weights are loaded in bfloat16; `device_map="auto"` leaves device placement to
# the loader, and `.eval()` switches the module to evaluation mode.
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
).eval()


def _preprocess_image(pil_img: Image.Image) -> Image.Image:
    """Deskew + auto-rotate a scanned document before extraction.

    Scanned POs and spec sheets — the high-frequency input routed to this Space —
    often arrive slightly tilted from a phone camera or flatbed scanner, which hurts
    NuExtract's field extraction. Mirrors the VLM Space's preprocessing exactly:

    1. Convert to grayscale and threshold to isolate text/content.
    2. Find the dominant skew angle via Hough lines and rotate to correct it.
    3. Return as RGB PIL Image (unchanged if preprocessing fails for any reason).
    """
    try:
        img = np.array(pil_img.convert("RGB"))
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

        # Detect lines to estimate skew angle
        lines = cv2.HoughLinesP(binary, 1, np.pi / 180, threshold=100, minLineLength=100, maxLineGap=10)
        if lines is None or len(lines) == 0:
            return pil_img

        angles = []
        for line in lines:
            x1, y1, x2, y2 = line[0]
            if x2 != x1:
                angles.append(np.degrees(np.arctan2(y2 - y1, x2 - x1)))

        if not angles:
            return pil_img

        # Median angle — robust against outlier lines
        skew = float(np.median(angles))
        # Only correct small skews (> 0.5° and < 45°) to avoid false rotations
        if abs(skew) < 0.5 or abs(skew) > 45:
            return pil_img

        h, w = img.shape[:2]
        center = (w / 2, h / 2)
        M = cv2.getRotationMatrix2D(center, skew, 1.0)
        rotated = cv2.warpAffine(img, M, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE)
        return Image.fromarray(rotated)
    except Exception:
        return pil_img


def _clean_json(text: str) -> str:
    """Drop <think> blocks, code fences and prose; keep the JSON object.

    `core/clients.py` does an unforgiving `json.loads` on the return value.
    """
    text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
    fence = re.search(r"```(?:json)?\s*(.*?)```", text, flags=re.DOTALL)
    if fence:
        text = fence.group(1)
    # Return the FIRST complete JSON object. A plain find('{')..rfind('}') span
    # breaks when the model emits two objects (e.g. an example then the answer):
    # the slice would span the gap between them and fail the caller's json.loads.
    start = text.find("{")
    if start != -1:
        try:
            obj, _ = json.JSONDecoder().raw_decode(text[start:])
            return json.dumps(obj)
        except json.JSONDecodeError:
            end = text.rfind("}")  # fallback: original brace-span heuristic
            if end > start:
                return text[start : end + 1].strip()
    return text.strip()


def _build_messages(prompt: str, image_path: str | None) -> list[dict]:
    content: list[dict] = []
    if image_path:
        img = _preprocess_image(Image.open(image_path).convert("RGB"))
        content.append({"type": "image", "image": img})
    content.append({"type": "text", "text": prompt})
    return [{"role": "user", "content": content}]


@spaces.GPU(duration=90)
def predict(prompt: str, schema_json: str, image_path: str | None) -> str:
    """Run one schema-constrained extraction and return a JSON string.

    The chat messages are built from `prompt` (plus the optional image), the
    incoming JSON Schema is converted to a NuExtract DSL template, and the model
    generates greedily with thinking disabled. Only the newly generated tokens
    are decoded, then passed through `_clean_json`.
    """
    chat = _build_messages(prompt, image_path)
    template = _schema_to_template(schema_json)  # JSON Schema -> NuExtract DSL

    model_inputs = processor.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        template=template,
        enable_thinking=False,
    ).to(model.device)

    with torch.inference_mode():
        sequences = model.generate(
            **model_inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
        )

    # generate() returns prompt + completion; slice off the prompt tokens.
    prompt_len = model_inputs["input_ids"].shape[1]
    completion = sequences[:, prompt_len:]
    decoded = processor.batch_decode(
        completion,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return _clean_json(decoded[0])


# gr.Interface exposes `fn` at api_name="/predict", which is what core.clients calls.
# The three inputs are listed in the same order as `predict`'s parameters
# (prompt, schema_json, image_path); the image component hands over a file path.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="prompt"),
        gr.Textbox(label="schema_json"),
        gr.Image(type="filepath", label="image_path"),
    ],
    outputs=gr.Textbox(label="json"),
    title="Hi-Tech Extract",
    description=(
        "NuExtract3 — schema-constrained document→JSON extraction (text or scanned "
        "image) for the Hi-Tech AI Platform. Returns a JSON string for core.clients "
        "to Pydantic-validate."
    ),
)

# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()