Spaces:
Sleeping
Sleeping
| """ | |
| 3-Class Document Classifier API | |
| Routes page images to: OCR | Wound Care | Clinical LLM | |
| API endpoint: /api/predict | |
| """ | |
| import time | |
| import numpy as np | |
| import cv2 | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| from PIL import Image | |
| from torchvision import models, transforms | |
# ── Constants ─────────────────────────────────────────────────────────────
NUM_CLASSES = 3
# Class name -> index of the matching CNN output.
LABEL_MAP = dict(text_document=0, wound=1, clinical_medical=2)
# Inverse lookup: CNN output index -> class name.
LABEL_NAMES = dict((idx, name) for name, idx in LABEL_MAP.items())
IMG_SIZE = 224  # square side length fed to the CNN
# Per-channel (R, G, B) normalisation statistics used by val_transform.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Class name -> downstream pipeline key (keys of ROUTE_DISPLAY).
ROUTE_MAP = dict(
    text_document="ocr",
    wound="wound_care",
    clinical_medical="clinical",
)
# Human-readable class names returned as "label_display".
DISPLAY_NAMES = dict(
    text_document="Text Document",
    wound="Wound Image",
    clinical_medical="Clinical / Medical Image",
)
# Human-readable pipeline names returned as "route_display". "both_medical"
# and "all" are composite routes produced only by ensemble().
ROUTE_DISPLAY = dict(
    ocr="OCR Text Extraction",
    wound_care="Wound Care Pipeline",
    clinical="Clinical LLM Analysis",
    both_medical="Wound Care + Clinical LLM",
    all="All Pipelines (Review Needed)",
)
# Labels treated as "medical" when the CNN and the heuristics disagree.
MEDICAL_LABELS = {"wound", "clinical_medical"}
# ensemble() thresholds:
CNN_HIGH = 0.92  # CNN confidence for the highest-trust (agreeing) route
CNN_MED = 0.75   # CNN confidence for the discounted agreeing route
HEUR_MIN = 0.55  # heuristic confidence also required alongside CNN_HIGH
# ── Model ─────────────────────────────────────────────────────────────────
class DocumentClassifierCNN(nn.Module):
    """EfficientNet-B0 backbone with a dropout + linear head over NUM_CLASSES.

    The backbone is created with ``weights=None`` (random initialisation, no
    pretrained download); trained parameters are loaded afterwards from a
    checkpoint via ``load_state_dict``.
    """

    def __init__(self):
        super().__init__()
        net = models.efficientnet_b0(weights=None)
        # Replace the stock classifier with our own head, reusing the width of
        # its Linear layer (element 1 of the original classifier Sequential).
        feat_dim = net.classifier[1].in_features
        net.classifier = nn.Sequential(
            nn.Dropout(p=0.3),
            nn.Linear(feat_dim, NUM_CLASSES),
        )
        self.backbone = net

    def forward(self, x):
        """Return unnormalised class scores (logits) for a batch of images."""
        return self.backbone(x)

    def predict_proba(self, x):
        """Return softmax class probabilities, computed without autograd.

        Does not switch the module to eval mode itself; callers are expected
        to have called ``eval()`` so dropout is disabled.
        """
        with torch.no_grad():
            logits = self.forward(x)
            probabilities = F.softmax(logits, dim=-1)
        return probabilities
def load_model(path="best_model.pth", map_location="cpu"):
    """Build a DocumentClassifierCNN and load trained weights from a checkpoint.

    Args:
        path: Checkpoint file to read. Defaults to ``"best_model.pth"``,
            resolved relative to the current working directory.
        map_location: Forwarded to ``torch.load``; defaults to ``"cpu"``.

    Returns:
        The model switched to eval mode (dropout disabled), ready for inference.

    Raises:
        OSError (e.g. FileNotFoundError) if the checkpoint cannot be opened;
        RuntimeError from ``load_state_dict`` if keys or shapes do not match.
    """
    model = DocumentClassifierCNN()
    # SECURITY: weights_only=False unpickles arbitrary Python objects and can
    # execute code embedded in the file. Only load checkpoints you trust;
    # switch to weights_only=True once the checkpoint is known to hold only
    # tensors and plain containers.
    ckpt = torch.load(path, map_location=map_location, weights_only=False)
    # Training checkpoints wrap the weights under "model_state_dict"; also
    # accept a bare state dict saved with torch.save(model.state_dict(), path).
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state = ckpt["model_state_dict"]
    else:
        state = ckpt
    model.load_state_dict(state)
    model.eval()
    return model
