| | """ |
| | Shape2Force (S2F) - GUI for force map prediction from bright field microscopy images. |
| | """ |
| | import csv |
| | import io |
| | import os |
| | import sys |
| | import traceback |
| |
|
| | import cv2 |
| | cv2.utils.logging.setLogLevel(cv2.utils.logging.LOG_LEVEL_ERROR) |
| |
|
| | import numpy as np |
| | import streamlit as st |
| | from PIL import Image |
| | import plotly.graph_objects as go |
| | from plotly.subplots import make_subplots |
| |
|
| | S2F_ROOT = os.path.dirname(os.path.abspath(__file__)) |
| | if S2F_ROOT not in sys.path: |
| | sys.path.insert(0, S2F_ROOT) |
| |
|
| | from utils.substrate_settings import list_substrates |
| |
|
| | try: |
| | from streamlit_drawable_canvas import st_canvas |
| | HAS_DRAWABLE_CANVAS = True |
| | except (ImportError, AttributeError): |
| | HAS_DRAWABLE_CANVAS = False |
| |
|
| | |
# Display labels for the two model types offered in the sidebar.
MODEL_TYPE_LABELS = {"single_cell": "Single cell", "spheroid": "Spheroid"}
# Drawing modes offered by the measure tool, and their display labels.
DRAW_TOOLS = ["polygon", "rect", "circle"]
TOOL_LABELS = {"polygon": "Polygon", "rect": "Rectangle", "circle": "Circle"}
# Side length (px) of the square drawing canvas and of the bright-field preview next to it.
CANVAS_SIZE = 320
# Extensions accepted when listing example images (compared against the lower-cased file name).
SAMPLE_EXTENSIONS = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
# Colormap display name -> OpenCV colormap constant.
COLORMAPS = {
    "Jet": cv2.COLORMAP_JET,
    "Viridis": cv2.COLORMAP_VIRIDIS,
    "Plasma": cv2.COLORMAP_PLASMA,
    "Inferno": cv2.COLORMAP_INFERNO,
    "Magma": cv2.COLORMAP_MAGMA,
}
| |
|
| |
|
def _cv_colormap_to_plotly_colorscale(colormap_name, n_samples=64):
    """Sample an OpenCV colormap into a Plotly colorscale.

    Returns a list of ``[position, "rgb(r,g,b)"]`` pairs whose positions are
    evenly spaced over [0, 1], so the on-screen plot uses the same colors as
    the images produced with OpenCV. Unknown colormap names fall back to Jet.
    """
    cmap_id = COLORMAPS.get(colormap_name, cv2.COLORMAP_JET)
    # A 1 x n_samples gray ramp; colorizing it yields one color per sample.
    ramp = np.linspace(0, 255, n_samples, dtype=np.uint8).reshape(1, -1)
    colors = cv2.cvtColor(cv2.applyColorMap(ramp, cmap_id), cv2.COLOR_BGR2RGB)[0]
    last = n_samples - 1
    return [
        [idx / last, f"rgb({red},{green},{blue})"]
        for idx, (red, green, blue) in enumerate(colors)
    ]
| |
|
| |
|
# Citation text shown as the last caption in the sidebar.
CITATION = (
    "Lautaro Baro, Kaveh Shahhosseini, Amparo Andrés-Bordería, Claudio Angione, and Maria Angeles Juanes. "
    "**\"Shape-to-force (S2F): Predicting Cell Traction Forces from LabelFree Imaging\"**, 2026."
)
| |
|
| |
|
def _make_annotated_heatmap(heatmap_rgb, mask, fill_alpha=0.3, stroke_color=(255, 102, 0), stroke_width=2):
    """Return a copy of the heatmap with the drawn region tinted and outlined.

    ``heatmap_rgb`` and ``mask`` must have the same height and width. Pixels
    where the mask is non-zero are alpha-blended towards ``stroke_color`` with
    weight ``fill_alpha``; the external contours of the mask are then stroked.
    """
    outlines, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    result = heatmap_rgb.copy()

    # Solid-colored copy that supplies the tint for the blend below.
    tinted = result.copy()
    cv2.fillPoly(tinted, outlines, stroke_color)

    # Broadcast the 2-D mask over the three color channels.
    inside = np.stack([mask] * 3, axis=-1).astype(bool)
    base_px = result[inside].astype(np.float32)
    tint_px = tinted[inside].astype(np.float32)
    result[inside] = ((1 - fill_alpha) * base_px + fill_alpha * tint_px).astype(np.uint8)

    cv2.drawContours(result, outlines, -1, stroke_color, stroke_width)
    return result
