| """ |
| Custom Inference Handler for Document Classifier. |
| HF Inference Endpoints call EndpointHandler. |
| """ |
| import time |
| import io |
| import base64 |
| import numpy as np |
| import cv2 |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from PIL import Image |
| from torchvision import models, transforms |
|
|
|
|
# Class-name <-> output-index mapping. Indices must be contiguous from 0: the
# endpoint reads CNN probabilities by position via LABEL_NAMES and
# range(NUM_CLASSES). The trained checkpoint's output order is assumed to match
# this mapping — TODO confirm against the training code.
LABEL_MAP = {"text_document": 0, "wound": 1, "clinical_medical": 2}
# Derived from LABEL_MAP so the class count cannot drift from the label mapping.
NUM_CLASSES = len(LABEL_MAP)
LABEL_NAMES = {v: k for k, v in LABEL_MAP.items()}

# Input resolution and ImageNet channel statistics used by val_transform.
IMG_SIZE = 224
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


# Classifier label -> downstream pipeline key.
ROUTE_MAP = {
    "text_document": "ocr",
    "wound": "wound_care",
    "clinical_medical": "clinical",
}
# Human-readable label names, returned as "label_display".
DISPLAY_NAMES = {
    "text_document": "Text Document",
    "wound": "Wound Image",
    "clinical_medical": "Clinical / Medical Image",
}
# Human-readable route names, returned as "route_display". This includes the
# two routes that only the ensemble produces: "both_medical" and "all".
ROUTE_DISPLAY = {
    "ocr": "OCR Text Extraction",
    "wound_care": "Wound Care Pipeline",
    "clinical": "Clinical LLM Analysis",
    "both_medical": "Wound Care + Clinical LLM",
    "all": "All Pipelines (Review Needed)",
}


# Labels treated as "medical" when reconciling CNN/heuristic disagreements.
MEDICAL_LABELS = {"wound", "clinical_medical"}
# Ensemble thresholds:
# - CNN_HIGH / CNN_MED: CNN softmax confidence for a high / medium vote.
# - HEUR_MIN: minimum heuristic confidence that confirms a high CNN vote.
CNN_HIGH = 0.92
CNN_MED = 0.75
HEUR_MIN = 0.55
|
|
|
|
class DocumentClassifierCNN(nn.Module):
    """EfficientNet-B0 image classifier with a dropout + linear head.

    The backbone is built with ``weights=None``, so no pretrained weights are
    downloaded. The endpoint handler loads trained weights afterwards via
    ``load_state_dict``.
    """

    def __init__(self):
        super().__init__()
        self.backbone = models.efficientnet_b0(weights=None)
        # Swap torchvision's stock classifier head for one sized to our label
        # set, reusing the input width of the original head's Linear layer.
        head_width = self.backbone.classifier[1].in_features
        new_head = nn.Sequential(
            nn.Dropout(p=0.3),
            nn.Linear(head_width, NUM_CLASSES),
        )
        self.backbone.classifier = new_head

    def forward(self, x):
        """Return the raw class logits produced by the backbone."""
        logits = self.backbone(x)
        return logits

    def predict_proba(self, x):
        """Return softmax probabilities over the last dimension.

        Runs without gradient tracking. It does not switch the module to eval
        mode, so callers must call ``eval()`` first for dropout to be disabled.
        """
        with torch.no_grad():
            logits = self.forward(x)
            return F.softmax(logits, dim=-1)
|
|
|
|
# Deterministic inference-time preprocessing:
# 1. Resize straight to IMG_SIZE x IMG_SIZE (the aspect ratio is not preserved).
# 2. Convert the PIL image to a tensor.
# 3. Normalise with the ImageNet channel statistics.
val_transform = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)
|
|
|
|
| |
|
|
def _sat_score(bgr):
    """Score how desaturated a BGR image is, on a 0-1 scale.

    Returns 1.0 when the mean HSV saturation is 0. The score falls linearly to
    0.0 at a mean saturation of 80 or more.
    """
    saturation = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)[:, :, 1]
    ratio = saturation.mean() / 80.0
    return float(1.0 - min(ratio, 1.0))