# Inference preprocessing: resize to exactly IMG_SIZE x IMG_SIZE (aspect ratio
# is not preserved), convert to a float CHW tensor in [0, 1], then normalise
# per channel with the ImageNet statistics. Presumably mirrors the
# validation transform used in training — confirm against the training code.
val_transform = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)
# ── Heuristics ────────────────────────────────────────────────────────────
def _sat_score(bgr):
    """Score how colourless a BGR image is, in [0, 1].

    Mean HSV saturation 0 gives 1.0 and a mean of 80 or more gives 0.0,
    linearly in between. It is weighted into the text-document score of
    ``heuristic_classify``.
    """
    mean_saturation = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)[..., 1].mean()
    return float(1.0 - min(mean_saturation / 80.0, 1.0))
def _edge_score(bgr):
    """Score the dominance of y-direction gradients (horizontal edges), in [0, 1].

    The image is converted to grey and resized to 512x512. Absolute Sobel
    responses are then summed per direction. The y-gradient share of the total
    maps linearly from 0.45 (score 0) to 0.60 (score 1), clamped to that range.
    Horizontal edges are presumably meant to capture text lines. Returns 0.5
    for a (near-)flat image with no gradient energy.
    """
    small = cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY), (512, 512))
    grad_x = np.abs(cv2.Sobel(small, cv2.CV_64F, 1, 0, ksize=3)).sum()
    grad_y = np.abs(cv2.Sobel(small, cv2.CV_64F, 0, 1, ksize=3)).sum()
    total = grad_x + grad_y
    if not total > 1e-6:
        return 0.5
    y_share = grad_y / total
    return float(min(max((y_share - 0.45) / 0.15, 0), 1))
def _white_score(bgr):
    """Score the share of near-white pixels (grey > 220), saturating at 50%."""
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    bright_fraction = (gray > 220).sum() / gray.size
    return float(min(bright_fraction / 0.5, 1.0))
def _comp_score(bgr):
    """Score the density of small dark connected components, in [0, 1].

    On a 512x512 grey copy, the image is Otsu-thresholded with dark pixels as
    foreground. Components with an area of 6-499 pixels are counted; such
    blobs are presumably glyph-sized. The count is taken per 1/100 of the
    image area and saturates at 3.0. Returns 0.5 when no foreground component
    exists.
    """
    small = cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY), (512, 512))
    # BINARY_INV turns pixels at or below the Otsu threshold (darker) into 255.
    _, mask = cv2.threshold(small, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    label_count, _, stats, _ = cv2.connectedComponentsWithStats(mask, 8)
    if label_count <= 1:  # only the background label was found
        return 0.5
    areas = stats[1:, cv2.CC_STAT_AREA]  # row 0 is the background
    small_blobs = np.sum((areas > 5) & (areas < 500))
    per_percent_area = small_blobs / (512 * 512 / 100)
    return float(min(per_percent_area / 3.0, 1.0))
def _warm_ratio(bgr):
    """Score the share of warm (red/orange) pixels, saturating at 30%.

    A pixel counts when its hue is below 25 or above 165 (8-bit OpenCV hue
    runs 0-179, so this wraps around red), its saturation exceeds 30 and its
    value exceeds 40.
    """
    hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
    hue, sat, val = (hsv[..., ch] for ch in range(3))
    reddish = (hue < 25) | (hue > 165)
    warm = reddish & (sat > 30) & (val > 40)
    return float(min(warm.sum() / hue.size / 0.3, 1.0))
def _gray_dom(bgr):
    """Score greyscale dominance: 1.0 at zero mean saturation, 0.0 at a mean of 30+."""
    saturation = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)[..., 1]
    return float(1.0 - min(saturation.mean() / 30.0, 1.0))
