# ocr-extraction / app.py
# Hugging Face Space file by mlbench123 — commit e0549ff ("Update app.py", verified), 22.5 kB.
"""
app.py – OCR Route Data Extraction | Hugging Face Space
============================================================
Pipeline
--------
Stage 1 PaddleOCR (EasyOCR is used as a fallback if PaddleOCR cannot be loaded)
Deep-learning OCR — far more accurate than EasyOCR for photos.
Returns text bounding boxes + recognised text.
Detected text boxes are clustered into horizontal row-bands and sorted L→R
to produce one clean text line per table row.
Stage 2 Qwen/Qwen2.5-72B-Instruct (HF serverless Inference API, GPU)
Receives row-organised text + strict JSON schema prompt.
Returns ONE complete structured JSON array in a single call.
The model corrects OCR typos, understands table context, and
classifies each instruction into a navigation constraint.
No hand-written column parsers. No regex tables.
The LLM understands meaning; PaddleOCR provides accurate characters.
"""
from __future__ import annotations
import datetime, json, logging, os, re, time
from statistics import median
import cv2
import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import InferenceClient
# Root logging configuration: INFO and above, one "time | LEVEL | message" line per record.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    level=logging.INFO,
)
log = logging.getLogger(__name__)

# Hugging Face Hub model id used for the Stage-2 chat-completion request.
LLM_MODEL = "Qwen/Qwen2.5-72B-Instruct"
# ──────────────────────────────────────────────────────────
# STAGE 1a – IMAGE PREPROCESSING
# ──────────────────────────────────────────────────────────
def _normalise_skew_angle(angle: float) -> float:
    """
    Fold a ``cv2.minAreaRect`` angle into the equivalent skew in [-45, 45).

    A rectangle's orientation is only defined modulo 90 degrees, and OpenCV
    reports it in [-90, 0) before 4.5.1 but in (0, 90] from 4.5.1 on. Folding
    makes the value independent of the OpenCV version; for (x, y) image points
    it is the angle to pass to ``cv2.getRotationMatrix2D`` to level the content.
    """
    angle = float(angle)
    while angle >= 45.0:
        angle -= 90.0
    while angle < -45.0:
        angle += 90.0
    return angle


