kaveh commited on
Commit
1abf691
·
1 Parent(s): 3e0bed9

added force scale

Browse files
Files changed (2) hide show
  1. S2FApp/app.py +30 -14
  2. models/s2f_model.py +0 -2
S2FApp/app.py CHANGED
@@ -111,6 +111,18 @@ with st.sidebar:
111
  except FileNotFoundError:
112
  st.error("config/substrate_settings.json not found")
113
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  # Main area: image input
115
  img_source = st.radio("Image source", ["Upload", "Example"], horizontal=True, label_visibility="collapsed")
116
  img = None
@@ -196,6 +208,9 @@ if just_ran:
196
 
197
  st.success("Prediction complete!")
198
 
 
 
 
199
  # Visualization - Plotly with zoom/pan, annotated (titles in Streamlit to avoid clipping)
200
  tit1, tit2 = st.columns(2)
201
  with tit1:
@@ -204,7 +219,7 @@ if just_ran:
204
  st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Output: Predicted traction force map</p>', unsafe_allow_html=True)
205
  fig_pl = make_subplots(rows=1, cols=2)
206
  fig_pl.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1)
207
- fig_pl.add_trace(go.Heatmap(z=heatmap, colorscale="Jet", zmin=0, zmax=1, showscale=True,
208
  colorbar=dict(len=0.4, thickness=12)), row=1, col=2)
