# Hugging Face Spaces status banner ("Spaces: Running") — page-extraction
# residue, not part of the program; kept only as a comment.
# ============================================================
# Colposcopy Inference Backend
# Production-ready | VS Code | Hugging Face compatible
# ============================================================
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import joblib
from torchvision import transforms, models
from PIL import Image

# ------------------------------------------------------------
# DEVICE
# ------------------------------------------------------------
# Prefer GPU when available; models and tensors are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ------------------------------------------------------------
# PATHS (RELATIVE — REQUIRED FOR DEPLOYMENT)
# ------------------------------------------------------------
# Everything is resolved relative to this file so the app works both
# locally and inside a Hugging Face Space container.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_DIR = os.path.join(BASE_DIR, "models")
OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")
os.makedirs(OUTPUT_DIR, exist_ok=True)

SEG_MODEL_PATH = os.path.join(MODEL_DIR, "seg_yolov8n_best.pt")    # YOLOv8 segmentation weights
FUSION_MODEL_PATH = os.path.join(MODEL_DIR, "fusion_model.pth")    # image+feature fusion encoder weights
CLF_PATH = os.path.join(MODEL_DIR, "logreg_classifier.joblib")     # logistic-regression classifier head

# ------------------------------------------------------------
# LOAD MODELS (ONCE)
# ------------------------------------------------------------
# Loaded at import time so every request reuses the same model instances.
from ultralytics import YOLO

seg_model = YOLO(SEG_MODEL_PATH)
clf = joblib.load(CLF_PATH)
# ------------------------------------------------------------
# FUSION MODEL DEFINITION
# ------------------------------------------------------------
class ImageEncoder(nn.Module):
    """Encode an RGB image batch into a 512-d embedding.

    Uses a ResNet-18 trunk (random init; trained weights are loaded via
    the enclosing FusionModel state_dict) with the final classification
    layer removed, followed by a linear projection to 512 dims.
    """

    def __init__(self):
        super().__init__()
        resnet = models.resnet18(pretrained=False)
        # Keep conv trunk + global average pool, drop the ImageNet fc head.
        # Attribute names must stay `backbone`/`fc` to match the saved state_dict.
        trunk = list(resnet.children())[:-1]
        self.backbone = nn.Sequential(*trunk)
        self.fc = nn.Linear(512, 512)

    def forward(self, x):
        features = self.backbone(x)
        flattened = features.view(features.size(0), -1)
        return self.fc(flattened)
