# Uploaded by bilalEthizo: "Upload app.py with huggingface_hub" (commit 66f4bbf, verified)
"""
3-Class Document Classifier API
Routes page images to: OCR | Wound Care | Clinical LLM
API endpoint: /api/predict
"""
import time
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from PIL import Image
from torchvision import models, transforms
# ── Constants ────────────────────────────────────────────────────────────────
# Width of the classifier head; also the number of entries in LABEL_MAP.
NUM_CLASSES = 3

# Class name -> logit index, plus the inverse lookup used to decode argmax.
LABEL_MAP = {
    "text_document": 0,
    "wound": 1,
    "clinical_medical": 2,
}
LABEL_NAMES = dict(zip(LABEL_MAP.values(), LABEL_MAP.keys()))

# Input geometry and per-channel normalisation fed to the CNN.
IMG_SIZE = 224
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Class name -> downstream pipeline identifier.
ROUTE_MAP = {
    "text_document": "ocr",
    "wound": "wound_care",
    "clinical_medical": "clinical",
}

# Human-readable names for classes and routes (returned alongside the raw ids).
DISPLAY_NAMES = {
    "text_document": "Text Document",
    "wound": "Wound Image",
    "clinical_medical": "Clinical / Medical Image",
}
ROUTE_DISPLAY = {
    "ocr": "OCR Text Extraction",
    "wound_care": "Wound Care Pipeline",
    "clinical": "Clinical LLM Analysis",
    "both_medical": "Wound Care + Clinical LLM",
    "all": "All Pipelines (Review Needed)",
}

# Labels that count as "medical" when the ensemble resolves disagreements.
MEDICAL_LABELS = {"wound", "clinical_medical"}

# Ensemble thresholds: CNN confidence tiers and the minimum heuristic
# confidence required for the highest-trust agreement rule.
CNN_HIGH, CNN_MED, HEUR_MIN = 0.92, 0.75, 0.55
# ── Model ────────────────────────────────────────────────────────────────────
class DocumentClassifierCNN(nn.Module):
    """EfficientNet-B0 backbone with a NUM_CLASSES-way classification head.

    The backbone is created without pretrained weights (``weights=None``);
    parameters are expected to be loaded afterwards via ``load_state_dict``.
    """

    def __init__(self):
        super().__init__()
        net = models.efficientnet_b0(weights=None)
        # Replace the stock classifier with dropout + a linear layer sized
        # for this task, reusing the input width of the original linear layer.
        head_width = net.classifier[1].in_features
        net.classifier = nn.Sequential(
            nn.Dropout(p=0.3),
            nn.Linear(head_width, NUM_CLASSES),
        )
        self.backbone = net

    def forward(self, x):
        """Return the raw class logits produced by the backbone for ``x``."""
        return self.backbone(x)

    def predict_proba(self, x):
        """Return softmax probabilities over the last dimension, without gradients."""
        with torch.no_grad():
            logits = self.forward(x)
            return torch.softmax(logits, dim=-1)
def load_model(checkpoint_path="best_model.pth"):
    """Build the classifier, load its weights from disk and switch to eval mode.

    Args:
        checkpoint_path: Path to a checkpoint saved with ``torch.save`` whose
            top-level object is a mapping containing a ``"model_state_dict"``
            entry. Defaults to ``"best_model.pth"``, the path this function
            used before it was parameterised.

    Returns:
        A ``DocumentClassifierCNN`` with the checkpoint weights loaded, in
        evaluation mode.

    Raises:
        KeyError: If the checkpoint has no ``"model_state_dict"`` entry.
    """
    model = DocumentClassifierCNN()
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects, so the checkpoint must come from a trusted source.
    # Consider weights_only=True once it is confirmed that the checkpoint
    # holds only tensors and plain containers.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    model.load_state_dict(ckpt["model_state_dict"])
    # Inference only: eval() disables dropout in the classification head.
    model.eval()
    return model
# Inference-time preprocessing: resize to the fixed square the CNN expects,
# convert to a tensor, then normalise each channel with the ImageNet stats.
val_transform = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)
# ── Heuristics ───────────────────────────────────────────────────────────────
def _sat_score(bgr) -> float:
    """Score how desaturated the image is.

    Returns 1.0 when the mean HSV saturation is 0 and falls linearly to 0.0
    once the mean saturation reaches 80 or more.
    """
    saturation = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)[..., 1]
    scaled = min(saturation.mean() / 80.0, 1.0)
    return float(1.0 - scaled)
def _edge_score(bgr) -> float:
    """Score the share of gradient energy that lies along the y-derivative.

    The image is converted to grayscale and resized to 512x512, then the
    absolute Sobel responses in x and y are summed. The y share of the total
    is mapped linearly so that 0.45 gives 0.0 and 0.60 gives 1.0 (clamped).
    Returns 0.5 when the image has essentially no gradient at all.
    """
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    small = cv2.resize(gray, (512, 512))
    grad_x = np.abs(cv2.Sobel(small, cv2.CV_64F, 1, 0, ksize=3)).sum()
    grad_y = np.abs(cv2.Sobel(small, cv2.CV_64F, 0, 1, ksize=3)).sum()
    total = grad_x + grad_y
    if not total > 1e-6:
        # Flat image: no evidence either way.
        return 0.5
    share = (grad_y / total - 0.45) / 0.15
    return float(min(max(share, 0), 1))