|
|
def _edge_score(bgr):
    """Score the share of y-direction gradient energy, on a 0-1 scale.

    The image is converted to grayscale and resized to 512x512. Absolute Sobel
    responses are summed along x and along y, and the y share of the total is
    mapped linearly: a share of 0.45 or less gives 0, and 0.60 or more gives 1.
    A practically flat image (total energy <= 1e-6) scores 0.5.
    """
    small = cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY), (512, 512))
    dx_energy = np.abs(cv2.Sobel(small, cv2.CV_64F, 1, 0, ksize=3)).sum()
    dy_energy = np.abs(cv2.Sobel(small, cv2.CV_64F, 0, 1, ksize=3)).sum()
    total = dx_energy + dy_energy
    if total > 1e-6:
        score = (dy_energy / total - 0.45) / 0.15
        return float(min(max(score, 0), 1))
    return 0.5
|
|
def _white_score(bgr):
    """Score the fraction of near-white pixels (gray > 220), on a 0-1 scale.

    The fraction is scaled so that 50% coverage already reaches 1.0.
    """
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    bright_fraction = np.count_nonzero(gray > 220) / gray.size
    return float(min(bright_fraction / 0.5, 1.0))
|
|
def _comp_score(bgr):
    """Score the density of small connected components, on a 0-1 scale.

    Steps:
    1. Convert to grayscale and resize to 512x512.
    2. Binarise with Otsu's threshold, inverted, so pixels darker than the
       threshold become foreground.
    3. Label 8-connected components.
    4. Count the components whose area lies strictly between 5 and 500 pixels.
    5. Normalise the count by (512*512/100) and then by 3.0, capped at 1.0.

    Returns 0.5 when no foreground component exists.
    """
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    small = cv2.resize(gray, (512, 512))
    _, binary = cv2.threshold(small, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
    # Pass connectivity by keyword. The second positional parameter of
    # connectedComponentsWithStats is the `labels` output array, so a bare
    # positional 8 does not set connectivity.
    n_labels, _, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
    if n_labels <= 1:  # Only label 0 (background) was found.
        return 0.5
    areas = stats[1:, cv2.CC_STAT_AREA]  # Skip the background row.
    small_blobs = np.sum((areas > 5) & (areas < 500))
    return float(min(small_blobs / (512 * 512 / 100) / 3.0, 1.0))
|
|
def _warm_ratio(bgr):
    """Score the fraction of warm-hued pixels, on a 0-1 scale.

    A pixel counts when its OpenCV hue is below 25 or above 165 (both ends of
    the red wrap-around), its saturation is above 30, and its value is above 40.
    The fraction is scaled so that 30% coverage already reaches 1.0.
    """
    hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
    hue, sat, val = (hsv[:, :, c] for c in range(3))
    reddish = (hue < 25) | (hue > 165)
    warm_mask = reddish & (sat > 30) & (val > 40)
    return float(min(np.count_nonzero(warm_mask) / hue.size / 0.3, 1.0))
|
|
def _gray_dom(bgr):
    """Score how strongly the image is dominated by gray, on a 0-1 scale.

    Returns 1.0 when the mean HSV saturation is 0. The score falls linearly to
    0.0 at a mean saturation of 30 or more.
    """
    mean_sat = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)[:, :, 1].mean()
    return float(1.0 - min(mean_sat / 30.0, 1.0))
|
|
def _dark_bg(bgr):
    """Score the fraction of very dark pixels (gray < 30), on a 0-1 scale.

    The fraction is scaled so that 30% coverage already reaches 1.0.
    """
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    dark_fraction = np.count_nonzero(gray < 30) / gray.size
    return float(min(dark_fraction / 0.3, 1.0))
