Spaces:
Running
Running
| """ | |
| app.py — PharMinds OCR API for HuggingFace Spaces (v6 — Pydantic + matcher) | |
| ========================================================================== | |
| Pipeline (handwritten): | |
| Primary: Florence-2 VLM full-image pass (microsoft/Florence-2-base) | |
| Fallback: TrOCR fine-tuned ONNX (Abdou-19/trocr-algerian-medical-onnx) | |
| — used only when Florence-2 is unavailable or returns no text | |
| Post-processing: DrugMatcher (RapidFuzz + phonetic + bilingual FR/AR) | |
| Pipeline (printed): | |
| Primary: PaddleOCR (paddleocr 2.9.1 / paddlepaddle 3.0.0) | |
| Supplement: Florence-2 VLM (adds regions PaddleOCR missed) | |
| Post-processing: DrugMatcher | |
| v6 additions over v5: | |
| - /v2/scan — Pydantic-strict response + DrugMatcher grounding | |
| - /v2/feedback — strict feedback contract, batched corrections | |
| - /v2/health — exposes matcher status + version stamps | |
| - /v2/metrics — simple request/latency counters | |
| - Backward-compat: /scan, /feedback, /health unchanged | |
| """ | |
| from __future__ import annotations | |
| import os, uuid, base64, csv, logging, threading, time, json | |
| from io import BytesIO | |
| from pathlib import Path | |
| from collections import deque | |
| import numpy as np | |
| from PIL import Image | |
# -- PyTorch 2.2+ weights_only compatibility fix --
# torch.load is wrapped BEFORE the regular `import torch` below so that every
# later caller that omits `weights_only` gets weights_only=False.
import torch as _torch_patch
_orig_load = _torch_patch.load
def _patched_load(*args, **kwargs):
    """Forward to the original torch.load, defaulting ``weights_only`` to False."""
    # setdefault leaves an explicit caller-supplied weights_only value untouched.
    # NOTE(review): weights_only=False allows full pickle deserialisation —
    # confirm that every checkpoint loaded by this service is trusted.
    kwargs.setdefault("weights_only", False)
    return _orig_load(*args, **kwargs)
_torch_patch.load = _patched_load
import torch
try:
    import transformers.utils.import_utils as _tu
    import transformers.modeling_utils as _mu
    # Replace transformers' torch-load safety check with a no-op in both
    # modules that expose it.
    _tu.check_torch_load_is_safe = lambda: None  # type: ignore[attr-defined]
    _mu.check_torch_load_is_safe = lambda: None  # type: ignore[attr-defined]
except Exception:
    # Best effort: if transformers (or these submodules) cannot be imported
    # the override is simply skipped.
    pass
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| import uvicorn | |
| # ββ v2: schema + matcher (colocated in spaces/) ββββββββββββββββββββββββββββββ | |
| from schema import ( | |
| Prescription, Medication, LineCrop, MatchResult, | |
| FeedbackPayload as FeedbackV2Payload, FeedbackCorrection, | |
| ) | |
| from matcher import DrugMatcher, CACHE_FILE as DRUGS_CACHE_PATH | |
# -- logging --
# Root logger at INFO with a compact "LEVEL | message" format.
logging.basicConfig(level=logging.INFO, format="%(levelname)s | %(message)s")
log = logging.getLogger("pharminds-ocr")
# -- config --
# Everything except DEVICE can be overridden through environment variables.
HF_MODEL_REPO = os.getenv("HF_MODEL_REPO", "Abdou-19/trocr-algerian-medical-onnx")
# Passed as `token=` when the ONNX TrOCR repo is downloaded; None when unset.
HF_TOKEN = os.getenv("HF_TOKEN", None)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_VERSION = os.getenv("MODEL_VERSION", "v1-deployed")  # set by deploy
DATASET_VERSION = os.getenv("DATASET_VERSION", "v1")  # set by deploy
# -- app --
# NOTE(review): the "β" inside `description` looks like a mis-decoded dash;
# confirm against the original file before changing the string.
app = FastAPI(
    title="PharMinds OCR API",
    version="6.0",
    description="Algerian prescription OCR β TrOCR + DrugMatcher",
)
# CORS is fully open: any origin, method and header is accepted.
# NOTE(review): confirm this is intended for the deployed Space.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# -- model holders --
# TrOCR state, written by _load_models_bg(). `_ready` is set only after the
# model and processor globals are assigned.
_trocr = None
_processor = None
_model_id = None
_is_onnx = False
_ready = False
_load_error : str | None = None
# -- PaddleOCR: loaded on demand by _get_paddle_ocr() --
_paddle_ocr = None
_paddle_ready = False
_paddle_load_error: str | None = None
_paddle_load_lock = threading.Lock()
# -- Florence-2 VLM: loaded on demand by _get_florence(); primary engine in
#    v2_scan (handwritten) and supplement engine in v2_scan_printed --
_florence_processor = None
_florence_model = None
_florence_ready = False
_florence_load_error: str | None = None
_florence_lock = threading.Lock()
FLORENCE_MODEL = "microsoft/Florence-2-base"
# -- matcher (loaded lazily by _get_matcher()) --
_matcher: DrugMatcher | None = None
_matcher_error: str | None = None
# -- metrics (in-memory only, so they start from zero in every process) --
_metrics_lock = threading.Lock()
_request_count = 0
_v2_request_count = 0
_latencies_ms: deque = deque(maxlen=500)  # rolling window of the last 500 latencies
_feedback_count = 0
def _record(latency_ms: float, *, v2: bool = False):
    """Account for one served scan request.

    Appends ``latency_ms`` to the rolling latency window and bumps the total
    request counter; the v2 counter is bumped as well when ``v2`` is truthy.
    All updates happen while holding ``_metrics_lock``.
    """
    global _request_count, _v2_request_count
    with _metrics_lock:
        _latencies_ms.append(latency_ms)
        _request_count += 1
        _v2_request_count += 1 if v2 else 0
