# PAIR / app.py
# Last change: commit b709ba3 "Update app.py" by fantos (verified commit).
import os
import tempfile
import zipfile
from pathlib import Path
from typing import List, Tuple, Any
import gradio as gr
import torch
from model import PairNet, load_weights_if_any
from utils import load_exam_as_batch, aggregate_predictions, clamp_days, today_plus_days
# Target device for the model and input tensors: the GPU when torch can see
# one, otherwise the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
def init_model() -> Tuple[torch.nn.Module, str, bool]:
    """Build PairNet on DEVICE in eval mode and try to load trained weights.

    Weight source precedence:
      1. the ``HF_WEIGHTS`` environment variable, if set to a non-blank value;
      2. otherwise ``weights/pair_v4.pt`` (relative to the working directory)
         if that file exists.
    With neither, no weight loading is attempted and a baseline-mode message
    is returned.

    Returns:
        ``(model, message, loaded)``. When a weight source was attempted,
        ``message`` and ``loaded`` are whatever ``load_weights_if_any``
        reported; otherwise ``loaded`` is False.
    """
    net = PairNet(pretrained=True).to(DEVICE)
    net.eval()

    source = os.getenv("HF_WEIGHTS", "").strip()
    # The bundled checkpoint is only consulted when no env override is given.
    if not source and Path("weights/pair_v4.pt").exists():
        source = "weights/pair_v4.pt"

    if not source:
        baseline_msg = (
            "Running in baseline mode (no weights). Output is a UI demo only."
        )
        return net, baseline_msg, False

    loaded, message = load_weights_if_any(net, source)
    return net, message, loaded


# Built once at import time; predict_on_files reuses this single instance and
# switches to placeholder outputs when HAS_WEIGHTS is False.
MODEL, LOAD_MSG, HAS_WEIGHTS = init_model()
def _to_path(maybe_file: Any) -> Path:
if isinstance(maybe_file, str):
return Path(maybe_file)
if isinstance(maybe_file, dict) and "name" in maybe_file:
return Path(maybe_file["name"])
name = getattr(maybe_file, "name", None)
if name:
return Path(name)
return Path(str(maybe_file))
def _extract_zip_images(zip_path: Path, dest: str) -> List[str]:
    """Extract ``zip_path`` into ``dest`` and return the image files in it, sorted.

    Extensions ``.png``/``.jpg``/``.jpeg``/``.dcm`` are matched
    case-insensitively. macOS archive metadata (``__MACOSX/`` entries and
    ``._*`` AppleDouble files) is skipped.
    """
    image_exts = {".png", ".jpg", ".jpeg", ".dcm"}
    with zipfile.ZipFile(zip_path, "r") as archive:
        # ZipFile.extractall strips absolute prefixes and ".." components from
        # member names, so entries cannot be written outside ``dest``.
        archive.extractall(dest)
    root = Path(dest)
    found: List[str] = []
    for candidate in root.rglob("*"):
        if not candidate.is_file() or candidate.suffix.lower() not in image_exts:
            continue
        # macOS archivers add metadata entries that carry the image's
        # extension but are not images.
        if "__MACOSX" in candidate.relative_to(root).parts or candidate.name.startswith("._"):
            continue
        found.append(str(candidate))
    # rglob order is filesystem-dependent; sort for a stable frame order.
    return sorted(found)


def predict_on_files(files) -> dict:
    """Run the delivery-timing predictor on a batch of uploaded files.

    Args:
        files: The ``gr.Files`` value: an iterable of uploads (anything
            ``_to_path`` accepts) or ``None``. A ``.zip`` upload (any case) is
            extracted and the ``.png``/``.jpg``/``.jpeg``/``.dcm`` files inside
            it are used. Any other upload is passed to the loader unchanged.

    Returns:
        ``{"status": "no files received"}`` when no frames were collected.
        Otherwise, a dict with per-frame and aggregate predictions plus the
        weight-loading message. Without loaded weights, each frame gets
        placeholder values: 150.0 days (before ``clamp_days``) and
        probability 0.5.
    """
    # One scratch directory per request. It must stay alive until
    # load_exam_as_batch has read the extracted files; the previous version
    # deleted them before loading.
    with tempfile.TemporaryDirectory() as scratch:
        paths: List[str] = []
        for f in files or []:
            p = _to_path(f)
            if p.suffix.lower() == ".zip":
                # Separate subdirectory per archive so same-named members in
                # different ZIPs cannot overwrite each other.
                dest = tempfile.mkdtemp(dir=scratch)
                paths.extend(_extract_zip_images(p, dest))
            else:
                paths.append(str(p))
        if not paths:
            return {"status": "no files received"}
        # Presumably returns an in-memory tensor, so the extracted files are
        # no longer needed once this returns. TODO confirm in utils.
        x = load_exam_as_batch(paths).to(DEVICE).float()

    if HAS_WEIGHTS:
        with torch.no_grad():
            days_raw, logits = MODEL(x)
        days = days_raw.squeeze(-1).detach().cpu().tolist()
        proba = torch.sigmoid(logits.squeeze(-1)).detach().cpu().tolist()
    else:
        # Baseline mode: no trained weights, so emit fixed placeholders.
        days = [150.0] * x.shape[0]
        proba = [0.5] * x.shape[0]

    days = [clamp_days(float(d)) for d in days]
    preterm = ["Preterm" if p >= 0.5 else "Term" for p in proba]
    days_mean, proba_mean = aggregate_predictions(days, proba)
    return {
        "frames": len(paths),
        "per_frame_days": days,
        "per_frame_preterm_proba": proba,
        "per_frame_preterm_label": preterm,
        "aggregate_days_mean": days_mean,
        "aggregate_predicted_date": today_plus_days(days_mean),
        "aggregate_preterm_proba": proba_mean,
        "aggregate_preterm_label": "Preterm" if proba_mean >= 0.5 else "Term",
        "weights_message": LOAD_MSG,
    }
_INTRO_MARKDOWN = """
PAIR-inspired Delivery Timing Predictor
This app is a technical scaffold inspired by the PAIR study.
It does not include the proprietary model or clinical dataset.
Not for medical use.
"""

# Layout: uploads and the run button on the left, JSON results on the right,
# and the weight-loading status line below both columns.
with gr.Blocks(title="PAIR-inspired Delivery Timing Predictor") as demo:
    gr.Markdown(_INTRO_MARKDOWN)
    with gr.Row():
        with gr.Column():
            in_files = gr.Files(
                label="Upload ultrasound images or a ZIP (PNG/JPG/DICOM)",
                file_count="multiple",
                type="filepath",
            )
            run_btn = gr.Button("Run prediction", variant="primary")
        with gr.Column():
            status = gr.JSON(label="Outputs")
    gr.Markdown(f"Model status: {LOAD_MSG}")
    # The button passes the uploads to predict_on_files and renders the
    # returned dict in the JSON panel.
    run_btn.click(predict_on_files, inputs=[in_files], outputs=[status])

if __name__ == "__main__":
    demo.launch()