|
|
def heuristic_classify(bgr):
    """Classify a BGR image with hand-tuned OpenCV features.

    The classification runs in two stages:

    1. A weighted text-document score is built from desaturation, edge
       direction, whiteness and small-component density. If it reaches 0.50,
       the image is "text_document", with confidence mapped linearly from
       0.5 to 1.0.
    2. Otherwise, a wound signal (warm hues, colour, non-dark background) is
       compared with a clinical signal (grayness, dark background, absence of
       warm hues). The larger one wins, with confidence 0.5 + 2 * |gap|,
       capped at 1.0. A tie goes to "clinical_medical".

    Returns:
        (label, confidence) where label is a LABEL_MAP key.
    """
    text_score = (
        0.30 * _sat_score(bgr)
        + 0.15 * _edge_score(bgr)
        + 0.35 * _white_score(bgr)
        + 0.20 * _comp_score(bgr)
    )
    if text_score >= 0.50:
        return "text_document", min((text_score - 0.50) / 0.50 * 0.5 + 0.5, 1.0)

    warm = _warm_ratio(bgr)
    gray = _gray_dom(bgr)
    dark = _dark_bg(bgr)
    wound_signal = warm * 0.50 + (1 - gray) * 0.20 + (1 - dark) * 0.30
    clinical_signal = gray * 0.30 + dark * 0.30 + (1 - warm) * 0.40

    gap = wound_signal - clinical_signal
    label = "wound" if gap > 0 else "clinical_medical"
    return label, min(0.5 + abs(gap) * 2, 1.0)
|
|
|
|
| |
|
|
def ensemble(cnn_label, cnn_conf, heur_label, heur_conf):
    """Reconcile the CNN and heuristic votes into a final decision.

    Args:
        cnn_label: CNN arg-max label (a LABEL_MAP key).
        cnn_conf: CNN softmax probability for ``cnn_label``.
        heur_label: Heuristic label (a LABEL_MAP key).
        heur_conf: Heuristic confidence.

    Returns:
        A tuple ``(label, route, confidence, needs_review)``, where ``route`` is
        a ROUTE_DISPLAY key. Whenever the catch-all route "all" is chosen,
        ``needs_review`` is True, matching its "Review Needed" display name.
    """
    agree = cnn_label == heur_label

    # Both sources pick the same label with strong confidence.
    if agree and cnn_conf >= CNN_HIGH and heur_conf >= HEUR_MIN:
        # NOTE(review): with heur_conf just above HEUR_MIN, this branch can
        # report a lower confidence than the weaker-agreement branch below
        # (cnn_conf * 0.9). Confirm this is intended.
        return cnn_label, ROUTE_MAP[cnn_label], min(cnn_conf, heur_conf), False

    # Agreement with a medium-confidence CNN vote.
    if agree and cnn_conf >= CNN_MED:
        return cnn_label, ROUTE_MAP[cnn_label], cnn_conf * 0.9, False

    cnn_med = cnn_label in MEDICAL_LABELS
    heur_med = heur_label in MEDICAL_LABELS

    # Both votes are medical but different kinds: run both medical pipelines.
    if cnn_med and heur_med and cnn_label != heur_label:
        pri = cnn_label if cnn_conf >= heur_conf else heur_label
        return pri, "both_medical", max(cnn_conf, heur_conf) * 0.6, cnn_conf < 0.6

    # Medical vs non-medical conflict: run everything and always flag it.
    if cnn_med != heur_med:
        pri = cnn_label if cnn_conf >= 0.6 else heur_label
        return pri, "all", max(cnn_conf, heur_conf) * 0.4, True

    # Remaining case: both votes fall on the same side of the medical split.
    # With the current label set, this means they agree but the CNN is below
    # CNN_MED.
    pri = cnn_label if cnn_conf >= 0.5 else heur_label
    route = ROUTE_MAP.get(pri, "all") if cnn_conf >= CNN_MED else "all"
    # The "all" route is presented as "Review Needed", so it must always be
    # flagged, not only when the CNN confidence drops below 0.6.
    review = route == "all" or cnn_conf < 0.6
    return pri, route, max(cnn_conf, heur_conf) * 0.5, review
