abdorenouni
feat(ocr): Florence-2 as primary handwritten engine — TrOCR as fallback
ce0c24a
"""
app.py — PharMinds OCR API for HuggingFace Spaces (v6 — Pydantic + matcher)
==========================================================================
Pipeline (handwritten):
Primary: Florence-2 VLM full-image pass (microsoft/Florence-2-base)
Fallback: TrOCR fine-tuned ONNX (Abdou-19/trocr-algerian-medical-onnx)
— used only when Florence-2 is unavailable or returns no text
Post-processing: DrugMatcher (RapidFuzz + phonetic + bilingual FR/AR)
Pipeline (printed):
Primary: PaddleOCR (paddleocr 2.9.1 / paddlepaddle 3.0.0)
Supplement: Florence-2 VLM (adds regions PaddleOCR missed)
Post-processing: DrugMatcher
v6 additions over v5:
✓ /v2/scan — Pydantic-strict response + DrugMatcher grounding
✓ /v2/feedback — strict feedback contract, batched corrections
✓ /v2/health — exposes matcher status + version stamps
✓ /v2/metrics — simple request/latency counters
✓ Backward-compat: /scan, /feedback, /health unchanged
"""
from __future__ import annotations
import os, uuid, base64, csv, logging, threading, time, json
from io import BytesIO
from pathlib import Path
from collections import deque
import numpy as np
from PIL import Image
# ── PyTorch 2.2+ weights_only compatibility fix ───────────────────────────────
# `torch.load` is replaced with a thin wrapper before the rest of this module
# uses torch, so every later caller that does not pass `weights_only` itself
# gets `weights_only=False`.
# NOTE(review): per the PyTorch docs, weights_only=False permits full pickle
# deserialization — only acceptable for checkpoints from trusted sources.
import torch as _torch_patch
_orig_load = _torch_patch.load
def _patched_load(*args, **kwargs):
    # Forward to the original torch.load; an explicit caller-supplied
    # `weights_only` is respected because setdefault never overwrites it.
    kwargs.setdefault("weights_only", False)
    return _orig_load(*args, **kwargs)
_torch_patch.load = _patched_load
import torch
try:
    import transformers.utils.import_utils as _tu
    import transformers.modeling_utils as _mu
    # Replace `check_torch_load_is_safe` with a no-op in both modules.
    # NOTE(review): presumably this is the transformers guard that refuses
    # torch.load on some torch versions — confirm against the installed release.
    _tu.check_torch_load_is_safe = lambda: None # type: ignore[attr-defined]
    _mu.check_torch_load_is_safe = lambda: None # type: ignore[attr-defined]
except Exception:
    # Best effort: if transformers (or these modules) cannot be imported the
    # override is skipped silently and startup continues.
    pass
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn
# ── v2: schema + matcher (colocated in spaces/) ──────────────────────────────
from schema import (
Prescription, Medication, LineCrop, MatchResult,
FeedbackPayload as FeedbackV2Payload, FeedbackCorrection,
)
from matcher import DrugMatcher, CACHE_FILE as DRUGS_CACHE_PATH
# ── logging ───────────────────────────────────────────────────────────────────
# Root logging at INFO with a compact "LEVEL | message" format; `log` is the
# logger used throughout this module.
logging.basicConfig(level=logging.INFO, format="%(levelname)s | %(message)s")
log = logging.getLogger("pharminds-ocr")
# ── config ────────────────────────────────────────────────────────────────────
# HF_MODEL_REPO / HF_TOKEN / MODEL_VERSION / DATASET_VERSION are read from the
# environment (with the defaults shown); DEVICE is derived from CUDA availability.
HF_MODEL_REPO = os.getenv("HF_MODEL_REPO", "Abdou-19/trocr-algerian-medical-onnx")
HF_TOKEN = os.getenv("HF_TOKEN", None)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_VERSION = os.getenv("MODEL_VERSION", "v1-deployed") # set by deploy
DATASET_VERSION = os.getenv("DATASET_VERSION", "v1") # set by deploy
# ── app ───────────────────────────────────────────────────────────────────────
# FastAPI application object; every route in this module is registered on it.
app = FastAPI(
    title="PharMinds OCR API",
    version="6.0",
    description="Algerian prescription OCR β€” TrOCR + DrugMatcher",
)
# Wildcard CORS: any origin, method and header is accepted.
# NOTE(review): confirm that a fully open CORS policy is intended for the
# deployed service.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── model holders ─────────────────────────────────────────────────────────────
# TrOCR state, written by _load_models_bg() and read by recognize_line(),
# _ensure_ready() and the health routes.
_trocr = None
_processor = None
_model_id = None
_is_onnx = False
_ready = False
_load_error : str | None = None
# ── PaddleOCR — lazy-loaded by _get_paddle_ocr() ──────────────────────────────
# Pre-warmed by the background loader once TrOCR is ready; otherwise loaded on
# the first /v2/scan-printed request. A failed load is remembered in
# _paddle_load_error.
_paddle_ocr = None
_paddle_ready = False
_paddle_load_error: str | None = None
_paddle_load_lock = threading.Lock()
# ── Florence-2 VLM — lazy-loaded by _get_florence() ───────────────────────────
# Primary engine for /v2/scan (handwritten) and supplement pass for
# /v2/scan-printed. A failed load is remembered in _florence_load_error.
_florence_processor = None
_florence_model = None
_florence_ready = False
_florence_load_error: str | None = None
_florence_lock = threading.Lock()
FLORENCE_MODEL = "microsoft/Florence-2-base"
# ── matcher (loaded lazily by _get_matcher()) ─────────────────────────────────
_matcher: DrugMatcher | None = None
_matcher_error: str | None = None
# ── metrics (in-memory, reset on cold-start) ──────────────────────────────────
# The counters below are updated under _metrics_lock.
_metrics_lock = threading.Lock()
_request_count = 0
_v2_request_count = 0
_latencies_ms: deque = deque(maxlen=500) # rolling window
_feedback_count = 0
def _record(latency_ms: float, *, v2: bool = False):
    """Account for one served request.

    Appends ``latency_ms`` to the rolling latency window and bumps the total
    request counter; the v2 counter is bumped as well when ``v2`` is true.
    All updates happen while holding ``_metrics_lock``.
    """
    global _request_count, _v2_request_count
    with _metrics_lock:
        _latencies_ms.append(latency_ms)
        _request_count += 1
        _v2_request_count += 1 if v2 else 0