| |
|
| |
|
| | def _parse_canvas_shapes_to_mask(json_data, canvas_h, canvas_w, heatmap_h, heatmap_w): |
| | """ |
| | Parse drawn shapes from streamlit-drawable-canvas json_data and create a binary mask |
| | in heatmap coordinates. Returns (mask, num_shapes) or (None, 0) if no valid shapes. |
| | """ |
| | if not json_data or "objects" not in json_data or not json_data["objects"]: |
| | return None, 0 |
| | scale_x = heatmap_w / canvas_w |
| | scale_y = heatmap_h / canvas_h |
| | mask = np.zeros((heatmap_h, heatmap_w), dtype=np.uint8) |
| | count = 0 |
| | for obj in json_data["objects"]: |
| | obj_type = obj.get("type", "") |
| | pts = [] |
| | if obj_type == "rect": |
| | left = obj.get("left", 0) |
| | top = obj.get("top", 0) |
| | w = obj.get("width", 0) |
| | h = obj.get("height", 0) |
| | pts = np.array([ |
| | [left, top], [left + w, top], [left + w, top + h], [left, top + h] |
| | ], dtype=np.float32) |
| | elif obj_type == "circle" or obj_type == "ellipse": |
| | left = obj.get("left", 0) |
| | top = obj.get("top", 0) |
| | width = obj.get("width", 0) |
| | height = obj.get("height", 0) |
| | radius = obj.get("radius", 0) |
| | angle_deg = obj.get("angle", 0) |
| | if radius > 0: |
| | |
| | |
| | rx = ry = radius |
| | angle_rad = np.deg2rad(angle_deg) |
| | cx = left + radius * np.cos(angle_rad) |
| | cy = top + radius * np.sin(angle_rad) |
| | else: |
| | |
| | rx = width / 2 if width > 0 else 0 |
| | ry = height / 2 if height > 0 else 0 |
| | if rx <= 0 or ry <= 0: |
| | continue |
| | cx = left + rx |
| | cy = top + ry |
| | if rx <= 0 or ry <= 0: |
| | continue |
| | n = 32 |
| | angles = np.linspace(0, 2 * np.pi, n, endpoint=False) |
| | pts = np.column_stack([cx + rx * np.cos(angles), cy + ry * np.sin(angles)]).astype(np.float32) |
| | elif obj_type == "path": |
| | path = obj.get("path", []) |
| | for cmd in path: |
| | if isinstance(cmd, (list, tuple)) and len(cmd) >= 3: |
| | if cmd[0] in ("M", "L"): |
| | pts.append([float(cmd[1]), float(cmd[2])]) |
| | elif cmd[0] == "Q" and len(cmd) >= 5: |
| | pts.append([float(cmd[3]), float(cmd[4])]) |
| | elif cmd[0] == "C" and len(cmd) >= 7: |
| | pts.append([float(cmd[5]), float(cmd[6])]) |
| | if len(pts) < 3: |
| | continue |
| | pts = np.array(pts, dtype=np.float32) |
| | else: |
| | continue |
| | pts[:, 0] *= scale_x |
| | pts[:, 1] *= scale_y |
| | pts = np.clip(pts, 0, [heatmap_w - 1, heatmap_h - 1]).astype(np.int32) |
| | cv2.fillPoly(mask, [pts], 1) |
| | count += 1 |
| | return mask, count |
| |
|
| |
|
def _heatmap_to_rgb(scaled_heatmap, colormap_name="Jet"):
    """Colorize a heatmap with the named OpenCV colormap and return an RGB array.

    Values are clipped to [0, 1] and mapped to 0-255 before the colormap is
    applied. Unknown colormap names fall back to Jet.
    """
    levels = np.clip(scaled_heatmap, 0, 1) * 255
    gray8 = levels.astype(np.uint8)
    cmap_id = COLORMAPS.get(colormap_name, cv2.COLORMAP_JET)
    bgr = cv2.applyColorMap(gray8, cmap_id)
    # OpenCV colorizes in BGR order; convert to RGB for the callers.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
| |
|
| |
|
def _heatmap_to_png_bytes(scaled_heatmap, colormap_name="Jet"):
    """Render the heatmap with the given colormap into an in-memory PNG.

    Returns an ``io.BytesIO`` positioned at the start of the PNG data.
    """
    png_buffer = io.BytesIO()
    colored = _heatmap_to_rgb(scaled_heatmap, colormap_name)
    Image.fromarray(colored).save(png_buffer, format="PNG")
    png_buffer.seek(0)
    return png_buffer
