""" Custom Inference Handler for Document Classifier. HF Inference Endpoints call EndpointHandler. """ import time import io import base64 import numpy as np import cv2 import torch import torch.nn as nn import torch.nn.functional as F from PIL import Image from torchvision import models, transforms NUM_CLASSES = 3 LABEL_MAP = {"text_document": 0, "wound": 1, "clinical_medical": 2} LABEL_NAMES = {v: k for k, v in LABEL_MAP.items()} IMG_SIZE = 224 IMAGENET_MEAN = [0.485, 0.456, 0.406] IMAGENET_STD = [0.229, 0.224, 0.225] ROUTE_MAP = { "text_document": "ocr", "wound": "wound_care", "clinical_medical": "clinical", } DISPLAY_NAMES = { "text_document": "Text Document", "wound": "Wound Image", "clinical_medical": "Clinical / Medical Image", } ROUTE_DISPLAY = { "ocr": "OCR Text Extraction", "wound_care": "Wound Care Pipeline", "clinical": "Clinical LLM Analysis", "both_medical": "Wound Care + Clinical LLM", "all": "All Pipelines (Review Needed)", } MEDICAL_LABELS = {"wound", "clinical_medical"} CNN_HIGH = 0.92 CNN_MED = 0.75 HEUR_MIN = 0.55 class DocumentClassifierCNN(nn.Module): def __init__(self): super().__init__() self.backbone = models.efficientnet_b0(weights=None) in_features = self.backbone.classifier[1].in_features self.backbone.classifier = nn.Sequential( nn.Dropout(p=0.3), nn.Linear(in_features, NUM_CLASSES), ) def forward(self, x): return self.backbone(x) def predict_proba(self, x): with torch.no_grad(): return F.softmax(self.forward(x), dim=-1) val_transform = transforms.Compose([ transforms.Resize((IMG_SIZE, IMG_SIZE)), transforms.ToTensor(), transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), ]) # ── Heuristics ─────────────────────────────────────────────────────────────── def _sat_score(bgr): hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV) return float(1.0 - min(hsv[:, :, 1].mean() / 80.0, 1.0)) def _edge_score(bgr): g = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) s = cv2.resize(g, (512, 512)) h = np.abs(cv2.Sobel(s, cv2.CV_64F, 1, 0, ksize=3)).sum() v = np.abs(cv2.Sobel(s, cv2.CV_64F, 0, 1, ksize=3)).sum() t = h + v return float(min(max((v / t - 0.45) / 0.15, 0), 1)) if t > 1e-6 else 0.5 def _white_score(bgr): g = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) return float(min(np.sum(g > 220) / g.size / 0.5, 1.0)) def _comp_score(bgr): g = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) s = cv2.resize(g, (512, 512)) _, b = cv2.threshold(s, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) n, _, stats, _ = cv2.connectedComponentsWithStats(b, 8) if n <= 1: return 0.5 a = stats[1:, cv2.CC_STAT_AREA] return float(min(np.sum((a > 5) & (a < 500)) / (512 * 512 / 100) / 3.0, 1.0)) def _warm_ratio(bgr): hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV) h, s, v = hsv[:, :, 0], hsv[:, :, 1], hsv[:, :, 2] m = ((h < 25) | (h > 165)) & (s > 30) & (v > 40) return float(min(m.sum() / h.size / 0.3, 1.0)) def _gray_dom(bgr): hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV) return float(1.0 - min(hsv[:, :, 1].mean() / 30.0, 1.0)) def _dark_bg(bgr): g = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) return float(min(np.sum(g < 30) / g.size / 0.3, 1.0)) def heuristic_classify(bgr): s1 = { "sat": _sat_score(bgr), "edge": _edge_score(bgr), "white": _white_score(bgr), "comp": _comp_score(bgr), } combined = s1["sat"] * 0.30 + s1["edge"] * 0.15 + s1["white"] * 0.35 + s1["comp"] * 0.20 if combined >= 0.50: conf = min((combined - 0.50) / 0.50 * 0.5 + 0.5, 1.0) return "text_document", conf warm = _warm_ratio(bgr) gray = _gray_dom(bgr) dark = _dark_bg(bgr) w_sig = warm * 0.50 + (1 - gray) * 0.20 + (1 - dark) * 0.30 c_sig = gray * 0.30 + dark * 0.30 + (1 - warm) * 0.40 if w_sig > c_sig: conf = min(0.5 + (w_sig - c_sig) * 2, 1.0) return "wound", conf conf = min(0.5 + (c_sig - w_sig) * 2, 1.0) return "clinical_medical", conf # ── Ensemble ───────────────────────────────────────────────────────────────── def ensemble(cnn_label, cnn_conf, heur_label, heur_conf): agree = cnn_label == heur_label if agree and cnn_conf >= CNN_HIGH and heur_conf >= HEUR_MIN: return cnn_label, ROUTE_MAP[cnn_label], min(cnn_conf, heur_conf), False if agree and cnn_conf >= CNN_MED: return cnn_label, ROUTE_MAP[cnn_label], cnn_conf * 0.9, False cnn_med = cnn_label in MEDICAL_LABELS heur_med = heur_label in MEDICAL_LABELS if cnn_med and heur_med and cnn_label != heur_label: pri = cnn_label if cnn_conf >= heur_conf else heur_label return pri, "both_medical", max(cnn_conf, heur_conf) * 0.6, cnn_conf < 0.6 if cnn_med != heur_med: pri = cnn_label if cnn_conf >= 0.6 else heur_label return pri, "all", max(cnn_conf, heur_conf) * 0.4, True pri = cnn_label if cnn_conf >= 0.5 else heur_label route = ROUTE_MAP.get(pri, "all") if cnn_conf >= CNN_MED else "all" return pri, route, max(cnn_conf, heur_conf) * 0.5, cnn_conf < 0.6 # ── HF Inference Endpoint Handler ──────────────────────────────────────────── class EndpointHandler: def __init__(self, path=""): import os model_path = os.path.join(path, "best_model.pth") self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = DocumentClassifierCNN() ckpt = torch.load(model_path, map_location=self.device, weights_only=False) self.model.load_state_dict(ckpt["model_state_dict"]) self.model.to(self.device) self.model.eval() def __call__(self, data): inputs = data.get("inputs", data) if isinstance(inputs, dict) and "image" in inputs: img_data = inputs["image"] elif isinstance(inputs, str): img_data = inputs else: img_data = inputs if isinstance(img_data, str): image_bytes = base64.b64decode(img_data) pil = Image.open(io.BytesIO(image_bytes)).convert("RGB") elif isinstance(img_data, bytes): pil = Image.open(io.BytesIO(img_data)).convert("RGB") elif isinstance(img_data, Image.Image): pil = img_data.convert("RGB") else: return {"error": f"Unsupported input type: {type(img_data)}"} t0 = time.time() tensor = val_transform(pil).unsqueeze(0).to(self.device) probs = self.model.predict_proba(tensor).squeeze(0).cpu() cnn_pred = probs.argmax().item() cnn_label = LABEL_NAMES[cnn_pred] cnn_conf = probs[cnn_pred].item() img_np = np.array(pil) bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) heur_label, heur_conf = heuristic_classify(bgr) label, route, conf, review = ensemble(cnn_label, cnn_conf, heur_label, heur_conf) elapsed = (time.time() - t0) * 1000 class_probs = {LABEL_NAMES[i]: round(probs[i].item(), 4) for i in range(NUM_CLASSES)} return { "label": label, "label_display": DISPLAY_NAMES.get(label, label), "route": route, "route_display": ROUTE_DISPLAY.get(route, route), "confidence": round(conf, 4), "class_probabilities": class_probs, "cnn_label": cnn_label, "cnn_confidence": round(cnn_conf, 4), "heuristic_label": heur_label, "heuristic_confidence": round(heur_conf, 4), "needs_review": review, "inference_ms": round(elapsed, 1), }