def _get_paddle_ocr():
    """Return the shared PaddleOCR engine, loading it on first use.

    Returns:
        The PaddleOCR instance, or ``None`` when loading failed. A failure is
        remembered in ``_paddle_load_error`` and the load is not retried.
    """
    global _paddle_ocr, _paddle_ready, _paddle_load_error
    # Fast paths that avoid taking the lock.
    if _paddle_ready:
        return _paddle_ocr
    if _paddle_load_error:
        return None
    with _paddle_load_lock:
        # Re-check BOTH flags: while this thread waited for the lock another
        # thread may have finished the load — or failed it. Without the error
        # re-check every queued caller would repeat the failing load.
        if _paddle_ready:
            return _paddle_ocr
        if _paddle_load_error:
            return None
        try:
            from paddleocr import PaddleOCR
            log.info("Loading PaddleOCR (fr)…")
            engine = PaddleOCR(
                use_angle_cls=True,
                lang="fr",
                use_gpu=False,
                show_log=False,
            )
        except Exception as e:
            _paddle_load_error = f"{type(e).__name__}: {e}"
            log.error(f"PaddleOCR load failed: {_paddle_load_error}")
            return None
        # Publish the engine before flipping the ready flag so lock-free
        # readers never see ready=True with no engine.
        _paddle_ocr = engine
        _paddle_ready = True
        log.info("=== PaddleOCR ready ===")
    return _paddle_ocr
def _get_florence():
    """Return the shared Florence-2 ``(processor, model)`` pair, loading on first use.

    Florence-2 is used by /v2/scan as the primary engine and by
    /v2/scan-printed as a supplement pass.

    Returns:
        ``(processor, model)`` on success, ``(None, None)`` when loading
        failed. A failure is remembered in ``_florence_load_error`` and the
        load is not retried.
    """
    global _florence_processor, _florence_model, _florence_ready, _florence_load_error
    # Fast paths that avoid taking the lock.
    if _florence_ready:
        return _florence_processor, _florence_model
    if _florence_load_error:
        return None, None
    with _florence_lock:
        # Re-check BOTH flags: another thread may have completed or failed the
        # load while this one was waiting for the lock.
        if _florence_ready:
            return _florence_processor, _florence_model
        if _florence_load_error:
            return None, None
        try:
            from transformers import AutoProcessor, AutoModelForCausalLM
            log.info(f"Loading Florence-2: {FLORENCE_MODEL}")
            # NOTE(review): trust_remote_code=True runs Python code shipped in
            # the model repository — confirm the repo/revision is trusted.
            processor = AutoProcessor.from_pretrained(
                FLORENCE_MODEL, trust_remote_code=True
            )
            model = AutoModelForCausalLM.from_pretrained(
                FLORENCE_MODEL, trust_remote_code=True, torch_dtype=torch.float32
            ).to(DEVICE).eval()
        except Exception as e:
            _florence_load_error = f"{type(e).__name__}: {e}"
            log.error(f"Florence-2 load failed: {_florence_load_error}")
            # Nothing was published, so callers never receive a processor
            # paired with a missing model.
            return None, None
        # Publish both halves together, then flip the ready flag.
        _florence_processor = processor
        _florence_model = model
        _florence_ready = True
        log.info("=== Florence-2 ready ===")
    return _florence_processor, _florence_model
