lmedz commited on
Commit
a11efe6
·
verified ·
1 Parent(s): fce7238

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -11
app.py CHANGED
@@ -9,6 +9,7 @@ from huggingface_hub import hf_hub_download
9
  # ---------------------------
10
  # Device configuration
11
  # ---------------------------
 
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
14
  # ---------------------------
@@ -16,7 +17,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  # ---------------------------
17
 
18
  def build_model():
19
- """Recreate the network architecture used during training."""
20
  backbone = timm.create_model(
21
  "convnext_small", pretrained=False, num_classes=0, global_pool="avg"
22
  )
@@ -57,19 +58,29 @@ transform = transforms.Compose(
57
  # ---------------------------
58
  # Inference function
59
  # ---------------------------
60
- THRESHOLD = 0.5 # adjust if you want to tweak the decision boundary
61
 
62
- def predict(img: Image.Image, magnification: int, ra_conc: float):
63
- """Return probabilities for High / Low CPM classes.
64
 
65
- Note: `magnification` and `ra_conc` are recorded but *not* used for
66
- inference. They are included so that users can tag inputs with these
67
- metadata values.
 
 
 
 
 
 
 
 
 
 
 
68
  """
69
  img_tensor = transform(img).unsqueeze(0).to(device)
70
  with torch.no_grad():
71
  logit = model(img_tensor)
72
- prob_high = torch.sigmoid(logit).item()
 
 
73
 
74
  prob_low = 1.0 - prob_high
75
 
@@ -89,13 +100,21 @@ demo = gr.Interface(
89
  gr.Image(type="pil", label="Microscopy Image"),
90
  gr.Dropdown(choices=[4, 10, 20], value=10, label="Magnification (×)"),
91
  gr.Dropdown(choices=[0.1, 0.5, 1.0], value=0.1, label="RA Concentration (µM)"),
 
 
 
 
 
 
 
 
92
  ],
93
  outputs=gr.Label(num_top_classes=2, label="Predicted CPM Class & Probability"),
94
  title="iPS Cell Quality Classifier",
95
  description=(
96
- "Upload a microscopy image, choose magnification and RA concentration "
97
- "(for record‑keeping), and instantly get the predicted CPM quality "
98
- "class with its probability."
99
  ),
100
  )
101
 
 
9
  # ---------------------------
10
  # Device configuration
11
  # ---------------------------
12
+
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
15
  # ---------------------------
 
17
  # ---------------------------
18
 
19
  def build_model():
20
+ """Re-create the network architecture used during training."""
21
  backbone = timm.create_model(
22
  "convnext_small", pretrained=False, num_classes=0, global_pool="avg"
23
  )
 
58
  # ---------------------------
59
  # Inference function
60
  # ---------------------------
 
61
 
 
 
62
 
63
+ def predict(
64
+ img: Image.Image,
65
+ magnification: int,
66
+ ra_conc: float,
67
+ temperature: float = 1.0,
68
+ ):
69
+ """Return probabilities for High/Low CPM classes with optional temperature scaling.
70
+
71
+ Args:
72
+ img: Microscopy image.
73
+ magnification: Tag for objective magnification (×4/10/20).
74
+ ra_conc: Tag for RA concentration (µM).
75
+ temperature: Temperature parameter for confidence calibration. T>1 lowers
76
+ confidence, T<1 increases confidence.
77
  """
78
  img_tensor = transform(img).unsqueeze(0).to(device)
79
  with torch.no_grad():
80
  logit = model(img_tensor)
81
+ # Temperature scaling for calibration
82
+ logit_scaled = logit / temperature
83
+ prob_high = torch.sigmoid(logit_scaled).item()
84
 
85
  prob_low = 1.0 - prob_high
86
 
 
100
  gr.Image(type="pil", label="Microscopy Image"),
101
  gr.Dropdown(choices=[4, 10, 20], value=10, label="Magnification (×)"),
102
  gr.Dropdown(choices=[0.1, 0.5, 1.0], value=0.1, label="RA Concentration (µM)"),
103
+ gr.Slider(
104
+ minimum=0.5,
105
+ maximum=5.0,
106
+ step=0.1,
107
+ value=1.0,
108
+ label="Temperature (confidence calibration)",
109
+ info="Increase temperature to reduce overconfidence",
110
+ ),
111
  ],
112
  outputs=gr.Label(num_top_classes=2, label="Predicted CPM Class & Probability"),
113
  title="iPS Cell Quality Classifier",
114
  description=(
115
+ "Upload a microscopy image, choose magnification & RA concentration "
116
+ "(metadata only), then optionally adjust the *Temperature* slider to "
117
+ "calibrate confidence if predictions look over‑ or under‑confident."
118
  ),
119
  )
120