File size: 3,611 Bytes
2a034f9
 
 
 
b709ba3
2a034f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b709ba3
 
 
 
 
 
 
 
 
 
 
2a034f9
 
b709ba3
2a034f9
 
 
 
 
 
 
 
 
 
 
 
b709ba3
2a034f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b709ba3
2a034f9
 
 
 
 
 
 
 
 
 
 
 
 
b709ba3
 
2a034f9
b709ba3
 
 
 
2a034f9
 
 
 
b709ba3
 
 
 
 
2a034f9
 
 
 
b709ba3
2a034f9
b709ba3
2a034f9
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import contextlib
import os
import tempfile
import zipfile
from pathlib import Path
from typing import List, Tuple, Any

import gradio as gr
import torch

from model import PairNet, load_weights_if_any
from utils import load_exam_as_batch, aggregate_predictions, clamp_days, today_plus_days

# Compute device chosen once at import time: CUDA when PyTorch reports it is
# available, otherwise the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

def init_model() -> Tuple[torch.nn.Module, str, bool]:
    """Build the PairNet model on DEVICE (eval mode) and try to load weights.

    Weight source precedence:
      1. the ``HF_WEIGHTS`` environment variable, if non-empty after stripping;
      2. otherwise ``weights/pair_v4.pt`` (relative to the working directory)
         when that file exists.

    Returns:
        ``(model, message, loaded)``. ``message`` and ``loaded`` come from
        ``load_weights_if_any`` when a source was found. Otherwise they are
        a baseline-mode notice and ``False``.
    """
    net = PairNet(pretrained=True).to(DEVICE)
    net.eval()

    source = os.getenv("HF_WEIGHTS", "").strip()
    if not source and Path("weights/pair_v4.pt").exists():
        source = "weights/pair_v4.pt"

    if not source:
        return net, "Running in baseline mode (no weights). Output is a UI demo only.", False

    loaded, message = load_weights_if_any(net, source)
    return net, message, loaded

# Built once at import time; predict_on_files reads MODEL / HAS_WEIGHTS /
# LOAD_MSG on every request, and the UI shows LOAD_MSG as its status line.
(MODEL, LOAD_MSG, HAS_WEIGHTS) = init_model()

def _to_path(maybe_file: Any) -> Path:
    if isinstance(maybe_file, str):
        return Path(maybe_file)
    if isinstance(maybe_file, dict) and "name" in maybe_file:
        return Path(maybe_file["name"])
    name = getattr(maybe_file, "name", None)
    if name:
        return Path(name)
    return Path(str(maybe_file))

# Lower-case extensions collected from inside uploaded ZIP archives.
_ZIP_IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".dcm")


def _collect_image_paths(files, stack: contextlib.ExitStack) -> List[str]:
    """Flatten uploads into a list of image paths, expanding ZIP archives.

    ZIP members are extracted into ``TemporaryDirectory`` objects registered
    on *stack*, so they stay on disk until the caller closes the stack.
    Non-ZIP uploads are passed through unchanged.
    """
    paths: List[str] = []
    for f in files or []:
        p = _to_path(f)
        if p.suffix.lower() != ".zip":
            paths.append(str(p))
            continue
        root = Path(stack.enter_context(tempfile.TemporaryDirectory()))
        with zipfile.ZipFile(p, "r") as z:
            z.extractall(root)
        # Case-insensitive extension match (consistent with the direct-upload
        # check above), sorted for a deterministic frame order, and skipping
        # macOS "__MACOSX" / "._*" metadata entries that are not images.
        paths.extend(
            str(q)
            for q in sorted(root.rglob("*"))
            if q.is_file()
            and q.suffix.lower() in _ZIP_IMAGE_EXTS
            and "__MACOSX" not in q.relative_to(root).parts
            and not q.name.startswith("._")
        )
    return paths