def _white_score(bgr) -> float:
    """Score the fraction of near-white pixels (gray level above 220).

    The fraction is scaled so that 50% near-white pixels already yields 1.0.
    """
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    bright_fraction = np.count_nonzero(gray > 220) / gray.size
    return float(min(bright_fraction / 0.5, 1.0))
def _comp_score(bgr) -> float:
    """Score the density of small connected components.

    The grayscale image is resized to 512x512 and binarised with an inverted
    Otsu threshold. Components (8-connectivity) whose area is strictly
    between 5 and 500 pixels are counted; the count is divided by
    ``512 * 512 / 100`` and then by 3.0, and capped at 1.0. Returns 0.5 when
    no component other than label 0 is found.
    """
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    small = cv2.resize(gray, (512, 512))
    _, binary = cv2.threshold(
        small, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU
    )
    n_labels, _, stats, _ = cv2.connectedComponentsWithStats(binary, 8)
    if n_labels <= 1:
        return 0.5
    # Row 0 of `stats` is skipped; only the remaining components are measured.
    areas = stats[1:, cv2.CC_STAT_AREA]
    small_count = np.sum((areas > 5) & (areas < 500))
    density = small_count / (512 * 512 / 100)
    return float(min(density / 3.0, 1.0))
def _warm_ratio(bgr) -> float:
    """Score the fraction of pixels with a warm, reasonably vivid colour.

    A pixel counts when its hue is below 25 or above 165 and its saturation
    and value exceed 30 and 40 respectively. The fraction is scaled so that
    30% coverage already yields 1.0.
    """
    hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
    hue, sat, val = hsv[..., 0], hsv[..., 1], hsv[..., 2]
    warm_hue = (hue < 25) | (hue > 165)
    vivid = (sat > 30) & (val > 40)
    coverage = (warm_hue & vivid).sum() / hue.size
    return float(min(coverage / 0.3, 1.0))
def _gray_dom(bgr) -> float:
    """Score how strongly the image is dominated by gray tones.

    Same shape as ``1 - mean_saturation / 30`` clamped to [0, 1]: 1.0 for a
    fully unsaturated image, 0.0 once mean HSV saturation reaches 30.
    """
    saturation = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)[..., 1]
    scaled = min(saturation.mean() / 30.0, 1.0)
    return float(1.0 - scaled)
def _dark_bg(bgr) -> float:
    """Score the fraction of near-black pixels (gray level below 30).

    The fraction is scaled so that 30% near-black pixels already yields 1.0.
    """
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    dark_fraction = np.count_nonzero(gray < 30) / gray.size
    return float(min(dark_fraction / 0.3, 1.0))
def heuristic_classify(bgr):
    """Classify a BGR image with hand-crafted colour and structure cues.

    Stage 1 combines the saturation, edge, white-pixel and component scores
    into a text-document score; at 0.50 or above the image is labelled
    ``"text_document"``. Otherwise stage 2 compares a wound signal against a
    clinical signal built from the warm-colour, gray-dominance and
    dark-background scores.

    Returns:
        ``(label, confidence)`` where the confidence lies in [0.5, 1.0].
    """
    # Stage 1: does the page look like a text document?
    sat = _sat_score(bgr)
    edge = _edge_score(bgr)
    white = _white_score(bgr)
    comp = _comp_score(bgr)
    combined = sat * 0.30 + edge * 0.15 + white * 0.35 + comp * 0.20
    if combined >= 0.50:
        return "text_document", min((combined - 0.50) / 0.50 * 0.5 + 0.5, 1.0)

    # Stage 2: wound vs. clinical, decided by whichever signal is larger.
    warm = _warm_ratio(bgr)
    gray = _gray_dom(bgr)
    dark = _dark_bg(bgr)
    w_sig = warm * 0.50 + (1 - gray) * 0.20 + (1 - dark) * 0.30
    c_sig = gray * 0.30 + dark * 0.30 + (1 - warm) * 0.40
    if w_sig > c_sig:
        label, margin = "wound", w_sig - c_sig
    else:
        label, margin = "clinical_medical", c_sig - w_sig
    # The confidence grows with the gap between the two signals.
    return label, min(0.5 + margin * 2, 1.0)