| |
|
| |
|
def _create_pdf_report(img, scaled_heatmap, pixel_sum, force, force_scale, base_name, colormap_name="Jet"):
    """Build a one-page A4 PDF report and return it as bytes.

    The page holds a title with the image name, the input image and the
    colorized force map side by side, a metrics list below them, and a footer
    with the generation time and project links.
    """
    from datetime import datetime
    from reportlab.lib.pagesizes import A4
    from reportlab.lib.units import inch
    from reportlab.lib.utils import ImageReader
    from reportlab.pdfgen import canvas

    def _png_reader(pil_image):
        # Hand the image to reportlab as an in-memory PNG.
        stream = io.BytesIO()
        pil_image.save(stream, format="PNG")
        stream.seek(0)
        return ImageReader(stream)

    left = 72  # x position shared by all left-aligned elements
    pdf_buf = io.BytesIO()
    pdf = canvas.Canvas(pdf_buf, pagesize=A4)
    pdf.setTitle("Shape2Force")
    pdf.setAuthor("Angione-Lab")
    _, page_h = A4
    panel_w = panel_h = 2.5 * inch

    # Footer: small gray text near the bottom of the page.
    footer_y = 40
    pdf.setFont("Helvetica", 8)
    pdf.setFillColorRGB(0.4, 0.4, 0.4)
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M")
    footer_lines = (
        f"Generated by Shape2Force (S2F) on {stamp}",
        "Model: https://huggingface.co/Angione-Lab/Shape2Force",
        "Web app: https://huggingface.co/spaces/Angione-Lab/Shape2force",
    )
    for row, text in enumerate(footer_lines):
        pdf.drawString(left, footer_y - 12 * row, text)
    pdf.setFillColorRGB(0, 0, 0)

    # Image panels: input on the left, force map to its right.
    panel_top = page_h - 70
    panel_bottom = panel_top - panel_h
    caption_y = panel_bottom - 12
    if img.ndim == 2:
        input_pil = Image.fromarray(img)
    else:
        input_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    pdf.drawImage(_png_reader(input_pil), left, panel_bottom, width=panel_w, height=panel_h, preserveAspectRatio=True)
    pdf.setFont("Helvetica", 9)
    pdf.drawString(left, caption_y, "Input: Bright-field")

    heatmap_x = left + panel_w + 20
    heatmap_pil = Image.fromarray(_heatmap_to_rgb(scaled_heatmap, colormap_name))
    pdf.drawImage(_png_reader(heatmap_pil), heatmap_x, panel_bottom, width=panel_w, height=panel_h, preserveAspectRatio=True)
    pdf.drawString(heatmap_x, caption_y, "Output: Force map")

    # Title block above the panels.
    pdf.setFont("Helvetica-Bold", 16)
    pdf.drawString(left, panel_top + 25, "Shape2Force (S2F) - Prediction Report")
    pdf.setFont("Helvetica", 10)
    pdf.drawString(left, panel_top + 8, f"Image: {base_name}")

    # Metrics list below the panels.
    cursor_y = panel_bottom - 45
    pdf.setFont("Helvetica-Bold", 10)
    pdf.drawString(left, cursor_y, "Metrics")
    pdf.setFont("Helvetica", 9)
    cursor_y -= 18
    report_rows = (
        ("Sum of all pixels", f"{pixel_sum * force_scale:.2f}"),
        ("Cell force (scaled)", f"{force * force_scale:.2f}"),
        ("Heatmap max", f"{np.max(scaled_heatmap):.4f}"),
        ("Heatmap mean", f"{np.mean(scaled_heatmap):.4f}"),
    )
    for label, value in report_rows:
        pdf.drawString(left, cursor_y, f"{label}: {value}")
        cursor_y -= 16

    pdf.save()
    pdf_buf.seek(0)
    return pdf_buf.getvalue()
| |
|
| |
|
| | def _build_original_vals(scaled_heatmap, pixel_sum, force, force_scale): |
| | """Build original_vals dict for measure tool.""" |
| | return { |
| | "pixel_sum": pixel_sum * force_scale, |
| | "force": force * force_scale, |
| | "max": float(np.max(scaled_heatmap)), |
| | "mean": float(np.mean(scaled_heatmap)), |
| | } |
| |
|
| |
|
def _render_result_display(img, scaled_heatmap, pixel_sum, force, force_scale, key_img, download_key_suffix="", colormap_name="Jet"):
    """Render a prediction: side-by-side plot, metrics, help text and action buttons.

    ``key_img`` is the input file name (its stem names the downloads);
    ``download_key_suffix`` is appended to the download-button widget keys.
    """
    base_name = os.path.splitext(key_img or "image")[0]
    heatmap_png = _heatmap_to_png_bytes(scaled_heatmap, colormap_name).getvalue()

    # (label, formatted value, tooltip) — shared by the CSV export and the metric row.
    summary = [
        ("Sum of all pixels", f"{pixel_sum * force_scale:.2f}", "Raw sum of all pixel values in the force map"),
        ("Cell force (scaled)", f"{force * force_scale:.2f}", "Total traction force in physical units"),
        ("Heatmap max", f"{np.max(scaled_heatmap):.4f}", "Peak force intensity in the map"),
        ("Heatmap mean", f"{np.mean(scaled_heatmap):.4f}", "Average force intensity"),
    ]
    csv_text = io.StringIO()
    writer = csv.writer(csv_text)
    writer.writerow(["image"] + [label for label, _, _ in summary])
    writer.writerow([base_name] + [value for _, value, _ in summary])

    # Column headings above the two plot panels.
    title_html = '<p style="font-size: 1.1rem; color: black; font-weight: 600;">{}</p>'
    panel_titles = ("Input: Bright-field image", "Output: Predicted traction force map")
    for column, title in zip(st.columns(2), panel_titles):
        with column:
            st.markdown(title_html.format(title), unsafe_allow_html=True)

    figure = make_subplots(rows=1, cols=2)
    figure.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1)
    force_trace = go.Heatmap(
        z=scaled_heatmap,
        colorscale=_cv_colormap_to_plotly_colorscale(colormap_name),
        zmin=0,
        zmax=1,
        showscale=True,
        colorbar=dict(len=0.4, thickness=12),
    )
    figure.add_trace(force_trace, row=1, col=2)
    figure.update_layout(
        height=400,
        margin=dict(l=10, r=10, t=10, b=10),
        xaxis=dict(scaleanchor="y", scaleratio=1),
        xaxis2=dict(scaleanchor="y2", scaleratio=1),
    )
    figure.update_xaxes(showticklabels=False)
    # Reverse the y axis so row 0 is drawn at the top, as in an image.
    figure.update_yaxes(showticklabels=False, autorange="reversed")
    st.plotly_chart(figure, use_container_width=True, config={"displayModeBar": True, "responsive": True})

    for column, (label, value, tooltip) in zip(st.columns(4), summary):
        with column:
            st.metric(label, value, help=tooltip)

    with st.expander("How to read the results"):
        st.markdown("""
        **Input (left):** Bright-field microscopy image of a cell or spheroid on a substrate.
        This is the raw image you provided—it shows cell shape but not forces.

        **Output (right):** Predicted traction force map.
        - **Color** indicates force magnitude: blue = low, red = high
        - **Brighter/warmer colors** = stronger forces exerted by the cell on the substrate
        - Values are normalized to [0, 1] for visualization

        **Metrics:**
        - **Sum of all pixels:** Total force is the sum of all pixels in the force map. Each pixel represents the magnitude of force at that location; summing them gives the overall traction.
        - **Cell force (scaled):** Total traction force in physical units (scaled by substrate stiffness)
        - **Heatmap max/mean:** Peak and average force intensity in the map
        """)

    original_vals = _build_original_vals(scaled_heatmap, pixel_sum, force, force_scale)
    pdf_bytes = _create_pdf_report(img, scaled_heatmap, pixel_sum, force, force_scale, base_name, colormap_name)

    measure_col, heatmap_col, values_col, report_col = st.columns(4)
    with measure_col:
        if not HAS_DRAWABLE_CANVAS:
            st.caption("Install `streamlit-drawable-canvas-fix` for region measurement: `pip install streamlit-drawable-canvas-fix`")
        elif st_dialog:
            # Dialog flavor: set a flag and rerun; the caller opens the dialog on the next run.
            if st.button("Measure tool", key="open_measure", icon=":material/straighten:"):
                st.session_state["open_measure_dialog"] = True
                st.rerun()
        else:
            # No dialog support: embed the tool inline in an expander.
            with st.expander("Measure tool"):
                _render_region_canvas(
                    scaled_heatmap,
                    bf_img=img,
                    original_vals=original_vals,
                    key_suffix="expander",
                    input_filename=key_img,
                    colormap_name=colormap_name,
                )

    # (column, label, payload, file name, MIME type, widget key, icon)
    downloads = [
        (heatmap_col, "Download heatmap", heatmap_png, "s2f_heatmap.png", "image/png",
         f"download_heatmap{download_key_suffix}", ":material/download:"),
        (values_col, "Download values", csv_text.getvalue(), f"{base_name}_main_values.csv", "text/csv",
         f"download_main_values{download_key_suffix}", ":material/download:"),
        (report_col, "Download report", pdf_bytes, f"{base_name}_report.pdf", "application/pdf",
         f"download_pdf{download_key_suffix}", ":material/picture_as_pdf:"),
    ]
    for column, label, payload, file_name, mime, widget_key, icon in downloads:
        with column:
            st.download_button(
                label,
                width="stretch",
                data=payload,
                file_name=file_name,
                mime=mime,
                key=widget_key,
                icon=icon,
            )
