# ============================================================
# Colposcopy Inference Backend
# Production-ready | VS Code | Hugging Face compatible
# ============================================================
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import joblib
from torchvision import transforms, models
from PIL import Image
from ultralytics import YOLO
# ------------------------------------------------------------
# DEVICE
# ------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ------------------------------------------------------------
# PATHS (RELATIVE — REQUIRED FOR DEPLOYMENT)
# ------------------------------------------------------------
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_DIR = os.path.join(BASE_DIR, "models")
OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")
os.makedirs(OUTPUT_DIR, exist_ok=True)
SEG_MODEL_PATH = os.path.join(MODEL_DIR, "seg_yolov8n_best.pt")
FUSION_MODEL_PATH = os.path.join(MODEL_DIR, "fusion_model.pth")
CLF_PATH = os.path.join(MODEL_DIR, "logreg_classifier.joblib")
# ------------------------------------------------------------
# LOAD MODELS (ONCE)
# ------------------------------------------------------------
seg_model = YOLO(SEG_MODEL_PATH)
clf = joblib.load(CLF_PATH)
# ------------------------------------------------------------
# FUSION MODEL DEFINITION
# ------------------------------------------------------------
class ImageEncoder(nn.Module):
    """ResNet-18 backbone (randomly initialized) projecting to a 512-d embedding."""

    def __init__(self):
        super().__init__()
        # `pretrained=` is deprecated in torchvision; `weights=None` is the
        # modern equivalent of pretrained=False.
        base = models.resnet18(weights=None)
        self.backbone = nn.Sequential(*list(base.children())[:-1])  # drop the final FC layer
        self.fc = nn.Linear(512, 512)

    def forward(self, x):
        x = self.backbone(x)
        return self.fc(x.view(x.size(0), -1))
class FeatureEncoder(nn.Module):
    """MLP encoding the 7 handcrafted geometry features into a 64-d embedding."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(7, 64),
            nn.ReLU(),
            nn.Linear(64, 64)
        )

    def forward(self, x):
        return self.net(x)
class FusionModel(nn.Module):
    """Concatenates image (512-d) and feature (64-d) embeddings into a normalized 576-d vector."""

    def __init__(self):
        super().__init__()
        self.img_enc = ImageEncoder()
        self.feat_enc = FeatureEncoder()
        self.norm = nn.BatchNorm1d(576)  # 512 + 64

    def forward(self, img, feat):
        img_emb = self.img_enc(img)
        feat_emb = self.feat_enc(feat)
        return self.norm(torch.cat([img_emb, feat_emb], dim=1))
fusion_model = FusionModel().to(device)
fusion_model.load_state_dict(torch.load(FUSION_MODEL_PATH, map_location=device))
fusion_model.eval()
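# Illustrative shape check (not executed; tensors and values are examples
# only): a single 224x224 RGB tensor plus a 7-d feature vector should yield
# a 576-d fused embedding, i.e. 512 (image) + 64 (features):
#
#   with torch.no_grad():
#       emb = fusion_model(torch.zeros(1, 3, 224, 224, device=device),
#                          torch.zeros(1, 7, device=device))
#       assert emb.shape == (1, 576)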
# ------------------------------------------------------------
# IMAGE TRANSFORM
# ------------------------------------------------------------
# Resize to 224x224 and scale to [0, 1] via ToTensor; note that no ImageNet
# mean/std normalization is applied, which must match the training pipeline.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
# ------------------------------------------------------------
# CONSTANTS
# ------------------------------------------------------------
CERVIX_ID = 0          # YOLO class: cervix region
SCJ_ID = 1             # YOLO class: squamocolumnar junction
ACET_ID = 3            # YOLO class: acetowhite lesion
MIN_ACET_RATIO = 0.01  # minimum lesion/cervix area ratio to count acetowhite as present
# ------------------------------------------------------------
# GEOMETRY UTILITIES
# ------------------------------------------------------------
def polygon_to_mask(polygon, H, W):
    """Rasterize a polygon of normalized (x, y) pairs into a binary HxW mask."""
    pts = np.array([[int(x * W), int(y * H)] for x, y in polygon], np.int32)
    mask = np.zeros((H, W), dtype=np.uint8)
    cv2.fillPoly(mask, [pts], 1)
    return mask

def mask_area(mask):
    """Mask area as a fraction of the whole image."""
    return mask.sum() / mask.size

def centroid_distance(mask1, mask2):
    """Distance between mask centroids, normalized by the longer image side.

    Returns 1.0 (the maximum) when either mask is missing or empty.
    """
    if mask2 is None:
        return 1.0
    ys1, xs1 = np.where(mask1 == 1)
    ys2, xs2 = np.where(mask2 == 1)
    if len(xs1) == 0 or len(xs2) == 0:
        return 1.0
    c1 = np.array([xs1.mean(), ys1.mean()])
    c2 = np.array([xs2.mean(), ys2.mean()])
    return np.linalg.norm(c1 - c2) / max(mask1.shape)

def overlap_ratio(mask1, mask2):
    """Fraction of mask1 covered by mask2 (0.0 if mask2 is missing or mask1 is empty)."""
    if mask2 is None:
        return 0.0
    inter = np.logical_and(mask1, mask2).sum()
    return inter / mask1.sum() if mask1.sum() > 0 else 0.0
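# Worked example (values illustrative): the polygon
# [(0.0, 0.0), (0.5, 0.0), (0.5, 1.0), (0.0, 1.0)] covers the left half of the
# frame, so on a 100x100 image polygon_to_mask gives mask_area(m) ~= 0.5;
# overlap_ratio(m, m) == 1.0 for any non-empty mask, and both
# centroid_distance(m, None) and centroid_distance on an empty mask fall back
# to the maximum value 1.0.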
# ------------------------------------------------------------
# LOAD YOLO POLYGONS
# ------------------------------------------------------------
def load_yolo_segmentation(label_path):
    """Parse an Ultralytics YOLO-seg label file into class ids and normalized polygons."""
    objects = []
    if not os.path.exists(label_path):
        return objects
    with open(label_path) as f:
        for line in f:
            parts = line.strip().split()
            if not parts:
                continue  # skip blank lines defensively
            parts = list(map(float, parts))
            cls = int(parts[0])
            coords = parts[1:]
            polygon = [(coords[i], coords[i + 1]) for i in range(0, len(coords), 2)]
            objects.append({"cls": cls, "polygon": polygon})
    return objects
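# Each line of a YOLO-seg label file is a class id followed by a flattened
# list of normalized polygon vertices, e.g. (values illustrative):
#
#   3 0.41 0.37 0.46 0.35 0.52 0.40 0.47 0.45
#
# which load_yolo_segmentation parses into
# {"cls": 3, "polygon": [(0.41, 0.37), (0.46, 0.35), (0.52, 0.40), (0.47, 0.45)]}.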
# ------------------------------------------------------------
# FEATURE EXTRACTION
# ------------------------------------------------------------
def extract_features_from_label(label_path, H, W):
    """Build the 7-d geometry feature vector consumed by the fusion model.

    Feature order (preserved as-is to match the trained classifier):
    acetowhite presence flag (duplicated), acetowhite area (fraction of
    image), acetowhite/cervix area ratio, acetowhite-SCJ centroid distance,
    lesion-cervix centroid distance, acetowhite-SCJ overlap ratio.
    """
    objects = load_yolo_segmentation(label_path)
    cervix_masks, scj_masks, acet_masks = [], [], []
    for obj in objects:
        m = polygon_to_mask(obj["polygon"], H, W)
        if obj["cls"] == CERVIX_ID:
            cervix_masks.append(m)
        elif obj["cls"] == SCJ_ID:
            scj_masks.append(m)
        elif obj["cls"] == ACET_ID:
            acet_masks.append(m)
    # Keep the largest cervix mask; fall back to an empty uint8 mask (the
    # original float default would break the morphology step below).
    cervix = max(cervix_masks, key=lambda m: m.sum()) if cervix_masks else np.zeros((H, W), dtype=np.uint8)
    scj = max(scj_masks, key=lambda m: m.sum()) if scj_masks else None
    cervix_area = mask_area(cervix)
    # Union of all acetowhite regions, clipped to the cervix.
    acet_union = np.zeros((H, W), dtype=np.uint8)
    for m in acet_masks:
        acet_union = np.maximum(acet_union, m)
    acet_union = acet_union * cervix
    if acet_union.sum() > 0:
        # Close small holes so fragmented detections count as one region.
        acet_union = cv2.morphologyEx(
            acet_union, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8)
        )
    acet_area = mask_area(acet_union)
    acet_present = int(cervix_area > 0 and acet_area / cervix_area >= MIN_ACET_RATIO)
    if acet_present:
        dist_acet_scj = centroid_distance(acet_union, scj)
        lesion_center_dist = centroid_distance(acet_union, cervix)
        overlap_lesion_scj = overlap_ratio(acet_union, scj)
    else:
        dist_acet_scj = lesion_center_dist = 1.0
        overlap_lesion_scj = 0.0
    return torch.tensor([
        acet_present,
        1 if acet_present else 0,  # duplicated presence flag, kept to match the trained layout
        acet_area if acet_present else 0.0,
        acet_area / cervix_area if acet_present else 0.0,
        dist_acet_scj,
        lesion_center_dist,
        overlap_lesion_scj
    ], dtype=torch.float32)
# ------------------------------------------------------------
# SAVE VISUALIZATION FOR UI
# ------------------------------------------------------------
def save_overlay(image_path, label_path, out_path):
    """Blend segmentation masks onto the image: cervix blue, SCJ green, acetowhite red (RGB)."""
    image = np.array(Image.open(image_path).convert("RGB"))
    H, W, _ = image.shape
    objects = load_yolo_segmentation(label_path)
    cervix = np.zeros((H, W), dtype=np.uint8)
    scj = np.zeros((H, W), dtype=np.uint8)
    acet = np.zeros((H, W), dtype=np.uint8)
    for obj in objects:
        m = polygon_to_mask(obj["polygon"], H, W)
        if obj["cls"] == CERVIX_ID:
            cervix = np.maximum(cervix, m)
        elif obj["cls"] == SCJ_ID:
            scj = np.maximum(scj, m)
        elif obj["cls"] == ACET_ID:
            acet = np.maximum(acet, m)
    # 70/30 alpha blend of the original pixels with each class color.
    overlay = image.copy()
    overlay[cervix == 1] = 0.7 * overlay[cervix == 1] + 0.3 * np.array([0, 0, 255])
    overlay[scj == 1] = 0.7 * overlay[scj == 1] + 0.3 * np.array([0, 255, 0])
    overlay[acet == 1] = 0.7 * overlay[acet == 1] + 0.3 * np.array([255, 0, 0])
    Image.fromarray(overlay.astype(np.uint8)).save(out_path)
# ------------------------------------------------------------
# PUBLIC API — UI CALLS THIS
# ------------------------------------------------------------
def run_inference(image_path: str) -> dict:
    """Run segmentation, feature fusion, and classification on one image."""
    # save_txt=True makes YOLO write polygon label files we can parse for
    # the geometry features.
    results = seg_model(image_path, conf=0.15, save_txt=True, save=False)
    save_dir = results[0].save_dir
    name = os.path.splitext(os.path.basename(image_path))[0]
    label_path = os.path.join(save_dir, "labels", f"{name}.txt")
    if not os.path.exists(label_path):
        return {"decision": "Segmentation failed"}
    image = Image.open(image_path).convert("RGB")
    W, H = image.size
    img_tensor = transform(image).unsqueeze(0).to(device)
    feat = extract_features_from_label(label_path, H, W)
    feat_tensor = feat.unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = fusion_model(img_tensor, feat_tensor)
    # Logistic-regression probability that the image is abnormal.
    prob = clf.predict_proba(embedding.cpu().numpy())[0, 1]
    acet_present = int(feat[0].item())
    if acet_present == 0:
        decision = ("Low-confidence normal (no acet detected)" if prob < 0.2
                    else "Uncertain – lesion may be subtle")
    else:
        decision = ("Likely Normal" if prob < 0.2
                    else "Borderline – Review" if prob < 0.5
                    else "Likely Abnormal")
    overlay_path = os.path.join(OUTPUT_DIR, f"{name}_overlay.png")
    save_overlay(image_path, label_path, overlay_path)
    return {
        "decision": decision,
        "probability_abnormal": float(prob),
        "acet_present": acet_present,
        "overlay_image": overlay_path
    }
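# ------------------------------------------------------------
# LOCAL SMOKE TEST
# ------------------------------------------------------------
# A minimal sketch of how the UI is expected to call this module. The default
# sample path below is hypothetical; pass any colposcopy image path as the
# first command-line argument.
if __name__ == "__main__":
    import json
    import sys

    sample = sys.argv[1] if len(sys.argv) > 1 else os.path.join(BASE_DIR, "sample.jpg")
    if os.path.exists(sample):
        print(json.dumps(run_inference(sample), indent=2))
    else:
        print(f"No test image found at {sample}; pass a path as the first argument.")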