| import os | |
| import tempfile | |
| import zipfile | |
| from pathlib import Path | |
| from typing import List, Tuple, Any | |
| import gradio as gr | |
| import torch | |
| from model import PairNet, load_weights_if_any | |
| from utils import load_exam_as_batch, aggregate_predictions, clamp_days, today_plus_days | |
# Device used for the model and its inputs: the GPU when torch can see one,
# otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
def init_model() -> Tuple[torch.nn.Module, str, bool]:
    """Build the network in eval mode and try to load trained weights into it.

    Weight source, first match wins:
      1. the value of the ``HF_WEIGHTS`` environment variable (whitespace
         stripped), if non-empty;
      2. otherwise ``weights/pair_v4.pt`` (relative to the current working
         directory) when that file exists;
      3. otherwise nothing is loaded and the model keeps its
         ``PairNet(pretrained=True)`` initialisation.

    A failed load from ``HF_WEIGHTS`` does NOT fall back to the local file.

    Returns:
        ``(model, status_message, weights_loaded)``. The message and flag come
        from ``load_weights_if_any`` when a source was tried, else a fixed
        baseline-mode message and ``False``.
    """
    net = PairNet(pretrained=True).to(DEVICE)
    net.eval()

    source = os.getenv("HF_WEIGHTS", "").strip()
    if not source and Path("weights/pair_v4.pt").exists():
        source = "weights/pair_v4.pt"

    if not source:
        return net, "Running in baseline mode (no weights). Output is a UI demo only.", False

    loaded, message = load_weights_if_any(net, source)
    return net, message, loaded
# Build the model once at import time; predict_on_files and the UI status
# line read these module-level globals rather than reloading per request.
(MODEL, LOAD_MSG, HAS_WEIGHTS) = init_model()
| def _to_path(maybe_file: Any) -> Path: | |
| if isinstance(maybe_file, str): | |
| return Path(maybe_file) | |
| if isinstance(maybe_file, dict) and "name" in maybe_file: | |
| return Path(maybe_file["name"]) | |
| name = getattr(maybe_file, "name", None) | |
| if name: | |
| return Path(name) | |
| return Path(str(maybe_file)) | |
| def predict_on_files(files) -> dict: | |
| paths: List[str] = [] | |
| for f in files or []: | |
| p = _to_path(f) | |
| if p.suffix.lower() == ".zip": | |
| with zipfile.ZipFile(p, "r") as z: | |
| with tempfile.TemporaryDirectory() as td: | |
| z.extractall(td) | |
| for ext in (".png", ".jpg", ".jpeg", ".dcm"): | |
| paths.extend([str(q) for q in Path(td).rglob(f"*{ext}")]) | |
| else: | |
| paths.append(str(p)) | |
| if not paths: | |
| return {"status": "no files received"} | |
| x = load_exam_as_batch(paths).to(DEVICE).float() | |
| if HAS_WEIGHTS: | |
| with torch.no_grad(): | |
| days_raw, logits = MODEL(x) | |
| days = days_raw.squeeze(-1).detach().cpu().tolist() | |
| proba = torch.sigmoid(logits.squeeze(-1)).detach().cpu().tolist() | |
| else: | |
| days = [150.0 for _ in range(x.shape[0])] | |
| proba = [0.5 for _ in range(x.shape[0])] | |
| days = [clamp_days(float(d)) for d in days] | |
| preterm = ["Preterm" if p >= 0.5 else "Term" for p in proba] | |
| days_mean, proba_mean = aggregate_predictions(days, proba) | |
| return { | |
| "frames": len(paths), | |
| "per_frame_days": days, | |
| "per_frame_preterm_proba": proba, | |
| "per_frame_preterm_label": preterm, | |
| "aggregate_days_mean": days_mean, | |
| "aggregate_predicted_date": today_plus_days(days_mean), | |
| "aggregate_preterm_proba": proba_mean, | |
| "aggregate_preterm_label": "Preterm" if proba_mean >= 0.5 else "Term", | |
| "weights_message": LOAD_MSG | |
| } | |
# Gradio UI: a row with an input column (file upload + run button) and an
# output column (JSON results + model load status).
with gr.Blocks(title="PAIR-inspired Delivery Timing Predictor") as demo:
    gr.Markdown(
        """
PAIR-inspired Delivery Timing Predictor
This app is a technical scaffold inspired by the PAIR study.
It does not include the proprietary model or clinical dataset.
Not for medical use.
"""
    )
    with gr.Row():
        # Input column.
        with gr.Column():
            # type="filepath" asks Gradio for path strings; predict_on_files
            # resolves each value through _to_path and expands .zip uploads.
            in_files = gr.Files(
                label="Upload ultrasound images or a ZIP (PNG/JPG/DICOM)",
                file_count="multiple",
                type="filepath"
            )
            run_btn = gr.Button("Run prediction", variant="primary")
        # Output column.
        with gr.Column():
            status = gr.JSON(label="Outputs")
            # LOAD_MSG is computed once when the module is imported, so this
            # text does not change while the app is running.
            gr.Markdown(f"Model status: {LOAD_MSG}")
    # Button click -> predict_on_files(uploaded files) -> dict shown as JSON.
    run_btn.click(fn=predict_on_files, inputs=[in_files], outputs=[status])
def _main() -> None:
    """Launch the module-level Gradio ``demo`` app."""
    demo.launch()


if __name__ == "__main__":
    _main()