import os
import tempfile
import zipfile
from pathlib import Path
from typing import List, Tuple, Any

import gradio as gr
import torch

from model import PairNet, load_weights_if_any
from utils import load_exam_as_batch, aggregate_predictions, clamp_days, today_plus_days

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Suffixes (compared lower-cased) of ZIP members treated as exam frames.
_IMAGE_SUFFIXES = frozenset({".png", ".jpg", ".jpeg", ".dcm"})

# Probability at or above which a prediction is labelled "Preterm".
_PRETERM_THRESHOLD = 0.5


def init_model() -> Tuple[torch.nn.Module, str, bool]:
    """Build PairNet on ``DEVICE`` and try to load weights into it.

    The weights source is the ``HF_WEIGHTS`` environment variable when it is
    non-blank; otherwise ``weights/pair_v4.pt`` (relative to the working
    directory) if that file exists. Either source is handed to
    ``load_weights_if_any``.

    Returns:
        ``(model, message, loaded)``: the model (``model.eval()`` has been
        called), a human-readable status message, and whether
        ``load_weights_if_any`` reported success.
    """
    model = PairNet(pretrained=True).to(DEVICE)
    model.eval()
    weights_hint = os.getenv("HF_WEIGHTS", "").strip()
    ok = False
    msg = "Running in baseline mode (no weights). Output is a UI demo only."
    if weights_hint:
        ok, msg = load_weights_if_any(model, weights_hint)
    elif Path("weights/pair_v4.pt").exists():
        ok, msg = load_weights_if_any(model, "weights/pair_v4.pt")
    return model, msg, ok


# Built once at import time and shared by every request.
MODEL, LOAD_MSG, HAS_WEIGHTS = init_model()


def _to_path(maybe_file: Any) -> Path:
    """Coerce one Gradio upload value into a ``Path``.

    Accepts a plain path string, a dict with a ``"name"`` key, or any object
    with a truthy ``name`` attribute; anything else falls back to ``str()``.
    """
    if isinstance(maybe_file, str):
        return Path(maybe_file)
    if isinstance(maybe_file, dict) and "name" in maybe_file:
        return Path(maybe_file["name"])
    name = getattr(maybe_file, "name", None)
    if name:
        return Path(name)
    return Path(str(maybe_file))


def _is_archive_frame(path: Path, root: Path) -> bool:
    """Return True if *path*, found under extraction dir *root*, is an exam frame."""
    if not path.is_file() or path.suffix.lower() not in _IMAGE_SUFFIXES:
        return False
    # macOS-created ZIPs add "__MACOSX/" and "._name" AppleDouble entries that
    # carry image suffixes but are not image data.
    rel = path.relative_to(root)
    return "__MACOSX" not in rel.parts and not path.name.startswith("._")


def _collect_image_paths(files: Any, workdir: Path) -> List[str]:
    """Flatten uploads into a list of frame file paths.

    Non-ZIP uploads are passed through unchanged. Each ZIP is extracted into
    its own subdirectory of *workdir*, so the caller must keep *workdir* alive
    until the returned files have been read.
    """
    paths: List[str] = []
    for index, item in enumerate(files or []):
        upload = _to_path(item)
        if upload.suffix.lower() != ".zip":
            paths.append(str(upload))
            continue
        # Per-archive directory keeps same-named members of different ZIPs apart;
        # created up front so an empty archive still yields an existing dir.
        dest = workdir / f"archive_{index}"
        dest.mkdir()
        # NOTE: zipfile strips absolute paths and ".." components when
        # extracting, but enforces no size limit on untrusted archives.
        with zipfile.ZipFile(upload, "r") as archive:
            archive.extractall(dest)
        # Sorted so frame order is deterministic (rglob order is filesystem-dependent).
        paths.extend(
            str(member)
            for member in sorted(dest.rglob("*"))
            if _is_archive_frame(member, dest)
        )
    return paths


def _run_model(x: torch.Tensor) -> Tuple[List[float], List[float]]:
    """Return per-frame ``(days, preterm_probability)`` lists for batch *x*.

    Without loaded weights this returns placeholders (150.0 days, 0.5) for
    every frame so the UI can still be exercised.
    """
    n = x.shape[0]
    if not HAS_WEIGHTS:
        return [150.0] * n, [0.5] * n
    with torch.no_grad():
        days_raw, logits = MODEL(x)
    # reshape(n) yields a 1-D tensor for both (n, 1) and (n,) outputs, so a
    # single frame still produces a list (squeeze(-1) would give a scalar);
    # it raises if the model emits more than one value per frame.
    days = days_raw.reshape(n).cpu().tolist()
    proba = torch.sigmoid(logits.reshape(n)).cpu().tolist()
    return days, proba


def predict_on_files(files) -> dict:
    """Run the model over uploaded frames and summarize the predictions.

    Args:
        files: Upload values from ``gr.Files`` (see ``_to_path`` for accepted
            forms). ``.zip`` uploads are expanded to the PNG/JPG/JPEG/DCM files
            they contain; ``None`` is treated as no uploads.

    Returns:
        ``{"status": "no files received"}`` when no frames are found;
        otherwise a dict of per-frame and aggregate predictions plus the
        weights-loading message.
    """
    # One scratch directory for the whole request: extracted ZIP members must
    # still exist when load_exam_as_batch reads them. The result is a tensor,
    # so the directory can be removed once it has been built.
    with tempfile.TemporaryDirectory() as workdir:
        paths = _collect_image_paths(files, Path(workdir))
        if not paths:
            return {"status": "no files received"}
        x = load_exam_as_batch(paths).to(DEVICE).float()

    raw_days, proba = _run_model(x)
    days = [clamp_days(float(d)) for d in raw_days]
    preterm = ["Preterm" if p >= _PRETERM_THRESHOLD else "Term" for p in proba]
    days_mean, proba_mean = aggregate_predictions(days, proba)
    return {
        "frames": len(paths),
        "per_frame_days": days,
        "per_frame_preterm_proba": proba,
        "per_frame_preterm_label": preterm,
        "aggregate_days_mean": days_mean,
        "aggregate_predicted_date": today_plus_days(days_mean),
        "aggregate_preterm_proba": proba_mean,
        "aggregate_preterm_label": "Preterm" if proba_mean >= _PRETERM_THRESHOLD else "Term",
        "weights_message": LOAD_MSG,
    }


with gr.Blocks(title="PAIR-inspired Delivery Timing Predictor") as demo:
    gr.Markdown(
        """
PAIR-inspired Delivery Timing Predictor

This app is a technical scaffold inspired by the PAIR study.
It does not include the proprietary model or clinical dataset.
Not for medical use.
"""
    )
    with gr.Row():
        with gr.Column():
            in_files = gr.Files(
                label="Upload ultrasound images or a ZIP (PNG/JPG/DICOM)",
                file_count="multiple",
                type="filepath",
            )
            run_btn = gr.Button("Run prediction", variant="primary")
        with gr.Column():
            status = gr.JSON(label="Outputs")
            gr.Markdown(f"Model status: {LOAD_MSG}")
    run_btn.click(fn=predict_on_files, inputs=[in_files], outputs=[status])

if __name__ == "__main__":
    demo.launch()