def preprocess(pil_img: Image.Image, max_skew: float = 5.0) -> np.ndarray:
    """
    Turn a PIL image into a clean, binarised image for OCR.

    Steps: PIL → BGR numpy, upscale so the long side is at least 2800 px,
    unsharp mask, non-local-means denoise, adaptive threshold, small deskew.

    Args:
        pil_img:  input image (any mode; converted to RGB first).
        max_skew: largest skew, in degrees, that is corrected. Larger estimates
                  from the whole-page ``cv2.minAreaRect`` fit are treated as
                  unreliable and not applied.

    Returns:
        Single-channel uint8 image: dark text (0) on a light (255) background.
    """
    img = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2BGR)
    h, w = img.shape[:2]

    # Upscale to at least 2800px long side
    if max(h, w) < 2800:
        s = 2800 / max(h, w)
        img = cv2.resize(img, None, fx=s, fy=s, interpolation=cv2.INTER_CUBIC)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Unsharp mask – recovers thin font strokes
    blur = cv2.GaussianBlur(gray, (0, 0), 3)
    gray = cv2.addWeighted(gray, 1.5, blur, -0.5, 0)

    # Denoise
    gray = cv2.fastNlMeansDenoising(gray, h=10)

    # Adaptive threshold – handles uneven lighting from phone camera
    thresh = cv2.adaptiveThreshold(
        gray, 255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        blockSize=21, C=8,
    )

    # Deskew up to ±max_skew degrees. minAreaRect expects (x, y) points, but
    # np.where yields (row, col) = (y, x); feeding it swapped axes mirrors the
    # measured angle, so the "correction" would double the skew.
    ys, xs = np.where(thresh < 128)
    if len(xs) > 100:
        pts = np.column_stack((xs, ys)).astype(np.float32)
        angle = _normalise_skew_angle(cv2.minAreaRect(pts)[-1])
        if 0.3 < abs(angle) <= max_skew:
            h2, w2 = thresh.shape
            M = cv2.getRotationMatrix2D((w2 // 2, h2 // 2), angle, 1.0)
            thresh = cv2.warpAffine(thresh, M, (w2, h2),
                                    flags=cv2.INTER_CUBIC,
                                    borderMode=cv2.BORDER_REPLICATE)
    return thresh
# ──────────────────────────────────────────────────────────
# STAGE 1b – PaddleOCR (lazy-loaded once)
# ──────────────────────────────────────────────────────────
_ocr_engine = None  # cached OCR engine, created on the first get_ocr() call
_ocr_name = "none"  # "PaddleOCR" or "EasyOCR" once an engine has loaded


def get_ocr():
    """
    Return the cached ``(engine, name)`` pair, loading an OCR engine on first use.

    PaddleOCR is tried first. If its import or construction fails for any
    reason, EasyOCR (CPU) is loaded instead. The loaded engine is stored in
    module globals, so later calls reuse it.

    Returns:
        tuple: ``(engine, name)``, where ``name`` is ``"PaddleOCR"`` or ``"EasyOCR"``.

    Raises:
        RuntimeError: if neither engine can be loaded. The EasyOCR failure is
            chained as ``__cause__``.
    """
    global _ocr_engine, _ocr_name
    if _ocr_engine is not None:
        return _ocr_engine, _ocr_name

    # ── Try PaddleOCR first (best accuracy for document photos) ──
    try:
        from paddleocr import PaddleOCR
        log.info("Loading PaddleOCR...")
        engine = PaddleOCR(
            use_textline_orientation=True,
            lang="en",
            text_det_thresh=0.3,
            text_det_box_thresh=0.5,
        )
        name = "PaddleOCR"
    except Exception as paddle_err:
        log.warning("PaddleOCR unavailable (%s) — falling back to EasyOCR", paddle_err)
        # ── Fallback: EasyOCR ────────────────────────────────────
        try:
            import easyocr
            log.info("Loading EasyOCR...")
            engine = easyocr.Reader(["en"], gpu=False, verbose=False)
            name = "EasyOCR"
        except Exception as easy_err:
            raise RuntimeError(f"No OCR engine available: {easy_err}") from easy_err

    # Publish engine and name together, only after a successful load.
    _ocr_engine, _ocr_name = engine, name
    log.info("%s ready.", name)
    return engine, name
def _run_paddle(engine, img: np.ndarray) -> list[tuple]:
    """
    Run PaddleOCR ``predict()`` and return normalised ``(bbox, text, conf)`` tuples.

    Each result from ``predict`` may be either of two shapes:

    * v3 dict-like: ``rec_texts`` / ``rec_scores`` plus one polygon key.
      ``rec_polys`` is preferred, then ``dt_polys``, then ``dt_boxes``.
    * v2-style list of ``[bbox, (text, conf)]`` pairs.

    Detections with blank text or confidence below 0.3 are dropped. ``bbox``
    is returned as a plain nested list of ``[x, y]`` points.

    Raises:
        RuntimeError: if ``engine.predict`` itself fails.
    """
    try:
        raw = list(engine.predict(img))  # materialise generator
    except Exception as e:
        raise RuntimeError(f"PaddleOCR.predict failed: {e}") from e

    detections: list[tuple] = []

    def _add(bbox_raw, text, conf) -> None:
        """Normalise one detection; keep it if the text is non-blank and conf >= 0.3."""
        text = str(text).strip()
        if not text:
            return
        conf = float(conf)
        if conf < 0.3:
            return
        if hasattr(bbox_raw, "tolist"):  # numpy array -> nested lists
            bbox = bbox_raw.tolist()
        elif isinstance(bbox_raw, (list, tuple)):
            bbox = [[float(c) for c in pt] for pt in bbox_raw]
        else:
            return
        detections.append((bbox, text, conf))

    for res in raw:
        # ── v3 key access: commit only when all three columns were found ──
        columns = None
        for box_key in ("rec_polys", "dt_polys", "dt_boxes"):
            try:
                columns = (list(res[box_key]),
                           list(res["rec_texts"]),
                           list(res["rec_scores"]))
            except (KeyError, TypeError, IndexError):
                continue
            break

        if columns is None:
            # ── Fallback: v2 style [[bbox, (text, conf)], ...] ──
            try:
                for item in res:
                    if not (isinstance(item, (list, tuple)) and len(item) == 2):
                        continue
                    bbox_raw, text_conf = item
                    if not (isinstance(text_conf, (list, tuple)) and len(text_conf) == 2):
                        continue
                    _add(bbox_raw, text_conf[0], text_conf[1])
            except Exception as e:
                log.warning("Skipping unparseable OCR result: %s", e)
            # Whether or not the fallback succeeded, there are no v3 columns
            # to parse for this result.
            continue

        boxes, texts, scores = columns
        if not (len(boxes) == len(texts) == len(scores)):
            log.warning("PaddleOCR result has %d boxes, %d texts, %d scores; "
                        "unmatched entries ignored", len(boxes), len(texts), len(scores))
        for bbox_raw, text, conf in zip(boxes, texts, scores):
            _add(bbox_raw, text, conf)

    log.info("PaddleOCR returned %d detections", len(detections))
    return detections
def _run_easyocr(engine, img: np.ndarray) -> list[tuple]:
    """
    Run EasyOCR on *img* and return ``(bbox, text, conf)`` tuples.

    Entries whose text is blank, or whose confidence is not above 0.3, are
    discarded. ``conf`` is converted to ``float``; bbox and text pass through.
    """
    kept = []
    for bbox, text, conf in engine.readtext(img, detail=1, paragraph=False):
        if not text.strip():
            continue
        score = float(conf)
        if score > 0.3:
            kept.append((bbox, text, score))
    return kept
def run_ocr(img: np.ndarray) -> tuple[list[tuple], str]:
    """
    OCR *img* with whichever engine ``get_ocr`` provides.

    Returns:
        ``(detections, engine_name)``, where detections are ``(bbox, text, conf)`` tuples.
    """
    engine, name = get_ocr()
    runner = _run_paddle if name == "PaddleOCR" else _run_easyocr
    dets = runner(engine, img)
    log.info("[%s] %d detections", name, len(dets))
    return dets, name
# ──────────────────────────────────────────────────────────
# STAGE 1c – ROW BAND CLUSTERING
# ──────────────────────────────────────────────────────────
def _mean_coord(bbox, axis: int) -> float:
    """Average of coordinate *axis* (0 = x, 1 = y) over the points of *bbox*."""
    points = bbox.tolist() if hasattr(bbox, "tolist") else bbox
    values = [float(pt[axis]) for pt in points]
    return sum(values) / len(values)


def _cx(bbox) -> float:
    """Horizontal centre of *bbox* (mean x of its points)."""
    return _mean_coord(bbox, 0)


def _cy(bbox) -> float:
    """Vertical centre of *bbox* (mean y of its points)."""
    return _mean_coord(bbox, 1)
# A band is treated as a table header when EVERY token in it is one of these
# column-title words. A band containing numbers or road codes never matches.
_HEADER_PAT = re.compile(
    r"^\s*(?:miles|route|distance|time|to)(?:\s+(?:miles|route|distance|time|to))*\s*$",
    re.I,
)


def _bbox_height(bbox) -> float:
    """Vertical extent (max y minus min y) of *bbox*'s points; 0.0 if it has none."""
    pts = bbox.tolist() if hasattr(bbox, "tolist") else bbox
    ys = [float(p[1]) for p in pts]
    return max(ys) - min(ys) if ys else 0.0


def detections_to_rows(detections: list[tuple]) -> list[str]:
    """
    Cluster bounding-box detections into horizontal row bands.

    Tokens are sorted by y-centre. A new band starts whenever a token's
    y-centre lies more than ``row_thr`` below the mean y-centre of the current
    band, where ``row_thr`` is 0.65 × the median detection height (at least
    10 px). Within a band, tokens are joined left → right with single spaces.
    Bands made up only of header words (``_HEADER_PAT``) are dropped.

    Args:
        detections: ``(bbox, text, conf)`` tuples. ``bbox`` is a sequence (or
            numpy array) of ``[x, y]`` points.

    Returns:
        One text line per table row, top to bottom, e.g.
        "9.90 SL8EFR n Merge onto SL8SFR ne [BW 8SFR] 43.00 01:00"
    """
    if not detections:
        return []

    # (y-centre, x-centre, text), top to bottom
    items = sorted(
        ((_cy(b), _cx(b), t) for b, t, _ in detections),
        key=lambda it: it[0],
    )

    # Row-separation threshold from the typical text height. Consecutive
    # y-gaps are unsuitable: most neighbouring tokens share a row, so their
    # median measures intra-row jitter rather than line spacing.
    heights = [h for h in (_bbox_height(b) for b, _, _ in detections) if h > 0]
    line_h = median(heights) if heights else 20.0
    row_thr = max(line_h * 0.65, 10.0)

    bands: list[list[tuple]] = []
    cur = [items[0]]
    cur_sum = items[0][0]
    for item in items[1:]:
        # Compare against the band's mean y (not its last token) so a
        # staircase of slightly offset tokens cannot chain two rows together.
        if item[0] - cur_sum / len(cur) > row_thr:
            bands.append(cur)
            cur, cur_sum = [item], item[0]
        else:
            cur.append(item)
            cur_sum += item[0]
    bands.append(cur)

    rows = []
    for band in bands:
        line = " ".join(t for _, _, t in sorted(band, key=lambda it: it[1]))
        if _HEADER_PAT.search(line.strip()):
            continue  # column-title band
        rows.append(line)

    log.info("Row clustering produced %d rows", len(rows))
    return rows
# ──────────────────────────────────────────────────────────
# STAGE 2 – LLM (Qwen2.5-72B via HF Inference API, GPU)
# ──────────────────────────────────────────────────────────
# System message for the Stage-2 chat-completion request. call_llm sends it
# verbatim and numbers the OCR rows "1. ...", "2. ...", matching the "numbered
# line" wording below. Every character here is model input.
# Two rules below are also enforced in code: the "> 1000 → misplaced decimal"
# mileage rule (_fix_miles divides such values by 100) and the HH:MM /
# hours 0-23 / minutes 0-59 time rule (_fix_time).
# NOTE(review): the arrows in this prompt appear as "β†’", which looks like
# UTF-8 "→" mis-decoded somewhere. Confirm the file's encoding before relying
# on the prompt text.
# NOTE(review): "always > 40" for cumulative_miles looks specific to the sample
# route document and may mislead the model on other routes. Confirm it is intended.
SYSTEM_PROMPT = """You are a strict route data extraction engine for permit documents.
INPUT: Raw OCR text from a route table image. Each numbered line is one table row.
The columns are: Segment Miles | Road/Route | Navigation Instruction | Cumulative Miles | Time
YOUR TASK: Parse every row and return ONE valid JSON array. Nothing else.
━━━ OUTPUT FORMAT ━━━
Start with [ and end with ].
No explanation, no markdown, no code fences.
Each element in the array:
{
"step": <integer, 1-based sequential>,
"segment_miles": <float, distance for this segment>,
"road": <string, road/highway identifier e.g. "SL8EFR n", "IH45 n">,
"instruction": <string, clean navigation text>,
"cumulative_miles": <float, total distance from start>,
"time": <string, "HH:MM" format>,
"constraints": [
{
"type": <"mandatory_action" | "restriction" | "conditional_rule">,
"action": <string>,
"location": <string, road or junction name>,
"priority": <"hard" | "soft">,
"condition": <string or null>
}
]
}
━━━ FIELD RULES ━━━
segment_miles : small decimal at START of line (e.g. 9.90, 0.20, 214.10)
road : highway code after segment miles (e.g. "SL8EFR n", "IH45 n", "US287 Ramp nw")
instruction : the navigation sentence in the middle (longest text part)
cumulative_miles: the larger decimal near the END of the line (running total, always > 40)
time : HH:MM near the end. ONLY accept values where hours 0-23, minutes 0-59.
If a value like "43.00" appears, it is cumulative_miles not time.
Fix separators: "01.12" or "01*12" β†’ "01:12"
If cumulative > 1000, it has a misplaced decimal: 38290 β†’ 382.90
━━━ OCR CORRECTION ━━━
Fix these common errors in the instruction and road fields:
onlo/Onlo β†’ onto/Onto Tum/Tumn β†’ Turn
lelt/lcli/Ielt β†’ left nighl/righl/rght β†’ right
loward/l0ward β†’ toward conneclor/conecor β†’ connector
Straighi/Straighl β†’ Straight Continuo/Conlinue β†’ Continue
SH1OT/SHTOT/SHTOI β†’ SH101 IH4S β†’ IH45
IHZO/IH2O β†’ IH20 IHAO/IH4O β†’ IH40
UST83/UST8J β†’ US183 USZ87 β†’ US287
SLB β†’ SL8 IH3S/IH3SE β†’ IH35
━━━ CONSTRAINT RULES ━━━
Extract exactly 1 constraint per step. Empty array [] only if truly no action.
mandatory_action β†’ merge, turn_left, turn_right, take_exit, take_ramp,
take_connector, continue_straight β†’ priority: "hard"
conditional_rule β†’ keep_left, keep_right β†’ priority: "soft"
restriction β†’ no_turn, prohibited_action β†’ priority: "hard"
━━━ EXAMPLES ━━━
Row: "9.90 SL8EFR n Merge onto SL8SFR ne [BW 8SFR] [WEST SAM HOUSTON PARKWAY] 43.00 01:00"
β†’ {"step":1,"segment_miles":9.9,"road":"SL8EFR n","instruction":"Merge onto SL8SFR ne [BW 8SFR] [WEST SAM HOUSTON PARKWAY]","cumulative_miles":43.0,"time":"01:00","constraints":[{"type":"mandatory_action","action":"merge","location":"SL8SFR ne / BW 8SFR","priority":"hard","condition":null}]}
Row: "0.20 IH45 e Keep left toward IH45 North/Dallas 51.30 01:13"
β†’ {"step":5,"segment_miles":0.2,"road":"IH45 e","instruction":"Keep left toward IH45 North/Dallas","cumulative_miles":51.3,"time":"01:13","constraints":[{"type":"conditional_rule","action":"keep_left","location":"IH45 North / Dallas","priority":"soft","condition":"heading toward IH45 North/Dallas"}]}"""
def call_llm(row_lines: list[str]) -> str:
    """
    Send the clustered OCR rows to ``LLM_MODEL`` and return its raw reply text.

    Rows are numbered ``"1. ..."``, ``"2. ..."`` to match the system prompt.
    ``HF_TOKEN`` from the environment is used when it is set and non-empty;
    otherwise the client is created without a token.

    Returns:
        The assistant message text, stripped. Returns ``""`` when the API
        returned no content, so ``parse_llm_json`` reports a clear
        "no JSON array" error rather than an ``AttributeError``.

    Raises:
        Whatever ``InferenceClient.chat_completion`` raises (network/API errors).
    """
    table_text = "\n".join(f"{i}. {line}" for i, line in enumerate(row_lines, start=1))
    # An empty HF_TOKEN is treated like an unset one.
    client = InferenceClient(token=os.environ.get("HF_TOKEN") or None)

    log.info("Calling %s with %d rows ...", LLM_MODEL, len(row_lines))
    t0 = time.perf_counter()
    response = client.chat_completion(
        model=LLM_MODEL,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content":
                f"OCR rows from route document:\n{table_text}\n\nReturn the complete JSON array:"},
        ],
        max_tokens=8000,
        temperature=0.01,
    )
    log.info("LLM call finished in %.1fs", time.perf_counter() - t0)

    # The message content may be None (e.g. an empty completion).
    content = response.choices[0].message.content
    raw = (content or "").strip()
    if not raw:
        log.warning("LLM returned an empty message")
    return raw