def _dark_bg(bgr):
    """Score the share of very dark pixels (grey < 30), saturating at 30%."""
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    dark_fraction = (gray < 30).sum() / gray.size
    return float(min(dark_fraction / 0.3, 1.0))
def heuristic_classify(bgr):
    """Classify a BGR image with hand-tuned colour and texture rules.

    Stage 1 computes a weighted document score from the saturation, edge,
    whiteness and small-component cues. If it reaches 0.50, the image is a
    text document. Otherwise stage 2 compares a wound signal and a clinical
    signal built from warm-colour, greyscale and dark-background cues.

    Returns:
        ``(label, confidence)`` with label in {"text_document", "wound",
        "clinical_medical"} and confidence in [0.5, 1.0].
    """
    doc_score = (
        _sat_score(bgr) * 0.30
        + _edge_score(bgr) * 0.15
        + _white_score(bgr) * 0.35
        + _comp_score(bgr) * 0.20
    )
    if doc_score >= 0.50:
        # Map doc_score in [0.5, 1.0] linearly onto confidence [0.5, 1.0].
        return "text_document", min((doc_score - 0.50) / 0.50 * 0.5 + 0.5, 1.0)

    warm = _warm_ratio(bgr)
    gray = _gray_dom(bgr)
    dark = _dark_bg(bgr)
    wound_signal = warm * 0.50 + (1 - gray) * 0.20 + (1 - dark) * 0.30
    clinical_signal = gray * 0.30 + dark * 0.30 + (1 - warm) * 0.40

    # Confidence grows by 2x the margin between the two signals, capped at 1.
    if wound_signal > clinical_signal:
        return "wound", min(0.5 + (wound_signal - clinical_signal) * 2, 1.0)
    return "clinical_medical", min(0.5 + (clinical_signal - wound_signal) * 2, 1.0)
# ── Ensemble ──────────────────────────────────────────────────────────────
def ensemble(cnn_label, cnn_conf, heur_label, heur_conf):
    """Fuse the CNN and heuristic predictions into a routing decision.

    Args:
        cnn_label: Top CNN class name (a key of ROUTE_MAP).
        cnn_conf: CNN softmax probability of ``cnn_label``.
        heur_label: Label returned by ``heuristic_classify``.
        heur_conf: Confidence returned by ``heuristic_classify``.

    Returns:
        ``(label, route, confidence, needs_review)``. ``route`` is a key of
        ROUTE_DISPLAY: "ocr", "wound_care", "clinical", "both_medical" or "all".
        Route "all" always comes with ``needs_review=True``.
    """
    agree = cnn_label == heur_label

    # 1. Agreement, CNN very confident and heuristic reasonably confident.
    if agree and cnn_conf >= CNN_HIGH and heur_conf >= HEUR_MIN:
        return cnn_label, ROUTE_MAP[cnn_label], min(cnn_conf, heur_conf), False
    # 2. Agreement with a moderately confident CNN: route, with a discount.
    if agree and cnn_conf >= CNN_MED:
        return cnn_label, ROUTE_MAP[cnn_label], cnn_conf * 0.9, False

    cnn_med = cnn_label in MEDICAL_LABELS
    heur_med = heur_label in MEDICAL_LABELS

    # 3. Both predict a medical class but different ones: run both medical pipelines.
    if cnn_med and heur_med and cnn_label != heur_label:
        pri = cnn_label if cnn_conf >= heur_conf else heur_label
        return pri, "both_medical", max(cnn_conf, heur_conf) * 0.6, cnn_conf < 0.6
    # 4. One predicts a document and the other a medical class: run everything.
    if cnn_med != heur_med:
        pri = cnn_label if cnn_conf >= 0.6 else heur_label
        return pri, "all", max(cnn_conf, heur_conf) * 0.4, True

    # 5. Fallback. With the current three labels this is reached only when the
    #    labels agree but cnn_conf < CNN_MED, so the route resolves to "all".
    pri = cnn_label if cnn_conf >= 0.5 else heur_label
    route = ROUTE_MAP.get(pri, "all") if cnn_conf >= CNN_MED else "all"
    # Route "all" is displayed as "All Pipelines (Review Needed)", so always
    # flag it for review, not only when cnn_conf < 0.6.
    needs_review = route == "all" or cnn_conf < 0.6
    return pri, route, max(cnn_conf, heur_conf) * 0.5, needs_review
