lmedz commited on
Commit
62532c0
·
verified ·
1 Parent(s): a11efe6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -24
app.py CHANGED
@@ -45,7 +45,7 @@ model.to(device)
45
  model.eval()
46
 
47
  # ---------------------------
48
- # Preprocessing pipeline
49
  # ---------------------------
50
  transform = transforms.Compose(
51
  [
@@ -56,35 +56,29 @@ transform = transforms.Compose(
56
  )
57
 
58
  # ---------------------------
59
- # Inference function
60
  # ---------------------------
61
 
 
62
 
63
- def predict(
64
- img: Image.Image,
65
- magnification: int,
66
- ra_conc: float,
67
- temperature: float = 1.0,
68
- ):
69
- """Return probabilities for High/Low CPM classes with optional temperature scaling.
70
 
71
  Args:
72
  img: Microscopy image.
73
  magnification: Tag for objective magnification (×4/10/20).
74
  ra_conc: Tag for RA concentration (µM).
75
- temperature: Temperature parameter for confidence calibration. T>1 lowers
76
- confidence, T<1 increases confidence.
77
  """
78
  img_tensor = transform(img).unsqueeze(0).to(device)
79
  with torch.no_grad():
80
  logit = model(img_tensor)
81
- # Temperature scaling for calibration
82
- logit_scaled = logit / temperature
83
  prob_high = torch.sigmoid(logit_scaled).item()
84
 
85
  prob_low = 1.0 - prob_high
86
 
87
- # gr.Label expects a mapping {class_name: probability}
88
  return {
89
  "High CPM Score": prob_high,
90
  "Low CPM Score": prob_low,
@@ -100,21 +94,13 @@ demo = gr.Interface(
100
  gr.Image(type="pil", label="Microscopy Image"),
101
  gr.Dropdown(choices=[4, 10, 20], value=10, label="Magnification (×)"),
102
  gr.Dropdown(choices=[0.1, 0.5, 1.0], value=0.1, label="RA Concentration (µM)"),
103
- gr.Slider(
104
- minimum=0.5,
105
- maximum=5.0,
106
- step=0.1,
107
- value=1.0,
108
- label="Temperature (confidence calibration)",
109
- info="Increase temperature to reduce overconfidence",
110
- ),
111
  ],
112
  outputs=gr.Label(num_top_classes=2, label="Predicted CPM Class & Probability"),
113
  title="iPS Cell Quality Classifier",
114
  description=(
115
  "Upload a microscopy image, choose magnification & RA concentration "
116
- "(metadata only), then optionally adjust the *Temperature* slider to "
117
- "calibrate confidence if predictions look over‑ or under‑confident."
118
  ),
119
  )
120
 
 
45
  model.eval()
46
 
47
  # ---------------------------
48
+ # Pre-processing pipeline
49
  # ---------------------------
50
  transform = transforms.Compose(
51
  [
 
56
  )
57
 
58
  # ---------------------------
59
+ # Inference function & temperature setting
60
  # ---------------------------
61
 
62
+ TEMPERATURE = 3.5 # fixed temperature (between 3 and 4) for confidence calibration
63
 
64
+
65
+ def predict(img: Image.Image, magnification: int, ra_conc: float):
66
+ """Return probabilities for High/Low CPM classes.
 
 
 
 
67
 
68
  Args:
69
  img: Microscopy image.
70
  magnification: Tag for objective magnification (×4/10/20).
71
  ra_conc: Tag for RA concentration (µM).
 
 
72
  """
73
  img_tensor = transform(img).unsqueeze(0).to(device)
74
  with torch.no_grad():
75
  logit = model(img_tensor)
76
+ # Apply fixed temperature scaling to mitigate over‑confidence
77
+ logit_scaled = logit / TEMPERATURE
78
  prob_high = torch.sigmoid(logit_scaled).item()
79
 
80
  prob_low = 1.0 - prob_high
81
 
 
82
  return {
83
  "High CPM Score": prob_high,
84
  "Low CPM Score": prob_low,
 
94
  gr.Image(type="pil", label="Microscopy Image"),
95
  gr.Dropdown(choices=[4, 10, 20], value=10, label="Magnification (×)"),
96
  gr.Dropdown(choices=[0.1, 0.5, 1.0], value=0.1, label="RA Concentration (µM)"),
 
 
 
 
 
 
 
 
97
  ],
98
  outputs=gr.Label(num_top_classes=2, label="Predicted CPM Class & Probability"),
99
  title="iPS Cell Quality Classifier",
100
  description=(
101
  "Upload a microscopy image, choose magnification & RA concentration "
102
+ "(metadata only). Probabilities have been temperature‑scaled for more "
103
+ "realistic confidence estimates."
104
  ),
105
  )
106