| |
|
| |
|
| | def _compute_region_metrics(scaled_heatmap, mask, original_vals=None): |
| | """Compute region metrics from mask. Returns dict with area_px, force_sum, density, etc.""" |
| | area_px = int(np.sum(mask)) |
| | region_values = scaled_heatmap * mask |
| | region_nonzero = region_values[mask > 0] |
| | force_sum = float(np.sum(region_values)) |
| | density = force_sum / area_px if area_px > 0 else 0 |
| | region_max = float(np.max(region_nonzero)) if len(region_nonzero) > 0 else 0 |
| | region_mean = float(np.mean(region_nonzero)) if len(region_nonzero) > 0 else 0 |
| | region_force_scaled = ( |
| | force_sum * (original_vals["force"] / original_vals["pixel_sum"]) |
| | if original_vals and original_vals.get("pixel_sum", 0) > 0 |
| | else force_sum |
| | ) |
| | return { |
| | "area_px": area_px, |
| | "force_sum": force_sum, |
| | "density": density, |
| | "max": region_max, |
| | "mean": region_mean, |
| | "force_scaled": region_force_scaled, |
| | } |
| |
|
| |
|
def _render_region_metrics_and_downloads(metrics, heatmap_rgb, mask, input_filename, key_suffix, has_original_vals):
    """Show the drawn region's metrics and offer CSV / annotated-PNG downloads.

    With ``has_original_vals`` the full six-metric row is shown; otherwise only
    area, force sum and density. ``key_suffix`` keeps the download-button
    widget keys unique per call site.
    """
    base_name = os.path.splitext(input_filename or "image")[0]
    st.markdown("**Region (drawn)**")

    area_px = metrics["area_px"]
    # Each entry: (metric label, metric text, CSV header, CSV cell).
    if has_original_vals:
        entries = [
            ("Area", f"{area_px:,}", "Area", area_px),
            ("F.sum", f"{metrics['force_sum']:.3f}", "F.sum", f"{metrics['force_sum']:.3f}"),
            ("Force", f"{metrics['force_scaled']:.1f}", "Force", f"{metrics['force_scaled']:.1f}"),
            ("Max", f"{metrics['max']:.3f}", "Max", f"{metrics['max']:.3f}"),
            ("Mean", f"{metrics['mean']:.3f}", "Mean", f"{metrics['mean']:.3f}"),
            ("Density", f"{metrics['density']:.4f}", "Density", f"{metrics['density']:.4f}"),
        ]
    else:
        entries = [
            ("Area (px²)", f"{area_px:,}", "Area", area_px),
            ("Force sum", f"{metrics['force_sum']:.4f}", "Force sum", f"{metrics['force_sum']:.4f}"),
            ("Density", f"{metrics['density']:.6f}", "Density", f"{metrics['density']:.6f}"),
        ]

    for column, (label, text, _, _) in zip(st.columns(len(entries)), entries):
        with column:
            st.metric(label, text)

    csv_text = io.StringIO()
    writer = csv.writer(csv_text)
    writer.writerow(["image"] + [header for _, _, header, _ in entries])
    writer.writerow([base_name] + [cell for _, _, _, cell in entries])

    png_bytes = io.BytesIO()
    annotated = _make_annotated_heatmap(heatmap_rgb, mask)
    Image.fromarray(annotated).save(png_bytes, format="PNG")
    png_bytes.seek(0)

    values_col, image_col = st.columns(2)
    with values_col:
        st.download_button(
            "Download values",
            data=csv_text.getvalue(),
            file_name=f"{base_name}_region_values.csv",
            mime="text/csv",
            key=f"download_region_values_{key_suffix}",
            icon=":material/download:",
        )
    with image_col:
        st.download_button(
            "Download annotated heatmap",
            data=png_bytes.getvalue(),
            file_name=f"{base_name}_annotated_heatmap.png",
            mime="image/png",
            key=f"download_annotated_{key_suffix}",
            icon=":material/image:",
        )