209
  fig_pl.update_layout(
210
  height=400,
@@ -216,16 +231,16 @@ if just_ran:
216
  fig_pl.update_yaxes(showticklabels=False, autorange="reversed")
217
  st.plotly_chart(fig_pl, use_container_width=True)
218
 
219
- # Metrics with help (below plot)
220
  col1, col2, col3, col4 = st.columns(4)
221
  with col1:
222
- st.metric("Sum of all pixels", f"{pixel_sum:.2f}", help="Raw sum of all pixel values in the force map")
223
  with col2:
224
- st.metric("Cell force (scaled)", f"{force:.2f}", help="Total traction force in physical units")
225
  with col3:
226
- st.metric("Heatmap max", f"{np.max(heatmap):.4f}", help="Peak force intensity in the map")
227
  with col4:
228
- st.metric("Heatmap mean", f"{np.mean(heatmap):.4f}", help="Average force intensity")
229
 
230
  # How to read (below numbers)
231
  with st.expander("ℹ️ How to read the results"):
@@ -244,8 +259,8 @@ This is the raw image you provided—it shows cell shape but not forces.
244
  - **Heatmap max/mean:** Peak and average force intensity in the map
245
  """)
246
 
247
- # Download
248
- heatmap_uint8 = (np.clip(heatmap, 0, 1) * 255).astype(np.uint8)
249
  heatmap_rgb = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
250
  heatmap_rgb = cv2.cvtColor(heatmap_rgb, cv2.COLOR_BGR2RGB)
251
  pil_heatmap = Image.fromarray(heatmap_rgb)
@@ -274,6 +289,7 @@ elif has_cached:
274
  # Show cached results (e.g. after clicking Download)
275
  r = st.session_state["prediction_result"]
276
  img, heatmap, force, pixel_sum = r["img"], r["heatmap"], r["force"], r["pixel_sum"]
 
277
  st.success("Prediction complete!")
278
  tit1, tit2 = st.columns(2)
279
  with tit1:
@@ -282,7 +298,7 @@ elif has_cached:
282
  st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Output: Predicted traction force map</p>', unsafe_allow_html=True)
283
  fig_pl = make_subplots(rows=1, cols=2)
284
  fig_pl.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1)
285
- fig_pl.add_trace(go.Heatmap(z=heatmap, colorscale="Jet", zmin=0, zmax=1, showscale=True,
286
  colorbar=dict(len=0.4, thickness=12)), row=1, col=2)
287
  fig_pl.update_layout(height=400, margin=dict(l=10, r=10, t=10, b=10),
288
  xaxis=dict(scaleanchor="y", scaleratio=1),
@@ -292,13 +308,13 @@ elif has_cached:
292
  st.plotly_chart(fig_pl, use_container_width=True)
293
  col1, col2, col3, col4 = st.columns(4)
294
  with col1:
295
- st.metric("Sum of all pixels", f"{pixel_sum:.2f}", help="Raw sum of all pixel values in the force map")
296
  with col2:
297
- st.metric("Cell force (scaled)", f"{force:.2f}", help="Total traction force in physical units")
298
  with col3:
299
- st.metric("Heatmap max", f"{np.max(heatmap):.4f}", help="Peak force intensity in the map")
300
  with col4:
301
- st.metric("Heatmap mean", f"{np.mean(heatmap):.4f}", help="Average force intensity")
302
  with st.expander("ℹ️ How to read the results"):
303
  st.markdown("""
304
  **Input (left):** Bright-field microscopy image of a cell or spheroid on a substrate.
@@ -314,7 +330,7 @@ This is the raw image you provided—it shows cell shape but not forces.
314
  - **Cell force (scaled):** Total traction force in physical units (scaled by substrate stiffness)
315
  - **Heatmap max/mean:** Peak and average force intensity in the map
316
  """)
317
- heatmap_uint8 = (np.clip(heatmap, 0, 1) * 255).astype(np.uint8)
318
  heatmap_rgb = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
319
  heatmap_rgb = cv2.cvtColor(heatmap_rgb, cv2.COLOR_BGR2RGB)
320
  pil_heatmap = Image.fromarray(heatmap_rgb)
 
111
  except FileNotFoundError:
112
  st.error("config/substrate_settings.json not found")
113
 
114
+ st.divider()
115
+ st.header("Display options")
116
+ force_scale = st.slider(
117
+ "Force scale",
118
+ min_value=0.0,
119
+ max_value=1.0,
120
+ value=1.0,
121
+ step=0.01,
122
+ format="%.2f",
123
+ help="Scale the displayed force values. 1 = full intensity, 0.5 = half the pixel values.",
124
+ )
125
+
126
  # Main area: image input
127
  img_source = st.radio("Image source", ["Upload", "Example"], horizontal=True, label_visibility="collapsed")
128
  img = None
 
208
 
209
  st.success("Prediction complete!")
210
 
211
+ # Apply force scale to displayed heatmap
212
+ scaled_heatmap = heatmap * force_scale
213
+
214
  # Visualization - Plotly with zoom/pan, annotated (titles in Streamlit to avoid clipping)
215
  tit1, tit2 = st.columns(2)
216
  with tit1:
 
219
  st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Output: Predicted traction force map</p>', unsafe_allow_html=True)
220
  fig_pl = make_subplots(rows=1, cols=2)
221
  fig_pl.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1)
222
+ fig_pl.add_trace(go.Heatmap(z=scaled_heatmap, colorscale="Jet", zmin=0, zmax=1, showscale=True,
223
  colorbar=dict(len=0.4, thickness=12)), row=1, col=2)
224
  fig_pl.update_layout(
225
  height=400,
 
231
  fig_pl.update_yaxes(showticklabels=False, autorange="reversed")
232
  st.plotly_chart(fig_pl, use_container_width=True)
233
 
234
+ # Metrics with help (below plot) - use scaled values
235
  col1, col2, col3, col4 = st.columns(4)
236
  with col1:
237
+ st.metric("Sum of all pixels", f"{pixel_sum * force_scale:.2f}", help="Raw sum of all pixel values in the force map")
238
  with col2:
239
+ st.metric("Cell force (scaled)", f"{force * force_scale:.2f}", help="Total traction force in physical units")
240
  with col3:
241
+ st.metric("Heatmap max", f"{np.max(scaled_heatmap):.4f}", help="Peak force intensity in the map")
242
  with col4:
243
+ st.metric("Heatmap mean", f"{np.mean(scaled_heatmap):.4f}", help="Average force intensity")
244
 
245
  # How to read (below numbers)
246
  with st.expander("ℹ️ How to read the results"):
 
259
  - **Heatmap max/mean:** Peak and average force intensity in the map
260
  """)
261
 
262
+ # Download (scaled heatmap)
263
+ heatmap_uint8 = (np.clip(scaled_heatmap, 0, 1) * 255).astype(np.uint8)
264
  heatmap_rgb = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