# ── Load model at startup ─────────────────────────────────────────────────
# The model is loaded once at import time and shared by every predict() call.
# flush=True makes these progress lines appear immediately in logs: stdout is
# block-buffered when it is not attached to a terminal (e.g. in a container).
print("Loading model...", flush=True)
model = load_model()
print("Model loaded.", flush=True)
# ── Predict function ──────────────────────────────────────────────────────
def predict(image):
    """Classify a page image into text_document, wound, or clinical_medical.

    Runs the CNN and the rule-based heuristics on the same RGB image and
    fuses their outputs with ``ensemble``.

    Args:
        image: Image array as delivered by ``gr.Image(type="numpy")``, or
            ``None`` when nothing was uploaded.

    Returns:
        A JSON-serialisable dict with the final label/route/confidence (plus
        display names), per-class CNN probabilities, both component
        predictions, the ``needs_review`` flag and ``inference_ms``.
        Returns ``{"error": "No image provided"}`` when ``image`` is None.
    """
    if image is None:
        return {"error": "No image provided"}

    # perf_counter is monotonic and high-resolution, unlike time.time(), so
    # the measured duration cannot be skewed by wall-clock adjustments.
    t0 = time.perf_counter()

    # Normalise to 3-channel RGB once and feed that same array to both
    # branches; a 2-D grayscale array passed straight to
    # cv2.COLOR_RGB2BGR would raise.
    pil = Image.fromarray(image).convert("RGB")
    rgb = np.array(pil)

    tensor = val_transform(pil).unsqueeze(0)  # add a batch dimension of 1
    probs = model.predict_proba(tensor).squeeze(0)
    cnn_pred = probs.argmax().item()
    cnn_label = LABEL_NAMES[cnn_pred]
    cnn_conf = probs[cnn_pred].item()

    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)  # OpenCV heuristics expect BGR
    heur_label, heur_conf = heuristic_classify(bgr)

    label, route, conf, review = ensemble(cnn_label, cnn_conf, heur_label, heur_conf)
    elapsed_ms = (time.perf_counter() - t0) * 1000

    class_probs = {LABEL_NAMES[i]: round(probs[i].item(), 4) for i in range(NUM_CLASSES)}
    return {
        "label": label,
        "label_display": DISPLAY_NAMES.get(label, label),
        "route": route,
        "route_display": ROUTE_DISPLAY.get(route, route),
        "confidence": round(conf, 4),
        "class_probabilities": class_probs,
        "cnn_label": cnn_label,
        "cnn_confidence": round(cnn_conf, 4),
        "heuristic_label": heur_label,
        "heuristic_confidence": round(heur_conf, 4),
        "needs_review": review,
        "inference_ms": round(elapsed_ms, 1),
    }
# ── Gradio Interface ──────────────────────────────────────────────────────
# predict() is exposed both as a web UI and, through api_name="predict", as a
# named API endpoint.
# NOTE(review): the description advertises `/api/predict`; the concrete HTTP
# path depends on the installed Gradio version — confirm against the Space's
# "Use via API" page.
_page_input = gr.Image(type="numpy", label="Upload Page Image")
_result_output = gr.JSON(label="Classification Result")

demo = gr.Interface(
    fn=predict,
    inputs=_page_input,
    outputs=_result_output,
    title="Document Classifier API",
    description=(
        "3-class classifier: **Text Document** | **Wound** | **Clinical/Medical**. "
        "Use the UI to test, or call the API programmatically at `/api/predict`."
    ),
    examples=None,
    api_name="predict",
    flagging_mode="never",
)

if __name__ == "__main__":
    # 0.0.0.0 binds every interface so the server is reachable from other
    # hosts (e.g. when running inside a container).
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)