| |
|
| |
|
def _render_region_canvas(scaled_heatmap, bf_img=None, original_vals=None, key_suffix="", input_filename=None, colormap_name="Jet"):
    """Render the measure tool: a drawable canvas over the heatmap plus region metrics.

    When ``bf_img`` is given, the canvas is shown next to the bright-field
    image (and the full-map values, if ``original_vals`` is set); otherwise
    only the canvas is shown. Region metrics and downloads appear once at
    least one valid shape has been drawn. ``key_suffix`` keeps widget keys
    unique between the dialog and the expander call sites.
    """
    heatmap_h, heatmap_w = scaled_heatmap.shape
    heatmap_rgb = _heatmap_to_rgb(scaled_heatmap, colormap_name)
    background = Image.fromarray(heatmap_rgb).resize((CANVAS_SIZE, CANVAS_SIZE), Image.Resampling.LANCZOS)

    def _draw_canvas(mode):
        # Both layouts use the same canvas configuration.
        return st_canvas(
            fill_color="rgba(255, 165, 0, 0.3)",
            stroke_width=2,
            stroke_color="#ff6600",
            background_image=background,
            drawing_mode=mode,
            update_streamlit=True,
            height=CANVAS_SIZE,
            width=CANVAS_SIZE,
            display_toolbar=True,
            key=f"region_measure_canvas_{key_suffix}",
        )

    st.markdown("""
    <style>
    [data-testid="stDialog"] [data-testid="stSelectbox"], [data-testid="stExpander"] [data-testid="stSelectbox"],
    [data-testid="stDialog"] [data-testid="stSelectbox"] > div, [data-testid="stExpander"] [data-testid="stSelectbox"] > div {
        width: 100% !important; max-width: 100% !important;
    }
    [data-testid="stDialog"] [data-testid="stMetric"] label, [data-testid="stDialog"] [data-testid="stMetric"] [data-testid="stMetricValue"],
    [data-testid="stExpander"] [data-testid="stMetric"] label, [data-testid="stExpander"] [data-testid="stMetric"] [data-testid="stMetricValue"] {
        font-size: 0.95rem !important;
    }
    [data-testid="stDialog"] img, [data-testid="stExpander"] img { border-radius: 0 !important; }
    </style>
    """, unsafe_allow_html=True)

    if bf_img is None:
        st.markdown("**Draw a region** on the heatmap.")
        draw_mode = st.selectbox(
            "Drawing tool",
            DRAW_TOOLS,
            format_func=lambda tool: "Polygon (free shape)" if tool == "polygon" else TOOL_LABELS[tool],
            key=f"draw_mode_region_{key_suffix}",
        )
        st.caption("Polygon: left-click to add points, right-click to close.")
        canvas_result = _draw_canvas(draw_mode)
    else:
        bf_resized = cv2.resize(bf_img, (CANVAS_SIZE, CANVAS_SIZE))
        if bf_img.ndim == 2:
            bf_rgb = cv2.cvtColor(bf_resized, cv2.COLOR_GRAY2RGB)
        else:
            bf_rgb = cv2.cvtColor(bf_resized, cv2.COLOR_BGR2RGB)
        canvas_col, reference_col = st.columns(2, gap=None)
        with canvas_col:
            draw_mode = st.selectbox(
                "Tool",
                DRAW_TOOLS,
                format_func=lambda tool: TOOL_LABELS[tool],
                key=f"draw_mode_region_{key_suffix}",
            )
            st.caption("Left-click add, right-click close. \nForce map (draw region)")
            canvas_result = _draw_canvas(draw_mode)
        with reference_col:
            if original_vals:
                st.markdown('<p style="font-weight: 400; color: #334155; font-size: 0.95rem; margin: 0 20px 4px 4px;">Full map</p>', unsafe_allow_html=True)
                st.markdown(f"""
                <div style="width: 100%; box-sizing: border-box; border: 1px solid #e2e8f0; border-radius: 10px;
                padding: 10px 12px; margin: 0 10px 20px 10px; background: linear-gradient(145deg, #f8fafc 0%, #f1f5f9 100%);
                box-shadow: 0 1px 3px rgba(0,0,0,0.06);">
                <div style="display: flex; flex-wrap: wrap; gap: 5px; font-size: 0.9rem;">
                <span><strong>Sum:</strong> {original_vals['pixel_sum']:.1f}</span>
                <span><strong>Force:</strong> {original_vals['force']:.1f}</span>
                <span><strong>Max:</strong> {original_vals['max']:.3f}</span>
                <span><strong>Mean:</strong> {original_vals['mean']:.3f}</span>
                </div>
                </div>
                """, unsafe_allow_html=True)
            st.caption("Bright-field")
            st.image(bf_rgb, width=CANVAS_SIZE)

    drawn = canvas_result.json_data
    if not drawn:
        return
    # The canvas is CANVAS_SIZE square; the parser rescales shapes to heatmap pixels.
    mask, n_shapes = _parse_canvas_shapes_to_mask(drawn, CANVAS_SIZE, CANVAS_SIZE, heatmap_h, heatmap_w)
    if mask is None or n_shapes <= 0:
        return
    metrics = _compute_region_metrics(scaled_heatmap, mask, original_vals)
    _render_region_metrics_and_downloads(metrics, heatmap_rgb, mask, input_filename, key_suffix, original_vals is not None)
