kaveh commited on
Commit
e6abf01
·
1 Parent(s): 8d06964

added measuring

Browse files
Files changed (2) hide show
  1. app.py +429 -131
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,9 +1,12 @@
1
  """
2
  Shape2Force (S2F) - GUI for force map prediction from bright field microscopy images.
3
  """
 
 
4
  import os
5
  import sys
6
- import io
 
7
  import cv2
8
  cv2.utils.logging.setLogLevel(cv2.utils.logging.LOG_LEVEL_ERROR)
9
 
@@ -13,19 +16,418 @@ from PIL import Image
13
  import plotly.graph_objects as go
14
  from plotly.subplots import make_subplots
15
 
16
- # Ensure S2F is in path
17
  S2F_ROOT = os.path.dirname(os.path.abspath(__file__))
18
  if S2F_ROOT not in sys.path:
19
  sys.path.insert(0, S2F_ROOT)
20
 
21
  from utils.substrate_settings import list_substrates
22
 
23
- st.set_page_config(page_title="Shape2Force (S2F)", page_icon="🦠", layout="centered")
24
- st.markdown("""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  <style>
26
- section[data-testid="stSidebar"] { width: 380px !important; }
 
 
 
 
 
 
 
 
27
  </style>
28
  """, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  st.title("🦠 Shape2Force (S2F)")
30
  st.caption("Predict traction force maps from bright-field microscopy images of cells or spheroids")
31
 
@@ -42,14 +444,12 @@ sample_base = os.path.join(S2F_ROOT, "samples")
42
  sample_single_cell = os.path.join(sample_base, "single_cell")
43
  sample_spheroid = os.path.join(sample_base, "spheroid")
44
 
45
- SAMPLE_EXTENSIONS = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
46
-
47
 
48
  def get_ckp_files_for_model(model_type):
49
  """Return list of .pth files in the checkpoint folder for the given model type."""
50
  folder = ckp_single_cell if model_type == "single_cell" else ckp_spheroid
51
  if os.path.isdir(folder):
52
- return sorted([f for f in os.listdir(folder) if f.endswith(".pth")])
53
  return []
54
 
55
 
@@ -57,8 +457,7 @@ def get_sample_files_for_model(model_type):
57
  """Return list of sample images in the sample folder for the given model type."""
58
  folder = sample_single_cell if model_type == "single_cell" else sample_spheroid
59
  if os.path.isdir(folder):
60
- return sorted([f for f in os.listdir(folder)
61
- if f.lower().endswith(SAMPLE_EXTENSIONS)])
62
  return []
63
 
64
  # Sidebar: model configuration
@@ -67,11 +466,11 @@ with st.sidebar:
67
  model_type = st.radio(
68
  "Model type",
69
  ["single_cell", "spheroid"],
70
- format_func=lambda x: "Single cell" if x == "single_cell" else "Spheroid",
71
  horizontal=False,
72
  help="Single cell: substrate-aware force prediction. Spheroid: spheroid force maps.",
73
  )
74
- st.caption(f"Inference mode: **{'Single cell' if model_type == 'single_cell' else 'Spheroid'}**")
75
 
76
  ckp_files = get_ckp_files_for_model(model_type)
77
  ckp_folder = ckp_single_cell if model_type == "single_cell" else ckp_spheroid
@@ -106,8 +505,6 @@ with st.sidebar:
106
  manual_young = st.number_input("Young's modulus (Pa)", min_value=100.0, max_value=100000.0,
107
  value=6000.0, step=100.0, format="%.0f")
108
  substrate_config = {"pixelsize": manual_pixelsize, "young": manual_young}
109
- else:
110
- substrate_config = None
111
  except FileNotFoundError:
112
  st.error("config/substrate_settings.json not found")
113
 
@@ -170,8 +567,7 @@ col_btn, col_model, col_path = st.columns([1, 1, 3])
170
  with col_btn:
171
  run = st.button("Run prediction", type="primary")
172
  with col_model:
173
- model_label = "Single cell" if model_type == "single_cell" else "Spheroid"
174
- st.markdown(f"<span style='display: inline-flex; align-items: center; height: 38px;'>{model_label}</span>", unsafe_allow_html=True)
175
  with col_path:
176
  ckp_path = f"ckp/{ckp_subfolder_name}/{checkpoint}" if checkpoint else f"ckp/{ckp_subfolder_name}/"
177
  st.markdown(f"<span style='display: inline-flex; align-items: center; height: 38px;'>Checkpoint: <code>{ckp_path}</code></span>", unsafe_allow_html=True)
@@ -208,69 +604,9 @@ if just_ran:
208
 
209
  st.success("Prediction complete!")
210
 
211
- # Apply force scale to displayed heatmap
212
  scaled_heatmap = heatmap * force_scale
213
 
214
- # Visualization - Plotly with zoom/pan, annotated (titles in Streamlit to avoid clipping)
215
- tit1, tit2 = st.columns(2)
216
- with tit1:
217
- st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Input: Bright-field image</p>', unsafe_allow_html=True)
218
- with tit2:
219
- st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Output: Predicted traction force map</p>', unsafe_allow_html=True)
220
- fig_pl = make_subplots(rows=1, cols=2)
221
- fig_pl.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1)
222
- fig_pl.add_trace(go.Heatmap(z=scaled_heatmap, colorscale="Jet", zmin=0, zmax=1, showscale=True,
223
- colorbar=dict(len=0.4, thickness=12)), row=1, col=2)
224
- fig_pl.update_layout(
225
- height=400,
226
- margin=dict(l=10, r=10, t=10, b=10),
227
- xaxis=dict(scaleanchor="y", scaleratio=1),
228
- xaxis2=dict(scaleanchor="y2", scaleratio=1),
229
- )
230
- fig_pl.update_xaxes(showticklabels=False)
231
- fig_pl.update_yaxes(showticklabels=False, autorange="reversed")
232
- st.plotly_chart(fig_pl, use_container_width=True)
233
-
234
- # Metrics with help (below plot) - use scaled values
235
- col1, col2, col3, col4 = st.columns(4)
236
- with col1:
237
- st.metric("Sum of all pixels", f"{pixel_sum * force_scale:.2f}", help="Raw sum of all pixel values in the force map")
238
- with col2:
239
- st.metric("Cell force (scaled)", f"{force * force_scale:.2f}", help="Total traction force in physical units")
240
- with col3:
241
- st.metric("Heatmap max", f"{np.max(scaled_heatmap):.4f}", help="Peak force intensity in the map")
242
- with col4:
243
- st.metric("Heatmap mean", f"{np.mean(scaled_heatmap):.4f}", help="Average force intensity")
244
-
245
- # How to read (below numbers)
246
- with st.expander("ℹ️ How to read the results"):
247
- st.markdown("""
248
- **Input (left):** Bright-field microscopy image of a cell or spheroid on a substrate.
249
- This is the raw image you provided—it shows cell shape but not forces.
250
-
251
- **Output (right):** Predicted traction force map.
252
- - **Color** indicates force magnitude: blue = low, red = high
253
- - **Brighter/warmer colors** = stronger forces exerted by the cell on the substrate
254
- - Values are normalized to [0, 1] for visualization
255
-
256
- **Metrics:**
257
- - **Sum of all pixels:** Total force is the sum of all pixels in the force map. Each pixel represents the magnitude of force at that location; summing them gives the overall traction.
258
- - **Cell force (scaled):** Total traction force in physical units (scaled by substrate stiffness)
259
- - **Heatmap max/mean:** Peak and average force intensity in the map
260
- """)
261
-
262
- # Download (scaled heatmap)
263
- heatmap_uint8 = (np.clip(scaled_heatmap, 0, 1) * 255).astype(np.uint8)
264
- heatmap_rgb = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
265
- heatmap_rgb = cv2.cvtColor(heatmap_rgb, cv2.COLOR_BGR2RGB)
266
- pil_heatmap = Image.fromarray(heatmap_rgb)
267
- buf_hm = io.BytesIO()
268
- pil_heatmap.save(buf_hm, format="PNG")
269
- buf_hm.seek(0)
270
- st.download_button("Download Heatmap", data=buf_hm.getvalue(),
271
- file_name="s2f_heatmap.png", mime="image/png", key="download_heatmap")
272
-
273
- # Store in session state so results persist when user clicks Download
274
  cache_key = (model_type, checkpoint, key_img)
