fantos commited on
Commit
b709ba3
·
verified ·
1 Parent(s): 57815fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -22
app.py CHANGED
@@ -2,11 +2,10 @@ import os
2
  import tempfile
3
  import zipfile
4
  from pathlib import Path
5
- from typing import List, Tuple
6
 
7
  import gradio as gr
8
  import torch
9
- import torch.nn.functional as F
10
 
11
  from model import PairNet, load_weights_if_any
12
  from utils import load_exam_as_batch, aggregate_predictions, clamp_days, today_plus_days
@@ -27,11 +26,20 @@ def init_model() -> Tuple[torch.nn.Module, str, bool]:
27
 
28
  MODEL, LOAD_MSG, HAS_WEIGHTS = init_model()
29
 
30
- def predict_on_files(files: List[gr.File]) -> dict:
31
- # Collect file paths (also support zip of images)
 
 
 
 
 
 
 
 
 
32
  paths: List[str] = []
33
  for f in files or []:
34
- p = Path(f.name)
35
  if p.suffix.lower() == ".zip":
36
  with zipfile.ZipFile(p, "r") as z:
37
  with tempfile.TemporaryDirectory() as td:
@@ -44,17 +52,14 @@ def predict_on_files(files: List[gr.File]) -> dict:
44
  if not paths:
45
  return {"status": "no files received"}
46
 
47
- x = load_exam_as_batch(paths).to(DEVICE)
48
- x = x.float()
49
 
50
  if HAS_WEIGHTS:
51
  with torch.no_grad():
52
  days_raw, logits = MODEL(x)
53
- # Aggregate across frames by mean
54
  days = days_raw.squeeze(-1).detach().cpu().tolist()
55
  proba = torch.sigmoid(logits.squeeze(-1)).detach().cpu().tolist()
56
  else:
57
- # Baseline demo: constant mid-gestation-ish guess and neutral probability
58
  days = [150.0 for _ in range(x.shape[0])]
59
  proba = [0.5 for _ in range(x.shape[0])]
60
 
@@ -63,7 +68,7 @@ def predict_on_files(files: List[gr.File]) -> dict:
63
 
64
  days_mean, proba_mean = aggregate_predictions(days, proba)
65
 
66
- result = {
67
  "frames": len(paths),
68
  "per_frame_days": days,
69
  "per_frame_preterm_proba": proba,
@@ -74,31 +79,32 @@ def predict_on_files(files: List[gr.File]) -> dict:
74
  "aggregate_preterm_label": "Preterm" if proba_mean >= 0.5 else "Term",
75
  "weights_message": LOAD_MSG
76
  }
77
- return result
78
 
79
  with gr.Blocks(title="PAIR-inspired Delivery Timing Predictor") as demo:
80
  gr.Markdown(
81
- "PAIR-inspired Delivery Timing Predictor
 
82
 
83
- "
84
- "This app is a technical scaffold inspired by the PAIR study. "
85
- "It does not include the proprietary model or clinical dataset. "
86
- "Not for medical use."
87
  )
88
 
89
  with gr.Row():
90
  with gr.Column():
91
- in_files = gr.Files(label="Upload ultrasound images or a ZIP (PNG/JPG/DICOM)", file_count="multiple", type="filepath")
 
 
 
 
92
  run_btn = gr.Button("Run prediction", variant="primary")
93
  with gr.Column():
94
  status = gr.JSON(label="Outputs")
95
 
96
- note = gr.Markdown(f"Model status: {LOAD_MSG}")
97
 
98
- def _run(files):
99
- return predict_on_files(files)
100
-
101
- run_btn.click(_run, inputs=[in_files], outputs=[status])
102
 
103
  if __name__ == "__main__":
104
  demo.launch()
 
2
  import tempfile
3
  import zipfile
4
  from pathlib import Path
5
+ from typing import List, Tuple, Any
6
 
7
  import gradio as gr
8
  import torch
 
9
 
10
  from model import PairNet, load_weights_if_any
11
  from utils import load_exam_as_batch, aggregate_predictions, clamp_days, today_plus_days
 
26
 
27
  MODEL, LOAD_MSG, HAS_WEIGHTS = init_model()
28
 
29
+ def _to_path(maybe_file: Any) -> Path:
30
+ if isinstance(maybe_file, str):
31
+ return Path(maybe_file)
32
+ if isinstance(maybe_file, dict) and "name" in maybe_file:
33
+ return Path(maybe_file["name"])
34
+ name = getattr(maybe_file, "name", None)
35
+ if name:
36
+ return Path(name)
37
+ return Path(str(maybe_file))
38
+
39
+ def predict_on_files(files) -> dict:
40
  paths: List[str] = []
41
  for f in files or []:
42
+ p = _to_path(f)
43
  if p.suffix.lower() == ".zip":
44
  with zipfile.ZipFile(p, "r") as z:
45
  with tempfile.TemporaryDirectory() as td:
 
52
  if not paths:
53
  return {"status": "no files received"}
54
 
55
+ x = load_exam_as_batch(paths).to(DEVICE).float()
 
56
 
57
  if HAS_WEIGHTS:
58
  with torch.no_grad():
59
  days_raw, logits = MODEL(x)
 
60
  days = days_raw.squeeze(-1).detach().cpu().tolist()
61
  proba = torch.sigmoid(logits.squeeze(-1)).detach().cpu().tolist()
62
  else:
 
63
  days = [150.0 for _ in range(x.shape[0])]
64
  proba = [0.5 for _ in range(x.shape[0])]
65
 
 
68
 
69
  days_mean, proba_mean = aggregate_predictions(days, proba)
70
 
71
+ return {
72
  "frames": len(paths),
73
  "per_frame_days": days,
74
  "per_frame_preterm_proba": proba,
 
79
  "aggregate_preterm_label": "Preterm" if proba_mean >= 0.5 else "Term",
80
  "weights_message": LOAD_MSG
81
  }
 
82
 
83
  with gr.Blocks(title="PAIR-inspired Delivery Timing Predictor") as demo:
84
  gr.Markdown(
85
+ """
86
+ PAIR-inspired Delivery Timing Predictor
87
 
88
+ This app is a technical scaffold inspired by the PAIR study.
89
+ It does not include the proprietary model or clinical dataset.
90
+ Not for medical use.
91
+ """
92
  )
93
 
94
  with gr.Row():
95
  with gr.Column():
96
+ in_files = gr.Files(
97
+ label="Upload ultrasound images or a ZIP (PNG/JPG/DICOM)",
98
+ file_count="multiple",
99
+ type="filepath"
100
+ )
101
  run_btn = gr.Button("Run prediction", variant="primary")
102
  with gr.Column():
103
  status = gr.JSON(label="Outputs")
104
 
105
+ gr.Markdown(f"Model status: {LOAD_MSG}")
106
 
107
+ run_btn.click(fn=predict_on_files, inputs=[in_files], outputs=[status])
 
 
 
108
 
109
  if __name__ == "__main__":
110
  demo.launch()