def parse_llm_json(raw: str) -> list[dict]:
    """
    Extract the JSON array from an LLM reply.

    Markdown code fences are removed. The value starting at the first ``[`` is
    then decoded with ``json.JSONDecoder.raw_decode``, and anything after the
    array is ignored. A real decoder, unlike bracket counting, is not confused
    by brackets inside string values such as the OCR fragment ``"[BW 8SFR"``.

    If decoding fails, the usual cause is a reply cut off at ``max_tokens``
    right after an element. One repair is attempted: strip trailing whitespace
    and a trailing comma, then close the array.

    Raises:
        ValueError: the reply contains no ``[``, or the array cannot be decoded
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    text = re.sub(r"```(?:json)?", "", raw, flags=re.I).strip()
    start = text.find("[")
    if start == -1:
        raise ValueError("LLM response contains no JSON array")
    try:
        value, _end = json.JSONDecoder().raw_decode(text, start)
    except json.JSONDecodeError as err:
        repaired = text[start:].rstrip().rstrip(",") + "]"
        try:
            return json.loads(repaired)
        except json.JSONDecodeError:
            raise err from None
    return value
# ──────────────────────────────────────────────────────────
# POST-PROCESSING – normalise types, fix edge cases
# ──────────────────────────────────────────────────────────
# Time-like token: 1-2 digit hours, a separator that OCR often garbles
# (. : ; * ,), then exactly two digits of minutes.
_TIME_RE = re.compile(r"\b(\d{1,2})[.:;*,](\d{2})\b")


def _fix_time(v: str) -> str:
    """
    Return the first valid time in *v* as zero-padded ``HH:MM``.

    Candidates with hours outside 0-23 or minutes outside 0-59 (e.g. a mileage
    such as ``"43.00"``) are skipped. ``"00:00"`` is returned when none is valid.
    """
    candidates = ((int(hh), int(mm)) for hh, mm in _TIME_RE.findall(str(v)))
    valid = next(((hh, mm) for hh, mm in candidates if hh <= 23 and mm <= 59), None)
    if valid is None:
        return "00:00"
    return "%02d:%02d" % valid
def _fix_miles(v) -> float:
    """
    Coerce an LLM mileage value to a float rounded to 2 decimals.

    A comma decimal separator is accepted (``"43,00"`` → 43.0). Values above
    1000 are assumed to have lost their decimal point and are divided by 100
    (38290 → 382.9), mirroring the rule in the system prompt. Unparseable and
    non-finite values become 0.0. ``json.loads`` accepts ``NaN`` and
    ``Infinity``, and ``json.dumps`` would write them back out as non-standard
    JSON tokens.
    """
    import math  # local import, like the file's other lazily imported modules

    try:
        f = float(str(v).replace(",", "."))
    except (ValueError, TypeError):
        return 0.0
    if not math.isfinite(f):
        return 0.0
    return round(f / 100 if f > 1000 else f, 2)
_VALID_TYPES = {"mandatory_action", "restriction", "conditional_rule"}
_VALID_PRIO = {"hard", "soft"}


def clean_steps(steps: list[dict]) -> list[dict]:
    """
    Normalise the LLM's step objects and return them as a new list.

    Step and constraint dicts are updated in place:

    * non-dict steps are dropped, and ``step`` is renumbered 1..n over the kept ones;
    * mileages go through ``_fix_miles`` and ``time`` through ``_fix_time``;
    * a missing or ``null`` ``road`` becomes ``"UNKNOWN"``, and ``instruction`` becomes ``""``;
    * ``constraints`` is coerced to a list: a single dict is wrapped, and any
      other non-list becomes ``[]``. Non-dict constraints are dropped;
    * an invalid or non-string constraint ``type``/``priority`` falls back to
      ``"mandatory_action"`` / ``"hard"``. ``action``, ``location`` and
      ``condition`` get defaults when missing.
    """
    out: list[dict] = []
    for s in steps:
        if not isinstance(s, dict):
            log.warning("Dropping non-object step from LLM output: %.80r", s)
            continue
        s["step"] = len(out) + 1
        s["segment_miles"] = _fix_miles(s.get("segment_miles", 0))
        s["cumulative_miles"] = _fix_miles(s.get("cumulative_miles", 0))
        s["time"] = _fix_time(s.get("time", ""))
        if s.get("road") is None:
            s["road"] = "UNKNOWN"
        if s.get("instruction") is None:
            s["instruction"] = ""

        raw_c = s.get("constraints")
        if isinstance(raw_c, dict):
            raw_c = [raw_c]  # model returned a single constraint object
        elif not isinstance(raw_c, (list, tuple)):
            raw_c = []

        clean_c = []
        for c in raw_c:
            if not isinstance(c, dict):
                continue
            ctype, prio = c.get("type"), c.get("priority")
            # isinstance guard: an unhashable value (e.g. a list) would make
            # the set-membership test raise TypeError.
            if not (isinstance(ctype, str) and ctype in _VALID_TYPES):
                c["type"] = "mandatory_action"
            if not (isinstance(prio, str) and prio in _VALID_PRIO):
                c["priority"] = "hard"
            c.setdefault("action", "")
            c.setdefault("location", "")
            c.setdefault("condition", None)
            clean_c.append(c)
        s["constraints"] = clean_c
        out.append(s)
    return out
# ──────────────────────────────────────────────────────────
# MAIN PIPELINE
# ──────────────────────────────────────────────────────────
def run_pipeline(image, progress=gr.Progress(track_tqdm=True)):
    """
    Gradio click handler: image → OCR → row clustering → LLM → cleaned JSON.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing was uploaded.
        progress: Gradio progress tracker (supplied by Gradio via the default).

    Returns:
        ``(json_text, debug_text)``. ``json_text`` is the result document, or a
        JSON object with an ``"error"`` key when a stage fails. ``debug_text``
        lists the clustered OCR rows once clustering has succeeded, else ``""``.
    """
    if image is None:
        return '{"error": "No image provided."}', ""
    t0 = time.perf_counter()

    # ── Stage 1: preprocess ──────────────────────────────
    progress(0.05, desc="Preprocessing image...")
    try:
        processed = preprocess(image)
    except Exception as e:
        log.exception("Preprocessing failed")
        return json.dumps({"error": f"Image preprocessing failed: {e}"}), ""

    # ── Stage 1: OCR ─────────────────────────────────────
    progress(0.15, desc="Running OCR (PaddleOCR)...")
    try:
        detections, ocr_name = run_ocr(processed)
    except Exception as e:
        log.exception("OCR failed")
        return json.dumps({"error": f"OCR failed: {e}"}), ""
    if not detections:
        return '{"error": "OCR returned no text. Try a clearer image."}', ""

    # ── Stage 1: row clustering ───────────────────────────
    progress(0.35, desc="Organising rows...")
    row_lines = detections_to_rows(detections)
    if not row_lines:
        return '{"error": "No table rows found after clustering."}', ""
    debug = "\n".join(f"[row {i:02d}] {r}" for i, r in enumerate(row_lines, start=1))

    # ── Stage 2: LLM ─────────────────────────────────────
    progress(0.50, desc=f"Sending {len(row_lines)} rows to LLM...")
    try:
        raw_llm = call_llm(row_lines)
    except Exception as e:
        log.error("LLM error: %s", e)
        return json.dumps({"error": f"LLM API failed: {e}", "ocr_rows": row_lines}), debug

    # ── Parse + clean ─────────────────────────────────────
    progress(0.90, desc="Parsing JSON response...")
    try:
        steps = parse_llm_json(raw_llm)
    except Exception as e:
        log.error("JSON parse error: %s | raw: %.300s", e, raw_llm)
        return json.dumps({
            "error": f"LLM returned invalid JSON: {e}",
            "raw_output": raw_llm[:1000],
        }), debug
    try:
        steps = clean_steps(steps)
    except Exception as e:
        # Valid JSON of an unexpected shape must not surface as an unhandled error.
        log.exception("Step normalisation failed")
        return json.dumps({
            "error": f"LLM JSON did not match the expected step schema: {e}",
            "raw_output": raw_llm[:1000],
        }), debug

    last_cum = max((s["cumulative_miles"] for s in steps), default=0.0)
    last_time = next((s["time"] for s in reversed(steps) if s["time"] != "00:00"), "00:00")

    # One timezone-aware timestamp for both fields, so "source" and
    # "extracted_at" cannot straddle a second boundary (utcnow() is deprecated).
    now = datetime.datetime.now(datetime.timezone.utc)
    result = {
        "source": f"uploaded_{now.strftime('%H%M%S')}.png",
        "extracted_at": now.strftime("%Y-%m-%dT%H:%M:%SZ"),
        "ocr_engine": ocr_name,
        "llm_model": LLM_MODEL,
        "total_steps": len(steps),
        "total_miles": last_cum,
        "total_time": last_time,
        "steps": steps,
    }
    log.info("Pipeline done in %.1fs — %d steps", time.perf_counter() - t0, len(steps))
    return json.dumps(result, indent=2, ensure_ascii=False), debug
# ──────────────────────────────────────────────────────────
# GRADIO UI
# ──────────────────────────────────────────────────────────
with gr.Blocks(title="OCR Route Extraction") as demo:
    # Configure the request queue while the app is built, so it applies even
    # when `demo` is imported rather than launched by the __main__ block below.
    demo.queue()
    # Blank lines matter here: without one after the table, Markdown renders
    # the "Constraint types" line as an extra table row.
    gr.Markdown(f"""
