added force scale
Browse files- S2FApp/app.py +30 -14
- models/s2f_model.py +0 -2
S2FApp/app.py
CHANGED
|
@@ -111,6 +111,18 @@ with st.sidebar:
|
|
| 111 |
except FileNotFoundError:
|
| 112 |
st.error("config/substrate_settings.json not found")
|
| 113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
# Main area: image input
|
| 115 |
img_source = st.radio("Image source", ["Upload", "Example"], horizontal=True, label_visibility="collapsed")
|
| 116 |
img = None
|
|
@@ -196,6 +208,9 @@ if just_ran:
|
|
| 196 |
|
| 197 |
st.success("Prediction complete!")
|
| 198 |
|
|
|
|
|
|
|
|
|
|
| 199 |
# Visualization - Plotly with zoom/pan, annotated (titles in Streamlit to avoid clipping)
|
| 200 |
tit1, tit2 = st.columns(2)
|
| 201 |
with tit1:
|
|
@@ -204,7 +219,7 @@ if just_ran:
|
|
| 204 |
st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Output: Predicted traction force map</p>', unsafe_allow_html=True)
|
| 205 |
fig_pl = make_subplots(rows=1, cols=2)
|
| 206 |
fig_pl.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1)
|
| 207 |
-
fig_pl.add_trace(go.Heatmap(z=
|
| 208 |
colorbar=dict(len=0.4, thickness=12)), row=1, col=2)
|
| 209 |
fig_pl.update_layout(
|
| 210 |
height=400,
|
|
@@ -216,16 +231,16 @@ if just_ran:
|
|
| 216 |
fig_pl.update_yaxes(showticklabels=False, autorange="reversed")
|
| 217 |
st.plotly_chart(fig_pl, use_container_width=True)
|
| 218 |
|
| 219 |
-
# Metrics with help (below plot)
|
| 220 |
col1, col2, col3, col4 = st.columns(4)
|
| 221 |
with col1:
|
| 222 |
-
st.metric("Sum of all pixels", f"{pixel_sum:.2f}", help="Raw sum of all pixel values in the force map")
|
| 223 |
with col2:
|
| 224 |
-
st.metric("Cell force (scaled)", f"{force:.2f}", help="Total traction force in physical units")
|
| 225 |
with col3:
|
| 226 |
-
st.metric("Heatmap max", f"{np.max(
|
| 227 |
with col4:
|
| 228 |
-
st.metric("Heatmap mean", f"{np.mean(
|
| 229 |
|
| 230 |
# How to read (below numbers)
|
| 231 |
with st.expander("ℹ️ How to read the results"):
|
|
@@ -244,8 +259,8 @@ This is the raw image you provided—it shows cell shape but not forces.
|
|
| 244 |
- **Heatmap max/mean:** Peak and average force intensity in the map
|
| 245 |
""")
|
| 246 |
|
| 247 |
-
# Download
|
| 248 |
-
heatmap_uint8 = (np.clip(
|
| 249 |
heatmap_rgb = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
|
| 250 |
heatmap_rgb = cv2.cvtColor(heatmap_rgb, cv2.COLOR_BGR2RGB)
|
| 251 |
pil_heatmap = Image.fromarray(heatmap_rgb)
|
|
@@ -274,6 +289,7 @@ elif has_cached:
|
|
| 274 |
# Show cached results (e.g. after clicking Download)
|
| 275 |
r = st.session_state["prediction_result"]
|
| 276 |
img, heatmap, force, pixel_sum = r["img"], r["heatmap"], r["force"], r["pixel_sum"]
|
|
|
|
| 277 |
st.success("Prediction complete!")
|
| 278 |
tit1, tit2 = st.columns(2)
|
| 279 |
with tit1:
|
|
@@ -282,7 +298,7 @@ elif has_cached:
|
|
| 282 |
st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Output: Predicted traction force map</p>', unsafe_allow_html=True)
|
| 283 |
fig_pl = make_subplots(rows=1, cols=2)
|
| 284 |
fig_pl.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1)
|
| 285 |
-
fig_pl.add_trace(go.Heatmap(z=
|
| 286 |
colorbar=dict(len=0.4, thickness=12)), row=1, col=2)
|
| 287 |
fig_pl.update_layout(height=400, margin=dict(l=10, r=10, t=10, b=10),
|
| 288 |
xaxis=dict(scaleanchor="y", scaleratio=1),
|
|
@@ -292,13 +308,13 @@ elif has_cached:
|
|
| 292 |
st.plotly_chart(fig_pl, use_container_width=True)
|
| 293 |
col1, col2, col3, col4 = st.columns(4)
|
| 294 |
with col1:
|
| 295 |
-
st.metric("Sum of all pixels", f"{pixel_sum:.2f}", help="Raw sum of all pixel values in the force map")
|
| 296 |
with col2:
|
| 297 |
-
st.metric("Cell force (scaled)", f"{force:.2f}", help="Total traction force in physical units")
|
| 298 |
with col3:
|
| 299 |
-
st.metric("Heatmap max", f"{np.max(
|
| 300 |
with col4:
|
| 301 |
-
st.metric("Heatmap mean", f"{np.mean(
|
| 302 |
with st.expander("ℹ️ How to read the results"):
|
| 303 |
st.markdown("""
|
| 304 |
**Input (left):** Bright-field microscopy image of a cell or spheroid on a substrate.
|
|
@@ -314,7 +330,7 @@ This is the raw image you provided—it shows cell shape but not forces.
|
|
| 314 |
- **Cell force (scaled):** Total traction force in physical units (scaled by substrate stiffness)
|
| 315 |
- **Heatmap max/mean:** Peak and average force intensity in the map
|
| 316 |
""")
|
| 317 |
-
heatmap_uint8 = (np.clip(
|
| 318 |
heatmap_rgb = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
|
| 319 |
heatmap_rgb = cv2.cvtColor(heatmap_rgb, cv2.COLOR_BGR2RGB)
|
| 320 |
pil_heatmap = Image.fromarray(heatmap_rgb)
|
|
|
|
| 111 |
except FileNotFoundError:
|
| 112 |
st.error("config/substrate_settings.json not found")
|
| 113 |
|
| 114 |
+
st.divider()
|
| 115 |
+
st.header("Display options")
|
| 116 |
+
force_scale = st.slider(
|
| 117 |
+
"Force scale",
|
| 118 |
+
min_value=0.0,
|
| 119 |
+
max_value=1.0,
|
| 120 |
+
value=1.0,
|
| 121 |
+
step=0.01,
|
| 122 |
+
format="%.2f",
|
| 123 |
+
help="Scale the displayed force values. 1 = full intensity, 0.5 = half the pixel values.",
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
# Main area: image input
|
| 127 |
img_source = st.radio("Image source", ["Upload", "Example"], horizontal=True, label_visibility="collapsed")
|
| 128 |
img = None
|
|
|
|
| 208 |
|
| 209 |
st.success("Prediction complete!")
|
| 210 |
|
| 211 |
+
# Apply force scale to displayed heatmap
|
| 212 |
+
scaled_heatmap = heatmap * force_scale
|
| 213 |
+
|
| 214 |
# Visualization - Plotly with zoom/pan, annotated (titles in Streamlit to avoid clipping)
|
| 215 |
tit1, tit2 = st.columns(2)
|
| 216 |
with tit1:
|
|
|
|
| 219 |
st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Output: Predicted traction force map</p>', unsafe_allow_html=True)
|
| 220 |
fig_pl = make_subplots(rows=1, cols=2)
|
| 221 |
fig_pl.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1)
|
| 222 |
+
fig_pl.add_trace(go.Heatmap(z=scaled_heatmap, colorscale="Jet", zmin=0, zmax=1, showscale=True,
|
| 223 |
colorbar=dict(len=0.4, thickness=12)), row=1, col=2)
|
| 224 |
fig_pl.update_layout(
|
| 225 |
height=400,
|
|
|
|
| 231 |
fig_pl.update_yaxes(showticklabels=False, autorange="reversed")
|
| 232 |
st.plotly_chart(fig_pl, use_container_width=True)
|
| 233 |
|
| 234 |
+
# Metrics with help (below plot) - use scaled values
|
| 235 |
col1, col2, col3, col4 = st.columns(4)
|
| 236 |
with col1:
|
| 237 |
+
st.metric("Sum of all pixels", f"{pixel_sum * force_scale:.2f}", help="Raw sum of all pixel values in the force map")
|
| 238 |
with col2:
|
| 239 |
+
st.metric("Cell force (scaled)", f"{force * force_scale:.2f}", help="Total traction force in physical units")
|
| 240 |
with col3:
|
| 241 |
+
st.metric("Heatmap max", f"{np.max(scaled_heatmap):.4f}", help="Peak force intensity in the map")
|
| 242 |
with col4:
|
| 243 |
+
st.metric("Heatmap mean", f"{np.mean(scaled_heatmap):.4f}", help="Average force intensity")
|
| 244 |
|
| 245 |
# How to read (below numbers)
|
| 246 |
with st.expander("ℹ️ How to read the results"):
|
|
|
|
| 259 |
- **Heatmap max/mean:** Peak and average force intensity in the map
|
| 260 |
""")
|
| 261 |
|
| 262 |
+
# Download (scaled heatmap)
|
| 263 |
+
heatmap_uint8 = (np.clip(scaled_heatmap, 0, 1) * 255).astype(np.uint8)
|
| 264 |
heatmap_rgb = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
|
| 265 |
heatmap_rgb = cv2.cvtColor(heatmap_rgb, cv2.COLOR_BGR2RGB)
|
| 266 |
pil_heatmap = Image.fromarray(heatmap_rgb)
|
|
|
|
| 289 |
# Show cached results (e.g. after clicking Download)
|
| 290 |
r = st.session_state["prediction_result"]
|
| 291 |
img, heatmap, force, pixel_sum = r["img"], r["heatmap"], r["force"], r["pixel_sum"]
|
| 292 |
+
scaled_heatmap = heatmap * force_scale
|
| 293 |
st.success("Prediction complete!")
|
| 294 |
tit1, tit2 = st.columns(2)
|
| 295 |
with tit1:
|
|
|
|
| 298 |
st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Output: Predicted traction force map</p>', unsafe_allow_html=True)
|
| 299 |
fig_pl = make_subplots(rows=1, cols=2)
|
| 300 |
fig_pl.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1)
|
| 301 |
+
fig_pl.add_trace(go.Heatmap(z=scaled_heatmap, colorscale="Jet", zmin=0, zmax=1, showscale=True,
|
| 302 |
colorbar=dict(len=0.4, thickness=12)), row=1, col=2)
|
| 303 |
fig_pl.update_layout(height=400, margin=dict(l=10, r=10, t=10, b=10),
|
| 304 |
xaxis=dict(scaleanchor="y", scaleratio=1),
|
|
|
|
| 308 |
st.plotly_chart(fig_pl, use_container_width=True)
|
| 309 |
col1, col2, col3, col4 = st.columns(4)
|
| 310 |
with col1:
|
| 311 |
+
st.metric("Sum of all pixels", f"{pixel_sum * force_scale:.2f}", help="Raw sum of all pixel values in the force map")
|
| 312 |
with col2:
|
| 313 |
+
st.metric("Cell force (scaled)", f"{force * force_scale:.2f}", help="Total traction force in physical units")
|
| 314 |
with col3:
|
| 315 |
+
st.metric("Heatmap max", f"{np.max(scaled_heatmap):.4f}", help="Peak force intensity in the map")
|
| 316 |
with col4:
|
| 317 |
+
st.metric("Heatmap mean", f"{np.mean(scaled_heatmap):.4f}", help="Average force intensity")
|
| 318 |
with st.expander("ℹ️ How to read the results"):
|
| 319 |
st.markdown("""
|
| 320 |
**Input (left):** Bright-field microscopy image of a cell or spheroid on a substrate.
|
|
|
|
| 330 |
- **Cell force (scaled):** Total traction force in physical units (scaled by substrate stiffness)
|
| 331 |
- **Heatmap max/mean:** Peak and average force intensity in the map
|
| 332 |
""")
|
| 333 |
+
heatmap_uint8 = (np.clip(scaled_heatmap, 0, 1) * 255).astype(np.uint8)
|
| 334 |
heatmap_rgb = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
|
| 335 |
heatmap_rgb = cv2.cvtColor(heatmap_rgb, cv2.COLOR_BGR2RGB)
|
| 336 |
pil_heatmap = Image.fromarray(heatmap_rgb)
|
models/s2f_model.py
CHANGED
|
@@ -174,7 +174,6 @@ class AttentionGate(nn.Module):
|
|
| 174 |
psi = F.interpolate(psi, size=x.shape[2:], mode='bilinear', align_corners=False)
|
| 175 |
return x * psi
|
| 176 |
|
| 177 |
-
|
| 178 |
class SpheroidAttentionGate(nn.Module):
|
| 179 |
"""Attention Gate from ForceNet2WithAttention (s2f_spheroid). Checkpoint-compatible for ckp_spheroid_FN.pth."""
|
| 180 |
def __init__(self, F_g, F_l, F_int):
|
|
@@ -238,7 +237,6 @@ class PatchGANDiscriminator(nn.Module):
|
|
| 238 |
x = x * self.attention(x)
|
| 239 |
return self.output_conv(x)
|
| 240 |
|
| 241 |
-
|
| 242 |
class S2FGenerator(nn.Module):
|
| 243 |
"""
|
| 244 |
S2F (Shape2Force) model: U-Net generator for force map prediction.
|
|
|
|
| 174 |
psi = F.interpolate(psi, size=x.shape[2:], mode='bilinear', align_corners=False)
|
| 175 |
return x * psi
|
| 176 |
|
|
|
|
| 177 |
class SpheroidAttentionGate(nn.Module):
|
| 178 |
"""Attention Gate from ForceNet2WithAttention (s2f_spheroid). Checkpoint-compatible for ckp_spheroid_FN.pth."""
|
| 179 |
def __init__(self, F_g, F_l, F_int):
|
|
|
|
| 237 |
x = x * self.attention(x)
|
| 238 |
return self.output_conv(x)
|
| 239 |
|
|
|
|
| 240 |
class S2FGenerator(nn.Module):
|
| 241 |
"""
|
| 242 |
S2F (Shape2Force) model: U-Net generator for force map prediction.
|