Update app.py
Browse files
app.py
CHANGED
|
@@ -2,11 +2,10 @@ import os
|
|
| 2 |
import tempfile
|
| 3 |
import zipfile
|
| 4 |
from pathlib import Path
|
| 5 |
-
from typing import List, Tuple
|
| 6 |
|
| 7 |
import gradio as gr
|
| 8 |
import torch
|
| 9 |
-
import torch.nn.functional as F
|
| 10 |
|
| 11 |
from model import PairNet, load_weights_if_any
|
| 12 |
from utils import load_exam_as_batch, aggregate_predictions, clamp_days, today_plus_days
|
|
@@ -27,11 +26,20 @@ def init_model() -> Tuple[torch.nn.Module, str, bool]:
|
|
| 27 |
|
| 28 |
MODEL, LOAD_MSG, HAS_WEIGHTS = init_model()
|
| 29 |
|
| 30 |
-
def
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
paths: List[str] = []
|
| 33 |
for f in files or []:
|
| 34 |
-
p =
|
| 35 |
if p.suffix.lower() == ".zip":
|
| 36 |
with zipfile.ZipFile(p, "r") as z:
|
| 37 |
with tempfile.TemporaryDirectory() as td:
|
|
@@ -44,17 +52,14 @@ def predict_on_files(files: List[gr.File]) -> dict:
|
|
| 44 |
if not paths:
|
| 45 |
return {"status": "no files received"}
|
| 46 |
|
| 47 |
-
x = load_exam_as_batch(paths).to(DEVICE)
|
| 48 |
-
x = x.float()
|
| 49 |
|
| 50 |
if HAS_WEIGHTS:
|
| 51 |
with torch.no_grad():
|
| 52 |
days_raw, logits = MODEL(x)
|
| 53 |
-
# Aggregate across frames by mean
|
| 54 |
days = days_raw.squeeze(-1).detach().cpu().tolist()
|
| 55 |
proba = torch.sigmoid(logits.squeeze(-1)).detach().cpu().tolist()
|
| 56 |
else:
|
| 57 |
-
# Baseline demo: constant mid-gestation-ish guess and neutral probability
|
| 58 |
days = [150.0 for _ in range(x.shape[0])]
|
| 59 |
proba = [0.5 for _ in range(x.shape[0])]
|
| 60 |
|
|
@@ -63,7 +68,7 @@ def predict_on_files(files: List[gr.File]) -> dict:
|
|
| 63 |
|
| 64 |
days_mean, proba_mean = aggregate_predictions(days, proba)
|
| 65 |
|
| 66 |
-
|
| 67 |
"frames": len(paths),
|
| 68 |
"per_frame_days": days,
|
| 69 |
"per_frame_preterm_proba": proba,
|
|
@@ -74,31 +79,32 @@ def predict_on_files(files: List[gr.File]) -> dict:
|
|
| 74 |
"aggregate_preterm_label": "Preterm" if proba_mean >= 0.5 else "Term",
|
| 75 |
"weights_message": LOAD_MSG
|
| 76 |
}
|
| 77 |
-
return result
|
| 78 |
|
| 79 |
with gr.Blocks(title="PAIR-inspired Delivery Timing Predictor") as demo:
|
| 80 |
gr.Markdown(
|
| 81 |
-
"
|
|
|
|
| 82 |
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
"
|
| 87 |
)
|
| 88 |
|
| 89 |
with gr.Row():
|
| 90 |
with gr.Column():
|
| 91 |
-
in_files = gr.Files(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
run_btn = gr.Button("Run prediction", variant="primary")
|
| 93 |
with gr.Column():
|
| 94 |
status = gr.JSON(label="Outputs")
|
| 95 |
|
| 96 |
-
|
| 97 |
|
| 98 |
-
|
| 99 |
-
return predict_on_files(files)
|
| 100 |
-
|
| 101 |
-
run_btn.click(_run, inputs=[in_files], outputs=[status])
|
| 102 |
|
| 103 |
if __name__ == "__main__":
|
| 104 |
demo.launch()
|
|
|
|
| 2 |
import tempfile
|
| 3 |
import zipfile
|
| 4 |
from pathlib import Path
|
| 5 |
+
from typing import List, Tuple, Any
|
| 6 |
|
| 7 |
import gradio as gr
|
| 8 |
import torch
|
|
|
|
| 9 |
|
| 10 |
from model import PairNet, load_weights_if_any
|
| 11 |
from utils import load_exam_as_batch, aggregate_predictions, clamp_days, today_plus_days
|
|
|
|
| 26 |
|
| 27 |
MODEL, LOAD_MSG, HAS_WEIGHTS = init_model()
|
| 28 |
|
| 29 |
+
def _to_path(maybe_file: Any) -> Path:
|
| 30 |
+
if isinstance(maybe_file, str):
|
| 31 |
+
return Path(maybe_file)
|
| 32 |
+
if isinstance(maybe_file, dict) and "name" in maybe_file:
|
| 33 |
+
return Path(maybe_file["name"])
|
| 34 |
+
name = getattr(maybe_file, "name", None)
|
| 35 |
+
if name:
|
| 36 |
+
return Path(name)
|
| 37 |
+
return Path(str(maybe_file))
|
| 38 |
+
|
| 39 |
+
def predict_on_files(files) -> dict:
|
| 40 |
paths: List[str] = []
|
| 41 |
for f in files or []:
|
| 42 |
+
p = _to_path(f)
|
| 43 |
if p.suffix.lower() == ".zip":
|
| 44 |
with zipfile.ZipFile(p, "r") as z:
|
| 45 |
with tempfile.TemporaryDirectory() as td:
|
|
|
|
| 52 |
if not paths:
|
| 53 |
return {"status": "no files received"}
|
| 54 |
|
| 55 |
+
x = load_exam_as_batch(paths).to(DEVICE).float()
|
|
|
|
| 56 |
|
| 57 |
if HAS_WEIGHTS:
|
| 58 |
with torch.no_grad():
|
| 59 |
days_raw, logits = MODEL(x)
|
|
|
|
| 60 |
days = days_raw.squeeze(-1).detach().cpu().tolist()
|
| 61 |
proba = torch.sigmoid(logits.squeeze(-1)).detach().cpu().tolist()
|
| 62 |
else:
|
|
|
|
| 63 |
days = [150.0 for _ in range(x.shape[0])]
|
| 64 |
proba = [0.5 for _ in range(x.shape[0])]
|
| 65 |
|
|
|
|
| 68 |
|
| 69 |
days_mean, proba_mean = aggregate_predictions(days, proba)
|
| 70 |
|
| 71 |
+
return {
|
| 72 |
"frames": len(paths),
|
| 73 |
"per_frame_days": days,
|
| 74 |
"per_frame_preterm_proba": proba,
|
|
|
|
| 79 |
"aggregate_preterm_label": "Preterm" if proba_mean >= 0.5 else "Term",
|
| 80 |
"weights_message": LOAD_MSG
|
| 81 |
}
|
|
|
|
| 82 |
|
| 83 |
with gr.Blocks(title="PAIR-inspired Delivery Timing Predictor") as demo:
|
| 84 |
gr.Markdown(
|
| 85 |
+
"""
|
| 86 |
+
PAIR-inspired Delivery Timing Predictor
|
| 87 |
|
| 88 |
+
This app is a technical scaffold inspired by the PAIR study.
|
| 89 |
+
It does not include the proprietary model or clinical dataset.
|
| 90 |
+
Not for medical use.
|
| 91 |
+
"""
|
| 92 |
)
|
| 93 |
|
| 94 |
with gr.Row():
|
| 95 |
with gr.Column():
|
| 96 |
+
in_files = gr.Files(
|
| 97 |
+
label="Upload ultrasound images or a ZIP (PNG/JPG/DICOM)",
|
| 98 |
+
file_count="multiple",
|
| 99 |
+
type="filepath"
|
| 100 |
+
)
|
| 101 |
run_btn = gr.Button("Run prediction", variant="primary")
|
| 102 |
with gr.Column():
|
| 103 |
status = gr.JSON(label="Outputs")
|
| 104 |
|
| 105 |
+
gr.Markdown(f"Model status: {LOAD_MSG}")
|
| 106 |
|
| 107 |
+
run_btn.click(fn=predict_on_files, inputs=[in_files], outputs=[status])
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
if __name__ == "__main__":
|
| 110 |
demo.launch()
|