def _get_paddle_ocr():
    """Return the shared PaddleOCR instance, loading it on first use.

    Returns:
        The PaddleOCR object, or ``None`` when loading failed (the failure is
        remembered in ``_paddle_load_error`` and never retried).
    """
    global _paddle_ocr, _paddle_ready, _paddle_load_error
    # Fast path without the lock once the outcome is known.
    if _paddle_ready:
        return _paddle_ocr
    if _paddle_load_error:
        return None
    with _paddle_load_lock:
        # Re-check BOTH outcomes: another thread may have finished, or failed,
        # while this one was waiting for the lock. Without the error re-check
        # a waiting thread would repeat a load that already failed.
        if _paddle_ready:
            return _paddle_ocr
        if _paddle_load_error:
            return None
        try:
            from paddleocr import PaddleOCR
            log.info("Loading PaddleOCR (fr)…")
            _paddle_ocr = PaddleOCR(
                use_angle_cls=True,
                lang="fr",
                use_gpu=False,
                show_log=False,
            )
            _paddle_ready = True
            log.info("=== PaddleOCR ready ===")
        except Exception as e:
            _paddle_load_error = f"{type(e).__name__}: {e}"
            log.error(f"PaddleOCR load failed: {_paddle_load_error}")
    return _paddle_ocr
def _get_florence():
    """Return ``(processor, model)`` for Florence-2, loading them on first use.

    Returns:
        A ``(processor, model)`` pair, or ``(None, None)`` when loading failed.
        A failure is remembered in ``_florence_load_error`` and never retried.
    """
    global _florence_processor, _florence_model, _florence_ready, _florence_load_error
    # Fast path without the lock once the outcome is known.
    if _florence_ready:
        return _florence_processor, _florence_model
    if _florence_load_error:
        return None, None
    with _florence_lock:
        # Re-check both outcomes: another thread may have finished, or failed,
        # while this one was waiting for the lock.
        if _florence_ready:
            return _florence_processor, _florence_model
        if _florence_load_error:
            return None, None
        try:
            from transformers import AutoProcessor, AutoModelForCausalLM
            log.info(f"Loading Florence-2: {FLORENCE_MODEL}")
            # Load into locals first so a failure while loading the model
            # cannot leave a processor published without its model.
            processor = AutoProcessor.from_pretrained(
                FLORENCE_MODEL, trust_remote_code=True
            )
            model = AutoModelForCausalLM.from_pretrained(
                FLORENCE_MODEL, trust_remote_code=True, torch_dtype=torch.float32
            ).to(DEVICE).eval()
        except Exception as e:
            _florence_load_error = f"{type(e).__name__}: {e}"
            log.error(f"Florence-2 load failed: {_florence_load_error}")
            return None, None
        _florence_processor, _florence_model = processor, model
        # Set the ready flag last: lock-free readers trust it.
        _florence_ready = True
        log.info("=== Florence-2 ready ===")
    return _florence_processor, _florence_model
def run_florence_ocr(image: Image.Image, processor, model) -> list[tuple[str, float, list[int]]]:
    """Run the Florence-2 ``<OCR_WITH_REGION>`` task on a PIL image.

    Returns:
        ``(text, confidence, [x_min, y_min, x_max, y_max])`` tuples ordered by
        ``y_min``. The confidence is the fixed value 0.92; this function does
        not read a score from the model output.
    """
    task = "<OCR_WITH_REGION>"
    inputs = processor(text=task, images=image, return_tensors="pt")
    if DEVICE != "cpu":
        inputs = {name: tensor.to(DEVICE) for name, tensor in inputs.items()}
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
        )
    decoded = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed = processor.post_process_generation(
        decoded, task=task, image_size=(image.width, image.height)
    )
    # parsed = {"<OCR_WITH_REGION>": {"quad_boxes": [...], "labels": [...]}}
    region_data = parsed.get(task, {})
    labels = region_data.get("labels", [])
    quads = region_data.get("quad_boxes", [])  # each quad: [x1,y1,x2,y2,x3,y3,x4,y4]
    regions: list[tuple[str, float, list[int]]] = []
    for label, quad in zip(labels, quads):
        cleaned = label.strip()
        if cleaned:
            # Even positions hold x coordinates, odd positions hold y.
            xs = quad[0::2]
            ys = quad[1::2]
            box = [int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))]
            regions.append((cleaned, 0.92, box))
    # Stable sort on the top edge so the order follows the page layout.
    return sorted(regions, key=lambda region: region[2][1])