def run_florence_ocr(image: Image.Image, processor, model) -> list[tuple[str, float, list[int]]]:
    """Run the Florence-2 ``<OCR_WITH_REGION>`` task on a PIL image.

    Args:
        image: page image to read.
        processor: Florence-2 processor (builds inputs, decodes, post-processes).
        model: Florence-2 model exposing ``generate``.

    Returns:
        ``(text, confidence, [x1, y1, x2, y2])`` tuples ordered by the top
        edge of their bounding box. The confidence is the constant 0.92.
    """
    task = "<OCR_WITH_REGION>"
    inputs = processor(text=task, images=image, return_tensors="pt")
    if DEVICE != "cpu":
        inputs = {name: tensor.to(DEVICE) for name, tensor in inputs.items()}
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
        )
    decoded = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed = processor.post_process_generation(
        decoded, task=task, image_size=(image.width, image.height)
    )
    # Expected shape: {"<OCR_WITH_REGION>": {"quad_boxes": [...], "labels": [...]}}
    payload = parsed.get(task, {})
    regions: list[tuple[str, float, list[int]]] = []
    for label, quad in zip(payload.get("labels", []), payload.get("quad_boxes", [])):
        cleaned = label.strip()
        if cleaned:
            # A quad is a flat [x1, y1, x2, y2, x3, y3, x4, y4]; collapse it to
            # an axis-aligned box.
            xs, ys = quad[0::2], quad[1::2]
            box = [int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))]
            regions.append((cleaned, 0.92, box))
    # Top-to-bottom order so the text follows the page layout.
    return sorted(regions, key=lambda region: region[2][1])
def _get_matcher() -> DrugMatcher | None:
    """Return the shared DrugMatcher, building it from the cache file on first use.

    Returns ``None`` when the cache file is missing or loading raised; the
    reason is kept in ``_matcher_error`` and the load is not attempted again.
    """
    global _matcher, _matcher_error
    if _matcher is None and _matcher_error is None:
        try:
            if DRUGS_CACHE_PATH.exists():
                _matcher = DrugMatcher.from_cache(DRUGS_CACHE_PATH)
                log.info(f"DrugMatcher loaded: {_matcher}")
            else:
                _matcher_error = f"drugs_cache.json missing at {DRUGS_CACHE_PATH}"
                log.warning(_matcher_error)
        except Exception as exc:
            _matcher_error = f"{type(exc).__name__}: {exc}"
            log.error(f"DrugMatcher load failed: {_matcher_error}")
            return None
    return _matcher
def _load_models_bg() -> None:
    """Load the OCR models one after another (meant to run in a background thread).

    Sequence:
        1. Import transformers; abort with ``_load_error`` set if that fails.
        2. Try the ONNX TrOCR checkpoint named by ``HF_MODEL_REPO``.
        3. If that fails, try ``microsoft/trocr-small-handwritten`` in PyTorch.
        4. If a TrOCR model is ready, pre-warm Florence-2 and then PaddleOCR
           so the first request does not pay for their load.

    If both TrOCR candidates fail, ``_load_error`` holds the joined error
    messages and no pre-warming takes place.
    """
    global _trocr, _processor, _model_id, _is_onnx, _ready, _load_error
    failures: list[str] = []
    try:
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel
        log.info("transformers imported OK")
    except Exception as exc:
        _load_error = f"transformers import failed: {type(exc).__name__}: {exc}"
        log.error(_load_error)
        return

    # ── Candidate 1: fine-tuned ONNX checkpoint ──────────────────────────────
    loaded_onnx = False
    try:
        from optimum.onnxruntime import ORTModelForVision2Seq
        log.info(f"Trying ONNX: {HF_MODEL_REPO}")
        onnx_processor = TrOCRProcessor.from_pretrained(HF_MODEL_REPO, token=HF_TOKEN)
        onnx_model = ORTModelForVision2Seq.from_pretrained(HF_MODEL_REPO, token=HF_TOKEN)
        _trocr, _processor = onnx_model, onnx_processor
        _model_id, _is_onnx = HF_MODEL_REPO, True
        _ready = True
        log.info(f"=== ONNX model ready: {HF_MODEL_REPO} ===")
        loaded_onnx = True
    except Exception as exc:
        note = f"ONNX load failed ({HF_MODEL_REPO}): {type(exc).__name__}: {exc}"
        failures.append(note)
        log.warning(note)

    # ── Candidate 2: stock PyTorch checkpoint ────────────────────────────────
    if not loaded_onnx:
        fallback_id = "microsoft/trocr-small-handwritten"
        try:
            log.info(f"Trying PyTorch fallback: {fallback_id}")
            torch_processor = TrOCRProcessor.from_pretrained(fallback_id)
            torch_model = VisionEncoderDecoderModel.from_pretrained(fallback_id).to(DEVICE).eval()
            _trocr, _processor = torch_model, torch_processor
            _model_id, _is_onnx = fallback_id, False
            _ready = True
            log.info(f"=== PyTorch model ready: {fallback_id} ===")
        except Exception as exc:
            note = f"PyTorch fallback failed ({fallback_id}): {type(exc).__name__}: {exc}"
            failures.append(note)
            log.error(note)
            _load_error = " | ".join(failures) or "Unknown error"
            log.error(f"ALL HANDWRITTEN CANDIDATES FAILED: {_load_error}")

    # ── Pre-warm the remaining engines (only once TrOCR is usable) ───────────
    if not _ready:
        return
    log.info("TrOCR ready β€” pre-warming Florence-2 in background…")
    _get_florence()
    # This message is logged whether or not the Florence-2 load succeeded;
    # _get_florence() records its own failure instead of raising.
    log.info("Florence-2 ready β€” pre-warming PaddleOCR in background…")
    _get_paddle_ocr()