| |
|
| |
|
# Resolve the dialog decorator under either of its two attribute names;
# None when Streamlit provides neither.
st_dialog = getattr(st, "dialog", None) or getattr(st, "experimental_dialog", None)


def measure_region_dialog():
    """No-op fallback used when the dialog decorator or the canvas component is missing."""


if HAS_DRAWABLE_CANVAS and st_dialog:

    @st_dialog("Measure tool", width="medium")
    def measure_region_dialog():
        """Open the measure tool in a dialog, fed from the ``measure_*`` session-state keys."""
        session = st.session_state
        scaled_heatmap = session.get("measure_scaled_heatmap")
        if scaled_heatmap is None:
            st.warning("No prediction available to measure.")
            return
        _render_region_canvas(
            scaled_heatmap,
            bf_img=session.get("measure_bf_img"),
            original_vals=session.get("measure_original_vals"),
            key_suffix="dialog",
            input_filename=session.get("measure_input_filename", "image"),
            colormap_name=session.get("measure_colormap", "Jet"),
        )
| |
|
| |
|
# Page-level configuration, followed by global CSS injected through st.markdown.
st.set_page_config(page_title="Shape2Force (S2F)", page_icon="🦠", layout="centered")
# The CSS fixes the sidebar width and styles any horizontal block that holds
# both a download button and a regular button: equal-width columns, full-width
# buttons, and a teal fill (darker on hover) for the button in the first column.
st.markdown("""
<style>
section[data-testid="stSidebar"] { width: 380px !important; }
div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([data-testid="stButton"]) > div {
    flex: 1 1 0 !important; min-width: 0 !important;
}
div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([data-testid="stButton"]) button {
    width: 100% !important; min-width: 100px !important; white-space: nowrap !important;
}
div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([data-testid="stButton"]) > div:nth-child(1) button {
    background-color: #0d9488 !important; color: white !important; border-color: #0d9488 !important;
}
div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([data-testid="stButton"]) > div:nth-child(1) button:hover {
    background-color: #0f766e !important; border-color: #0f766e !important; color: white !important;
}
</style>
""", unsafe_allow_html=True)
st.title("🦠 Shape2Force (S2F)")
st.caption("Predict traction force maps from bright-field microscopy images of cells or spheroids")
| |
|
| | |
# Checkpoints are looked up in <app dir>/ckp first.
ckp_base = os.path.join(S2F_ROOT, "ckp")
# If that folder does not exist, fall back to a "ckp" folder one level above
# the app directory, when present.
if not os.path.isdir(ckp_base):
    project_root = os.path.dirname(S2F_ROOT)
    if os.path.isdir(os.path.join(project_root, "ckp")):
        ckp_base = os.path.join(project_root, "ckp")
# One sub-folder per model type, for checkpoints and for example images.
ckp_single_cell = os.path.join(ckp_base, "single_cell")
ckp_spheroid = os.path.join(ckp_base, "spheroid")
sample_base = os.path.join(S2F_ROOT, "samples")
sample_single_cell = os.path.join(sample_base, "single_cell")
sample_spheroid = os.path.join(sample_base, "spheroid")
| |
|
| |
|
def get_ckp_files_for_model(model_type):
    """List the ``.pth`` file names, sorted, in the checkpoint folder for ``model_type``.

    Any value other than "single_cell" selects the spheroid folder. Returns an
    empty list when the folder does not exist.
    """
    folder = ckp_spheroid
    if model_type == "single_cell":
        folder = ckp_single_cell
    if not os.path.isdir(folder):
        return []
    return sorted(name for name in os.listdir(folder) if name.endswith(".pth"))
| |
|
| |
|
def get_sample_files_for_model(model_type):
    """List the example image file names, sorted, in the sample folder for ``model_type``.

    Files are matched against SAMPLE_EXTENSIONS case-insensitively. Any value
    other than "single_cell" selects the spheroid folder. Returns an empty
    list when the folder does not exist.
    """
    folder = sample_spheroid
    if model_type == "single_cell":
        folder = sample_single_cell
    if not os.path.isdir(folder):
        return []
    return sorted(
        name for name in os.listdir(folder)
        if name.lower().endswith(SAMPLE_EXTENSIONS)
    )