## OCR Route Data Extraction Pipeline

| Stage | Component | Role |
|-------|-----------|------|
| 1 | **PaddleOCR** (local) | Deep-learning OCR → word bounding boxes |
| 2 | **Row clustering** | Groups words into table rows by y-position |
| 3 | **{LLM_MODEL}** (HF GPU) | Row text → complete structured JSON in one call |

*Constraint types: `mandatory_action` · `restriction` · `conditional_rule`*
""")
    with gr.Row():
        with gr.Column(scale=1):
            img_input = gr.Image(type="pil", label="Upload Route Document Image", height=460)
            run_btn = gr.Button("Extract Route Data", variant="primary", size="lg")
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.Tab("JSON Output"):
                    json_out = gr.Code(language="json", label="Structured JSON", lines=34)
                with gr.Tab("OCR Rows (sent to LLM)"):
                    ocr_out = gr.Textbox(
                        label="Row-organised text — exactly what the LLM receives",
                        lines=26, max_lines=60,
                    )
    run_btn.click(
        fn=run_pipeline,
        inputs=[img_input],
        outputs=[json_out, ocr_out],
        api_name=False,
    )
    gr.Examples(examples=[["route_sample.png"]], inputs=[img_input],
                label="Sample route image")

if __name__ == "__main__":
    # The queue is already configured above; a second demo.queue() is redundant.
    # NOTE(review): passing `theme` to launch() matches recent Gradio releases;
    # older releases take it in gr.Blocks(...). Confirm against the pinned version.
    demo.launch(theme=gr.themes.Soft(), share=True)