def _get_matcher() -> DrugMatcher | None:
    """Return the shared DrugMatcher, building it from the cache file on first use.

    Returns ``None`` when the cache file is missing or loading raised; the
    reason is kept in ``_matcher_error`` and the load is not attempted again.
    """
    global _matcher, _matcher_error
    # Outcome already known: either the matcher exists or a failure was recorded.
    if _matcher is not None or _matcher_error is not None:
        return _matcher
    try:
        if DRUGS_CACHE_PATH.exists():
            _matcher = DrugMatcher.from_cache(DRUGS_CACHE_PATH)
            log.info(f"DrugMatcher loaded: {_matcher}")
            return _matcher
        _matcher_error = f"drugs_cache.json missing at {DRUGS_CACHE_PATH}"
        log.warning(_matcher_error)
    except Exception as exc:
        _matcher_error = f"{type(exc).__name__}: {exc}"
        log.error(f"DrugMatcher load failed: {_matcher_error}")
    return None
def _load_models_bg() -> None:
    """Load the OCR engines one after another (intended for a background thread).

    Load order (sequential, to keep peak memory down):
      1. Custom TrOCR ONNX model from ``HF_MODEL_REPO``.
      2. ``microsoft/trocr-small-handwritten`` (PyTorch) if the ONNX load fails.
      3. Florence-2 pre-warm via ``_get_florence()``.
      4. PaddleOCR pre-warm via ``_get_paddle_ocr()``.

    Steps 3 and 4 only run once a TrOCR model is ready. Failures of steps 1-2
    are accumulated into ``_load_error``; the pre-warm helpers record their own
    errors.
    """
    global _trocr, _processor, _model_id, _is_onnx, _ready, _load_error
    errors: list[str] = []
    try:
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel
        log.info("transformers imported OK")
    except Exception as e:
        _load_error = f"transformers import failed: {type(e).__name__}: {e}"
        log.error(_load_error)
        return

    # -- Stage 1a: fine-tuned ONNX TrOCR --
    onnx_ok = False
    try:
        from optimum.onnxruntime import ORTModelForVision2Seq
        log.info(f"Trying ONNX: {HF_MODEL_REPO}")
        proc = TrOCRProcessor.from_pretrained(HF_MODEL_REPO, token=HF_TOKEN)
        model = ORTModelForVision2Seq.from_pretrained(HF_MODEL_REPO, token=HF_TOKEN)
        _trocr = model
        _processor = proc
        _model_id = HF_MODEL_REPO
        _is_onnx = True
        _ready = True  # published last: request handlers gate on this flag
        log.info(f"=== ONNX model ready: {HF_MODEL_REPO} ===")
        onnx_ok = True
    except Exception as e:
        msg = f"ONNX load failed ({HF_MODEL_REPO}): {type(e).__name__}: {e}"
        errors.append(msg)
        log.warning(msg)

    # -- Stage 1b: stock PyTorch TrOCR fallback --
    if not onnx_ok:
        BASE = "microsoft/trocr-small-handwritten"
        try:
            log.info(f"Trying PyTorch fallback: {BASE}")
            proc = TrOCRProcessor.from_pretrained(BASE)
            model = VisionEncoderDecoderModel.from_pretrained(BASE).to(DEVICE).eval()
            _trocr = model
            _processor = proc
            _model_id = BASE
            _is_onnx = False
            _ready = True
            log.info(f"=== PyTorch model ready: {BASE} ===")
        except Exception as e:
            msg = f"PyTorch fallback failed ({BASE}): {type(e).__name__}: {e}"
            errors.append(msg)
            log.error(msg)
            _load_error = " | ".join(errors) or "Unknown error"
            log.error(f"ALL HANDWRITTEN CANDIDATES FAILED: {_load_error}")

    if not _ready:
        return

    # -- Stage 2: pre-warm Florence-2 (primary for handwritten scans,
    #    supplement for printed scans) --
    log.info("TrOCR ready — pre-warming Florence-2 in background…")
    _get_florence()

    # -- Stage 3: pre-warm PaddleOCR (primary for printed scans) so the first
    #    printed request does not wait on a cold load --
    if _florence_ready:
        log.info("Florence-2 ready — pre-warming PaddleOCR in background…")
    else:
        # Report the real outcome instead of claiming Florence-2 is ready.
        log.warning("Florence-2 unavailable — pre-warming PaddleOCR in background…")
    _get_paddle_ocr()