| |
|
| | |
# Sidebar: model selection, checkpoint, substrate settings and display options.
with st.sidebar:
    st.header("Model configuration")
    model_type = st.radio(
        "Model type",
        ["single_cell", "spheroid"],
        format_func=lambda x: MODEL_TYPE_LABELS[x],
        horizontal=False,
        help="Single cell: substrate-aware force prediction. Spheroid: spheroid force maps.",
    )
    st.caption(f"Inference mode: **{MODEL_TYPE_LABELS[model_type]}**")

    # Checkpoints are listed from the folder that matches the chosen model type.
    ckp_files = get_ckp_files_for_model(model_type)
    ckp_folder = ckp_single_cell if model_type == "single_cell" else ckp_spheroid
    ckp_subfolder_name = "single_cell" if model_type == "single_cell" else "spheroid"

    if ckp_files:
        checkpoint = st.selectbox(
            "Checkpoint",
            ckp_files,
            help=f"Select a .pth file from ckp/{ckp_subfolder_name}/",
        )
    else:
        st.warning(f"No .pth files in ckp/{ckp_subfolder_name}/. Add checkpoints to load.")
        checkpoint = None

    # Substrate defaults; only the single-cell model exposes substrate controls.
    substrate_config = None
    substrate_val = "fibroblasts_PDMS"
    use_manual = False
    if model_type == "single_cell":
        try:
            substrates = list_substrates()
            substrate_val = st.selectbox(
                "Substrate (from config)",
                substrates,
                help="Select a preset from config/substrate_settings.json",
            )
            use_manual = st.checkbox("Enter substrate values manually", value=False)
            if use_manual:
                # Manual values are passed to the predictor as substrate_config.
                st.caption("Enter pixelsize (µm/px) and Young's modulus (Pa)")
                manual_pixelsize = st.number_input("Pixelsize (µm/px)", min_value=0.1, max_value=50.0,
                                                   value=3.0769, step=0.1, format="%.4f")
                manual_young = st.number_input("Young's modulus (Pa)", min_value=100.0, max_value=100000.0,
                                               value=6000.0, step=100.0, format="%.0f")
                substrate_config = {"pixelsize": manual_pixelsize, "young": manual_young}
        except FileNotFoundError:
            # The defaults assigned above stay in effect.
            st.error("config/substrate_settings.json not found")

    st.divider()
    st.header("Display options")
    # force_scale multiplies the heatmap and the reported sums before display.
    force_scale = st.slider(
        "Force scale",
        min_value=0.0,
        max_value=1.0,
        value=1.0,
        step=0.01,
        format="%.2f",
        help="Scale the displayed force values. 1 = full intensity, 0.5 = half the pixel values.",
    )
    colormap_name = st.selectbox(
        "Heatmap colormap",
        list(COLORMAPS.keys()),
        help="Color scheme for the force map. Viridis is often preferred for accessibility.",
    )
| |
|
| | |
# Image input: either an uploaded file or one of the bundled example images.
# `img` stays None until an image has been decoded successfully.
img_source = st.radio("Image source", ["Upload", "Example"], horizontal=True, label_visibility="collapsed")
img = None
uploaded = None
selected_sample = None

if img_source == "Upload":
    uploaded = st.file_uploader(
        "Upload bright-field image",
        type=["tif", "tiff", "png", "jpg", "jpeg"],
        help="Bright-field microscopy image of a cell or spheroid on a substrate (grayscale or RGB). The model will predict traction forces from the cell shape.",
    )
    if uploaded:
        bytes_data = uploaded.read()
        nparr = np.frombuffer(bytes_data, np.uint8)
        # Do not hand an empty buffer to the decoder; treat it like any other
        # undecodable upload.
        img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE) if nparr.size else None
        uploaded.seek(0)
        if img is None:
            # Without this message the failure would only surface later as the
            # misleading "Please upload an image" warning.
            st.error("Could not decode the uploaded file as an image. Please upload a valid TIFF, PNG or JPEG file.")
else:
    sample_files = get_sample_files_for_model(model_type)
    sample_folder = sample_single_cell if model_type == "single_cell" else sample_spheroid
    sample_subfolder_name = "single_cell" if model_type == "single_cell" else "spheroid"
    if sample_files:
        selected_sample = st.selectbox(
            f"Select example image (from `samples/{sample_subfolder_name}/`)",
            sample_files,
            format_func=lambda x: x,
            key=f"sample_{model_type}",
        )
        if selected_sample:
            sample_path = os.path.join(sample_folder, selected_sample)
            img = cv2.imread(sample_path, cv2.IMREAD_GRAYSCALE)
            if img is None:
                st.error(f"Could not read example image `{selected_sample}`.")

        # Thumbnail strip: at most eight examples, laid out in up to five columns.
        n_cols = min(5, len(sample_files))
        cols = st.columns(n_cols)
        for i, fname in enumerate(sample_files[:8]):
            with cols[i % n_cols]:
                path = os.path.join(sample_folder, fname)
                sample_img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
                if sample_img is not None:
                    st.image(sample_img, caption=fname, width='content')
    else:
        st.info(f"No example images in samples/{sample_subfolder_name}/. Add images or use Upload.")