class FeatureEncoder(nn.Module):
    """Encode the 7-d geometric feature vector into a 64-d embedding."""

    def __init__(self):
        super().__init__()
        # Two-layer MLP: 7 -> 64 -> 64 with a ReLU in between.
        # Built as *layers so state_dict keys (net.0/net.1/net.2) match the
        # original nn.Sequential definition.
        layers = [nn.Linear(7, 64), nn.ReLU(), nn.Linear(64, 64)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class FusionModel(nn.Module):
    """Fuse image (512-d) and geometric-feature (64-d) embeddings.

    The two embeddings are concatenated into a 576-d vector and
    batch-normalized before being handed to the downstream classifier.
    """

    def __init__(self):
        super().__init__()
        self.img_enc = ImageEncoder()
        self.feat_enc = FeatureEncoder()
        self.norm = nn.BatchNorm1d(576)  # 512 + 64

    def forward(self, img, feat):
        fused = torch.cat([self.img_enc(img), self.feat_enc(feat)], dim=1)
        return self.norm(fused)
# Instantiate the fusion encoder and restore its trained weights.
fusion_model = FusionModel().to(device)
fusion_model.load_state_dict(torch.load(FUSION_MODEL_PATH, map_location=device))
fusion_model.eval()  # inference mode: BatchNorm uses running stats (safe for batch size 1)
# ------------------------------------------------------------
# IMAGE TRANSFORM
# ------------------------------------------------------------
# Resize to the 224x224 input expected by the ResNet-18 backbone and
# convert to a [0, 1] float tensor.
# NOTE(review): no ImageNet mean/std normalization is applied here —
# presumably the fusion model was trained without it; confirm against
# the training pipeline.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# ------------------------------------------------------------
# CONSTANTS
# ------------------------------------------------------------
# Class ids as emitted in the YOLO segmentation label files.
CERVIX_ID = 0  # cervix region
SCJ_ID = 1     # squamocolumnar junction
ACET_ID = 3    # acetowhite (lesion) region
# Minimum lesion-to-cervix area ratio for a lesion to count as "present".
MIN_ACET_RATIO = 0.01
# ------------------------------------------------------------
# GEOMETRY UTILITIES
# ------------------------------------------------------------
def polygon_to_mask(polygon, H, W):
    """Rasterize a polygon of normalized (x, y) points into an HxW uint8 mask.

    Coordinates are scaled from [0, 1] to pixel space; interior pixels
    are set to 1, everything else stays 0.
    """
    canvas = np.zeros((H, W), dtype=np.uint8)
    scaled = np.array(
        [[int(px * W), int(py * H)] for px, py in polygon], np.int32
    )
    cv2.fillPoly(canvas, [scaled], 1)
    return canvas
def mask_area(mask):
    """Return the mask's normalized area: sum of pixel values / total pixels."""
    covered = mask.sum()
    return covered / mask.size
def centroid_distance(mask1, mask2):
    """Distance between the centroids of two binary masks.

    Normalized by the larger dimension of mask1. Returns 1.0 (maximum)
    when mask2 is None or either mask contains no pixels equal to 1.
    """
    if mask2 is None:
        return 1.0
    rows_a, cols_a = np.where(mask1 == 1)
    rows_b, cols_b = np.where(mask2 == 1)
    if len(cols_a) == 0 or len(cols_b) == 0:
        return 1.0
    # Centroids in (x, y) order.
    center_a = np.array([cols_a.mean(), rows_a.mean()])
    center_b = np.array([cols_b.mean(), rows_b.mean()])
    return np.linalg.norm(center_a - center_b) / max(mask1.shape)
def overlap_ratio(mask1, mask2):
    """Fraction of mask1 covered by mask2.

    Returns 0.0 when mask2 is None or mask1 is empty.
    """
    if mask2 is None:
        return 0.0
    area1 = mask1.sum()
    if not area1 > 0:
        return 0.0
    intersection = np.logical_and(mask1, mask2).sum()
    return intersection / area1
# ------------------------------------------------------------
# LOAD YOLO POLYGONS
# ------------------------------------------------------------
def load_yolo_segmentation(label_path):
    """Parse a YOLO segmentation label file into a list of objects.

    Each line has the form ``<cls> x1 y1 x2 y2 ...`` with coordinates
    normalized to [0, 1].

    Args:
        label_path: path to the ``.txt`` label file.

    Returns:
        A list of ``{"cls": int, "polygon": [(x, y), ...]}`` dicts;
        an empty list when the file does not exist.

    Fix over original: blank/whitespace-only lines no longer raise
    IndexError (``parts[0]`` on an empty list), and a dangling unpaired
    coordinate on a malformed line is ignored instead of crashing.
    """
    objects = []
    if not os.path.exists(label_path):
        return objects
    with open(label_path) as f:
        for line in f:
            parts = line.split()
            if not parts:
                # Skip blank lines (e.g. trailing newlines in the file).
                continue
            values = list(map(float, parts))
            cls = int(values[0])
            coords = values[1:]
            # Pair up (x, y); len(coords) - 1 drops a dangling odd coordinate.
            polygon = [(coords[i], coords[i + 1]) for i in range(0, len(coords) - 1, 2)]
            objects.append({"cls": cls, "polygon": polygon})
    return objects
# ------------------------------------------------------------
# FEATURE EXTRACTION
# ------------------------------------------------------------
def extract_features_from_label(label_path, H, W):
    """Build the 7-d geometric feature vector from a YOLO label file.

    Args:
        label_path: path to the YOLO segmentation label file.
        H: image height in pixels.
        W: image width in pixels.

    Returns:
        A float32 torch tensor of 7 features in the fixed order the
        downstream classifier was trained on (see return statement).
    """
    objects = load_yolo_segmentation(label_path)
    cervix_masks, scj_masks, acet_masks = [], [], []
    # Rasterize each polygon and bucket it by class id.
    for obj in objects:
        m = polygon_to_mask(obj["polygon"], H, W)
        if obj["cls"] == CERVIX_ID:
            cervix_masks.append(m)
        elif obj["cls"] == SCJ_ID:
            scj_masks.append(m)
        elif obj["cls"] == ACET_ID:
            acet_masks.append(m)
    # Keep the largest cervix/SCJ instance; all-zeros cervix (and scj=None)
    # when nothing was detected.
    cervix = max(cervix_masks, key=lambda m: m.sum()) if cervix_masks else np.zeros((H, W))
    scj = max(scj_masks, key=lambda m: m.sum()) if scj_masks else None
    cervix_area = mask_area(cervix)
    # Union of all acetowhite regions, then restricted to the cervix mask.
    acet_union = np.zeros((H, W), dtype=np.uint8)
    for m in acet_masks:
        acet_union = np.maximum(acet_union, m)
    acet_union = acet_union * cervix
    if acet_union.sum() > 0:
        # Morphological closing fills small holes/gaps in the lesion mask.
        acet_union = cv2.morphologyEx(
            acet_union, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8)
        )
    acet_area = mask_area(acet_union)
    # Lesion counts as present only if it covers >= MIN_ACET_RATIO of the cervix.
    acet_present = int(cervix_area > 0 and acet_area / cervix_area >= MIN_ACET_RATIO)
    if acet_present:
        dist_acet_scj = centroid_distance(acet_union, scj)
        lesion_center_dist = centroid_distance(acet_union, cervix)
        overlap_lesion_scj = overlap_ratio(acet_union, scj)
    else:
        # No lesion: distances saturate at 1.0, overlap at 0.0.
        dist_acet_scj = lesion_center_dist = 1.0
        overlap_lesion_scj = 0.0
    # NOTE(review): features 0 and 1 are identical by construction (both
    # equal acet_present). Kept as-is because the fusion model / classifier
    # were presumably trained on this exact 7-d layout — confirm before
    # changing.
    return torch.tensor([
        acet_present,
        1 if acet_present else 0,
        acet_area if acet_present else 0.0,
        acet_area / cervix_area if acet_present else 0.0,
        dist_acet_scj,
        lesion_center_dist,
        overlap_lesion_scj
    ], dtype=torch.float32)