@app.on_event("startup")
def startup_event():
    """Kick off model loading when the application starts.

    The heavy OCR models load in a daemon thread so the server can answer
    health checks (which report ``loading``) straight away. The matcher is
    loaded synchronously here.
    """
    # Without the decorator above nothing in this file ever calls this
    # function, so no model would be loaded.
    t = threading.Thread(target=_load_models_bg, daemon=True)
    t.start()
    log.info("Model loading started in background (TrOCR → then Florence-2)")
    # Warm the matcher while the OCR models load in the background thread.
    _get_matcher()
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Line segmentation β horizontal projection profile | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def segment_lines(
    image: Image.Image,
    min_line_height: int = 10,
    pad: int = 6,
    smooth_window: int = 5,
    density_threshold: float = 0.06,
) -> list[tuple[Image.Image, tuple[int, int, int, int]]]:
    """Split a page into full-width text-line crops using a row-ink profile.

    The image is binarised (inverted Otsu), ink is summed per row, the profile
    is smoothed with a box filter of ``smooth_window`` rows, and consecutive
    rows above ``density_threshold`` times the smoothed maximum form a line.
    Runs shorter than ``min_line_height`` rows are dropped; kept runs are
    padded vertically by ``pad`` rows.

    Returns:
        ``[(crop, (x1, y1, x2, y2)), ...]`` from top to bottom. When no ink or
        no line is found, a single entry covering the whole image is returned.
    """
    import cv2
    whole_page = [(image, (0, 0, image.width, image.height))]
    gray = np.array(image.convert("L"))
    _, ink = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
    row_ink = np.sum(ink, axis=1).astype(float)
    if row_ink.max() == 0:
        return whole_page
    kernel = np.ones(smooth_window) / smooth_window
    smooth = np.convolve(row_ink, kernel, mode="same")
    cutoff = smooth.max() * density_threshold
    height, width = gray.shape

    # Collect (top, bottom) row bands first, crop afterwards.
    bands: list[tuple[int, int]] = []
    start = None  # first row of the run currently being tracked, if any
    for row in range(height):
        dense = smooth[row] > cutoff
        if start is None:
            if dense:
                start = row
            continue
        if dense:
            continue
        # The run ended on this row; keep it only if it is tall enough.
        if row - start >= min_line_height:
            bands.append((max(0, start - pad), min(height, row + pad)))
        start = None
    # A run that reaches the bottom edge is closed at the image height.
    if start is not None and height - start >= min_line_height:
        bands.append((max(0, start - pad), height))

    out = [
        (image.crop((0, top, width, bottom)), (0, top, width, bottom))
        for top, bottom in bands
    ]
    log.info(f"segment_lines: {len(out)} line(s) detected")
    return out or whole_page
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Recognition helpers | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def _preprocess_line(img: Image.Image) -> Image.Image:
    """Bring a line crop up to a minimum size of 64x32 before recognition.

    Crops shorter than 32 px are upscaled (aspect ratio kept) to a height of
    32 px. Crops narrower than 64 px AFTER that step are pasted at the left of
    a white 64 px wide canvas.
    """
    w, h = img.size
    if h < 32:
        scale = 32 / h
        # max(1, ...) keeps a very thin crop from rounding down to zero width.
        img = img.resize((max(1, int(w * scale)), 32), Image.LANCZOS)
        # Refresh the width: the padding check below must see the resized
        # image, otherwise a crop that grew past 64 px gets truncated by the
        # 64 px canvas.
        w = img.width
    if w < 64:
        canvas = Image.new("RGB", (64, img.height), (255, 255, 255))
        canvas.paste(img, (0, 0))
        img = canvas
    return img
def recognize_line(pil_img: Image.Image) -> tuple[str, float]:
    """Transcribe one line crop with the loaded TrOCR model.

    Returns:
        ``(text, confidence)`` where the confidence is the fixed value 0.75;
        no score is read from the model.
    """
    prepared = _preprocess_line(pil_img)
    encoded = _processor(images=prepared, return_tensors="pt")
    pixel_values = encoded.pixel_values
    # Only the PyTorch model needs its inputs moved to the accelerator.
    if DEVICE != "cpu" and not _is_onnx:
        pixel_values = pixel_values.to(DEVICE)
    with torch.no_grad():
        token_ids = _trocr.generate(pixel_values, max_new_tokens=64)
    decoded = _processor.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0].strip(), 0.75