275
  st.session_state["prediction_result"] = {
276
  "img": img.copy(),
@@ -279,77 +615,39 @@ This is the raw image you provided—it shows cell shape but not forces.
279
  "pixel_sum": pixel_sum,
280
  "cache_key": cache_key,
281
  }
 
 
 
 
 
 
282
 
283
  except Exception as e:
284
  st.error(f"Prediction failed: {e}")
285
- import traceback
286
  st.code(traceback.format_exc())
287
 
288
  elif has_cached:
289
- # Show cached results (e.g. after clicking Download)
290
  r = st.session_state["prediction_result"]
291
  img, heatmap, force, pixel_sum = r["img"], r["heatmap"], r["force"], r["pixel_sum"]
292
  scaled_heatmap = heatmap * force_scale
293
- st.success("Prediction complete!")
294
- tit1, tit2 = st.columns(2)
295
- with tit1:
296
- st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Input: Bright-field image</p>', unsafe_allow_html=True)
297
- with tit2:
298
- st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Output: Predicted traction force map</p>', unsafe_allow_html=True)
299
- fig_pl = make_subplots(rows=1, cols=2)
300
- fig_pl.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1)
301
- fig_pl.add_trace(go.Heatmap(z=scaled_heatmap, colorscale="Jet", zmin=0, zmax=1, showscale=True,
302
- colorbar=dict(len=0.4, thickness=12)), row=1, col=2)
303
- fig_pl.update_layout(height=400, margin=dict(l=10, r=10, t=10, b=10),
304
- xaxis=dict(scaleanchor="y", scaleratio=1),
305
- xaxis2=dict(scaleanchor="y2", scaleratio=1))
306
- fig_pl.update_xaxes(showticklabels=False)
307
- fig_pl.update_yaxes(showticklabels=False, autorange="reversed")
308
- st.plotly_chart(fig_pl, use_container_width=True)
309
- col1, col2, col3, col4 = st.columns(4)
310
- with col1:
311
- st.metric("Sum of all pixels", f"{pixel_sum * force_scale:.2f}", help="Raw sum of all pixel values in the force map")
312
- with col2:
313
- st.metric("Cell force (scaled)", f"{force * force_scale:.2f}", help="Total traction force in physical units")
314
- with col3:
315
- st.metric("Heatmap max", f"{np.max(scaled_heatmap):.4f}", help="Peak force intensity in the map")
316
- with col4:
317
- st.metric("Heatmap mean", f"{np.mean(scaled_heatmap):.4f}", help="Average force intensity")
318
- with st.expander("ℹ️ How to read the results"):
319
- st.markdown("""
320
- **Input (left):** Bright-field microscopy image of a cell or spheroid on a substrate.
321
- This is the raw image you provided—it shows cell shape but not forces.
322
 
323
- **Output (right):** Predicted traction force map.
324
- - **Color** indicates force magnitude: blue = low, red = high
325
- - **Brighter/warmer colors** = stronger forces exerted by the cell on the substrate
326
- - Values are normalized to [0, 1] for visualization
327
 
328
- **Metrics:**
329
- - **Sum of all pixels:** Total force is the sum of all pixels in the force map. Each pixel represents the magnitude of force at that location; summing them gives the overall traction.
330
- - **Cell force (scaled):** Total traction force in physical units (scaled by substrate stiffness)
331
- - **Heatmap max/mean:** Peak and average force intensity in the map
332
- """)
333
- heatmap_uint8 = (np.clip(scaled_heatmap, 0, 1) * 255).astype(np.uint8)
334
- heatmap_rgb = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
335
- heatmap_rgb = cv2.cvtColor(heatmap_rgb, cv2.COLOR_BGR2RGB)
336
- pil_heatmap = Image.fromarray(heatmap_rgb)
337
- buf_hm = io.BytesIO()
338
- pil_heatmap.save(buf_hm, format="PNG")
339
- buf_hm.seek(0)
340
- st.download_button("Download Heatmap", data=buf_hm.getvalue(),
341
- file_name="s2f_heatmap.png", mime="image/png", key="download_cached")
342
 