# ── Ensemble ─────────────────────────────────────────────────────────────────
def ensemble(cnn_label, cnn_conf, heur_label, heur_conf):
    """Merge the CNN and heuristic predictions into one routing decision.

    Rules, evaluated in order:

    1. Both agree, CNN >= CNN_HIGH and heuristic >= HEUR_MIN: route by label
       with the lower of the two confidences.
    2. Both agree and CNN >= CNN_MED: route by label with 0.9 * CNN confidence.
    3. Both medical but different labels: route ``"both_medical"``.
    4. Exactly one side is medical: route ``"all"`` and always request review.
    5. Anything else: route by label only if CNN >= CNN_MED, else ``"all"``.

    Returns:
        ``(label, route, confidence, needs_review)``.
    """
    if cnn_label == heur_label:
        if cnn_conf >= CNN_HIGH and heur_conf >= HEUR_MIN:
            return cnn_label, ROUTE_MAP[cnn_label], min(cnn_conf, heur_conf), False
        if cnn_conf >= CNN_MED:
            return cnn_label, ROUTE_MAP[cnn_label], cnn_conf * 0.9, False

    cnn_is_medical = cnn_label in MEDICAL_LABELS
    heur_is_medical = heur_label in MEDICAL_LABELS

    if cnn_is_medical and heur_is_medical and cnn_label != heur_label:
        # Both say "medical" but disagree on which kind: run both pipelines.
        if cnn_conf >= heur_conf:
            primary = cnn_label
        else:
            primary = heur_label
        needs_review = cnn_conf < 0.6
        return primary, "both_medical", max(cnn_conf, heur_conf) * 0.6, needs_review

    if cnn_is_medical != heur_is_medical:
        # Medical vs. non-medical disagreement: run everything and flag it.
        if cnn_conf >= 0.6:
            primary = cnn_label
        else:
            primary = heur_label
        return primary, "all", max(cnn_conf, heur_conf) * 0.4, True

    # Remaining cases, e.g. agreement with a CNN confidence below CNN_MED.
    if cnn_conf >= 0.5:
        primary = cnn_label
    else:
        primary = heur_label
    if cnn_conf >= CNN_MED:
        route = ROUTE_MAP.get(primary, "all")
    else:
        route = "all"
    return primary, route, max(cnn_conf, heur_conf) * 0.5, cnn_conf < 0.6
# ── Load model at startup ────────────────────────────────────────────────────
# Load the checkpoint once at import time; `predict` reads the module-level
# `model` defined here, so this must run before the first request is served.
print("Loading model...")
model = load_model()
print("Model loaded.")
# ── Predict function ─────────────────────────────────────────────────────────
def predict(image):
    """Classify a page image into text_document, wound, or clinical_medical.

    Args:
        image: Array accepted by ``PIL.Image.fromarray`` (the Gradio input is
            configured with ``type="numpy"``), or ``None`` when nothing was
            uploaded.

    Returns:
        ``{"error": "No image provided"}`` when ``image`` is ``None``.
        Otherwise a JSON-serialisable dict with the ensemble label and route,
        their display names, the ensemble confidence, the per-class CNN
        probabilities, the individual CNN and heuristic predictions, the
        ``needs_review`` flag and the elapsed time in milliseconds.
    """
    if image is None:
        return {"error": "No image provided"}

    # perf_counter is monotonic, so the measured interval cannot be skewed
    # by wall-clock adjustments the way time.time() can.
    started = time.perf_counter()

    # Normalise to 3-channel RGB once and feed BOTH classifiers from it.
    # The heuristics previously received the raw array, so a 2-D grayscale
    # input crashed in cv2.cvtColor even though the CNN path accepted it.
    pil = Image.fromarray(image).convert("RGB")
    rgb = np.array(pil)

    tensor = val_transform(pil).unsqueeze(0)  # add the batch dimension
    probs = model.predict_proba(tensor).squeeze(0)
    cnn_pred = probs.argmax().item()
    cnn_label = LABEL_NAMES[cnn_pred]
    cnn_conf = probs[cnn_pred].item()

    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    heur_label, heur_conf = heuristic_classify(bgr)

    label, route, conf, review = ensemble(cnn_label, cnn_conf, heur_label, heur_conf)
    elapsed = (time.perf_counter() - started) * 1000

    class_probs = {
        LABEL_NAMES[idx]: round(prob, 4) for idx, prob in enumerate(probs.tolist())
    }
    return {
        "label": label,
        "label_display": DISPLAY_NAMES.get(label, label),
        "route": route,
        "route_display": ROUTE_DISPLAY.get(route, route),
        "confidence": round(conf, 4),
        "class_probabilities": class_probs,
        "cnn_label": cnn_label,
        "cnn_confidence": round(cnn_conf, 4),
        "heuristic_label": heur_label,
        "heuristic_confidence": round(heur_conf, 4),
        "needs_review": review,
        "inference_ms": round(elapsed, 1),
    }
# ── Gradio Interface ─────────────────────────────────────────────────────────
# Single-function Gradio app: one image in, one JSON document out. The same
# `predict` callable backs both the web UI and the named API endpoint.
# NOTE(review): `flagging_mode` suggests a recent Gradio release; confirm the
# HTTP route quoted in the description still matches that version.
demo = gr.Interface(
    fn=predict,
    api_name="predict",
    inputs=gr.Image(label="Upload Page Image", type="numpy"),
    outputs=gr.JSON(label="Classification Result"),
    title="Document Classifier API",
    description=(
        "3-class classifier: **Text Document** | **Wound** | **Clinical/Medical**. "
        "Use the UI to test, or call the API programmatically at `/api/predict`."
    ),
    examples=None,
    flagging_mode="never",
)
if __name__ == "__main__":
    # Serve on every network interface at port 7860 and surface errors in
    # the UI; only runs when the file is executed as a script.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)