def run_paddle_ocr(image: Image.Image, ocr) -> list[tuple[str, float, list[int]]]:
    """Run PaddleOCR on a PIL image.

    Returns:
        ``(text, confidence, [x_min, y_min, x_max, y_max])`` tuples in the
        order the engine reported them; whitespace-only detections are dropped.
        An empty list is returned when the engine reports nothing.
    """
    pixels = np.array(image.convert("RGB"))
    result = ocr.ocr(pixels, cls=True)
    detections: list[tuple[str, float, list[int]]] = []
    if not result or not result[0]:
        return detections
    # Each entry: [[x1,y1],[x2,y2],[x3,y3],[x4,y4]], (text, confidence)
    for corners, (text, conf) in result[0]:
        cleaned = text.strip()
        if not cleaned:
            continue
        xs = [int(pt[0]) for pt in corners]
        ys = [int(pt[1]) for pt in corners]
        detections.append((cleaned, float(conf), [min(xs), min(ys), max(xs), max(ys)]))
    return detections
def pil_to_b64(img: Image.Image) -> str:
    """Encode an image as JPEG (quality 85) and return the bytes as base64 text."""
    with BytesIO() as sink:
        img.save(sink, format="JPEG", quality=85)
        payload = sink.getvalue()
    return base64.b64encode(payload).decode()
def _ensure_ready():
    """Abort the request unless the TrOCR model finished loading.

    Raises:
        HTTPException: 500 when model loading failed, 503 while it is still
            in progress.
    """
    if _load_error:
        raise HTTPException(status_code=500, detail=f"Model load failed: {_load_error}")
    if _ready:
        return
    raise HTTPException(status_code=503, detail="Models loading; retry in 30s")
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # v1 routes (kept for backward compatibility β DO NOT BREAK) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
@app.get("/health")
def health():
    """v1 health probe.

    Returns a dict whose ``status`` is ``"error"`` (with the load error),
    ``"loading"`` (TrOCR not ready yet) or ``"ok"`` (with device, model id
    and engine).
    """
    # The route decorator is required: without it this handler is never
    # reachable through the app.
    if _load_error:
        return {"status": "error", "detail": _load_error}
    if not _ready:
        return {"status": "loading", "model": HF_MODEL_REPO}
    return {
        "status": "ok",
        "device": DEVICE,
        "model": _model_id,
        "engine": "onnx" if _is_onnx else "pytorch",
        "trocr_loaded": True,
    }
@app.post("/scan")
async def scan(file: UploadFile = File(...)):
    """v1 scan: segment the uploaded image into lines and transcribe each with TrOCR.

    The image is downscaled to at most 1600 px wide. Lines whose recognition
    raises are logged and skipped. ``extracted_data`` is always returned with
    empty fields and ``confidence_score`` is the fixed value 0.75.

    Raises:
        HTTPException: 400 when the upload is not a decodable image; 500/503
            from ``_ensure_ready`` when the model is unavailable.
    """
    _ensure_ready()
    t0 = time.time()
    raw = await file.read()
    try:
        image = Image.open(BytesIO(raw)).convert("RGB")
    except Exception:
        # The decoder error carries no useful detail for the client.
        raise HTTPException(status_code=400, detail="Cannot decode image.") from None
    max_w = 1600
    if image.width > max_w:
        ratio = max_w / image.width
        image = image.resize((max_w, int(image.height * ratio)), Image.LANCZOS)
    segs = segment_lines(image)
    texts: list[str] = []
    crops: list[dict] = []
    for i, (line_img, _bbox) in enumerate(segs):
        try:
            text, conf = recognize_line(line_img)
            if text.strip():
                texts.append(text)
                crops.append({
                    "line_index": i,
                    "predicted_text": text,
                    "confidence": round(conf, 3),
                    "image_base64": pil_to_b64(line_img),
                })
                log.info(f" line {i:02d}: {text!r}")
        except Exception as e:
            # Best effort: one bad line must not fail the whole scan.
            log.warning(f" line {i} skipped ({e})")
    _record((time.time() - t0) * 1000, v2=False)
    return {
        "success": True,
        "method": "trocr",
        "raw_ocr": texts,
        "confidence_score": 0.75,
        "extracted_data": {
            "doctor_name": None,
            "patient_name": None,
            "prescription_date": None,
            "medications": [],
        },
        "line_crops": crops,
    }
# Request body of the v1 feedback handler: one image plus its corrected text.
class FeedbackV1Payload(BaseModel):
    # Base64-encoded image; the handler keeps only the part after the last
    # comma, so a "data:...;base64," prefix is tolerated.
    image_base64: str
    # Text written next to the saved file name in dataset/labels.csv.
    corrected_text: str
