# PAIR-inspired delivery timing predictor — Gradio demo app.
import contextlib
import os
import tempfile
import zipfile
from pathlib import Path
from typing import Any, List, Tuple

import gradio as gr
import torch

from model import PairNet, load_weights_if_any
from utils import load_exam_as_batch, aggregate_predictions, clamp_days, today_plus_days
# Run on GPU when one is visible to torch; otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def init_model() -> Tuple[torch.nn.Module, str, bool]:
    """Build the PairNet model in eval mode and try to load weights.

    Weight source priority: the HF_WEIGHTS env var first, then the
    bundled default checkpoint weights/pair_v4.pt if it exists.

    Returns:
        (model, status message, True if weights were actually loaded).
    """
    net = PairNet(pretrained=True).to(DEVICE)
    net.eval()
    loaded = False
    status = "Running in baseline mode (no weights). Output is a UI demo only."
    # Resolve a single weight source, env var taking precedence.
    source = os.getenv("HF_WEIGHTS", "").strip()
    if not source and Path("weights/pair_v4.pt").exists():
        source = "weights/pair_v4.pt"
    if source:
        loaded, status = load_weights_if_any(net, source)
    return net, status, loaded
# Module-level singletons: the model is built once at import time so the
# Gradio request handler reuses the same instance across calls.
MODEL, LOAD_MSG, HAS_WEIGHTS = init_model()
def _to_path(maybe_file: Any) -> Path:
if isinstance(maybe_file, str):
return Path(maybe_file)
if isinstance(maybe_file, dict) and "name" in maybe_file:
return Path(maybe_file["name"])
name = getattr(maybe_file, "name", None)
if name:
return Path(name)
return Path(str(maybe_file))
def predict_on_files(files) -> dict:
    """Run the model over uploaded images/ZIPs and return a JSON-able summary.

    Args:
        files: list of uploads as delivered by gr.Files (path strings,
            dicts, or file-like objects); ZIPs are expanded in place.

    Returns:
        dict with per-frame and aggregate predictions, or a short
        ``{"status": ...}`` dict when no usable file arrived.
    """
    # BUG FIX: ZIP contents used to be extracted into a TemporaryDirectory
    # that was deleted as soon as the extraction `with` block exited, so
    # every collected path dangled by the time load_exam_as_batch() ran.
    # ExitStack keeps each extraction dir alive until this function returns.
    with contextlib.ExitStack() as stack:
        paths: List[str] = []
        wanted = {".png", ".jpg", ".jpeg", ".dcm"}
        for f in files or []:
            p = _to_path(f)
            if p.suffix.lower() == ".zip":
                td = stack.enter_context(tempfile.TemporaryDirectory())
                # NOTE(review): extractall on an untrusted ZIP is exposed to
                # zip-slip path traversal — consider validating member names.
                with zipfile.ZipFile(p, "r") as z:
                    z.extractall(td)
                # Match extensions case-insensitively (".PNG" etc. previously dropped).
                paths.extend(
                    str(q) for q in sorted(Path(td).rglob("*"))
                    if q.suffix.lower() in wanted
                )
            else:
                paths.append(str(p))
        if not paths:
            return {"status": "no files received"}
        x = load_exam_as_batch(paths).to(DEVICE).float()
        if HAS_WEIGHTS:
            with torch.no_grad():
                days_raw, logits = MODEL(x)
            # assumes model returns (days, logits) each of shape (B, 1) — TODO confirm
            days = days_raw.squeeze(-1).detach().cpu().tolist()
            proba = torch.sigmoid(logits.squeeze(-1)).detach().cpu().tolist()
        else:
            # Baseline placeholders so the UI stays usable without weights.
            days = [150.0] * x.shape[0]
            proba = [0.5] * x.shape[0]
        days = [clamp_days(float(d)) for d in days]
        preterm = ["Preterm" if pr >= 0.5 else "Term" for pr in proba]
        days_mean, proba_mean = aggregate_predictions(days, proba)
        return {
            "frames": len(paths),
            "per_frame_days": days,
            "per_frame_preterm_proba": proba,
            "per_frame_preterm_label": preterm,
            "aggregate_days_mean": days_mean,
            "aggregate_predicted_date": today_plus_days(days_mean),
            "aggregate_preterm_proba": proba_mean,
            "aggregate_preterm_label": "Preterm" if proba_mean >= 0.5 else "Term",
            "weights_message": LOAD_MSG
        }
# UI layout: component creation order inside the Blocks context defines the
# rendered layout, so statement order here is load-bearing.
with gr.Blocks(title="PAIR-inspired Delivery Timing Predictor") as demo:
    gr.Markdown(
        """
PAIR-inspired Delivery Timing Predictor
This app is a technical scaffold inspired by the PAIR study.
It does not include the proprietary model or clinical dataset.
Not for medical use.
"""
    )
    with gr.Row():
        with gr.Column():
            # type="filepath" makes Gradio hand the callback path strings.
            in_files = gr.Files(
                label="Upload ultrasound images or a ZIP (PNG/JPG/DICOM)",
                file_count="multiple",
                type="filepath"
            )
            run_btn = gr.Button("Run prediction", variant="primary")
        with gr.Column():
            # Raw JSON output keeps the demo dependency-free on the UI side.
            status = gr.JSON(label="Outputs")
    # Surface the weight-loading status captured at import time.
    gr.Markdown(f"Model status: {LOAD_MSG}")
    run_btn.click(fn=predict_on_files, inputs=[in_files], outputs=[status])
# Start the Gradio server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()