@app.on_event("startup")
def startup_event():
    # Kick off model loading on a daemon thread so the server can answer
    # /health and /v2/health (status: loading) while the models load.
    threading.Thread(target=_load_models_bg, daemon=True).start()
    log.info("Model loading started in background (TrOCR β†’ then Florence-2)")
    # The matcher is loaded right here, in the startup handler itself.
    _get_matcher()
# ═══════════════════════════════════════════════════════════════════════════════
# Line segmentation — horizontal projection profile
# ═══════════════════════════════════════════════════════════════════════════════
def segment_lines(
    image: Image.Image,
    min_line_height: int = 10,
    pad: int = 6,
    smooth_window: int = 5,
    density_threshold: float = 0.06,
) -> list[tuple[Image.Image, tuple[int, int, int, int]]]:
    """Split a page into full-width text-line crops.

    The image is binarised (inverted Otsu), the ink is summed per row, the
    profile is smoothed with a box filter, and every run of rows above
    ``density_threshold`` times the peak becomes a line if it is at least
    ``min_line_height`` rows tall. Each line is padded by ``pad`` rows.

    Returns:
        ``[(crop, (x1, y1, x2, y2)), ...]`` top to bottom; the whole image as
        a single entry when the page is blank or no run qualifies.
    """
    import cv2

    whole_page = [(image, (0, 0, image.width, image.height))]
    gray = np.array(image.convert("L"))
    _, ink = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
    profile = np.sum(ink, axis=1).astype(float)
    if profile.max() == 0:
        return whole_page

    kernel = np.ones(smooth_window) / smooth_window
    smoothed = np.convolve(profile, kernel, mode="same")
    cutoff = smoothed.max() * density_threshold
    height, width = gray.shape

    # Mark rows above the cutoff (only the first `height` rows are considered)
    # and locate run boundaries from the rising/falling edges of the mask.
    active = smoothed[:height] > cutoff
    edges = np.diff(np.concatenate(([False], active, [False])).astype(np.int8))
    run_starts = np.flatnonzero(edges == 1)
    run_stops = np.flatnonzero(edges == -1)  # exclusive end row of each run

    segments = []
    for start, stop in zip(run_starts, run_stops):
        start, stop = int(start), int(stop)
        if stop - start < min_line_height:
            continue
        top = max(0, start - pad)
        # A run that reaches the last row is cut at the image bottom.
        bottom = height if stop == height else min(height, stop + pad)
        box = (0, top, width, bottom)
        segments.append((image.crop(box), box))

    log.info(f"segment_lines: {len(segments)} line(s) detected")
    return segments or whole_page
# ═══════════════════════════════════════════════════════════════════════════════
# Recognition helpers
# ═══════════════════════════════════════════════════════════════════════════════
def _preprocess_line(img: Image.Image) -> Image.Image:
    """Bring a line crop up to the minimum size used for recognition.

    Crops shorter than 32 px are upscaled (aspect ratio kept) to a height of
    32 px; crops narrower than 64 px are then placed at the left of a white
    64 px-wide canvas.

    Args:
        img: line crop.

    Returns:
        The resized and/or padded image, or ``img`` itself when it is already
        large enough.
    """
    w, h = img.size
    if h < 32:
        scale = 32 / h
        img = img.resize((int(w * scale), 32), Image.LANCZOS)
        # Refresh the width: the padding test below must see the upscaled
        # width, otherwise a crop that started narrower than 64 px but grew
        # past it would be pasted onto a 64 px canvas and cut off.
        w = img.width
    if w < 64:
        canvas = Image.new("RGB", (64, img.height), (255, 255, 255))
        canvas.paste(img, (0, 0))
        img = canvas
    return img
def recognize_line(pil_img: Image.Image) -> tuple[str, float]:
    """Recognise one line crop with the loaded TrOCR model.

    Returns:
        ``(text, confidence)`` where the text is stripped and the confidence
        is the constant 0.75.
    """
    prepared = _preprocess_line(pil_img)
    pixel_values = _processor(images=prepared, return_tensors="pt").pixel_values
    # Only the PyTorch model needs its input moved to the accelerator.
    move_to_device = DEVICE != "cpu" and not _is_onnx
    if move_to_device:
        pixel_values = pixel_values.to(DEVICE)
    with torch.no_grad():
        token_ids = _trocr.generate(pixel_values, max_new_tokens=64)
    decoded = _processor.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0].strip(), 0.75