@app.post("/feedback")
async def feedback(data: FeedbackV1Payload):
    """v1 feedback: store a corrected line image and its label.

    The decoded image is written to ``dataset/images/feedback_<id>.jpg`` and a
    ``file_name,text`` row is appended to ``dataset/labels.csv`` (with a header
    row when the file is new).

    Returns:
        ``{"success": True, "saved": <file name>}`` on success, otherwise
        ``{"success": False, "error": <message>}``. Errors are reported in the
        body, not as an HTTP error.
    """
    global _feedback_count
    try:
        images_dir = Path("dataset/images")
        images_dir.mkdir(parents=True, exist_ok=True)
        name = f"feedback_{uuid.uuid4().hex[:8]}.jpg"
        # Drop an optional data-URL prefix ("data:image/jpeg;base64,").
        b64 = data.image_base64.split(",")[-1]
        with open(images_dir / name, "wb") as f:
            f.write(base64.b64decode(b64))
        csv_path = Path("dataset/labels.csv")
        is_new = not csv_path.exists()
        with open(csv_path, "a", encoding="utf-8", newline="") as f:
            w = csv.writer(f)
            if is_new:
                w.writerow(["file_name", "text"])
            w.writerow([name, data.corrected_text])
        with _metrics_lock:
            _feedback_count += 1
        return {"success": True, "saved": name}
    except Exception as e:
        log.error(f"Feedback error: {e}")
        return {"success": False, "error": str(e)}
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # v2 routes β Pydantic-strict + DrugMatcher | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Response shape built by v2_health().
class V2HealthResponse(BaseModel):
    # "ok", "loading" or "degraded" as computed in v2_health().
    status: str
    # True once a TrOCR model is ready.
    model_loaded: bool
    # True when _get_matcher() returned a matcher.
    matcher_loaded: bool
    # Identifier of the loaded TrOCR model; None until one is loaded.
    model: str | None = None
    # "onnx" or "pytorch"; None while no model is ready.
    engine: str | None = None
    device: str
    model_version: str
    dataset_version: str
    # len() of the matcher, 0 when it is unavailable.
    drug_count: int = 0
    matcher_error: str | None = None
    load_error: str | None = None
# Response shape built by v2_metrics().
class V2MetricsResponse(BaseModel):
    # Scan requests recorded by _record() (v1 and v2 together).
    request_count: int
    # Subset of request_count recorded with v2=True.
    v2_request_count: int
    feedback_count: int
    # Percentiles over the rolling latency window, in milliseconds.
    p50_ms: float
    p95_ms: float
    # Number of latency samples the percentiles were computed from.
    sample_size: int
    uptime_seconds: float
# Import time of this module; v2_metrics() reports uptime relative to it.
_BOOT_TIME = time.time()
@app.get("/v2/health")
def v2_health():
    """v2 health probe exposing model state, matcher state and version stamps.

    ``status`` is ``"ok"`` when the TrOCR model is ready and the matcher is
    truthy, ``"loading"`` while the model is not ready, and ``"degraded"``
    when the model is ready but the matcher is not usable.
    """
    # The route decorator is required: without it this handler is never
    # reachable through the app.
    matcher = _get_matcher()
    return V2HealthResponse(
        status=("ok" if _ready and matcher else "loading" if not _ready else "degraded"),
        model_loaded=_ready,
        matcher_loaded=matcher is not None,
        model=_model_id,
        engine="onnx" if _is_onnx else ("pytorch" if _ready else None),
        device=DEVICE,
        model_version=MODEL_VERSION,
        dataset_version=DATASET_VERSION,
        drug_count=len(matcher) if matcher else 0,
        matcher_error=_matcher_error,
        load_error=_load_error,
    )