265
  heatmap_rgb = cv2.cvtColor(heatmap_rgb, cv2.COLOR_BGR2RGB)
266
  pil_heatmap = Image.fromarray(heatmap_rgb)
 
289
  # Show cached results (e.g. after clicking Download)
290
  r = st.session_state["prediction_result"]
291
  img, heatmap, force, pixel_sum = r["img"], r["heatmap"], r["force"], r["pixel_sum"]
292
+ scaled_heatmap = heatmap * force_scale
293
  st.success("Prediction complete!")
294
  tit1, tit2 = st.columns(2)
295
  with tit1:
 
298
  st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Output: Predicted traction force map</p>', unsafe_allow_html=True)
299
  fig_pl = make_subplots(rows=1, cols=2)
300
  fig_pl.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1)
301
+ fig_pl.add_trace(go.Heatmap(z=scaled_heatmap, colorscale="Jet", zmin=0, zmax=1, showscale=True,
302
  colorbar=dict(len=0.4, thickness=12)), row=1, col=2)
303
  fig_pl.update_layout(height=400, margin=dict(l=10, r=10, t=10, b=10),
304
  xaxis=dict(scaleanchor="y", scaleratio=1),
 
308
  st.plotly_chart(fig_pl, use_container_width=True)
309
  col1, col2, col3, col4 = st.columns(4)
310
  with col1:
311
+ st.metric("Sum of all pixels", f"{pixel_sum * force_scale:.2f}", help="Raw sum of all pixel values in the force map")
312
  with col2:
313
+ st.metric("Cell force (scaled)", f"{force * force_scale:.2f}", help="Total traction force in physical units")
314
  with col3:
315
+ st.metric("Heatmap max", f"{np.max(scaled_heatmap):.4f}", help="Peak force intensity in the map")
316
  with col4:
317
+ st.metric("Heatmap mean", f"{np.mean(scaled_heatmap):.4f}", help="Average force intensity")
318
  with st.expander("ℹ️ How to read the results"):
319
  st.markdown("""
320
  **Input (left):** Bright-field microscopy image of a cell or spheroid on a substrate.
 
330
  - **Cell force (scaled):** Total traction force in physical units (scaled by substrate stiffness)
331
  - **Heatmap max/mean:** Peak and average force intensity in the map
332
  """)
333
+ heatmap_uint8 = (np.clip(scaled_heatmap, 0, 1) * 255).astype(np.uint8)
334
  heatmap_rgb = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
335
  heatmap_rgb = cv2.cvtColor(heatmap_rgb, cv2.COLOR_BGR2RGB)
336
  pil_heatmap = Image.fromarray(heatmap_rgb)
models/s2f_model.py CHANGED
@@ -174,7 +174,6 @@ class AttentionGate(nn.Module):
174
  psi = F.interpolate(psi, size=x.shape[2:], mode='bilinear', align_corners=False)
175
  return x * psi
176
 
177
-
178
  class SpheroidAttentionGate(nn.Module):
179
  """Attention Gate from ForceNet2WithAttention (s2f_spheroid). Checkpoint-compatible for ckp_spheroid_FN.pth."""
180
  def __init__(self, F_g, F_l, F_int):
@@ -238,7 +237,6 @@ class PatchGANDiscriminator(nn.Module):
238
  x = x * self.attention(x)
239
  return self.output_conv(x)
240
 
241
-
242
  class S2FGenerator(nn.Module):
243
  """
244
  S2F (Shape2Force) model: U-Net generator for force map prediction.
 
174
  psi = F.interpolate(psi, size=x.shape[2:], mode='bilinear', align_corners=False)
175
  return x * psi
176
 
 
177
  class SpheroidAttentionGate(nn.Module):
178
  """Attention Gate from ForceNet2WithAttention (s2f_spheroid). Checkpoint-compatible for ckp_spheroid_FN.pth."""
179
  def __init__(self, F_g, F_l, F_int):
 
237
  x = x * self.attention(x)
238
  return self.output_conv(x)
239
 
 
240
  class S2FGenerator(nn.Module):
241
  """
242
  S2F (Shape2Force) model: U-Net generator for force map prediction.