def _predict_batch(x: torch.Tensor) -> Tuple[List[float], List[float]]:
    """Return raw per-frame ``(days, preterm_probability)`` lists for *x*.

    Without loaded weights, placeholder values (150.0 days, 0.5 probability)
    are returned for each of the ``x.shape[0]`` frames.
    """
    n_frames = x.shape[0]
    if not HAS_WEIGHTS:
        return [150.0] * n_frames, [0.5] * n_frames
    with torch.no_grad():
        days_raw, logits = MODEL(x)
        # reshape(-1) rather than squeeze(-1): for a single-frame batch with a
        # 1-D output, squeeze gives a 0-d tensor whose tolist() is a bare
        # float, not a list. Assumes one output value per frame — TODO confirm.
        days = days_raw.reshape(-1).cpu().tolist()
        proba = torch.sigmoid(logits.reshape(-1)).cpu().tolist()
    return days, proba


def predict_on_files(files) -> dict:
    """Run the delivery-timing model on uploaded images and/or ZIP archives.

    Args:
        files: Upload values from the Gradio ``Files`` component (may be
            ``None`` or empty). Each is coerced with ``_to_path``; ``.zip``
            uploads are expanded to the PNG/JPG/JPEG/DCM files they contain.

    Returns:
        ``{"status": "no files received"}`` when no image paths were found.
        Otherwise a dict with the following entries:
          * the number of input paths;
          * per-frame clamped days, preterm probability and label;
          * aggregate mean days, predicted date, probability and label;
          * the weight-loading message.
    """
    with contextlib.ExitStack() as stack:
        paths = _collect_image_paths(files, stack)
        if not paths:
            return {"status": "no files received"}
        # Load the batch *before* the stack closes: extracted ZIP members are
        # deleted together with their temporary directories.
        x = load_exam_as_batch(paths).to(DEVICE).float()

    days, proba = _predict_batch(x)
    days = [clamp_days(float(d)) for d in days]
    preterm = ["Preterm" if p >= 0.5 else "Term" for p in proba]

    days_mean, proba_mean = aggregate_predictions(days, proba)

    return {
        "frames": len(paths),
        "per_frame_days": days,
        "per_frame_preterm_proba": proba,
        "per_frame_preterm_label": preterm,
        "aggregate_days_mean": days_mean,
        "aggregate_predicted_date": today_plus_days(days_mean),
        "aggregate_preterm_proba": proba_mean,
        "aggregate_preterm_label": "Preterm" if proba_mean >= 0.5 else "Term",
        "weights_message": LOAD_MSG,
    }

# Gradio UI: an upload panel with a run button on the left and a JSON output
# panel on the right, followed by a static model-status line.
with gr.Blocks(title="PAIR-inspired Delivery Timing Predictor") as demo:
    # Intro / disclaimer text shown at the top of the page.
    gr.Markdown(
        """
PAIR-inspired Delivery Timing Predictor

This app is a technical scaffold inspired by the PAIR study.
It does not include the proprietary model or clinical dataset.
Not for medical use.
        """
    )

    # Components are laid out in the order they are created inside these
    # Row/Column contexts, so statement order here defines the page layout.
    with gr.Row():
        with gr.Column():
            # type="filepath": the handler presumably receives path strings;
            # predict_on_files normalizes each value via _to_path either way.
            in_files = gr.Files(
                label="Upload ultrasound images or a ZIP (PNG/JPG/DICOM)",
                file_count="multiple",
                type="filepath"
            )
            run_btn = gr.Button("Run prediction", variant="primary")
        with gr.Column():
            # Displays the dict returned by predict_on_files.
            status = gr.JSON(label="Outputs")

    # LOAD_MSG is computed once at import time by init_model(); this line does
    # not update if weights change while the app is running.
    gr.Markdown(f"Model status: {LOAD_MSG}")

    # Button click -> predict_on_files(uploaded files) -> JSON panel.
    run_btn.click(fn=predict_on_files, inputs=[in_files], outputs=[status])

def main() -> None:
    """Start the Gradio app (``demo.launch()`` with no extra arguments)."""
    demo.launch()


if __name__ == "__main__":
    main()