| |
|
# Action row: the run button, the active model type and the checkpoint path.
col_btn, col_model, col_path = st.columns([1, 1, 3])
with col_btn:
    run = st.button("Run prediction", type="primary")
with col_model:
    st.markdown(f"<span style='display: inline-flex; align-items: center; height: 38px;'>{MODEL_TYPE_LABELS[model_type]}</span>", unsafe_allow_html=True)
with col_path:
    # Display-only path; the predictor receives `checkpoint` and `ckp_folder` separately.
    ckp_path = f"ckp/{ckp_subfolder_name}/{checkpoint}" if checkpoint else f"ckp/{ckp_subfolder_name}/"
    st.markdown(f"<span style='display: inline-flex; align-items: center; height: 38px;'>Checkpoint: <code>{ckp_path}</code></span>", unsafe_allow_html=True)
has_image = img is not None
| |
|
| | |
# Keep the last prediction in session state so that reruns triggered by other
# widgets can redraw it without running the model again.
if "prediction_result" not in st.session_state:
    st.session_state["prediction_result"] = None

# A new prediction runs only when the button was pressed and both a checkpoint
# and an image are available. A cached result is reused when it was produced
# for the same (model type, checkpoint, image name).
# NOTE(review): the cache key does not include the substrate settings, so a
# cached result stays on screen after the substrate is changed — confirm this
# is intended.
just_ran = run and checkpoint and has_image
cached = st.session_state["prediction_result"]
key_img = (uploaded.name if uploaded else None) if img_source == "Upload" else selected_sample
current_key = (model_type, checkpoint, key_img)
has_cached = cached is not None and cached.get("cache_key") == current_key

if just_ran:
    # Drop the previous result first so a failed run leaves nothing cached.
    st.session_state["prediction_result"] = None
    with st.spinner("Loading model and predicting..."):
        try:
            # Imported here rather than at module top; the import runs only when a prediction is requested.
            from predictor import S2FPredictor
            predictor = S2FPredictor(
                model_type=model_type,
                checkpoint_path=checkpoint,
                ckp_folder=ckp_folder,
            )
            # The preset name is used only for single-cell runs without manual
            # values; otherwise the default preset name is passed.
            sub_val = substrate_val if model_type == "single_cell" and not use_manual else "fibroblasts_PDMS"
            heatmap, force, pixel_sum = predictor.predict(
                image_array=img,
                substrate=sub_val,
                substrate_config=substrate_config if model_type == "single_cell" else None,
            )

            st.success("Prediction complete!")

            scaled_heatmap = heatmap * force_scale

            # Cache the unscaled outputs; force_scale is applied again on every rerun.
            cache_key = (model_type, checkpoint, key_img)
            st.session_state["prediction_result"] = {
                "img": img.copy(),
                "heatmap": heatmap.copy(),
                "force": force,
                "pixel_sum": pixel_sum,
                "cache_key": cache_key,
            }
            # Inputs read by the measure dialog.
            st.session_state["measure_scaled_heatmap"] = scaled_heatmap.copy()
            st.session_state["measure_bf_img"] = img.copy()
            st.session_state["measure_input_filename"] = key_img or "image"
            st.session_state["measure_original_vals"] = _build_original_vals(scaled_heatmap, pixel_sum, force, force_scale)
            st.session_state["measure_colormap"] = colormap_name

            _render_result_display(img, scaled_heatmap, pixel_sum, force, force_scale, key_img, colormap_name=colormap_name)

        except Exception as e:
            # Top-level UI boundary: show the error and the traceback on the page.
            st.error(f"Prediction failed: {e}")
            st.code(traceback.format_exc())

elif has_cached:
    r = st.session_state["prediction_result"]
    img, heatmap, force, pixel_sum = r["img"], r["heatmap"], r["force"], r["pixel_sum"]
    scaled_heatmap = heatmap * force_scale

    # Refresh the dialog inputs so they follow the current scale and colormap.
    st.session_state["measure_scaled_heatmap"] = scaled_heatmap.copy()
    st.session_state["measure_bf_img"] = img.copy()
    st.session_state["measure_input_filename"] = key_img or "image"
    st.session_state["measure_original_vals"] = _build_original_vals(scaled_heatmap, pixel_sum, force, force_scale)
    st.session_state["measure_colormap"] = colormap_name

    # The "Measure tool" button sets this flag and reruns; consume it once here.
    if st.session_state.pop("open_measure_dialog", False):
        measure_region_dialog()

    st.success("Prediction complete!")
    # Uses a different download-key suffix than the fresh-run branch.
    _render_result_display(img, scaled_heatmap, pixel_sum, force, force_scale, key_img, download_key_suffix="_cached", colormap_name=colormap_name)

elif run and not checkpoint:
    st.warning("Please add checkpoint files to the ckp/ folder and select one.")
elif run and not has_image:
    st.warning("Please upload an image or select an example.")
| |
|
# Sidebar footer: examples folder for the selected model type, then the citation.
st.sidebar.divider()
# ckp_subfolder_name holds "single_cell" or "spheroid", which is also the samples sub-folder name.
st.sidebar.caption(f"Examples: `samples/{ckp_subfolder_name}/`")
st.sidebar.caption("If you find this software useful, please cite:")
st.sidebar.caption(CITATION)
| |
|