343
  elif run and not checkpoint:
344
  st.warning("Please add checkpoint files to the ckp/ folder and select one.")
345
  elif run and not has_image:
346
  st.warning("Please upload an image or select an example.")
347
 
348
- # Footer
349
  st.sidebar.divider()
350
  st.sidebar.caption(f"Examples: `samples/{ckp_subfolder_name}/`")
351
  st.sidebar.caption("If you find this software useful, please cite:")
352
- st.sidebar.caption(
353
- "Lautaro Baro, Kaveh Shahhosseini, Amparo Andrés-Bordería, Claudio Angione, and Maria Angeles Juanes. "
354
- "**\"Shape-to-force (S2F): Predicting Cell Traction Forces from LabelFree Imaging\"**, 2026."
355
- )
 
1
  """
2
  Shape2Force (S2F) - GUI for force map prediction from bright field microscopy images.
3
  """
4
+ import csv
5
+ import io
6
  import os
7
  import sys
8
+ import traceback
9
+
10
  import cv2
11
  cv2.utils.logging.setLogLevel(cv2.utils.logging.LOG_LEVEL_ERROR)
12
 
 
16
  import plotly.graph_objects as go
17
  from plotly.subplots import make_subplots
18
 
 
19
  S2F_ROOT = os.path.dirname(os.path.abspath(__file__))
20
  if S2F_ROOT not in sys.path:
21
  sys.path.insert(0, S2F_ROOT)
22
 
23
  from utils.substrate_settings import list_substrates
24
 