# ------------------------------------------------------------
# SAVE VISUALIZATION FOR UI
# ------------------------------------------------------------
def save_overlay(image_path, label_path, out_path):
    """Render segmentation masks over the input image and save to out_path.

    Color coding (RGB): cervix = blue, SCJ = green, acetowhite = red,
    each alpha-blended at 30% over the original pixels.

    Args:
        image_path: path of the source image.
        label_path: path of the YOLO segmentation label file.
        out_path: where the PNG overlay is written.
    """
    image = np.array(Image.open(image_path).convert("RGB"))
    H, W, _ = image.shape
    objects = load_yolo_segmentation(label_path)
    # Union mask per class across all detected instances.
    cervix = np.zeros((H, W))
    scj = np.zeros((H, W))
    acet = np.zeros((H, W))
    for obj in objects:
        m = polygon_to_mask(obj["polygon"], H, W)
        if obj["cls"] == CERVIX_ID:
            cervix = np.maximum(cervix, m)
        elif obj["cls"] == SCJ_ID:
            scj = np.maximum(scj, m)
        elif obj["cls"] == ACET_ID:
            acet = np.maximum(acet, m)
    overlay = image.copy()
    # 70% original pixel + 30% class color; values are float here and cast
    # back to uint8 just before saving.
    overlay[cervix == 1] = 0.7 * overlay[cervix == 1] + 0.3 * np.array([0, 0, 255])
    overlay[scj == 1] = 0.7 * overlay[scj == 1] + 0.3 * np.array([0, 255, 0])
    overlay[acet == 1] = 0.7 * overlay[acet == 1] + 0.3 * np.array([255, 0, 0])
    Image.fromarray(overlay.astype(np.uint8)).save(out_path)
# ------------------------------------------------------------
# PUBLIC API — UI CALLS THIS
# ------------------------------------------------------------
def run_inference(image_path: str) -> dict:
    """Run the full pipeline: segmentation -> features -> fusion -> classifier.

    Args:
        image_path: path to the colposcopy image to analyze.

    Returns:
        A dict with keys ``decision``, ``probability_abnormal``,
        ``acet_present`` and ``overlay_image``; or just
        ``{"decision": "Segmentation failed"}`` when YOLO produced no
        label file.
    """
    # save_txt=True makes YOLO write polygon labels that we re-read below.
    results = seg_model(image_path, conf=0.15, save_txt=True, save=False)
    save_dir = results[0].save_dir
    name = os.path.splitext(os.path.basename(image_path))[0]
    label_path = os.path.join(save_dir, "labels", f"{name}.txt")
    if not os.path.exists(label_path):
        # YOLO writes no label file when nothing was detected.
        return {"decision": "Segmentation failed"}
    image = Image.open(image_path).convert("RGB")
    W, H = image.size  # PIL reports (width, height)
    img_tensor = transform(image).unsqueeze(0).to(device)
    feat = extract_features_from_label(label_path, H, W)
    feat_tensor = feat.unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = fusion_model(img_tensor, feat_tensor)
    # Probability of the "abnormal" class (column 1).
    prob = clf.predict_proba(embedding.cpu().numpy())[0, 1]
    # Feature 0 of the geometric vector is the acetowhite-present flag.
    acet_present = int(feat[0].item())
    # Decision thresholds (0.2 / 0.5) depend on whether a lesion was seen.
    if acet_present == 0:
        decision = "Low-confidence normal (no acet detected)" if prob < 0.2 else "Uncertain – lesion may be subtle"
    else:
        decision = "Likely Normal" if prob < 0.2 else "Borderline – Review" if prob < 0.5 else "Likely Abnormal"
    overlay_path = os.path.join(OUTPUT_DIR, f"{name}_overlay.png")
    save_overlay(image_path, label_path, overlay_path)
    return {
        "decision": decision,
        "probability_abnormal": float(prob),
        "acet_present": acet_present,
        "overlay_image": overlay_path
    }