def run_paddle_ocr(image: Image.Image, ocr) -> list[tuple[str, float, list[int]]]:
    """Run PaddleOCR on a PIL image.

    Args:
        image: page image; converted to RGB before recognition.
        ocr: engine exposing ``ocr(array, cls=True)``.

    Returns:
        ``(text, confidence, [x1, y1, x2, y2])`` for every non-blank detection,
        in the order the engine reported them; an empty list when nothing was
        detected.
    """
    import numpy as np

    pixels = np.array(image.convert("RGB"))
    result = ocr.ocr(pixels, cls=True)
    if not result or not result[0]:
        return []

    detections = []
    # Each entry: [[x1,y1],[x2,y2],[x3,y3],[x4,y4]], (text, confidence)
    for corners, (text, conf) in result[0]:
        cleaned = text.strip()
        if not cleaned:
            continue
        xs = [int(pt[0]) for pt in corners]
        ys = [int(pt[1]) for pt in corners]
        detections.append((cleaned, float(conf), [min(xs), min(ys), max(xs), max(ys)]))
    return detections
def pil_to_b64(img: Image.Image) -> str:
    """Encode ``img`` as a quality-85 JPEG and return it as a base64 string."""
    with BytesIO() as buffer:
        img.save(buffer, format="JPEG", quality=85)
        encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode()
def _ensure_ready():
    """Raise unless the TrOCR model is loaded.

    Raises:
        HTTPException: 500 when model loading failed, 503 while it is still
            in progress.
    """
    if _ready and not _load_error:
        return
    if _load_error:
        raise HTTPException(status_code=500, detail=f"Model load failed: {_load_error}")
    raise HTTPException(status_code=503, detail="Models loading; retry in 30s")
# ═══════════════════════════════════════════════════════════════════════════════
# v1 routes (kept for backward compatibility — DO NOT BREAK)
# ═══════════════════════════════════════════════════════════════════════════════
@app.get("/health")
def health():
    # v1 health probe: "error" once loading has failed, "ok" when the TrOCR
    # model is ready, "loading" otherwise.
    if _load_error:
        body = {"status": "error", "detail": _load_error}
    elif _ready:
        body = {
            "status": "ok",
            "device": DEVICE,
            "model": _model_id,
            "engine": "onnx" if _is_onnx else "pytorch",
            "trocr_loaded": True,
        }
    else:
        body = {"status": "loading", "model": HF_MODEL_REPO}
    return body