25
+ try:
26
+ from streamlit_drawable_canvas import st_canvas
27
+ HAS_DRAWABLE_CANVAS = True
28
+ except (ImportError, AttributeError):
29
+ HAS_DRAWABLE_CANVAS = False
30
+
31
+ # Constants
32
+ MODEL_TYPE_LABELS = {"single_cell": "Single cell", "spheroid": "Spheroid"}
33
+ DRAW_TOOLS = ["polygon", "rect", "circle"]
34
+ TOOL_LABELS = {"polygon": "Polygon", "rect": "Rectangle", "circle": "Circle"}
35
+ CANVAS_SIZE = 320
36
+ SAMPLE_EXTENSIONS = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
37
+ CITATION = (
38
+ "Lautaro Baro, Kaveh Shahhosseini, Amparo Andrés-Bordería, Claudio Angione, and Maria Angeles Juanes. "
39
+ "**\"Shape-to-force (S2F): Predicting Cell Traction Forces from LabelFree Imaging\"**, 2026."
40
+ )
41
+
42
+
43
+ def _make_annotated_heatmap(heatmap_rgb, mask, fill_alpha=0.3, stroke_color=(255, 102, 0), stroke_width=2):
44
+ """Composite heatmap with drawn region overlay. heatmap_rgb and mask must match in size."""
45
+ annotated = heatmap_rgb.copy()
46
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
47
+ # Semi-transparent orange fill
48
+ overlay = annotated.copy()
49
+ cv2.fillPoly(overlay, contours, stroke_color)
50
+ mask_3d = np.stack([mask] * 3, axis=-1).astype(bool)
51
+ annotated[mask_3d] = (
52
+ (1 - fill_alpha) * annotated[mask_3d].astype(np.float32)
53
+ + fill_alpha * overlay[mask_3d].astype(np.float32)
54
+ ).astype(np.uint8)
55
+ # Orange contour
56
+ cv2.drawContours(annotated, contours, -1, stroke_color, stroke_width)
57
+ return annotated
58
+
59
+
60
+ def _parse_canvas_shapes_to_mask(json_data, canvas_h, canvas_w, heatmap_h, heatmap_w):
61
+ """
62
+ Parse drawn shapes from streamlit-drawable-canvas json_data and create a binary mask
63
+ in heatmap coordinates. Returns (mask, num_shapes) or (None, 0) if no valid shapes.
64
+ """
65
+ if not json_data or "objects" not in json_data or not json_data["objects"]:
66
+ return None, 0
67
+ scale_x = heatmap_w / canvas_w
68
+ scale_y = heatmap_h / canvas_h
69
+ mask = np.zeros((heatmap_h, heatmap_w), dtype=np.uint8)
70
+ count = 0
71
+ for obj in json_data["objects"]:
72
+ obj_type = obj.get("type", "")
73
+ pts = []
74
+ if obj_type == "rect":
75
+ left = obj.get("left", 0)
76
+ top = obj.get("top", 0)
77
+ w = obj.get("width", 0)
78
+ h = obj.get("height", 0)
79
+ pts = np.array([
80
+ [left, top], [left + w, top], [left + w, top + h], [left, top + h]
81
+ ], dtype=np.float32)
82
+ elif obj_type == "circle" or obj_type == "ellipse":
83
+ left = obj.get("left", 0)
84
+ top = obj.get("top", 0)
85
+ width = obj.get("width", 0)
86
+ height = obj.get("height", 0)
87
+ radius = obj.get("radius", 0)
88
+ angle_deg = obj.get("angle", 0)
89
+ if radius > 0:
90
+ # Circle: (left, top) is mouse start point, not center.
91
+ # Center = start + radius * (cos(angle), sin(angle))
92
+ rx = ry = radius
93
+ angle_rad = np.deg2rad(angle_deg)
94
+ cx = left + radius * np.cos(angle_rad)
95
+ cy = top + radius * np.sin(angle_rad)
96
+ else:
97
+ # Ellipse: left, top = top-left of bounding box
98
+ rx = width / 2 if width > 0 else 0
99
+ ry = height / 2 if height > 0 else 0
100
+ if rx <= 0 or ry <= 0:
101
+ continue
102
+ cx = left + rx
103
+ cy = top + ry
104
+ if rx <= 0 or ry <= 0:
105
+ continue
106
+ n = 32
107
+ angles = np.linspace(0, 2 * np.pi, n, endpoint=False)
108
+ pts = np.column_stack([cx + rx * np.cos(angles), cy + ry * np.sin(angles)]).astype(np.float32)
109
+ elif obj_type == "path":
110
+ path = obj.get("path", [])
111
+ for cmd in path:
112
+ if isinstance(cmd, (list, tuple)) and len(cmd) >= 3:
113
+ if cmd[0] in ("M", "L"):
114
+ pts.append([float(cmd[1]), float(cmd[2])])
115
+ elif cmd[0] == "Q" and len(cmd) >= 5:
116
+ pts.append([float(cmd[3]), float(cmd[4])])
117
+ elif cmd[0] == "C" and len(cmd) >= 7:
118
+ pts.append([float(cmd[5]), float(cmd[6])])
119
+ if len(pts) < 3:
120
+ continue
121
+ pts = np.array(pts, dtype=np.float32)
122
+ else:
123
+ continue
124
+ pts[:, 0] *= scale_x
125
+ pts[:, 1] *= scale_y
126
+ pts = np.clip(pts, 0, [heatmap_w - 1, heatmap_h - 1]).astype(np.int32)
127
+ cv2.fillPoly(mask, [pts], 1)
128
+ count += 1
129
+ return mask, count
130
+
131
+
132
+ def _heatmap_to_png_bytes(scaled_heatmap):
133
+ """Convert scaled heatmap (float 0-1) to PNG bytes buffer."""
134
+ heatmap_uint8 = (np.clip(scaled_heatmap, 0, 1) * 255).astype(np.uint8)
135
+ heatmap_rgb = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
136
+ heatmap_rgb = cv2.cvtColor(heatmap_rgb, cv2.COLOR_BGR2RGB)
137
+ buf = io.BytesIO()
138
+ Image.fromarray(heatmap_rgb).save(buf, format="PNG")
139
+ buf.seek(0)
140
+ return buf
141
+
142
+
143
+ def _build_original_vals(scaled_heatmap, pixel_sum, force, force_scale):
144
+ """Build original_vals dict for measure tool."""
145
+ return {
146
+ "pixel_sum": pixel_sum * force_scale,
147
+ "force": force * force_scale,
148
+ "max": float(np.max(scaled_heatmap)),
149
+ "mean": float(np.mean(scaled_heatmap)),
150
+ }
151
+
152
+
153
+ def _render_result_display(img, scaled_heatmap, pixel_sum, force, force_scale, key_img, download_key_suffix=""):
154
+ """Render prediction result: plot, metrics, expander, and download/measure buttons."""
155
+ buf_hm = _heatmap_to_png_bytes(scaled_heatmap)
156
+ base_name = os.path.splitext(key_img or "image")[0]
157
+ main_csv_rows = [
158
+ ["image", "Sum of all pixels", "Cell force (scaled)", "Heatmap max", "Heatmap mean"],
159
+ [base_name, f"{pixel_sum * force_scale:.2f}", f"{force * force_scale:.2f}",
160
+ f"{np.max(scaled_heatmap):.4f}", f"{np.mean(scaled_heatmap):.4f}"],
161
+ ]
162
+ buf_main_csv = io.StringIO()
163
+ csv.writer(buf_main_csv).writerows(main_csv_rows)
164
+
165
+ tit1, tit2 = st.columns(2)
166
+ with tit1:
167
+ st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Input: Bright-field image</p>', unsafe_allow_html=True)
168
+ with tit2:
169
+ st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Output: Predicted traction force map</p>', unsafe_allow_html=True)
170
+ fig_pl = make_subplots(rows=1, cols=2)
171
+ fig_pl.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1)
172
+ fig_pl.add_trace(go.Heatmap(z=scaled_heatmap, colorscale="Jet", zmin=0, zmax=1, showscale=True,
173
+ colorbar=dict(len=0.4, thickness=12)), row=1, col=2)
174
+ fig_pl.update_layout(
175
+ height=400,
176
+ margin=dict(l=10, r=10, t=10, b=10),
177
+ xaxis=dict(scaleanchor="y", scaleratio=1),
178
+ xaxis2=dict(scaleanchor="y2", scaleratio=1),
179
+ )
180
+ fig_pl.update_xaxes(showticklabels=False)
181
+ fig_pl.update_yaxes(showticklabels=False, autorange="reversed")
182
+ st.plotly_chart(fig_pl, use_container_width=True, config={"displayModeBar": True, "responsive": True})
183
+
184
+ col1, col2, col3, col4 = st.columns(4)
185
+ with col1:
186
+ st.metric("Sum of all pixels", f"{pixel_sum * force_scale:.2f}", help="Raw sum of all pixel values in the force map")
187
+ with col2:
188
+ st.metric("Cell force (scaled)", f"{force * force_scale:.2f}", help="Total traction force in physical units")
189
+ with col3:
190
+ st.metric("Heatmap max", f"{np.max(scaled_heatmap):.4f}", help="Peak force intensity in the map")
191
+ with col4:
192
+ st.metric("Heatmap mean", f"{np.mean(scaled_heatmap):.4f}", help="Average force intensity")
193
+
194
+ with st.expander("How to read the results"):
195
+ st.markdown("""
196
+ **Input (left):** Bright-field microscopy image of a cell or spheroid on a substrate.
197
+ This is the raw image you provided—it shows cell shape but not forces.
198
+
199
+ **Output (right):** Predicted traction force map.
200
+ - **Color** indicates force magnitude: blue = low, red = high
201
+ - **Brighter/warmer colors** = stronger forces exerted by the cell on the substrate
202
+ - Values are normalized to [0, 1] for visualization
203
+
204
+ **Metrics:**
205
+ - **Sum of all pixels:** Total force is the sum of all pixels in the force map. Each pixel represents the magnitude of force at that location; summing them gives the overall traction.
206
+ - **Cell force (scaled):** Total traction force in physical units (scaled by substrate stiffness)
207
+ - **Heatmap max/mean:** Peak and average force intensity in the map
208
+ """)
209
+
210
+ original_vals = _build_original_vals(scaled_heatmap, pixel_sum, force, force_scale)
211
+ btn_col1, btn_col2, btn_col3 = st.columns(3)
212
+ with btn_col1:
213
+ if HAS_DRAWABLE_CANVAS and st_dialog:
214
+ if st.button("Measure tool", key="open_measure", icon=":material/straighten:"):
215
+ st.session_state["open_measure_dialog"] = True
216
+ st.rerun()
217
+ elif HAS_DRAWABLE_CANVAS:
218
+ with st.expander("Measure tool"):
219
+ _render_region_canvas(
220
+ scaled_heatmap,
221
+ bf_img=img,
222
+ original_vals=original_vals,
223
+ key_suffix="expander",
224
+ input_filename=key_img,
225
+ )
226
+ else:
227
+ st.caption("Install `streamlit-drawable-canvas-fix` for region measurement: `pip install streamlit-drawable-canvas-fix`")
228
+ with btn_col2:
229
+ st.download_button(
230
+ "Download heatmap",
231
+ width="stretch",
232
+ data=buf_hm.getvalue(),
233
+ file_name="s2f_heatmap.png",
234
+ mime="image/png",
235
+ key=f"download_heatmap{download_key_suffix}",
236
+ icon=":material/download:",
237
+ )
238
+ with btn_col3:
239
+ st.download_button(
240
+ "Download values",
241
+ width="stretch",
242
+ data=buf_main_csv.getvalue(),
243
+ file_name=f"{base_name}_main_values.csv",
244
+ mime="text/csv",
245
+ key=f"download_main_values{download_key_suffix}",
246
+ icon=":material/download:",
247
+ )
248
+
249
+
250
+ def _compute_region_metrics(scaled_heatmap, mask, original_vals=None):
251
+ """Compute region metrics from mask. Returns dict with area_px, force_sum, density, etc."""
252
+ area_px = int(np.sum(mask))
253
+ region_values = scaled_heatmap * mask
254
+ region_nonzero = region_values[mask > 0]
255
+ force_sum = float(np.sum(region_values))
256
+ density = force_sum / area_px if area_px > 0 else 0
257
+ region_max = float(np.max(region_nonzero)) if len(region_nonzero) > 0 else 0
258
+ region_mean = float(np.mean(region_nonzero)) if len(region_nonzero) > 0 else 0
259
+ region_force_scaled = (
260
+ force_sum * (original_vals["force"] / original_vals["pixel_sum"])
261
+ if original_vals and original_vals.get("pixel_sum", 0) > 0
262
+ else force_sum
263
+ )
264
+ return {
265
+ "area_px": area_px,
266
+ "force_sum": force_sum,
267
+ "density": density,
268
+ "max": region_max,
269
+ "mean": region_mean,
270
+ "force_scaled": region_force_scaled,
271
+ }
272
+
273
+
274
+ def _render_region_metrics_and_downloads(metrics, heatmap_rgb, mask, input_filename, key_suffix, has_original_vals):
275
+ """Render region metrics and download buttons."""
276
+ base_name = os.path.splitext(input_filename or "image")[0]
277
+ st.markdown("**Region (drawn)**")
278
+ if has_original_vals:
279
+ r1, r2, r3, r4, r5, r6 = st.columns(6)
280
+ with r1:
281
+ st.metric("Area", f"{metrics['area_px']:,}")
282
+ with r2:
283
+ st.metric("F.sum", f"{metrics['force_sum']:.3f}")
284
+ with r3:
285
+ st.metric("Force", f"{metrics['force_scaled']:.1f}")
286
+ with r4:
287
+ st.metric("Max", f"{metrics['max']:.3f}")
288
+ with r5:
289
+ st.metric("Mean", f"{metrics['mean']:.3f}")
290
+ with r6:
291
+ st.metric("Density", f"{metrics['density']:.4f}")
292
+ csv_rows = [
293
+ ["image", "Area", "F.sum", "Force", "Max", "Mean", "Density"],
294
+ [base_name, metrics["area_px"], f"{metrics['force_sum']:.3f}", f"{metrics['force_scaled']:.1f}",
295
+ f"{metrics['max']:.3f}", f"{metrics['mean']:.3f}", f"{metrics['density']:.4f}"],
296
+ ]
297
+ else:
298
+ c1, c2, c3 = st.columns(3)
299
+ with c1:
300
+ st.metric("Area (px²)", f"{metrics['area_px']:,}")
301
+ with c2:
302
+ st.metric("Force sum", f"{metrics['force_sum']:.4f}")
303
+ with c3:
304
+ st.metric("Density", f"{metrics['density']:.6f}")
305
+ csv_rows = [
306
+ ["image", "Area", "Force sum", "Density"],
307
+ [base_name, metrics["area_px"], f"{metrics['force_sum']:.4f}", f"{metrics['density']:.6f}"],
308
+ ]
309
+ buf_csv = io.StringIO()
310
+ csv.writer(buf_csv).writerows(csv_rows)
311
+ buf_img = io.BytesIO()
312
+ Image.fromarray(_make_annotated_heatmap(heatmap_rgb, mask)).save(buf_img, format="PNG")
313
+ buf_img.seek(0)
314
+ dl_col1, dl_col2 = st.columns(2)
315
+ with dl_col1:
316
+ st.download_button("Download values", data=buf_csv.getvalue(),
317
+ file_name=f"{base_name}_region_values.csv", mime="text/csv",
318
+ key=f"download_region_values_{key_suffix}", icon=":material/download:")
319
+ with dl_col2:
320
+ st.download_button("Download annotated heatmap", data=buf_img.getvalue(),
321
+ file_name=f"{base_name}_annotated_heatmap.png", mime="image/png",
322
+ key=f"download_annotated_{key_suffix}", icon=":material/image:")
323
+
324
+
325
+ def _render_region_canvas(scaled_heatmap, bf_img=None, original_vals=None, key_suffix="", input_filename=None):
326
+ """Render drawable canvas and region metrics. Used in dialog or expander."""
327
+ h, w = scaled_heatmap.shape
328
+ heatmap_display = (np.clip(scaled_heatmap, 0, 1) * 255).astype(np.uint8)
329
+ heatmap_rgb = cv2.cvtColor(cv2.applyColorMap(heatmap_display, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
330
+ pil_bg = Image.fromarray(heatmap_rgb).resize((CANVAS_SIZE, CANVAS_SIZE), Image.Resampling.LANCZOS)
331
+
332
+ st.markdown("""
333
  <style>
334
+ [data-testid="stDialog"] [data-testid="stSelectbox"], [data-testid="stExpander"] [data-testid="stSelectbox"],
335
+ [data-testid="stDialog"] [data-testid="stSelectbox"] > div, [data-testid="stExpander"] [data-testid="stSelectbox"] > div {
336
+ width: 100% !important; max-width: 100% !important;
337
+ }
338
+ [data-testid="stDialog"] [data-testid="stMetric"] label, [data-testid="stDialog"] [data-testid="stMetric"] [data-testid="stMetricValue"],
339
+ [data-testid="stExpander"] [data-testid="stMetric"] label, [data-testid="stExpander"] [data-testid="stMetric"] [data-testid="stMetricValue"] {
340
+ font-size: 0.95rem !important;
341
+ }
342
+ [data-testid="stDialog"] img, [data-testid="stExpander"] img { border-radius: 0 !important; }
343
  </style>
344
  """, unsafe_allow_html=True)
345
+
346
+ if bf_img is not None:
347
+ bf_resized = cv2.resize(bf_img, (CANVAS_SIZE, CANVAS_SIZE))
348
+ bf_rgb = cv2.cvtColor(bf_resized, cv2.COLOR_GRAY2RGB) if bf_img.ndim == 2 else cv2.cvtColor(bf_resized, cv2.COLOR_BGR2RGB)
349
+ left_col, right_col = st.columns(2, gap=None)
350
+ with left_col:
351
+ draw_mode = st.selectbox("Tool", DRAW_TOOLS, format_func=lambda x: TOOL_LABELS[x], key=f"draw_mode_region_{key_suffix}")
352
+ st.caption("Left-click add, right-click close. \nForce map (draw region)")
353
+ canvas_result = st_canvas(
354
+ fill_color="rgba(255, 165, 0, 0.3)", stroke_width=2, stroke_color="#ff6600",
355
+ background_image=pil_bg, drawing_mode=draw_mode, update_streamlit=True,
356
+ height=CANVAS_SIZE, width=CANVAS_SIZE, display_toolbar=True,
357
+ key=f"region_measure_canvas_{key_suffix}",
358
+ )
359
+ with right_col:
360
+ if original_vals:
361
+ st.markdown('<p style="font-weight: 400; color: #334155; font-size: 0.95rem; margin: 0 20px 4px 4px;">Full map</p>', unsafe_allow_html=True)
362
+ st.markdown(f"""
363
+ <div style="width: 100%; box-sizing: border-box; border: 1px solid #e2e8f0; border-radius: 10px;
364
+ padding: 10px 12px; margin: 0 10px 20px 10px; background: linear-gradient(145deg, #f8fafc 0%, #f1f5f9 100%);
365
+ box-shadow: 0 1px 3px rgba(0,0,0,0.06);">
366
+ <div style="display: flex; flex-wrap: wrap; gap: 5px; font-size: 0.9rem;">
367
+ <span><strong>Sum:</strong> {original_vals['pixel_sum']:.1f}</span>
368
+ <span><strong>Force:</strong> {original_vals['force']:.1f}</span>
369
+ <span><strong>Max:</strong> {original_vals['max']:.3f}</span>
370
+ <span><strong>Mean:</strong> {original_vals['mean']:.3f}</span>
371
+ </div>
372
+ </div>
373
+ """, unsafe_allow_html=True)
374
+ st.caption("Bright-field")
375
+ st.image(bf_rgb, width=CANVAS_SIZE)
376
+ else:
377
+ st.markdown("**Draw a region** on the heatmap.")
378
+ draw_mode = st.selectbox("Drawing tool", DRAW_TOOLS,
379
+ format_func=lambda x: "Polygon (free shape)" if x == "polygon" else TOOL_LABELS[x],
380
+ key=f"draw_mode_region_{key_suffix}")
381
+ st.caption("Polygon: left-click to add points, right-click to close.")
382
+ canvas_result = st_canvas(
383
+ fill_color="rgba(255, 165, 0, 0.3)", stroke_width=2, stroke_color="#ff6600",
384
+ background_image=pil_bg, drawing_mode=draw_mode, update_streamlit=True,
385
+ height=CANVAS_SIZE, width=CANVAS_SIZE, display_toolbar=True,
386
+ key=f"region_measure_canvas_{key_suffix}",
387
+ )
388
+
389
+ if canvas_result.json_data:
390
+ mask, n = _parse_canvas_shapes_to_mask(canvas_result.json_data, CANVAS_SIZE, CANVAS_SIZE, h, w)
391
+ if mask is not None and n > 0:
392
+ metrics = _compute_region_metrics(scaled_heatmap, mask, original_vals)
393
+ _render_region_metrics_and_downloads(metrics, heatmap_rgb, mask, input_filename, key_suffix, original_vals is not None)
394
+
395
+
396
+ st_dialog = getattr(st, "dialog", None) or getattr(st, "experimental_dialog", None)
397
+ if HAS_DRAWABLE_CANVAS and st_dialog:
398
+ @st_dialog("Measure tool", width="medium")
399
+ def measure_region_dialog():
400
+ scaled_heatmap = st.session_state.get("measure_scaled_heatmap")
401
+ if scaled_heatmap is None:
402
+ st.warning("No prediction available to measure.")
403
+ return
404
+ bf_img = st.session_state.get("measure_bf_img")
405
+ original_vals = st.session_state.get("measure_original_vals")
406
+ input_filename = st.session_state.get("measure_input_filename", "image")
407
+ _render_region_canvas(scaled_heatmap, bf_img=bf_img, original_vals=original_vals, key_suffix="dialog", input_filename=input_filename)
408
+ else:
409
+ def measure_region_dialog():
410
+ pass # no-op when canvas or dialog not available
411
+
412
+
413
+ st.set_page_config(page_title="Shape2Force (S2F)", page_icon="🦠", layout="centered")
414
+ st.markdown("""
415
+ <style>
416
+ section[data-testid="stSidebar"] { width: 380px !important; }
417
+ div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([data-testid="stButton"]) > div {
418
+ flex: 1 1 0 !important; min-width: 0 !important;
419
+ }
420
+ div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([data-testid="stButton"]) button {
421
+ width: 100% !important; min-width: 100px !important; white-space: nowrap !important;
422
+ }
423
+ div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([data-testid="stButton"]) > div:nth-child(1) button {
424
+ background-color: #0d9488 !important; color: white !important; border-color: #0d9488 !important;
425
+ }
426
+ div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([data-testid="stButton"]) > div:nth-child(1) button:hover {
427
+ background-color: #0f766e !important; border-color: #0f766e !important; color: white !important;
428
+ }
429
+ </style>
430
+ """, unsafe_allow_html=True)
431
  st.title("🦠 Shape2Force (S2F)")
432
  st.caption("Predict traction force maps from bright-field microscopy images of cells or spheroids")
433
 
 
444
  sample_single_cell = os.path.join(sample_base, "single_cell")
445
  sample_spheroid = os.path.join(sample_base, "spheroid")
446
 
 
 
447
 
448
  def get_ckp_files_for_model(model_type):
449
  """Return list of .pth files in the checkpoint folder for the given model type."""
450
  folder = ckp_single_cell if model_type == "single_cell" else ckp_spheroid
451
  if os.path.isdir(folder):
452
+ return sorted(f for f in os.listdir(folder) if f.endswith(".pth"))
453
  return []
454
 
455
 
 
457
  """Return list of sample images in the sample folder for the given model type."""
458
  folder = sample_single_cell if model_type == "single_cell" else sample_spheroid
459
  if os.path.isdir(folder):
460
+ return sorted(f for f in os.listdir(folder) if f.lower().endswith(SAMPLE_EXTENSIONS))
 
461
  return []
462
 
463
  # Sidebar: model configuration
 
466
  model_type = st.radio(
467
  "Model type",
468
  ["single_cell", "spheroid"],
469
+ format_func=lambda x: MODEL_TYPE_LABELS[x],
470
  horizontal=False,
471
  help="Single cell: substrate-aware force prediction. Spheroid: spheroid force maps.",
472
  )
473
+ st.caption(f"Inference mode: **{MODEL_TYPE_LABELS[model_type]}**")
474
 
475
  ckp_files = get_ckp_files_for_model(model_type)
476
  ckp_folder = ckp_single_cell if model_type == "single_cell" else ckp_spheroid
 
505
  manual_young = st.number_input("Young's modulus (Pa)", min_value=100.0, max_value=100000.0,
506
  value=6000.0, step=100.0, format="%.0f")
507
  substrate_config = {"pixelsize": manual_pixelsize, "young": manual_young}
 
 
508
  except FileNotFoundError:
509
  st.error("config/substrate_settings.json not found")
510
 
 
567
  with col_btn:
568
  run = st.button("Run prediction", type="primary")
569
  with col_model:
570
+ st.markdown(f"<span style='display: inline-flex; align-items: center; height: 38px;'>{MODEL_TYPE_LABELS[model_type]}</span>", unsafe_allow_html=True)
 
571
  with col_path:
572
  ckp_path = f"ckp/{ckp_subfolder_name}/{checkpoint}" if checkpoint else f"ckp/{ckp_subfolder_name}/"
573
  st.markdown(f"<span style='display: inline-flex; align-items: center; height: 38px;'>Checkpoint: <code>{ckp_path}</code></span>", unsafe_allow_html=True)
 
604
 
605
  st.success("Prediction complete!")
606
 
 
607
  scaled_heatmap = heatmap * force_scale
608
 
609
+ # Store result and measure data before rendering (Measure click survives rerun)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
610
  cache_key = (model_type, checkpoint, key_img)
611
  st.session_state["prediction_result"] = {
612
  "img": img.copy(),
 
615
  "pixel_sum": pixel_sum,
616
  "cache_key": cache_key,
617
  }
618
+ st.session_state["measure_scaled_heatmap"] = scaled_heatmap.copy()
619
+ st.session_state["measure_bf_img"] = img.copy()
620
+ st.session_state["measure_input_filename"] = key_img or "image"
621
+ st.session_state["measure_original_vals"] = _build_original_vals(scaled_heatmap, pixel_sum, force, force_scale)
622
+
623
+ _render_result_display(img, scaled_heatmap, pixel_sum, force, force_scale, key_img)
624
 
625
  except Exception as e:
626
  st.error(f"Prediction failed: {e}")
 
627
  st.code(traceback.format_exc())
628
 
629
  elif has_cached:
 
630
  r = st.session_state["prediction_result"]
631
  img, heatmap, force, pixel_sum = r["img"], r["heatmap"], r["force"], r["pixel_sum"]
632
  scaled_heatmap = heatmap * force_scale
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
633
 
634
+ st.session_state["measure_scaled_heatmap"] = scaled_heatmap.copy()
635
+ st.session_state["measure_bf_img"] = img.copy()
636
+ st.session_state["measure_input_filename"] = key_img or "image"
637
+ st.session_state["measure_original_vals"] = _build_original_vals(scaled_heatmap, pixel_sum, force, force_scale)
638
 
639
+ if st.session_state.pop("open_measure_dialog", False):
640
+ measure_region_dialog()
641
+
642
+ st.success("Prediction complete!")
643
+ _render_result_display(img, scaled_heatmap, pixel_sum, force, force_scale, key_img, download_key_suffix="_cached")
 
 
 
 
 
 
 
 
 
644
 
645
  elif run and not checkpoint:
646
  st.warning("Please add checkpoint files to the ckp/ folder and select one.")
647
  elif run and not has_image:
648
  st.warning("Please upload an image or select an example.")
649
 
 
650
  st.sidebar.divider()
651
  st.sidebar.caption(f"Examples: `samples/{ckp_subfolder_name}/`")
652
  st.sidebar.caption("If you find this software useful, please cite:")
653
+ st.sidebar.caption(CITATION)
 
 
 
requirements.txt CHANGED
@@ -4,6 +4,7 @@ torchvision>=0.15.0
4
  numpy>=1.20.0
5
  opencv-python>=4.5.0
6
  streamlit>=1.28.0
 
7
  matplotlib>=3.5.0
8
  Pillow>=9.0.0
9
  plotly>=5.14.0
 
4
  numpy>=1.20.0
5
  opencv-python>=4.5.0
6
  streamlit>=1.28.0
7
+ streamlit-drawable-canvas-fix>=0.9.8
8
  matplotlib>=3.5.0
9
  Pillow>=9.0.0
10
  plotly>=5.14.0