|
|
|
|
| |
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoint entry point for the document classifier.

    The CNN checkpoint is loaded once in ``__init__``. Each call decodes an
    image, runs the CNN and the OpenCV heuristic, and reconciles them with
    ``ensemble`` to choose a label and a downstream route.
    """

    def __init__(self, path=""):
        """Load ``best_model.pth`` from ``path`` onto CUDA if available, else CPU.

        Args:
            path: Directory containing ``best_model.pth``. Presumably the
                endpoint's model repository directory — TODO confirm.
        """
        import os

        model_path = os.path.join(path, "best_model.pth")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = DocumentClassifierCNN()
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
        # This is acceptable only because the checkpoint ships with the endpoint
        # itself; never point this at untrusted files.
        ckpt = torch.load(model_path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(ckpt["model_state_dict"])
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _decode_image(img_data):
        """Turn a request payload into an RGB PIL image.

        Accepts a base64 string (optionally a ``data:...;base64,`` URL), raw
        bytes or bytearray, or a PIL image.

        Returns:
            ``(image, None)`` on success, or ``(None, error_message)`` on failure.
        """
        if isinstance(img_data, str):
            # Drop a data-URL header such as "data:image/png;base64,".
            if img_data.startswith("data:") and "," in img_data:
                img_data = img_data.split(",", 1)[1]
            try:
                img_data = base64.b64decode(img_data)
            except ValueError as err:  # binascii.Error subclasses ValueError
                return None, f"Invalid base64 image data: {err}"
        if isinstance(img_data, (bytes, bytearray)):
            try:
                return Image.open(io.BytesIO(img_data)).convert("RGB"), None
            except OSError as err:  # includes PIL.UnidentifiedImageError
                return None, f"Could not decode image: {err}"
        if isinstance(img_data, Image.Image):
            return img_data.convert("RGB"), None
        return None, f"Unsupported input type: {type(img_data)}"

    def __call__(self, data):
        """Classify one image and return routing information.

        Args:
            data: The payload itself, or a dict whose ``"inputs"`` key holds it.
                The payload may also be a dict with an ``"image"`` key. See
                ``_decode_image`` for the accepted image types.

        Returns:
            On success, a dict with these keys:

            - ``label``, ``label_display``, ``route``, ``route_display`` and
              ``confidence`` for the ensemble decision.
            - ``class_probabilities``: the per-class CNN probabilities.
            - ``cnn_label``, ``cnn_confidence``, ``heuristic_label`` and
              ``heuristic_confidence``: the individual votes.
            - ``needs_review``.
            - ``inference_ms``: model + heuristic time, excluding decoding.

            If the image cannot be decoded, returns ``{"error": message}``.
        """
        inputs = data.get("inputs", data) if isinstance(data, dict) else data
        if isinstance(inputs, dict) and "image" in inputs:
            img_data = inputs["image"]
        else:
            img_data = inputs

        pil, error = self._decode_image(img_data)
        if error is not None:
            return {"error": error}

        # perf_counter is monotonic and high-resolution; time.time() is neither.
        t0 = time.perf_counter()

        tensor = val_transform(pil).unsqueeze(0).to(self.device)
        probs = self.model.predict_proba(tensor).squeeze(0).cpu()
        cnn_pred = probs.argmax().item()
        cnn_label = LABEL_NAMES[cnn_pred]
        cnn_conf = probs[cnn_pred].item()

        # The OpenCV heuristics operate on BGR channel order.
        bgr = cv2.cvtColor(np.array(pil), cv2.COLOR_RGB2BGR)
        heur_label, heur_conf = heuristic_classify(bgr)

        label, route, conf, review = ensemble(cnn_label, cnn_conf, heur_label, heur_conf)

        elapsed = (time.perf_counter() - t0) * 1000
        class_probs = {LABEL_NAMES[i]: round(probs[i].item(), 4) for i in range(NUM_CLASSES)}

        return {
            "label": label,
            "label_display": DISPLAY_NAMES.get(label, label),
            "route": route,
            "route_display": ROUTE_DISPLAY.get(route, route),
            "confidence": round(conf, 4),
            "class_probabilities": class_probs,
            "cnn_label": cnn_label,
            "cnn_confidence": round(cnn_conf, 4),
            "heuristic_label": heur_label,
            "heuristic_confidence": round(heur_conf, 4),
            "needs_review": review,
            "inference_ms": round(elapsed, 1),
        }
|
|