@app.get("/v2/metrics")
def v2_metrics():
    """Report request counters and latency percentiles.

    Percentiles are taken from the rolling latency window by index
    (``n // 2`` for p50, ``int(n * 0.95)`` for p95) and are 0.0 when no
    request has been recorded yet.
    """
    # Snapshot the counters together with the latencies while holding the
    # lock, so the reported numbers belong to the same instant.
    with _metrics_lock:
        latencies = sorted(_latencies_ms)
        requests = _request_count
        v2_requests = _v2_request_count
        feedbacks = _feedback_count
    n = len(latencies)
    p50 = latencies[n // 2] if n else 0.0
    p95 = latencies[int(n * 0.95)] if n else 0.0
    return V2MetricsResponse(
        request_count=requests,
        v2_request_count=v2_requests,
        feedback_count=feedbacks,
        p50_ms=round(p50, 1),
        p95_ms=round(p95, 1),
        sample_size=n,
        uptime_seconds=round(time.time() - _BOOT_TIME, 1),
    )
@app.post("/v2/scan")
async def v2_scan(file: UploadFile = File(...)):
    """Handwritten prescription OCR.

    Primary:      Florence-2 VLM, one pass over the full image.
    Fallback:     TrOCR on segmented lines, used only when Florence-2 is
                  unavailable, raises, or returns no text.
    Post-process: DrugMatcher over the joined text; each drug id is reported
                  once.

    The image is downscaled to at most 1600 px wide. ``confidence`` is the
    mean of the per-line confidences (0.0 when there are no lines).

    Raises:
        HTTPException: 400 when the upload is not a decodable image; 500/503
            from ``_ensure_ready`` when the TrOCR model is unavailable.
    """
    _ensure_ready()
    t0 = time.time()
    raw = await file.read()
    try:
        image = Image.open(BytesIO(raw)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Cannot decode image.") from None
    max_w = 1600
    if image.width > max_w:
        ratio = max_w / image.width
        image = image.resize((max_w, int(image.height * ratio)), Image.LANCZOS)
    line_crops: list[LineCrop] = []
    raw_lines: list[str] = []
    used_method = "trocr"

    # -- Stage 1: Florence-2 VLM (primary, full image) --
    fl_proc, fl_model = _get_florence()
    if fl_proc is not None:
        try:
            fl_lines = run_florence_ocr(image, fl_proc, fl_model)
            if fl_lines:
                for i, (text, conf, bbox) in enumerate(fl_lines):
                    raw_lines.append(text)
                    line_crops.append(LineCrop(
                        bbox=bbox,
                        text=text,
                        confidence=conf,
                        line_index=i,
                    ))
                used_method = "florence2-handwritten"
                log.info(f"Florence-2 primary: {len(fl_lines)} regions")
            else:
                log.warning("Florence-2 returned no text — falling back to TrOCR")
        except Exception as e:
            log.warning(f"Florence-2 primary failed ({e}) — falling back to TrOCR")
    else:
        log.warning(f"Florence-2 unavailable ({_florence_load_error or 'not loaded'}) — falling back to TrOCR")

    # -- Stage 2: TrOCR fallback (only when Florence-2 produced nothing) --
    if not raw_lines:
        segs = segment_lines(image)
        for i, (line_img, bbox) in enumerate(segs):
            try:
                text, conf = recognize_line(line_img)
            except Exception as e:
                # Best effort: skip the line, keep the rest of the page.
                log.warning(f"recognize_line failed: {e}")
                continue
            if text.strip():
                raw_lines.append(text)
                line_crops.append(LineCrop(
                    bbox=list(bbox),
                    text=text,
                    confidence=conf,
                    line_index=i,
                    image_base64=pil_to_b64(line_img),
                ))
        used_method = "trocr"
        log.info(f"TrOCR fallback: {len(line_crops)} lines")
    raw_ocr = " ".join(raw_lines)

    # -- Stage 3: DrugMatcher --
    medications: list[Medication] = []
    seen_ids: set[str] = set()
    matcher = _get_matcher()
    if matcher and raw_ocr:
        for r in matcher.match_text(raw_ocr):
            if r.matched_name and r.drug_id and r.drug_id not in seen_ids:
                medications.append(Medication(
                    name=r.matched_name,
                    drug_id=r.drug_id,
                    confidence=r.confidence,
                    match_strategy=r.strategy,
                ))
                seen_ids.add(r.drug_id)
    overall_conf = (
        sum(c.confidence for c in line_crops) / len(line_crops)
        if line_crops else 0.0
    )
    elapsed = (time.time() - t0) * 1000
    _record(elapsed, v2=True)
    return Prescription(
        success=True,
        method=used_method,
        confidence=round(overall_conf, 3),
        medications=medications,
        line_crops=line_crops,
        raw_ocr_text=raw_ocr,
        processing_ms=int(elapsed),
        model_version=MODEL_VERSION,
        dataset_version=DATASET_VERSION,
    )
@app.post("/v2/scan-printed")
async def v2_scan_printed(file: UploadFile = File(...)):
    """Printed prescription OCR.

    Primary:      PaddleOCR.
    Supplement:   Florence-2 VLM; its regions are appended when their text
                  (case-insensitive, trimmed) was not already found by PaddleOCR.
    Post-process: DrugMatcher over the joined text; each drug id is reported
                  once.

    Unlike the handwritten endpoint this one does not call ``_ensure_ready``.
    When neither engine yields text, a ``success=False`` Prescription with an
    ``error`` message is returned.

    Raises:
        HTTPException: 400 when the upload is not a decodable image.
    """
    t0 = time.time()
    raw = await file.read()
    try:
        image = Image.open(BytesIO(raw)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Cannot decode image.") from None
    max_w = 1600
    if image.width > max_w:
        ratio = max_w / image.width
        image = image.resize((max_w, int(image.height * ratio)), Image.LANCZOS)
    line_crops: list[LineCrop] = []
    raw_lines: list[str] = []
    used_paddle = False
    used_florence = False

    # -- Stage 1: PaddleOCR (primary for printed text) --
    paddle = _get_paddle_ocr()
    if paddle is not None:
        try:
            paddle_lines = run_paddle_ocr(image, paddle)
            for i, (text, conf, bbox) in enumerate(paddle_lines):
                raw_lines.append(text)
                line_crops.append(LineCrop(bbox=bbox, text=text, confidence=conf, line_index=i))
            log.info(f"PaddleOCR printed: {len(paddle_lines)} lines")
            used_paddle = bool(paddle_lines)
        except Exception as e:
            log.warning(f"PaddleOCR failed: {e}")
    else:
        log.warning(f"PaddleOCR unavailable: {_paddle_load_error or 'not loaded'}")

    # -- Stage 2: Florence-2 supplement (adds regions PaddleOCR missed) --
    fl_proc, fl_model = _get_florence()
    if fl_proc is not None:
        try:
            fl_lines = run_florence_ocr(image, fl_proc, fl_model)
            # Snapshot of what PaddleOCR found; Florence-2 regions are only
            # compared against this set, not against each other.
            existing = {t.lower().strip() for t in raw_lines}
            added = 0
            for text, conf, bbox in fl_lines:
                if text.lower().strip() not in existing:
                    raw_lines.append(text)
                    line_crops.append(LineCrop(
                        bbox=bbox, text=text, confidence=conf,
                        line_index=len(line_crops),
                    ))
                    added += 1
            if added:
                log.info(f"Florence-2 supplement: {added} additional regions")
            used_florence = added > 0
        except Exception as e:
            log.warning(f"Florence-2 supplement failed: {e}")
    else:
        log.warning(f"Florence-2 unavailable: {_florence_load_error or 'not loaded'}")

    # -- Both engines returned nothing --
    if not raw_lines:
        log.warning(
            f"Printed OCR: all engines returned nothing "
            f"(Paddle: {_paddle_load_error or 'empty'} | "
            f"Florence: {_florence_load_error or 'empty'})"
        )
        elapsed = (time.time() - t0) * 1000
        # Count this request too; it was served and consumed OCR time.
        _record(elapsed, v2=True)
        return Prescription(
            success=False,
            method="paddle-printed",
            confidence=0.0,
            medications=[],
            line_crops=[],
            raw_ocr_text="",
            processing_ms=int(elapsed),
            model_version=MODEL_VERSION,
            dataset_version=DATASET_VERSION,
            error="No text detected. Try a clearer photo with better lighting.",
        )

    # -- Method label --
    if used_paddle and used_florence:
        used_method = "paddle+florence2"
    elif used_paddle:
        used_method = "paddle-printed"
    else:
        used_method = "florence2-printed"

    # -- Stage 3: DrugMatcher --
    raw_ocr = " ".join(raw_lines)
    medications: list[Medication] = []
    seen_ids: set[str] = set()
    matcher = _get_matcher()
    if matcher and raw_ocr:
        for r in matcher.match_text(raw_ocr):
            if r.matched_name and r.drug_id and r.drug_id not in seen_ids:
                medications.append(Medication(
                    name=r.matched_name,
                    drug_id=r.drug_id,
                    confidence=r.confidence,
                    match_strategy=r.strategy,
                ))
                seen_ids.add(r.drug_id)
    overall_conf = (
        sum(c.confidence for c in line_crops) / len(line_crops)
        if line_crops else 0.0
    )
    elapsed = (time.time() - t0) * 1000
    _record(elapsed, v2=True)
    return Prescription(
        success=True,
        method=used_method,
        confidence=round(overall_conf, 3),
        medications=medications,
        line_crops=line_crops,
        raw_ocr_text=raw_ocr,
        processing_ms=int(elapsed),
        model_version=MODEL_VERSION,
        dataset_version=DATASET_VERSION,
    )
@app.post("/v2/feedback")
async def v2_feedback(payload: FeedbackV2Payload):
    """Log a batch of corrections with an audit trail.

    One JSON line per request is appended to
    ``dataset/feedback_log/<YYYYMMDD>.jsonl`` (UTC date). The feedback counter
    grows by the number of corrections in the batch.

    Returns:
        ``{"success": True, "logged": <count>, "file": <path>}``.

    Raises:
        HTTPException: 500 when the record cannot be serialised or written.
    """
    global _feedback_count
    try:
        # Use a single UTC instant for both the file name and the timestamp:
        # the "Z" suffix below claims UTC, so local time would be wrong.
        now = time.gmtime()
        log_dir = Path("dataset/feedback_log")
        log_dir.mkdir(parents=True, exist_ok=True)
        log_file = log_dir / f"{time.strftime('%Y%m%d', now)}.jsonl"
        record = {
            "ts": time.strftime("%Y-%m-%dT%H:%M:%SZ", now),
            "image_id": payload.image_id,
            "reviewer_id": payload.reviewer_id,
            "submitted_at": payload.submitted_at,
            "corrections": [c.model_dump() for c in payload.corrections],
        }
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
        with _metrics_lock:
            _feedback_count += len(payload.corrections)
        return {"success": True, "logged": len(payload.corrections), "file": str(log_file)}
    except Exception as e:
        log.error(f"v2 feedback error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# Direct execution: serve the app on all interfaces, port 7860.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)