Spaces:
Running
Running
| """ | |
| Shape2Force (S2F) - GUI for force map prediction from bright field microscopy images. | |
| """ | |
| import csv | |
| import io | |
| import os | |
| import sys | |
| import traceback | |
| import cv2 | |
| cv2.utils.logging.setLogLevel(cv2.utils.logging.LOG_LEVEL_ERROR) | |
| import numpy as np | |
| import streamlit as st | |
| from PIL import Image | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
| S2F_ROOT = os.path.dirname(os.path.abspath(__file__)) | |
| if S2F_ROOT not in sys.path: | |
| sys.path.insert(0, S2F_ROOT) | |
| from utils.substrate_settings import list_substrates | |
| try: | |
| from streamlit_drawable_canvas import st_canvas | |
| HAS_DRAWABLE_CANVAS = True | |
| except (ImportError, AttributeError): | |
| HAS_DRAWABLE_CANVAS = False | |
| # Constants | |
| MODEL_TYPE_LABELS = {"single_cell": "Single cell", "spheroid": "Spheroid"} | |
| DRAW_TOOLS = ["polygon", "rect", "circle"] | |
| TOOL_LABELS = {"polygon": "Polygon", "rect": "Rectangle", "circle": "Circle"} | |
| CANVAS_SIZE = 320 | |
| SAMPLE_EXTENSIONS = (".tif", ".tiff", ".png", ".jpg", ".jpeg") | |
| COLORMAPS = { | |
| "Jet": cv2.COLORMAP_JET, | |
| "Viridis": cv2.COLORMAP_VIRIDIS, | |
| "Plasma": cv2.COLORMAP_PLASMA, | |
| "Inferno": cv2.COLORMAP_INFERNO, | |
| "Magma": cv2.COLORMAP_MAGMA, | |
| } | |
| def _cv_colormap_to_plotly_colorscale(colormap_name, n_samples=64): | |
| """Build a Plotly colorscale from OpenCV colormap so UI matches download/PDF exactly.""" | |
| cv2_cmap = COLORMAPS.get(colormap_name, cv2.COLORMAP_JET) | |
| gradient = np.linspace(0, 255, n_samples, dtype=np.uint8).reshape(1, -1) | |
| rgb = cv2.applyColorMap(gradient, cv2_cmap) | |
| rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB) | |
| # Plotly colorscale: [[position 0..1, 'rgb(r,g,b)'], ...] | |
| scale = [] | |
| for i in range(n_samples): | |
| r, g, b = rgb[0, i] | |
| scale.append([i / (n_samples - 1), f"rgb({r},{g},{b})"]) | |
| return scale | |
| CITATION = ( | |
| "Lautaro Baro, Kaveh Shahhosseini, Amparo Andrés-Bordería, Claudio Angione, and Maria Angeles Juanes. " | |
| "**\"Shape-to-force (S2F): Predicting Cell Traction Forces from LabelFree Imaging\"**, 2026." | |
| ) | |
| def _make_annotated_heatmap(heatmap_rgb, mask, fill_alpha=0.3, stroke_color=(255, 102, 0), stroke_width=2): | |
| """Composite heatmap with drawn region overlay. heatmap_rgb and mask must match in size.""" | |
| annotated = heatmap_rgb.copy() | |
| contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) | |
| # Semi-transparent orange fill | |
| overlay = annotated.copy() | |
| cv2.fillPoly(overlay, contours, stroke_color) | |
| mask_3d = np.stack([mask] * 3, axis=-1).astype(bool) | |
| annotated[mask_3d] = ( | |
| (1 - fill_alpha) * annotated[mask_3d].astype(np.float32) | |
| + fill_alpha * overlay[mask_3d].astype(np.float32) | |
| ).astype(np.uint8) | |
| # Orange contour | |
| cv2.drawContours(annotated, contours, -1, stroke_color, stroke_width) | |
| return annotated | |
| def _parse_canvas_shapes_to_mask(json_data, canvas_h, canvas_w, heatmap_h, heatmap_w): | |
| """ | |
| Parse drawn shapes from streamlit-drawable-canvas json_data and create a binary mask | |
| in heatmap coordinates. Returns (mask, num_shapes) or (None, 0) if no valid shapes. | |
| """ | |
| if not json_data or "objects" not in json_data or not json_data["objects"]: | |
| return None, 0 | |
| scale_x = heatmap_w / canvas_w | |
| scale_y = heatmap_h / canvas_h | |
| mask = np.zeros((heatmap_h, heatmap_w), dtype=np.uint8) | |
| count = 0 | |
| for obj in json_data["objects"]: | |
| obj_type = obj.get("type", "") | |
| pts = [] | |
| if obj_type == "rect": | |
| left = obj.get("left", 0) | |
| top = obj.get("top", 0) | |
| w = obj.get("width", 0) | |
| h = obj.get("height", 0) | |
| pts = np.array([ | |
| [left, top], [left + w, top], [left + w, top + h], [left, top + h] | |
| ], dtype=np.float32) | |
| elif obj_type == "circle" or obj_type == "ellipse": | |
| left = obj.get("left", 0) | |
| top = obj.get("top", 0) | |
| width = obj.get("width", 0) | |
| height = obj.get("height", 0) | |
| radius = obj.get("radius", 0) | |
| angle_deg = obj.get("angle", 0) | |
| if radius > 0: | |
| # Circle: (left, top) is mouse start point, not center. | |
| # Center = start + radius * (cos(angle), sin(angle)) | |
| rx = ry = radius | |
| angle_rad = np.deg2rad(angle_deg) | |
| cx = left + radius * np.cos(angle_rad) | |
| cy = top + radius * np.sin(angle_rad) | |
| else: | |
| # Ellipse: left, top = top-left of bounding box | |
| rx = width / 2 if width > 0 else 0 | |
| ry = height / 2 if height > 0 else 0 | |
| if rx <= 0 or ry <= 0: | |
| continue | |
| cx = left + rx | |
| cy = top + ry | |
| if rx <= 0 or ry <= 0: | |
| continue | |
| n = 32 | |
| angles = np.linspace(0, 2 * np.pi, n, endpoint=False) | |
| pts = np.column_stack([cx + rx * np.cos(angles), cy + ry * np.sin(angles)]).astype(np.float32) | |
| elif obj_type == "path": | |
| path = obj.get("path", []) | |
| for cmd in path: | |
| if isinstance(cmd, (list, tuple)) and len(cmd) >= 3: | |
| if cmd[0] in ("M", "L"): | |
| pts.append([float(cmd[1]), float(cmd[2])]) | |
| elif cmd[0] == "Q" and len(cmd) >= 5: | |
| pts.append([float(cmd[3]), float(cmd[4])]) | |
| elif cmd[0] == "C" and len(cmd) >= 7: | |
| pts.append([float(cmd[5]), float(cmd[6])]) | |
| if len(pts) < 3: | |
| continue | |
| pts = np.array(pts, dtype=np.float32) | |
| else: | |
| continue | |
| pts[:, 0] *= scale_x | |
| pts[:, 1] *= scale_y | |
| pts = np.clip(pts, 0, [heatmap_w - 1, heatmap_h - 1]).astype(np.int32) | |
| cv2.fillPoly(mask, [pts], 1) | |
| count += 1 | |
| return mask, count | |
| def _heatmap_to_rgb(scaled_heatmap, colormap_name="Jet"): | |
| """Convert scaled heatmap (float 0-1) to RGB array using the given colormap.""" | |
| heatmap_uint8 = (np.clip(scaled_heatmap, 0, 1) * 255).astype(np.uint8) | |
| cv2_colormap = COLORMAPS.get(colormap_name, cv2.COLORMAP_JET) | |
| heatmap_rgb = cv2.cvtColor(cv2.applyColorMap(heatmap_uint8, cv2_colormap), cv2.COLOR_BGR2RGB) | |
| return heatmap_rgb | |
| def _heatmap_to_png_bytes(scaled_heatmap, colormap_name="Jet"): | |
| """Convert scaled heatmap (float 0-1) to PNG bytes buffer.""" | |
| heatmap_rgb = _heatmap_to_rgb(scaled_heatmap, colormap_name) | |
| buf = io.BytesIO() | |
| Image.fromarray(heatmap_rgb).save(buf, format="PNG") | |
| buf.seek(0) | |
| return buf | |
| def _create_pdf_report(img, scaled_heatmap, pixel_sum, force, force_scale, base_name, colormap_name="Jet"): | |
| """Create a PDF report with input image, heatmap, and metrics.""" | |
| from datetime import datetime | |
| from reportlab.lib.pagesizes import A4 | |
| from reportlab.lib.units import inch | |
| from reportlab.pdfgen import canvas | |
| from reportlab.lib.utils import ImageReader | |
| buf = io.BytesIO() | |
| c = canvas.Canvas(buf, pagesize=A4) | |
| c.setTitle("Shape2Force") | |
| c.setAuthor("Angione-Lab") | |
| w, h = A4 | |
| img_w, img_h = 2.5 * inch, 2.5 * inch | |
| # Footer area (reserve space at bottom) | |
| footer_y = 40 | |
| c.setFont("Helvetica", 8) | |
| c.setFillColorRGB(0.4, 0.4, 0.4) | |
| gen_date = datetime.now().strftime("%Y-%m-%d %H:%M") | |
| c.drawString(72, footer_y, f"Generated by Shape2Force (S2F) on {gen_date}") | |
| c.drawString(72, footer_y - 12, "Model: https://huggingface.co/Angione-Lab/Shape2Force") | |
| c.drawString(72, footer_y - 24, "Web app: https://huggingface.co/spaces/Angione-Lab/Shape2force") | |
| c.setFillColorRGB(0, 0, 0) | |
| # Images first (drawn lower so title can go on top) | |
| img_top = h - 70 | |
| img_pil = Image.fromarray(img) if img.ndim == 2 else Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) | |
| img_buf = io.BytesIO() | |
| img_pil.save(img_buf, format="PNG") | |
| img_buf.seek(0) | |
| c.drawImage(ImageReader(img_buf), 72, img_top - img_h, width=img_w, height=img_h, preserveAspectRatio=True) | |
| c.setFont("Helvetica", 9) | |
| c.drawString(72, img_top - img_h - 12, "Input: Bright-field") | |
| heatmap_rgb = _heatmap_to_rgb(scaled_heatmap, colormap_name) | |
| hm_buf = io.BytesIO() | |
| Image.fromarray(heatmap_rgb).save(hm_buf, format="PNG") | |
| hm_buf.seek(0) | |
| c.drawImage(ImageReader(hm_buf), 72 + img_w + 20, img_top - img_h, width=img_w, height=img_h, preserveAspectRatio=True) | |
| c.drawString(72 + img_w + 20, img_top - img_h - 12, "Output: Force map") | |
| # Title above images | |
| c.setFont("Helvetica-Bold", 16) | |
| c.drawString(72, img_top + 25, "Shape2Force (S2F) - Prediction Report") | |
| c.setFont("Helvetica", 10) | |
| c.drawString(72, img_top + 8, f"Image: {base_name}") | |
| # Metrics table below images | |
| y = img_top - img_h - 45 | |
| c.setFont("Helvetica-Bold", 10) | |
| c.drawString(72, y, "Metrics") | |
| c.setFont("Helvetica", 9) | |
| y -= 18 | |
| metrics = [ | |
| ("Sum of all pixels", f"{pixel_sum * force_scale:.2f}"), | |
| ("Cell force (scaled)", f"{force * force_scale:.2f}"), | |
| ("Heatmap max", f"{np.max(scaled_heatmap):.4f}"), | |
| ("Heatmap mean", f"{np.mean(scaled_heatmap):.4f}"), | |
| ] | |
| for label, val in metrics: | |
| c.drawString(72, y, f"{label}: {val}") | |
| y -= 16 | |
| c.save() | |
| buf.seek(0) | |
| return buf.getvalue() | |
| def _build_original_vals(scaled_heatmap, pixel_sum, force, force_scale): | |
| """Build original_vals dict for measure tool.""" | |
| return { | |
| "pixel_sum": pixel_sum * force_scale, | |
| "force": force * force_scale, | |
| "max": float(np.max(scaled_heatmap)), | |
| "mean": float(np.mean(scaled_heatmap)), | |
| } | |
| def _render_result_display(img, scaled_heatmap, pixel_sum, force, force_scale, key_img, download_key_suffix="", colormap_name="Jet"): | |
| """Render prediction result: plot, metrics, expander, and download/measure buttons.""" | |
| buf_hm = _heatmap_to_png_bytes(scaled_heatmap, colormap_name) | |
| base_name = os.path.splitext(key_img or "image")[0] | |
| main_csv_rows = [ | |
| ["image", "Sum of all pixels", "Cell force (scaled)", "Heatmap max", "Heatmap mean"], | |
| [base_name, f"{pixel_sum * force_scale:.2f}", f"{force * force_scale:.2f}", | |
| f"{np.max(scaled_heatmap):.4f}", f"{np.mean(scaled_heatmap):.4f}"], | |
| ] | |
| buf_main_csv = io.StringIO() | |
| csv.writer(buf_main_csv).writerows(main_csv_rows) | |
| tit1, tit2 = st.columns(2) | |
| with tit1: | |
| st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Input: Bright-field image</p>', unsafe_allow_html=True) | |
| with tit2: | |
| st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Output: Predicted traction force map</p>', unsafe_allow_html=True) | |
| fig_pl = make_subplots(rows=1, cols=2) | |
| fig_pl.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1) | |
| plotly_colorscale = _cv_colormap_to_plotly_colorscale(colormap_name) | |
| fig_pl.add_trace(go.Heatmap(z=scaled_heatmap, colorscale=plotly_colorscale, zmin=0, zmax=1, showscale=True, | |
| colorbar=dict(len=0.4, thickness=12)), row=1, col=2) | |
| fig_pl.update_layout( | |
| height=400, | |
| margin=dict(l=10, r=10, t=10, b=10), | |
| xaxis=dict(scaleanchor="y", scaleratio=1), | |
| xaxis2=dict(scaleanchor="y2", scaleratio=1), | |
| ) | |
| fig_pl.update_xaxes(showticklabels=False) | |
| fig_pl.update_yaxes(showticklabels=False, autorange="reversed") | |
| st.plotly_chart(fig_pl, use_container_width=True, config={"displayModeBar": True, "responsive": True}) | |
| col1, col2, col3, col4 = st.columns(4) | |
| with col1: | |
| st.metric("Sum of all pixels", f"{pixel_sum * force_scale:.2f}", help="Raw sum of all pixel values in the force map") | |
| with col2: | |
| st.metric("Cell force (scaled)", f"{force * force_scale:.2f}", help="Total traction force in physical units") | |
| with col3: | |
| st.metric("Heatmap max", f"{np.max(scaled_heatmap):.4f}", help="Peak force intensity in the map") | |
| with col4: | |
| st.metric("Heatmap mean", f"{np.mean(scaled_heatmap):.4f}", help="Average force intensity") | |
| with st.expander("How to read the results"): | |
| st.markdown(""" | |
| **Input (left):** Bright-field microscopy image of a cell or spheroid on a substrate. | |
| This is the raw image you provided—it shows cell shape but not forces. | |
| **Output (right):** Predicted traction force map. | |
| - **Color** indicates force magnitude: blue = low, red = high | |
| - **Brighter/warmer colors** = stronger forces exerted by the cell on the substrate | |
| - Values are normalized to [0, 1] for visualization | |
| **Metrics:** | |
| - **Sum of all pixels:** Total force is the sum of all pixels in the force map. Each pixel represents the magnitude of force at that location; summing them gives the overall traction. | |
| - **Cell force (scaled):** Total traction force in physical units (scaled by substrate stiffness) | |
| - **Heatmap max/mean:** Peak and average force intensity in the map | |
| """) | |
| original_vals = _build_original_vals(scaled_heatmap, pixel_sum, force, force_scale) | |
| pdf_bytes = _create_pdf_report(img, scaled_heatmap, pixel_sum, force, force_scale, base_name, colormap_name) | |
| btn_col1, btn_col2, btn_col3, btn_col4 = st.columns(4) | |
| with btn_col1: | |
| if HAS_DRAWABLE_CANVAS and st_dialog: | |
| if st.button("Measure tool", key="open_measure", icon=":material/straighten:"): | |
| st.session_state["open_measure_dialog"] = True | |
| st.rerun() | |
| elif HAS_DRAWABLE_CANVAS: | |
| with st.expander("Measure tool"): | |
| _render_region_canvas( | |
| scaled_heatmap, | |
| bf_img=img, | |
| original_vals=original_vals, | |
| key_suffix="expander", | |
| input_filename=key_img, | |
| colormap_name=colormap_name, | |
| ) | |
| else: | |
| st.caption("Install `streamlit-drawable-canvas-fix` for region measurement: `pip install streamlit-drawable-canvas-fix`") | |
| with btn_col2: | |
| st.download_button( | |
| "Download heatmap", | |
| width="stretch", | |
| data=buf_hm.getvalue(), | |
| file_name="s2f_heatmap.png", | |
| mime="image/png", | |
| key=f"download_heatmap{download_key_suffix}", | |
| icon=":material/download:", | |
| ) | |
| with btn_col3: | |
| st.download_button( | |
| "Download values", | |
| width="stretch", | |
| data=buf_main_csv.getvalue(), | |
| file_name=f"{base_name}_main_values.csv", | |
| mime="text/csv", | |
| key=f"download_main_values{download_key_suffix}", | |
| icon=":material/download:", | |
| ) | |
| with btn_col4: | |
| st.download_button( | |
| "Download report", | |
| width="stretch", | |
| data=pdf_bytes, | |
| file_name=f"{base_name}_report.pdf", | |
| mime="application/pdf", | |
| key=f"download_pdf{download_key_suffix}", | |
| icon=":material/picture_as_pdf:", | |
| ) | |
| def _compute_region_metrics(scaled_heatmap, mask, original_vals=None): | |
| """Compute region metrics from mask. Returns dict with area_px, force_sum, density, etc.""" | |
| area_px = int(np.sum(mask)) | |
| region_values = scaled_heatmap * mask | |
| region_nonzero = region_values[mask > 0] | |
| force_sum = float(np.sum(region_values)) | |
| density = force_sum / area_px if area_px > 0 else 0 | |
| region_max = float(np.max(region_nonzero)) if len(region_nonzero) > 0 else 0 | |
| region_mean = float(np.mean(region_nonzero)) if len(region_nonzero) > 0 else 0 | |
| region_force_scaled = ( | |
| force_sum * (original_vals["force"] / original_vals["pixel_sum"]) | |
| if original_vals and original_vals.get("pixel_sum", 0) > 0 | |
| else force_sum | |
| ) | |
| return { | |
| "area_px": area_px, | |
| "force_sum": force_sum, | |
| "density": density, | |
| "max": region_max, | |
| "mean": region_mean, | |
| "force_scaled": region_force_scaled, | |
| } | |
| def _render_region_metrics_and_downloads(metrics, heatmap_rgb, mask, input_filename, key_suffix, has_original_vals): | |
| """Render region metrics and download buttons.""" | |
| base_name = os.path.splitext(input_filename or "image")[0] | |
| st.markdown("**Region (drawn)**") | |
| if has_original_vals: | |
| r1, r2, r3, r4, r5, r6 = st.columns(6) | |
| with r1: | |
| st.metric("Area", f"{metrics['area_px']:,}") | |
| with r2: | |
| st.metric("F.sum", f"{metrics['force_sum']:.3f}") | |
| with r3: | |
| st.metric("Force", f"{metrics['force_scaled']:.1f}") | |
| with r4: | |
| st.metric("Max", f"{metrics['max']:.3f}") | |
| with r5: | |
| st.metric("Mean", f"{metrics['mean']:.3f}") | |
| with r6: | |
| st.metric("Density", f"{metrics['density']:.4f}") | |
| csv_rows = [ | |
| ["image", "Area", "F.sum", "Force", "Max", "Mean", "Density"], | |
| [base_name, metrics["area_px"], f"{metrics['force_sum']:.3f}", f"{metrics['force_scaled']:.1f}", | |
| f"{metrics['max']:.3f}", f"{metrics['mean']:.3f}", f"{metrics['density']:.4f}"], | |
| ] | |
| else: | |
| c1, c2, c3 = st.columns(3) | |
| with c1: | |
| st.metric("Area (px²)", f"{metrics['area_px']:,}") | |
| with c2: | |
| st.metric("Force sum", f"{metrics['force_sum']:.4f}") | |
| with c3: | |
| st.metric("Density", f"{metrics['density']:.6f}") | |
| csv_rows = [ | |
| ["image", "Area", "Force sum", "Density"], | |
| [base_name, metrics["area_px"], f"{metrics['force_sum']:.4f}", f"{metrics['density']:.6f}"], | |
| ] | |
| buf_csv = io.StringIO() | |
| csv.writer(buf_csv).writerows(csv_rows) | |
| buf_img = io.BytesIO() | |
| Image.fromarray(_make_annotated_heatmap(heatmap_rgb, mask)).save(buf_img, format="PNG") | |
| buf_img.seek(0) | |
| dl_col1, dl_col2 = st.columns(2) | |
| with dl_col1: | |
| st.download_button("Download values", data=buf_csv.getvalue(), | |
| file_name=f"{base_name}_region_values.csv", mime="text/csv", | |
| key=f"download_region_values_{key_suffix}", icon=":material/download:") | |
| with dl_col2: | |
| st.download_button("Download annotated heatmap", data=buf_img.getvalue(), | |
| file_name=f"{base_name}_annotated_heatmap.png", mime="image/png", | |
| key=f"download_annotated_{key_suffix}", icon=":material/image:") | |
| def _render_region_canvas(scaled_heatmap, bf_img=None, original_vals=None, key_suffix="", input_filename=None, colormap_name="Jet"): | |
| """Render drawable canvas and region metrics. Used in dialog or expander.""" | |
| h, w = scaled_heatmap.shape | |
| heatmap_rgb = _heatmap_to_rgb(scaled_heatmap, colormap_name) | |
| pil_bg = Image.fromarray(heatmap_rgb).resize((CANVAS_SIZE, CANVAS_SIZE), Image.Resampling.LANCZOS) | |
| st.markdown(""" | |
| <style> | |
| [data-testid="stDialog"] [data-testid="stSelectbox"], [data-testid="stExpander"] [data-testid="stSelectbox"], | |
| [data-testid="stDialog"] [data-testid="stSelectbox"] > div, [data-testid="stExpander"] [data-testid="stSelectbox"] > div { | |
| width: 100% !important; max-width: 100% !important; | |
| } | |
| [data-testid="stDialog"] [data-testid="stMetric"] label, [data-testid="stDialog"] [data-testid="stMetric"] [data-testid="stMetricValue"], | |
| [data-testid="stExpander"] [data-testid="stMetric"] label, [data-testid="stExpander"] [data-testid="stMetric"] [data-testid="stMetricValue"] { | |
| font-size: 0.95rem !important; | |
| } | |
| [data-testid="stDialog"] img, [data-testid="stExpander"] img { border-radius: 0 !important; } | |
| </style> | |
| """, unsafe_allow_html=True) | |
| if bf_img is not None: | |
| bf_resized = cv2.resize(bf_img, (CANVAS_SIZE, CANVAS_SIZE)) | |
| bf_rgb = cv2.cvtColor(bf_resized, cv2.COLOR_GRAY2RGB) if bf_img.ndim == 2 else cv2.cvtColor(bf_resized, cv2.COLOR_BGR2RGB) | |
| left_col, right_col = st.columns(2, gap=None) | |
| with left_col: | |
| draw_mode = st.selectbox("Tool", DRAW_TOOLS, format_func=lambda x: TOOL_LABELS[x], key=f"draw_mode_region_{key_suffix}") | |
| st.caption("Left-click add, right-click close. \nForce map (draw region)") | |
| canvas_result = st_canvas( | |
| fill_color="rgba(255, 165, 0, 0.3)", stroke_width=2, stroke_color="#ff6600", | |
| background_image=pil_bg, drawing_mode=draw_mode, update_streamlit=True, | |
| height=CANVAS_SIZE, width=CANVAS_SIZE, display_toolbar=True, | |
| key=f"region_measure_canvas_{key_suffix}", | |
| ) | |
| with right_col: | |
| if original_vals: | |
| st.markdown('<p style="font-weight: 400; color: #334155; font-size: 0.95rem; margin: 0 20px 4px 4px;">Full map</p>', unsafe_allow_html=True) | |
| st.markdown(f""" | |
| <div style="width: 100%; box-sizing: border-box; border: 1px solid #e2e8f0; border-radius: 10px; | |
| padding: 10px 12px; margin: 0 10px 20px 10px; background: linear-gradient(145deg, #f8fafc 0%, #f1f5f9 100%); | |
| box-shadow: 0 1px 3px rgba(0,0,0,0.06);"> | |
| <div style="display: flex; flex-wrap: wrap; gap: 5px; font-size: 0.9rem;"> | |
| <span><strong>Sum:</strong> {original_vals['pixel_sum']:.1f}</span> | |
| <span><strong>Force:</strong> {original_vals['force']:.1f}</span> | |
| <span><strong>Max:</strong> {original_vals['max']:.3f}</span> | |
| <span><strong>Mean:</strong> {original_vals['mean']:.3f}</span> | |
| </div> | |
| </div> | |
| """, unsafe_allow_html=True) | |
| st.caption("Bright-field") | |
| st.image(bf_rgb, width=CANVAS_SIZE) | |
| else: | |
| st.markdown("**Draw a region** on the heatmap.") | |
| draw_mode = st.selectbox("Drawing tool", DRAW_TOOLS, | |
| format_func=lambda x: "Polygon (free shape)" if x == "polygon" else TOOL_LABELS[x], | |
| key=f"draw_mode_region_{key_suffix}") | |
| st.caption("Polygon: left-click to add points, right-click to close.") | |
| canvas_result = st_canvas( | |
| fill_color="rgba(255, 165, 0, 0.3)", stroke_width=2, stroke_color="#ff6600", | |
| background_image=pil_bg, drawing_mode=draw_mode, update_streamlit=True, | |
| height=CANVAS_SIZE, width=CANVAS_SIZE, display_toolbar=True, | |
| key=f"region_measure_canvas_{key_suffix}", | |
| ) | |
| if canvas_result.json_data: | |
| mask, n = _parse_canvas_shapes_to_mask(canvas_result.json_data, CANVAS_SIZE, CANVAS_SIZE, h, w) | |
| if mask is not None and n > 0: | |
| metrics = _compute_region_metrics(scaled_heatmap, mask, original_vals) | |
| _render_region_metrics_and_downloads(metrics, heatmap_rgb, mask, input_filename, key_suffix, original_vals is not None) | |
| st_dialog = getattr(st, "dialog", None) or getattr(st, "experimental_dialog", None) | |
| if HAS_DRAWABLE_CANVAS and st_dialog: | |
| def measure_region_dialog(): | |
| scaled_heatmap = st.session_state.get("measure_scaled_heatmap") | |
| if scaled_heatmap is None: | |
| st.warning("No prediction available to measure.") | |
| return | |
| bf_img = st.session_state.get("measure_bf_img") | |
| original_vals = st.session_state.get("measure_original_vals") | |
| input_filename = st.session_state.get("measure_input_filename", "image") | |
| colormap_name = st.session_state.get("measure_colormap", "Jet") | |
| _render_region_canvas(scaled_heatmap, bf_img=bf_img, original_vals=original_vals, key_suffix="dialog", input_filename=input_filename, colormap_name=colormap_name) | |
| else: | |
| def measure_region_dialog(): | |
| pass # no-op when canvas or dialog not available | |
| st.set_page_config(page_title="Shape2Force (S2F)", page_icon="🦠", layout="centered") | |
| st.markdown(""" | |
| <style> | |
| section[data-testid="stSidebar"] { width: 380px !important; } | |
| div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([data-testid="stButton"]) > div { | |
| flex: 1 1 0 !important; min-width: 0 !important; | |
| } | |
| div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([data-testid="stButton"]) button { | |
| width: 100% !important; min-width: 100px !important; white-space: nowrap !important; | |
| } | |
| div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([data-testid="stButton"]) > div:nth-child(1) button { | |
| background-color: #0d9488 !important; color: white !important; border-color: #0d9488 !important; | |
| } | |
| div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([data-testid="stButton"]) > div:nth-child(1) button:hover { | |
| background-color: #0f766e !important; border-color: #0f766e !important; color: white !important; | |
| } | |
| </style> | |
| """, unsafe_allow_html=True) | |
| st.title("🦠 Shape2Force (S2F)") | |
| st.caption("Predict traction force maps from bright-field microscopy images of cells or spheroids") | |
| # Folders: checkpoints in subfolders by model type (single_cell / spheroid) | |
| ckp_base = os.path.join(S2F_ROOT, "ckp") | |
| # Fallback: use project root ckp when running from S2F repo (ckp at S2F/ckp/) | |
| if not os.path.isdir(ckp_base): | |
| project_root = os.path.dirname(S2F_ROOT) | |
| if os.path.isdir(os.path.join(project_root, "ckp")): | |
| ckp_base = os.path.join(project_root, "ckp") | |
| ckp_single_cell = os.path.join(ckp_base, "single_cell") | |
| ckp_spheroid = os.path.join(ckp_base, "spheroid") | |
| sample_base = os.path.join(S2F_ROOT, "samples") | |
| sample_single_cell = os.path.join(sample_base, "single_cell") | |
| sample_spheroid = os.path.join(sample_base, "spheroid") | |
| def get_ckp_files_for_model(model_type): | |
| """Return list of .pth files in the checkpoint folder for the given model type.""" | |
| folder = ckp_single_cell if model_type == "single_cell" else ckp_spheroid | |
| if os.path.isdir(folder): | |
| return sorted(f for f in os.listdir(folder) if f.endswith(".pth")) | |
| return [] | |
| def get_sample_files_for_model(model_type): | |
| """Return list of sample images in the sample folder for the given model type.""" | |
| folder = sample_single_cell if model_type == "single_cell" else sample_spheroid | |
| if os.path.isdir(folder): | |
| return sorted(f for f in os.listdir(folder) if f.lower().endswith(SAMPLE_EXTENSIONS)) | |
| return [] | |
| # Sidebar: model configuration | |
| with st.sidebar: | |
| st.header("Model configuration") | |
| model_type = st.radio( | |
| "Model type", | |
| ["single_cell", "spheroid"], | |
| format_func=lambda x: MODEL_TYPE_LABELS[x], | |
| horizontal=False, | |
| help="Single cell: substrate-aware force prediction. Spheroid: spheroid force maps.", | |
| ) | |
| st.caption(f"Inference mode: **{MODEL_TYPE_LABELS[model_type]}**") | |
| ckp_files = get_ckp_files_for_model(model_type) | |
| ckp_folder = ckp_single_cell if model_type == "single_cell" else ckp_spheroid | |
| ckp_subfolder_name = "single_cell" if model_type == "single_cell" else "spheroid" | |
| if ckp_files: | |
| checkpoint = st.selectbox( | |
| "Checkpoint", | |
| ckp_files, | |
| help=f"Select a .pth file from ckp/{ckp_subfolder_name}/", | |
| ) | |
| else: | |
| st.warning(f"No .pth files in ckp/{ckp_subfolder_name}/. Add checkpoints to load.") | |
| checkpoint = None | |
| substrate_config = None | |
| substrate_val = "fibroblasts_PDMS" | |
| use_manual = False | |
| if model_type == "single_cell": | |
| try: | |
| substrates = list_substrates() | |
| substrate_val = st.selectbox( | |
| "Substrate (from config)", | |
| substrates, | |
| help="Select a preset from config/substrate_settings.json", | |
| ) | |
| use_manual = st.checkbox("Enter substrate values manually", value=False) | |
| if use_manual: | |
| st.caption("Enter pixelsize (µm/px) and Young's modulus (Pa)") | |
| manual_pixelsize = st.number_input("Pixelsize (µm/px)", min_value=0.1, max_value=50.0, | |
| value=3.0769, step=0.1, format="%.4f") | |
| manual_young = st.number_input("Young's modulus (Pa)", min_value=100.0, max_value=100000.0, | |
| value=6000.0, step=100.0, format="%.0f") | |
| substrate_config = {"pixelsize": manual_pixelsize, "young": manual_young} | |
| except FileNotFoundError: | |
| st.error("config/substrate_settings.json not found") | |
| st.divider() | |
| st.header("Display options") | |
| force_scale = st.slider( | |
| "Force scale", | |
| min_value=0.0, | |
| max_value=1.0, | |
| value=1.0, | |
| step=0.01, | |
| format="%.2f", | |
| help="Scale the displayed force values. 1 = full intensity, 0.5 = half the pixel values.", | |
| ) | |
| colormap_name = st.selectbox( | |
| "Heatmap colormap", | |
| list(COLORMAPS.keys()), | |
| help="Color scheme for the force map. Viridis is often preferred for accessibility.", | |
| ) | |
| # Main area: image input | |
| img_source = st.radio("Image source", ["Upload", "Example"], horizontal=True, label_visibility="collapsed") | |
| img = None | |
| uploaded = None | |
| selected_sample = None | |
| if img_source == "Upload": | |
| uploaded = st.file_uploader( | |
| "Upload bright-field image", | |
| type=["tif", "tiff", "png", "jpg", "jpeg"], | |
| help="Bright-field microscopy image of a cell or spheroid on a substrate (grayscale or RGB). The model will predict traction forces from the cell shape.", | |
| ) | |
| if uploaded: | |
| bytes_data = uploaded.read() | |
| nparr = np.frombuffer(bytes_data, np.uint8) | |
| img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE) | |
| uploaded.seek(0) # reset for potential re-read | |
| else: | |
| sample_files = get_sample_files_for_model(model_type) | |
| sample_folder = sample_single_cell if model_type == "single_cell" else sample_spheroid | |
| sample_subfolder_name = "single_cell" if model_type == "single_cell" else "spheroid" | |
| if sample_files: | |
| selected_sample = st.selectbox( | |
| f"Select example image (from `samples/{sample_subfolder_name}/`)", | |
| sample_files, | |
| format_func=lambda x: x, | |
| key=f"sample_{model_type}", | |
| ) | |
| if selected_sample: | |
| sample_path = os.path.join(sample_folder, selected_sample) | |
| img = cv2.imread(sample_path, cv2.IMREAD_GRAYSCALE) | |
| # Show example thumbnails (filtered by model type) | |
| n_cols = min(5, len(sample_files)) | |
| cols = st.columns(n_cols) | |
| for i, fname in enumerate(sample_files[:8]): # show up to 8 | |
| with cols[i % n_cols]: | |
| path = os.path.join(sample_folder, fname) | |
| sample_img = cv2.imread(path, cv2.IMREAD_GRAYSCALE) | |
| if sample_img is not None: | |
| st.image(sample_img, caption=fname, width='content') | |
| else: | |
| st.info(f"No example images in samples/{sample_subfolder_name}/. Add images or use Upload.") | |
| col_btn, col_model, col_path = st.columns([1, 1, 3]) | |
| with col_btn: | |
| run = st.button("Run prediction", type="primary") | |
| with col_model: | |
| st.markdown(f"<span style='display: inline-flex; align-items: center; height: 38px;'>{MODEL_TYPE_LABELS[model_type]}</span>", unsafe_allow_html=True) | |
| with col_path: | |
| ckp_path = f"ckp/{ckp_subfolder_name}/{checkpoint}" if checkpoint else f"ckp/{ckp_subfolder_name}/" | |
| st.markdown(f"<span style='display: inline-flex; align-items: center; height: 38px;'>Checkpoint: <code>{ckp_path}</code></span>", unsafe_allow_html=True) | |
| has_image = img is not None | |
| # Persist results in session state so they survive re-runs (e.g. when clicking Download) | |
| if "prediction_result" not in st.session_state: | |
| st.session_state["prediction_result"] = None | |
| # Show results if we just ran prediction OR we have cached results from a previous run | |
| just_ran = run and checkpoint and has_image | |
| cached = st.session_state["prediction_result"] | |
| key_img = (uploaded.name if uploaded else None) if img_source == "Upload" else selected_sample | |
| current_key = (model_type, checkpoint, key_img) | |
| has_cached = cached is not None and cached.get("cache_key") == current_key | |
| if just_ran: | |
| st.session_state["prediction_result"] = None # Clear before new run | |
| with st.spinner("Loading model and predicting..."): | |
| try: | |
| from predictor import S2FPredictor | |
| predictor = S2FPredictor( | |
| model_type=model_type, | |
| checkpoint_path=checkpoint, | |
| ckp_folder=ckp_folder, | |
| ) | |
| sub_val = substrate_val if model_type == "single_cell" and not use_manual else "fibroblasts_PDMS" | |
| heatmap, force, pixel_sum = predictor.predict( | |
| image_array=img, | |
| substrate=sub_val, | |
| substrate_config=substrate_config if model_type == "single_cell" else None, | |
| ) | |
| st.success("Prediction complete!") | |
| scaled_heatmap = heatmap * force_scale | |
| cache_key = (model_type, checkpoint, key_img) | |
| st.session_state["prediction_result"] = { | |
| "img": img.copy(), | |
| "heatmap": heatmap.copy(), | |
| "force": force, | |
| "pixel_sum": pixel_sum, | |
| "cache_key": cache_key, | |
| } | |
| st.session_state["measure_scaled_heatmap"] = scaled_heatmap.copy() | |
| st.session_state["measure_bf_img"] = img.copy() | |
| st.session_state["measure_input_filename"] = key_img or "image" | |
| st.session_state["measure_original_vals"] = _build_original_vals(scaled_heatmap, pixel_sum, force, force_scale) | |
| st.session_state["measure_colormap"] = colormap_name | |
| _render_result_display(img, scaled_heatmap, pixel_sum, force, force_scale, key_img, colormap_name=colormap_name) | |
| except Exception as e: | |
| st.error(f"Prediction failed: {e}") | |
| st.code(traceback.format_exc()) | |
| elif has_cached: | |
| r = st.session_state["prediction_result"] | |
| img, heatmap, force, pixel_sum = r["img"], r["heatmap"], r["force"], r["pixel_sum"] | |
| scaled_heatmap = heatmap * force_scale | |
| st.session_state["measure_scaled_heatmap"] = scaled_heatmap.copy() | |
| st.session_state["measure_bf_img"] = img.copy() | |
| st.session_state["measure_input_filename"] = key_img or "image" | |
| st.session_state["measure_original_vals"] = _build_original_vals(scaled_heatmap, pixel_sum, force, force_scale) | |
| st.session_state["measure_colormap"] = colormap_name | |
| if st.session_state.pop("open_measure_dialog", False): | |
| measure_region_dialog() | |
| st.success("Prediction complete!") | |
| _render_result_display(img, scaled_heatmap, pixel_sum, force, force_scale, key_img, download_key_suffix="_cached", colormap_name=colormap_name) | |
| elif run and not checkpoint: | |
| st.warning("Please add checkpoint files to the ckp/ folder and select one.") | |
| elif run and not has_image: | |
| st.warning("Please upload an image or select an example.") | |
| st.sidebar.divider() | |
| st.sidebar.caption(f"Examples: `samples/{ckp_subfolder_name}/`") | |
| st.sidebar.caption("If you find this software useful, please cite:") | |
| st.sidebar.caption(CITATION) | |