@app.post("/scan")
async def scan(file: UploadFile = File(...)):
    # v1 scan: segment the page into lines, run TrOCR on each line and return
    # the recognised text plus a JPEG crop per line. Response shape is part of
    # the v1 contract.
    _ensure_ready()
    started = time.time()
    payload = await file.read()
    try:
        image = Image.open(BytesIO(payload)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Cannot decode image.")

    # Cap the width at 1600 px, keeping the aspect ratio.
    max_w = 1600
    if image.width > max_w:
        ratio = max_w / image.width
        image = image.resize((max_w, int(image.height * ratio)), Image.LANCZOS)

    texts: list[str] = []
    crops: list[dict] = []
    for idx, (line_img, _) in enumerate(segment_lines(image)):
        # A failure on one line is logged and the line skipped; the rest of
        # the page is still processed.
        try:
            text, conf = recognize_line(line_img)
            if not text.strip():
                continue
            texts.append(text)
            crops.append({
                "line_index": idx,
                "predicted_text": text,
                "confidence": round(conf, 3),
                "image_base64": pil_to_b64(line_img),
            })
            log.info(f" line {idx:02d}: {text!r}")
        except Exception as exc:
            log.warning(f" line {idx} skipped ({exc})")

    _record((time.time() - started) * 1000, v2=False)
    extracted = {
        "doctor_name": None,
        "patient_name": None,
        "prescription_date": None,
        "medications": [],
    }
    return {
        "success": True,
        "method": "trocr",
        "raw_ocr": texts,
        "confidence_score": 0.75,
        "extracted_data": extracted,
        "line_crops": crops,
    }
class FeedbackV1Payload(BaseModel):
    # Request body of the v1 POST /feedback route.
    # Base64 image data; the handler keeps only the text after the last comma,
    # so a "data:...;base64," prefix is tolerated.
    image_base64: str
    # Corrected transcription, stored next to the image name in dataset/labels.csv.
    corrected_text: str
@app.post("/feedback")
async def feedback(data: FeedbackV1Payload):
    # v1 feedback: store the corrected image under dataset/images/ and append
    # (file_name, text) to dataset/labels.csv. Any failure is reported in the
    # response body rather than as an HTTP error.
    global _feedback_count
    try:
        image_dir = Path("dataset/images")
        image_dir.mkdir(parents=True, exist_ok=True)
        name = f"feedback_{uuid.uuid4().hex[:8]}.jpg"
        # Keep only what follows the last comma (drops any data-URL prefix).
        encoded = data.image_base64.rpartition(",")[2]
        with (image_dir / name).open("wb") as image_file:
            image_file.write(base64.b64decode(encoded))

        csv_path = Path("dataset/labels.csv")
        # The header row is written only when the file is being created.
        rows = [] if csv_path.exists() else [["file_name", "text"]]
        rows.append([name, data.corrected_text])
        with open(csv_path, "a", encoding="utf-8", newline="") as labels_file:
            csv.writer(labels_file).writerows(rows)

        with _metrics_lock:
            _feedback_count += 1
        return {"success": True, "saved": name}
    except Exception as exc:
        log.error(f"Feedback error: {exc}")
        return {"success": False, "error": str(exc)}
# ═══════════════════════════════════════════════════════════════════════════════
# v2 routes — Pydantic-strict + DrugMatcher
# ═══════════════════════════════════════════════════════════════════════════════
class V2HealthResponse(BaseModel):
    # Response body of GET /v2/health.
    # Short state label such as "ok", "loading" or "degraded".
    status: str
    # True once a TrOCR model has been loaded.
    model_loaded: bool
    # True when the DrugMatcher is available.
    matcher_loaded: bool
    # Identifier of the loaded TrOCR checkpoint, if any.
    model: str | None = None
    # "onnx" or "pytorch"; None while no model is loaded.
    engine: str | None = None
    # "cuda" or "cpu".
    device: str
    # Version stamps taken from the MODEL_VERSION / DATASET_VERSION settings.
    model_version: str
    dataset_version: str
    # len() of the matcher; 0 when the matcher is unavailable.
    drug_count: int = 0
    # Last matcher / model load error message, if any.
    matcher_error: str | None = None
    load_error: str | None = None
class V2MetricsResponse(BaseModel):
    # Response body of GET /v2/metrics. All values are in-memory and start
    # from zero when the process starts.
    # Requests recorded through _record(), in total and for v2 routes only.
    request_count: int
    v2_request_count: int
    # Feedback items accepted by the feedback routes.
    feedback_count: int
    # Latency percentiles over the rolling latency window (at most 500 samples).
    p50_ms: float
    p95_ms: float
    # Number of latency samples the percentiles were computed from.
    sample_size: int
    # Seconds elapsed since _BOOT_TIME.
    uptime_seconds: float
# Captured when the module is imported; reference point for uptime_seconds.
_BOOT_TIME = time.time()
@app.get("/v2/health", response_model=V2HealthResponse)
def v2_health():
    # Health probe for v2 clients. Status values:
    #   "error"    - model loading failed for good (details in load_error);
    #                previously reported as "loading" forever, unlike /health
    #   "loading"  - the TrOCR model is not ready yet
    #   "ok"       - model and matcher are both available
    #   "degraded" - model ready but the matcher is unavailable/empty
    matcher = _get_matcher()
    if _load_error and not _ready:
        status = "error"
    elif not _ready:
        status = "loading"
    elif matcher:
        status = "ok"
    else:
        status = "degraded"
    return V2HealthResponse(
        status=status,
        model_loaded=_ready,
        matcher_loaded=matcher is not None,
        model=_model_id,
        engine="onnx" if _is_onnx else ("pytorch" if _ready else None),
        device=DEVICE,
        model_version=MODEL_VERSION,
        dataset_version=DATASET_VERSION,
        drug_count=len(matcher) if matcher else 0,
        matcher_error=_matcher_error,
        load_error=_load_error,
    )
@app.get("/v2/metrics", response_model=V2MetricsResponse)
def v2_metrics():
    # Report request counters and latency percentiles. Everything guarded by
    # _metrics_lock is snapshotted in one critical section so the counters and
    # the latency window describe the same moment.
    with _metrics_lock:
        latencies = sorted(_latencies_ms)
        requests = _request_count
        v2_requests = _v2_request_count
        feedbacks = _feedback_count
    n = len(latencies)
    # Nearest-rank style percentiles; both indices are always < n.
    p50 = latencies[n // 2] if n else 0.0
    p95 = latencies[int(n * 0.95)] if n else 0.0
    return V2MetricsResponse(
        request_count=requests,
        v2_request_count=v2_requests,
        feedback_count=feedbacks,
        p50_ms=round(p50, 1),
        p95_ms=round(p95, 1),
        sample_size=n,
        uptime_seconds=round(time.time() - _BOOT_TIME, 1),
    )
@app.post("/v2/scan", response_model=Prescription)
async def v2_scan(file: UploadFile = File(...)):
    """Handwritten prescription OCR.

    Primary:      Florence-2 VLM on the full image.
    Fallback:     line segmentation + TrOCR, used only when Florence-2 is
                  unavailable, fails, or returns no text.
    Post-process: DrugMatcher over the space-joined OCR text.
    """
    _ensure_ready()
    started = time.time()
    payload = await file.read()
    try:
        image = Image.open(BytesIO(payload)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Cannot decode image.")

    # Cap the width at 1600 px, keeping the aspect ratio.
    max_w = 1600
    if image.width > max_w:
        ratio = max_w / image.width
        image = image.resize((max_w, int(image.height * ratio)), Image.LANCZOS)

    line_crops: list[LineCrop] = []
    raw_lines: list[str] = []
    used_method = "trocr"

    # ── Stage 1: Florence-2 on the whole page ────────────────────────────────
    florence_processor, florence_model = _get_florence()
    if florence_processor is None:
        log.warning(f"Florence-2 unavailable ({_florence_load_error or 'not loaded'}) β€” falling back to TrOCR")
    else:
        try:
            regions = run_florence_ocr(image, florence_processor, florence_model)
            if not regions:
                log.warning("Florence-2 returned no text β€” falling back to TrOCR")
            else:
                for idx, (text, conf, bbox) in enumerate(regions):
                    raw_lines.append(text)
                    line_crops.append(LineCrop(
                        bbox=bbox,
                        text=text,
                        confidence=conf,
                        line_index=idx,
                    ))
                used_method = "florence2-handwritten"
                log.info(f"Florence-2 primary: {len(regions)} regions")
        except Exception as exc:
            log.warning(f"Florence-2 primary failed ({exc}) β€” falling back to TrOCR")

    # ── Stage 2: TrOCR per line, only when stage 1 left no text ──────────────
    if not raw_lines:
        for idx, (line_img, bbox) in enumerate(segment_lines(image)):
            try:
                text, conf = recognize_line(line_img)
            except Exception as exc:
                log.warning(f"recognize_line failed: {exc}")
                continue
            if not text.strip():
                continue
            raw_lines.append(text)
            line_crops.append(LineCrop(
                bbox=list(bbox),
                text=text,
                confidence=conf,
                line_index=idx,
                image_base64=pil_to_b64(line_img),
            ))
        used_method = "trocr"
        log.info(f"TrOCR fallback: {len(line_crops)} lines")

    raw_ocr = " ".join(raw_lines)

    # ── Stage 3: ground the text against the drug list ───────────────────────
    medications: list[Medication] = []
    seen_ids: set[str] = set()
    matcher = _get_matcher()
    if matcher and raw_ocr:
        for hit in matcher.match_text(raw_ocr):
            # Keep only resolved matches, one entry per drug id.
            if not (hit.matched_name and hit.drug_id and hit.drug_id not in seen_ids):
                continue
            medications.append(Medication(
                name=hit.matched_name,
                drug_id=hit.drug_id,
                confidence=hit.confidence,
                match_strategy=hit.strategy,
            ))
            seen_ids.add(hit.drug_id)

    # Overall confidence is the mean of the per-line confidences.
    if line_crops:
        overall_conf = sum(crop.confidence for crop in line_crops) / len(line_crops)
    else:
        overall_conf = 0.0

    elapsed = (time.time() - started) * 1000
    _record(elapsed, v2=True)
    return Prescription(
        success=True,
        method=used_method,
        confidence=round(overall_conf, 3),
        medications=medications,
        line_crops=line_crops,
        raw_ocr_text=raw_ocr,
        processing_ms=int(elapsed),
        model_version=MODEL_VERSION,
        dataset_version=DATASET_VERSION,
    )
@app.post("/v2/scan-printed", response_model=Prescription)
async def v2_scan_printed(file: UploadFile = File(...)):
    """Printed prescription OCR.

    Primary:      PaddleOCR.
    Supplement:   Florence-2 VLM; its regions are added when their text is not
                  already among the PaddleOCR lines (case-insensitive).
    Post-process: DrugMatcher over the space-joined OCR text.

    When neither engine yields text, a ``success=False`` response with an
    explanatory ``error`` is returned.
    """
    started = time.time()
    payload = await file.read()
    try:
        image = Image.open(BytesIO(payload)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Cannot decode image.")

    # Cap the width at 1600 px, keeping the aspect ratio.
    max_w = 1600
    if image.width > max_w:
        ratio = max_w / image.width
        image = image.resize((max_w, int(image.height * ratio)), Image.LANCZOS)

    line_crops: list[LineCrop] = []
    raw_lines: list[str] = []
    used_paddle = False
    used_florence = False

    # ── Stage 1: PaddleOCR ───────────────────────────────────────────────────
    paddle = _get_paddle_ocr()
    if paddle is None:
        log.warning(f"PaddleOCR unavailable: {_paddle_load_error or 'not loaded'}")
    else:
        try:
            paddle_lines = run_paddle_ocr(image, paddle)
            for idx, (text, conf, bbox) in enumerate(paddle_lines):
                raw_lines.append(text)
                line_crops.append(LineCrop(bbox=bbox, text=text, confidence=conf, line_index=idx))
            log.info(f"PaddleOCR printed: {len(paddle_lines)} lines")
            used_paddle = bool(paddle_lines)
        except Exception as exc:
            log.warning(f"PaddleOCR failed: {exc}")

    # ── Stage 2: Florence-2 supplement ───────────────────────────────────────
    florence_processor, florence_model = _get_florence()
    if florence_processor is None:
        log.warning(f"Florence-2 unavailable: {_florence_load_error or 'not loaded'}")
    else:
        try:
            regions = run_florence_ocr(image, florence_processor, florence_model)
            # Normalised texts already present after stage 1.
            known = {line.lower().strip() for line in raw_lines}
            added = 0
            for text, conf, bbox in regions:
                if text.lower().strip() in known:
                    continue
                raw_lines.append(text)
                line_crops.append(LineCrop(
                    bbox=bbox, text=text, confidence=conf,
                    line_index=len(line_crops),
                ))
                added += 1
            if added:
                log.info(f"Florence-2 supplement: {added} additional regions")
            used_florence = added > 0
        except Exception as exc:
            log.warning(f"Florence-2 supplement failed: {exc}")

    # ── Neither engine produced text ─────────────────────────────────────────
    if not raw_lines:
        log.warning(
            f"Printed OCR: all engines returned nothing "
            f"(Paddle: {_paddle_load_error or 'empty'} | "
            f"Florence: {_florence_load_error or 'empty'})"
        )
        return Prescription(
            success=False,
            method="paddle-printed",
            confidence=0.0,
            medications=[],
            line_crops=[],
            raw_ocr_text="",
            processing_ms=int((time.time() - started) * 1000),
            model_version=MODEL_VERSION,
            dataset_version=DATASET_VERSION,
            error="No text detected. Try a clearer photo with better lighting.",
        )

    # ── Method label ─────────────────────────────────────────────────────────
    if not used_paddle:
        used_method = "florence2-printed"
    elif used_florence:
        used_method = "paddle+florence2"
    else:
        used_method = "paddle-printed"

    # ── Stage 3: ground the text against the drug list ───────────────────────
    raw_ocr = " ".join(raw_lines)
    medications: list[Medication] = []
    seen_ids: set[str] = set()
    matcher = _get_matcher()
    if matcher and raw_ocr:
        for hit in matcher.match_text(raw_ocr):
            # Keep only resolved matches, one entry per drug id.
            if not (hit.matched_name and hit.drug_id and hit.drug_id not in seen_ids):
                continue
            medications.append(Medication(
                name=hit.matched_name,
                drug_id=hit.drug_id,
                confidence=hit.confidence,
                match_strategy=hit.strategy,
            ))
            seen_ids.add(hit.drug_id)

    # Overall confidence is the mean of the per-line confidences.
    if line_crops:
        overall_conf = sum(crop.confidence for crop in line_crops) / len(line_crops)
    else:
        overall_conf = 0.0

    elapsed = (time.time() - started) * 1000
    _record(elapsed, v2=True)
    return Prescription(
        success=True,
        method=used_method,
        confidence=round(overall_conf, 3),
        medications=medications,
        line_crops=line_crops,
        raw_ocr_text=raw_ocr,
        processing_ms=int(elapsed),
        model_version=MODEL_VERSION,
        dataset_version=DATASET_VERSION,
    )
@app.post("/v2/feedback")
async def v2_feedback(payload: FeedbackV2Payload):
    """Log a batch of corrections with an audit trail.

    One JSON record per request is appended to
    ``dataset/feedback_log/<YYYYMMDD>.jsonl`` (UTC date), and the feedback
    counter grows by the number of corrections.

    Returns:
        ``{"success": True, "logged": <count>, "file": <log path>}``.

    Raises:
        HTTPException: 500 when the record cannot be built or written.
    """
    global _feedback_count
    try:
        # Take one UTC reading and use it for both the file name and the
        # record: the "ts" value carries a "Z" suffix, so it must be UTC rather
        # than the server's local time.
        now_utc = time.gmtime()
        log_dir = Path("dataset/feedback_log")
        log_dir.mkdir(parents=True, exist_ok=True)
        log_file = log_dir / f"{time.strftime('%Y%m%d', now_utc)}.jsonl"
        # NOTE(review): json.dumps requires image_id / reviewer_id /
        # submitted_at to be JSON-serialisable as declared in the schema —
        # confirm submitted_at is not a datetime.
        record = {
            "ts": time.strftime("%Y-%m-%dT%H:%M:%SZ", now_utc),
            "image_id": payload.image_id,
            "reviewer_id": payload.reviewer_id,
            "submitted_at": payload.submitted_at,
            "corrections": [c.model_dump() for c in payload.corrections],
        }
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
        with _metrics_lock:
            _feedback_count += len(payload.corrections)
        return {"success": True, "logged": len(payload.corrections), "file": str(log_file)}
    except Exception as e:
        log.error(f"v2 feedback error: {e}")
        # Chain the cause so the original traceback is kept.
        raise HTTPException(status_code=500, detail=str(e)) from e
# Script entry point: serve the app with uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)