""" 3-Class Document Classifier API Routes page images to: OCR | Wound Care | Clinical LLM API endpoint: /api/predict """ import time import numpy as np import cv2 import torch import torch.nn as nn import torch.nn.functional as F import gradio as gr from PIL import Image from torchvision import models, transforms # ── Constants ──────────────────────────────────────────────────────────────── NUM_CLASSES = 3 LABEL_MAP = {"text_document": 0, "wound": 1, "clinical_medical": 2} LABEL_NAMES = {v: k for k, v in LABEL_MAP.items()} IMG_SIZE = 224 IMAGENET_MEAN = [0.485, 0.456, 0.406] IMAGENET_STD = [0.229, 0.224, 0.225] ROUTE_MAP = { "text_document": "ocr", "wound": "wound_care", "clinical_medical": "clinical", } DISPLAY_NAMES = { "text_document": "Text Document", "wound": "Wound Image", "clinical_medical": "Clinical / Medical Image", } ROUTE_DISPLAY = { "ocr": "OCR Text Extraction", "wound_care": "Wound Care Pipeline", "clinical": "Clinical LLM Analysis", "both_medical": "Wound Care + Clinical LLM", "all": "All Pipelines (Review Needed)", } MEDICAL_LABELS = {"wound", "clinical_medical"} CNN_HIGH = 0.92 CNN_MED = 0.75 HEUR_MIN = 0.55 # ── Model ──────────────────────────────────────────────────────────────────── class DocumentClassifierCNN(nn.Module): def __init__(self): super().__init__() self.backbone = models.efficientnet_b0(weights=None) in_features = self.backbone.classifier[1].in_features self.backbone.classifier = nn.Sequential( nn.Dropout(p=0.3), nn.Linear(in_features, NUM_CLASSES), ) def forward(self, x): return self.backbone(x) def predict_proba(self, x): with torch.no_grad(): return F.softmax(self.forward(x), dim=-1) def load_model(): model = DocumentClassifierCNN() ckpt = torch.load("best_model.pth", map_location="cpu", weights_only=False) model.load_state_dict(ckpt["model_state_dict"]) model.eval() return model val_transform = transforms.Compose([ transforms.Resize((IMG_SIZE, IMG_SIZE)), transforms.ToTensor(), transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), ]) # ── Heuristics ─────────────────────────────────────────────────────────────── def _sat_score(bgr): hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV) return float(1.0 - min(hsv[:, :, 1].mean() / 80.0, 1.0)) def _edge_score(bgr): g = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) s = cv2.resize(g, (512, 512)) h = np.abs(cv2.Sobel(s, cv2.CV_64F, 1, 0, ksize=3)).sum() v = np.abs(cv2.Sobel(s, cv2.CV_64F, 0, 1, ksize=3)).sum() t = h + v return float(min(max((v / t - 0.45) / 0.15, 0), 1)) if t > 1e-6 else 0.5 def _white_score(bgr): g = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) return float(min(np.sum(g > 220) / g.size / 0.5, 1.0)) def _comp_score(bgr): g = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) s = cv2.resize(g, (512, 512)) _, b = cv2.threshold(s, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) n, _, stats, _ = cv2.connectedComponentsWithStats(b, 8) if n <= 1: return 0.5 a = stats[1:, cv2.CC_STAT_AREA] return float(min(np.sum((a > 5) & (a < 500)) / (512 * 512 / 100) / 3.0, 1.0)) def _warm_ratio(bgr): hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV) h, s, v = hsv[:, :, 0], hsv[:, :, 1], hsv[:, :, 2] m = ((h < 25) | (h > 165)) & (s > 30) & (v > 40) return float(min(m.sum() / h.size / 0.3, 1.0)) def _gray_dom(bgr): hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV) return float(1.0 - min(hsv[:, :, 1].mean() / 30.0, 1.0)) def _dark_bg(bgr): g = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) return float(min(np.sum(g < 30) / g.size / 0.3, 1.0)) def heuristic_classify(bgr): s1 = { "sat": _sat_score(bgr), "edge": _edge_score(bgr), "white": _white_score(bgr), "comp": _comp_score(bgr), } combined = s1["sat"] * 0.30 + s1["edge"] * 0.15 + s1["white"] * 0.35 + s1["comp"] * 0.20 if combined >= 0.50: conf = min((combined - 0.50) / 0.50 * 0.5 + 0.5, 1.0) return "text_document", conf warm = _warm_ratio(bgr) gray = _gray_dom(bgr) dark = _dark_bg(bgr) w_sig = warm * 0.50 + (1 - gray) * 0.20 + (1 - dark) * 0.30 c_sig = gray * 0.30 + dark * 0.30 + (1 - warm) * 0.40 if w_sig > c_sig: conf = min(0.5 + (w_sig - c_sig) * 2, 1.0) return "wound", conf conf = min(0.5 + (c_sig - w_sig) * 2, 1.0) return "clinical_medical", conf # ── Ensemble ───────────────────────────────────────────────────────────────── def ensemble(cnn_label, cnn_conf, heur_label, heur_conf): agree = cnn_label == heur_label if agree and cnn_conf >= CNN_HIGH and heur_conf >= HEUR_MIN: return cnn_label, ROUTE_MAP[cnn_label], min(cnn_conf, heur_conf), False if agree and cnn_conf >= CNN_MED: return cnn_label, ROUTE_MAP[cnn_label], cnn_conf * 0.9, False cnn_med = cnn_label in MEDICAL_LABELS heur_med = heur_label in MEDICAL_LABELS if cnn_med and heur_med and cnn_label != heur_label: pri = cnn_label if cnn_conf >= heur_conf else heur_label return pri, "both_medical", max(cnn_conf, heur_conf) * 0.6, cnn_conf < 0.6 if cnn_med != heur_med: pri = cnn_label if cnn_conf >= 0.6 else heur_label return pri, "all", max(cnn_conf, heur_conf) * 0.4, True pri = cnn_label if cnn_conf >= 0.5 else heur_label route = ROUTE_MAP.get(pri, "all") if cnn_conf >= CNN_MED else "all" return pri, route, max(cnn_conf, heur_conf) * 0.5, cnn_conf < 0.6 # ── Load model at startup ──────────────────────────────────────────────────── print("Loading model...") model = load_model() print("Model loaded.") # ── Predict function ───────────────────────────────────────────────────────── def predict(image): """Classify a page image into text_document, wound, or clinical_medical.""" if image is None: return {"error": "No image provided"} t0 = time.time() pil = Image.fromarray(image).convert("RGB") tensor = val_transform(pil).unsqueeze(0) probs = model.predict_proba(tensor).squeeze(0) cnn_pred = probs.argmax().item() cnn_label = LABEL_NAMES[cnn_pred] cnn_conf = probs[cnn_pred].item() bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) heur_label, heur_conf = heuristic_classify(bgr) label, route, conf, review = ensemble(cnn_label, cnn_conf, heur_label, heur_conf) elapsed = (time.time() - t0) * 1000 class_probs = {LABEL_NAMES[i]: round(probs[i].item(), 4) for i in range(NUM_CLASSES)} return { "label": label, "label_display": DISPLAY_NAMES.get(label, label), "route": route, "route_display": ROUTE_DISPLAY.get(route, route), "confidence": round(conf, 4), "class_probabilities": class_probs, "cnn_label": cnn_label, "cnn_confidence": round(cnn_conf, 4), "heuristic_label": heur_label, "heuristic_confidence": round(heur_conf, 4), "needs_review": review, "inference_ms": round(elapsed, 1), } # ── Gradio Interface ───────────────────────────────────────────────────────── demo = gr.Interface( fn=predict, inputs=gr.Image(type="numpy", label="Upload Page Image"), outputs=gr.JSON(label="Classification Result"), title="Document Classifier API", description="3-class classifier: **Text Document** | **Wound** | **Clinical/Medical**. " "Use the UI to test, or call the API programmatically at `/api/predict`.", examples=None, api_name="predict", flagging_mode="never", ) if __name__ == "__main__": demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)