# MyoSeg / src /streamlit_app.py
# skarugu's picture
# Update src/streamlit_app.py
# 08c400d verified
# src/streamlit_app.py
"""
MyoSeg β€” Myotube & Nuclei Analyser
========================================
Drop-in replacement for streamlit_app.py on Hugging Face Spaces.
Features:
✦ Animated count-up metrics (9 counters)
✦ Instance overlay β€” nucleus IDs (1,2,3…) + myotube IDs (M1,M2…)
✦ Separate nuclei outline + myotube outline tabs
✦ Watershed nuclei splitting for accurate counts
✦ Myotube surface area (total, mean, max ¡m²) + per-tube bar chart
✦ Active learning β€” upload corrected masks β†’ saved to corrections/
✦ Low-confidence auto-flagging β†’ image queued for retraining
✦ Retraining queue status panel
✦ All original sidebar controls preserved
v9 changes:
✦ FIXED: SVG viewer myotube ID count now matches live metrics count.
Root cause: viewer showed all connected components (myo_lab), but
metrics only counted those with β‰₯1 MyHC+ nucleus. Now the viewer
badge shows the biological myotube_count from compute_bio_metrics,
and non-bio myotube regions are shown as faint outlines (not labelled).
✦ Outlines split into two separate tabs: "Nuclei outlines" and
"Myotube outlines" per collaborator request.
✦ Privacy mode: sidebar toggle for "Private mode β€” do not use my data
for training". When enabled, images are NOT queued for retraining
(no low_confidence queue, no corrections submission).
✦ Training contribution mode: explicit user-initiated action to submit
current image + tuned parameters as a training contribution. Only
runs when user clicks "Submit for training" after finding good params.
✦ Parameter learning: when user submits, the current sidebar parameter
set (thresholds, postprocessing knobs) is saved alongside the image
so self_train.py can learn optimal parameters per image type.
✦ All v8 fixes preserved (no closing, shape filter, erode+dilate, etc).
"""
import io
import os
import json
import time
import zipfile
import hashlib
from datetime import datetime
from pathlib import Path
import numpy as np
import pandas as pd
from PIL import Image
import streamlit as st
import streamlit.components.v1
import torch
import torch.nn as nn
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from huggingface_hub import hf_hub_download
import scipy.ndimage as ndi
from skimage.morphology import remove_small_objects, disk, closing, opening, binary_dilation, binary_erosion
from skimage import measure
from skimage.segmentation import watershed, find_boundaries
from skimage.feature import peak_local_max
# ─────────────────────────────────────────────────────────────────────────────
# CONFIG ← edit the two MODEL_* lines below to match your HF model repo
# ─────────────────────────────────────────────────────────────────────────────
# Hugging Face Hub repository id and checkpoint filename of the segmentation
# model (presumably downloaded with hf_hub_download elsewhere in this file).
MODEL_REPO_ID = "skarugu/myotube-unet"
MODEL_FILENAME = "model_final.pt"
# Images whose prediction confidence falls below this value are queued for
# retraining (the "low-confidence auto-flagging" feature).
CONF_FLAG_THR = 0.60
# Relative paths, resolved against the process's current working directory:
#   QUEUE_DIR       — retraining queue for flagged / submitted images
#   CORRECTIONS_DIR — user-uploaded corrected masks (active learning)
QUEUE_DIR = Path("retrain_queue")
CORRECTIONS_DIR = Path("corrections")
# ─────────────────────────────────────────────────────────────────────────────
# Helpers (identical to originals so nothing breaks)
# ─────────────────────────────────────────────────────────────────────────────
def sha256_file(path: str) -> str:
    """Return the hexadecimal SHA-256 digest of the file at *path*.

    The file is streamed in 1 MiB pieces, so large images are never held
    in memory all at once.
    """
    digest = hashlib.sha256()
    piece_size = 1024 * 1024
    with open(path, "rb") as handle:
        while True:
            piece = handle.read(piece_size)
            if not piece:  # b"" marks end of file
                break
            digest.update(piece)
    return digest.hexdigest()
def png_bytes(arr_u8: np.ndarray) -> bytes:
    """Encode *arr_u8* as a PNG entirely in memory and return the file bytes."""
    with io.BytesIO() as sink:
        Image.fromarray(arr_u8).save(sink, format="PNG")
        # Read the buffer before the context manager closes it.
        return sink.getvalue()
def resize_u8_to_float01(ch_u8: np.ndarray, W: int, H: int,
                         resample=Image.BILINEAR) -> np.ndarray:
    """Resize one 8-bit channel to ``(W, H)`` and rescale it to float32 in [0, 1].

    Parameters
    ----------
    ch_u8    : 2-D array holding a single image channel; cast to uint8.
    W, H     : target width and height in pixels (PIL order: width first).
    resample : PIL resampling filter (bilinear by default).

    Returns
    -------
    float32 array of shape (H, W) with values in [0, 1].

    Raises
    ------
    ValueError if ``ch_u8`` is not two-dimensional.
    """
    # Hand PIL a contiguous uint8 plane and let it infer mode "L" itself.
    # Passing mode="L" explicitly is deprecated in recent Pillow releases,
    # and it does not convert arrays that are not already 8-bit.
    plane = np.ascontiguousarray(ch_u8, dtype=np.uint8)
    if plane.ndim != 2:
        raise ValueError(
            f"expected a 2-D single-channel array, got shape {plane.shape}")
    im = Image.fromarray(plane).resize((W, H), resample=resample)
    return np.asarray(im, dtype=np.float32) / 255.0
def get_channel(rgb_u8: np.ndarray, source: str) -> np.ndarray:
    """Pick one channel of an RGB uint8 image by name.

    ``source`` is "Red", "Green" or "Blue"; any other value returns the
    ITU-R BT.601 luma (0.299 R + 0.587 G + 0.114 B), truncated to uint8.
    """
    for index, name in enumerate(("Red", "Green", "Blue")):
        if source == name:
            return rgb_u8[..., index]
    red, green, blue = rgb_u8[..., 0], rgb_u8[..., 1], rgb_u8[..., 2]
    luma = 0.299 * red + 0.587 * green + 0.114 * blue
    return luma.astype(np.uint8)
def hex_to_rgb(h: str):
    """Convert a hex colour string to an ``(r, g, b)`` tuple of ints 0-255.

    Accepts ``"#rrggbb"`` / ``"rrggbb"`` (extra trailing digits such as an
    alpha pair are ignored, as before) and the CSS shorthand ``"#rgb"`` /
    ``"#rgba"``, where each digit is doubled (``"#0fa"`` == ``"#00ffaa"``).
    Surrounding whitespace is ignored.

    Raises
    ------
    ValueError if the string has fewer than six hex digits after shorthand
    expansion, or contains non-hex characters.
    """
    digits = h.strip().lstrip("#")
    if len(digits) in (3, 4):
        # CSS shorthand: every digit stands for a doubled pair.
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) < 6:
        raise ValueError(f"expected a #rgb or #rrggbb colour, got {h!r}")
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))
# ─────────────────────────────────────────────────────────────────────────────
# Postprocessing
# ─────────────────────────────────────────────────────────────────────────────
def postprocess_masks(nuc_mask, myo_mask,
                      min_nuc_area=20, min_myo_area=500,
                      nuc_close_radius=2,
                      myo_open_radius=2,
                      myo_erode_radius=2,
                      min_myo_aspect_ratio=0.0):
    """
    Clean up raw predicted masks (v8 unified postprocessing, matching the
    training script).

    Nuclei:   optional morphological closing (disk of nuc_close_radius) to
              fill small gaps, then removal of objects below min_nuc_area.
    Myotubes: opening (noise removal) → erode+dilate (bridge breaking) →
              size filter → optional aspect-ratio shape filter.

    No closing is applied to myotubes: closing merges adjacent myotubes into
    a single connected component and causes severe undercounting in dense
    cultures (validation showed r=0.245 against manual counts with closing).

    Parameters
    ----------
    min_nuc_area, min_myo_area : minimum object size in pixels.
    nuc_close_radius     : disk radius for nuclei closing; 0 disables.
    myo_open_radius      : disk radius for myotube opening; removes small
                           noise/debris without merging separate objects.
                           0 disables.
    myo_erode_radius     : disk radius for erode-then-dilate; separates
                           touching myotubes joined by thin pixel bridges.
                           Start at 2 px, increase for dense cultures,
                           0 disables.
    min_myo_aspect_ratio : minimum major/minor axis ratio. Myotubes are
                           elongated (aspect > 2); round debris (aspect ~1)
                           is rejected. 0 disables; 1.5–2.0 is suggested
                           for sparse cultures.

    Returns
    -------
    (nuc_clean, myo_clean) — uint8 masks with values 0/1.
    """
    # ── Nuclei ────────────────────────────────────────────────────────────
    close_r = int(nuc_close_radius)
    nuclei = nuc_mask.astype(bool)
    if close_r > 0:
        nuclei = closing(nuclei, disk(close_r))
    nuclei = remove_small_objects(nuclei, min_size=int(min_nuc_area))

    # ── Myotubes ──────────────────────────────────────────────────────────
    open_r = int(myo_open_radius)
    erode_r = int(myo_erode_radius)
    tubes = myo_mask.astype(bool)
    if open_r > 0:
        tubes = opening(tubes, disk(open_r))
    if erode_r > 0:
        bridge_se = disk(erode_r)
        # Erode to cut thin bridges, then dilate with the same element to
        # restore the surviving objects to roughly their original size.
        tubes = binary_dilation(binary_erosion(tubes, bridge_se), bridge_se)
    tubes = remove_small_objects(tubes, min_size=int(min_myo_area))
    min_aspect = float(min_myo_aspect_ratio)
    if min_aspect > 0:
        tubes = _filter_by_aspect_ratio(tubes, min_aspect)

    return nuclei.astype(np.uint8), tubes.astype(np.uint8)
def _filter_by_aspect_ratio(mask_bin: np.ndarray, min_aspect: float) -> np.ndarray:
    """Keep only connected regions whose major/minor axis ratio is >= min_aspect.

    Regions with a zero minor axis (degenerate, line-like) count as
    infinitely elongated and are always kept.

    Returns a boolean mask with the same shape as ``mask_bin``.
    """
    lab, _ = ndi.label(mask_bin.astype(np.uint8))
    keep_labels = []
    for prop in measure.regionprops(lab):
        minor = prop.minor_axis_length
        if not minor > 0:
            # Degenerate (line-like) — keep it (very elongated).
            keep_labels.append(prop.label)
        elif prop.major_axis_length / minor >= min_aspect:
            keep_labels.append(prop.label)
    # One vectorised membership test instead of a full-image comparison
    # per region (which was O(regions × pixels)).
    return np.isin(lab, np.asarray(keep_labels, dtype=lab.dtype))
def label_cc(mask: np.ndarray) -> np.ndarray:
    """Label the connected non-zero regions of *mask*.

    Uses scipy.ndimage.label with its default structuring element, so
    diagonal neighbours are NOT connected in 2-D. Returns an integer array
    of the same shape: 0 = background, 1..N = regions in raster order.
    """
    as_u8 = mask.astype(np.uint8)
    labelled, _region_count = ndi.label(as_u8)
    return labelled
def split_large_myotubes(myo_lab: np.ndarray,
                         nuc_lab: np.ndarray,
                         max_area_px: int = 0,
                         min_seeds: int = 2) -> np.ndarray:
    """
    Split oversized myotube regions with a nucleus-seeded watershed.

    Addresses myotube merging: when adjacent or branching myotubes form one
    connected region, the region is split using nucleus centroids as seeds —
    the same idea as nucleus watershed splitting, applied to myotubes.

    Algorithm
    ---------
    For each myotube region whose area is larger than max_area_px:
      1. Collect the nucleus centroids that fall inside the region.
      2. If at least min_seeds centroids are found, run a distance-transform
         watershed on the region using those centroids as markers.
      3. Replace the region with the resulting sub-regions; sub-regions
         smaller than 10 px are discarded (their pixels become background).
    Finally all labels are renumbered sequentially 1..N.

    Parameters
    ----------
    myo_lab     : 2-D int array — labelled myotube instances (e.g. from label_cc)
    nuc_lab     : 2-D int array — labelled nuclei, same shape as myo_lab
    max_area_px : regions larger than this (pixels) are split candidates.
                  0 or less disables splitting and returns myo_lab unchanged.
    min_seeds   : minimum number of nucleus seeds needed to attempt a split

    Returns
    -------
    New labelled myotube array (same dtype) with labels numbered 1..N.
    """
    if max_area_px <= 0:
        return myo_lab

    out = myo_lab.copy()
    next_id = int(myo_lab.max()) + 1
    H, W = myo_lab.shape

    # Nucleus centroids do not depend on the myotube being processed, so
    # compute them once instead of re-running regionprops for every region.
    nuc_seeds = []
    for nuc_prop in measure.regionprops(nuc_lab):
        r, c = int(nuc_prop.centroid[0]), int(nuc_prop.centroid[1])
        if 0 <= r < H and 0 <= c < W:
            nuc_seeds.append((r, c, nuc_prop.label))

    for prop in measure.regionprops(myo_lab):
        if prop.area <= max_area_px:
            continue

        region_mask = (myo_lab == prop.label)

        # Seed image: nucleus label at each centroid inside this region.
        seeds_img = np.zeros((H, W), dtype=np.int32)
        seed_count = 0
        for r, c, nuc_label in nuc_seeds:
            if region_mask[r, c]:
                seeds_img[r, c] = nuc_label
                seed_count += 1
        if seed_count < min_seeds:
            # Not enough nuclei to split — leave the region as is.
            continue

        dist = ndi.distance_transform_edt(region_mask)
        result = watershed(-dist, seeds_img, mask=region_mask)

        # Clear the merged region and write the split sub-regions back.
        out[region_mask] = 0
        for sub_id in np.unique(result):
            if sub_id == 0:
                continue
            sub_mask = (result == sub_id)
            if sub_mask.sum() < 10:  # discard tiny slivers
                continue
            out[sub_mask] = next_id
            next_id += 1

    # Renumber to 1..N in one pass. np.unique sorts, so background 0 (if
    # present) maps to index 0. If there is no background pixel at all,
    # shift every index up by one; otherwise the lowest label would be lost.
    uniq, inverse = np.unique(out, return_inverse=True)
    inverse = inverse.reshape(out.shape)
    if uniq.size and uniq[0] != 0:
        inverse = inverse + 1
    return inverse.astype(out.dtype, copy=False)
def label_nuclei_watershed(nuc_bin: np.ndarray,
                           min_distance: int = 3,
                           min_nuc_area: int = 6) -> np.ndarray:
    """Split touching nuclei via distance-transform watershed.

    Parameters
    ----------
    nuc_bin      : binary nucleus mask.
    min_distance : minimum separation (px) between watershed seeds.
    min_nuc_area : objects smaller than this (px) are removed first.

    Returns
    -------
    int32 label image (0 = background, 1..N = nuclei).
    """
    nuc_bin = remove_small_objects(nuc_bin.astype(bool), min_size=min_nuc_area)
    if nuc_bin.sum() == 0:
        return np.zeros_like(nuc_bin, dtype=np.int32)

    # Label the separate blobs first and let peak_local_max search each one
    # independently. With a single boolean label, a peak in one nucleus can
    # suppress the only peak of a neighbouring, non-touching nucleus closer
    # than min_distance. Watershed would then drop that nucleus entirely.
    cc_lab, _ = ndi.label(nuc_bin)
    dist = ndi.distance_transform_edt(nuc_bin)
    coords = peak_local_max(dist, labels=cc_lab,
                            min_distance=min_distance, exclude_border=False)

    markers = np.zeros_like(nuc_bin, dtype=np.int32)
    for i, (r, c) in enumerate(coords, start=1):
        markers[r, c] = i
    if markers.max() == 0:
        # No peaks at all — fall back to plain connected components.
        return cc_lab.astype(np.int32)
    return watershed(-dist, markers, mask=nuc_bin).astype(np.int32)
# ─────────────────────────────────────────────────────────────────────────────
# Surface area (new)
# ─────────────────────────────────────────────────────────────────────────────
def compute_surface_area(myo_mask: np.ndarray, px_um: float = 1.0) -> dict:
    """Measure every connected myotube region of *myo_mask* in µm².

    ``px_um`` is the side length of one mask pixel in µm, so each pixel
    contributes ``px_um ** 2`` µm². All values are rounded to 2 decimals.

    Returns a dict with total_area_um2, mean_area_um2, max_area_um2 (0.0 when
    there are no regions) and per_myotube_areas (list, in label order).
    """
    um2_per_px = px_um ** 2
    areas = []
    for region in measure.regionprops(label_cc(myo_mask)):
        areas.append(round(region.area * um2_per_px, 2))

    if areas:
        mean_area = float(np.mean(areas))
        max_area = float(np.max(areas))
    else:
        mean_area = max_area = 0.0

    return {
        "total_area_um2": round(sum(areas), 2),
        "mean_area_um2": round(mean_area, 2),
        "max_area_um2": round(max_area, 2),
        "per_myotube_areas": areas,
    }
# ─────────────────────────────────────────────────────────────────────────────
# Cytoplasm-hole nucleus classifier (MyoFuse method, Lair et al. 2025)
# ─────────────────────────────────────────────────────────────────────────────
def classify_nucleus_in_myotube(nuc_coords: np.ndarray,
                                myc_channel: np.ndarray,
                                myo_mask_full: np.ndarray,
                                ring_width: int = 6,
                                hole_ratio_thr: float = 0.85) -> bool:
    """
    Decide whether a nucleus is genuinely inside a myotube using the
    cytoplasm-hole method (MyoFuse, Lair et al. 2025).

    A fused nucleus inside a myotube displaces the cytoplasm, creating a
    local dip ("hole") in the MyHC signal beneath it. An unfused nucleus
    sitting on top of a myotube in Z does not; the MyHC signal under it
    stays bright.

    Algorithm
    ---------
    1. If no nucleus pixel lies on the myotube mask → not fused.
    2. I_nuc  = mean MyHC intensity under the nucleus pixels.
    3. I_ring = mean MyHC intensity in a ring around the nucleus (dilation by
       ring_width minus the footprint), clipped to the myotube mask.
    4. hole_ratio = I_nuc / I_ring; fused if hole_ratio < hole_ratio_thr.
    If the ring has fewer than 4 pixels or I_ring is ~0, fall back to an
    overlap test: fused if >= 10 % of the nucleus pixels are on the mask.

    Parameters
    ----------
    nuc_coords     : (N, 2) int array of (row, col) pixel coords of the nucleus
    myc_channel    : 2-D float array of the MyHC channel at full resolution
    myo_mask_full  : 2-D binary myotube mask, same shape as myc_channel
    ring_width     : dilation radius (px) of the cytoplasm ring
    hole_ratio_thr : ratio below which the nucleus counts as fused
                     (default 0.85, the MyoFuse calibration)

    Returns
    -------
    True if the nucleus is classified as fused (inside myotube cytoplasm).
    """
    rows, cols = nuc_coords[:, 0], nuc_coords[:, 1]
    H, W = myc_channel.shape

    # Step 1 — must overlap the myotube mask at all.
    in_myo = myo_mask_full[rows, cols]
    if in_myo.sum() == 0:
        return False

    # Step 2 — mean MyHC under the nucleus.
    I_nuc = float(myc_channel[rows, cols].mean())

    # Step 3 — ring around the nucleus footprint, clipped to the myotube mask.
    # Work in a window padded by ring_width + 1 around the nucleus bounding
    # box instead of the whole image. The dilation cannot reach further than
    # ring_width pixels, so the ring is identical, but the per-nucleus cost
    # no longer scales with the full image size.
    pad = max(int(ring_width), 0) + 1
    r0 = max(int(rows.min()) - pad, 0)
    r1 = min(int(rows.max()) + pad + 1, H)
    c0 = max(int(cols.min()) - pad, 0)
    c1 = min(int(cols.max()) + pad + 1, W)

    nuc_footprint = np.zeros((r1 - r0, c1 - c0), dtype=bool)
    nuc_footprint[rows - r0, cols - c0] = True
    nuc_dilated = binary_dilation(nuc_footprint, footprint=disk(ring_width))
    myo_window = myo_mask_full[r0:r1, c0:c1].astype(bool)
    ring_mask = nuc_dilated & ~nuc_footprint & myo_window

    if ring_mask.sum() < 4:
        # Ring too small (nucleus near the myotube edge) — overlap test.
        return bool(in_myo.mean() >= 0.10)

    I_ring = float(myc_channel[r0:r1, c0:c1][ring_mask].mean())
    if I_ring < 1e-6:
        # No myotube signal in the ring — cannot form a ratio; overlap test.
        return bool(in_myo.mean() >= 0.10)

    # Step 4 — hole-ratio test.
    hole_ratio = I_nuc / I_ring
    return bool(hole_ratio < hole_ratio_thr)
# ─────────────────────────────────────────────────────────────────────────────
# Biological metrics (counting + fusion + surface area)
# ─────────────────────────────────────────────────────────────────────────────
def compute_bio_metrics(nuc_mask, myo_mask,
                        myc_channel_full=None,
                        min_overlap_frac=0.10,
                        nuc_ws_min_distance=3,
                        nuc_ws_min_area=6,
                        px_um=1.0,
                        ring_width=6,
                        hole_ratio_thr=0.85) -> dict:
    """
    Compute all biological metrics (counts, fusion, surface area).

    Nuclei are instance-labelled with label_nuclei_watershed and myotubes
    with label_cc. Each nucleus is then classified as MyHC+ ("fused"):

    * If myc_channel_full (the raw MyHC grayscale image at original
      resolution) is supplied, the nucleus labels and the myotube mask are
      upscaled (nearest neighbour) to that resolution and each nucleus is
      tested with classify_nucleus_in_myotube — the cytoplasm-hole method
      (MyoFuse, Lair et al. 2025), intended to reject nuclei that only
      overlap a myotube in Z.
    * If myc_channel_full is None, the legacy pixel-overlap test is used:
      a nucleus is MyHC+ when at least min_overlap_frac of its pixels lie on
      the myotube mask.

    A MyHC+ nucleus is assigned to the model-resolution myotube label it
    overlaps most.

    Parameters
    ----------
    nuc_mask, myo_mask  : binary masks, same shape (model resolution).
    myc_channel_full    : optional 2-D MyHC image; values > 1 are assumed to
                          be 8-bit and divided by 255.
    min_overlap_frac    : overlap fraction for the legacy test.
    nuc_ws_min_distance, nuc_ws_min_area : forwarded to label_nuclei_watershed.
    px_um               : µm per pixel of myo_mask, used for surface area.
    ring_width, hole_ratio_thr : forwarded to classify_nucleus_in_myotube.

    Returns
    -------
    dict with:
      total_nuclei, myHC_positive_nuclei, myHC_positive_percentage,
      nuclei_fused (nuclei in myotubes holding >= 2 assigned nuclei),
      myotube_count (myotubes with >= 1 assigned nucleus),
      avg_nuclei_per_myotube, fusion_index (nuclei_fused / total_nuclei, %),
      total/mean/max_area_um2 and _per_myotube_areas — computed over EVERY
      connected component of myo_mask, not only those in myotube_count —
      plus _bio_myo_ids (set of assigned myotube labels) and
      _total_cc_count (number of connected myotube components).
    """
    nuc_lab = label_nuclei_watershed(nuc_mask,
                                     min_distance=nuc_ws_min_distance,
                                     min_nuc_area=nuc_ws_min_area)
    myo_lab = label_cc(myo_mask)
    total = int(nuc_lab.max())
    # Bring masks/channel into the SAME space for comparison.
    # nuc_lab and myo_mask are at model resolution (e.g. 512×512);
    # myc_channel_full is at original image resolution. Everything is
    # resized to the original resolution for the cytoplasm-hole test.
    if myc_channel_full is not None:
        H_full, W_full = myc_channel_full.shape
        # Resize label maps up to original resolution (nearest neighbour
        # keeps label values intact).
        nuc_lab_full = np.array(
            Image.fromarray(nuc_lab.astype(np.int32))
            .resize((W_full, H_full), Image.NEAREST)
        )
        myo_mask_full = np.array(
            Image.fromarray((myo_mask * 255).astype(np.uint8))
            .resize((W_full, H_full), Image.NEAREST)
        ) > 0
        # Normalise MyHC channel to 0-1 float. Assumes 8-bit data when the
        # max exceeds 1; the hole ratio itself is scale-invariant.
        myc_f = myc_channel_full.astype(np.float32)
        if myc_f.max() > 1.0:
            myc_f = myc_f / 255.0
    else:
        nuc_lab_full = nuc_lab
        myo_mask_full = myo_mask.astype(bool)
        myc_f = None
    # pos: number of MyHC+ nuclei; nm: myotube label -> list of nucleus labels.
    pos, nm = 0, {}
    for prop in measure.regionprops(nuc_lab_full):
        coords = prop.coords  # (N,2) row/col in nuc_lab_full's space (full-res when myc_f is set)
        if myc_f is not None:
            # ── Cytoplasm-hole method (MyoFuse 2025) ─────────────────────────
            is_fused = classify_nucleus_in_myotube(
                coords, myc_f, myo_mask_full,
                ring_width=ring_width,
                hole_ratio_thr=hole_ratio_thr,
            )
        else:
            # ── Legacy pixel-overlap fallback ─────────────────────────────────
            ids = myo_mask_full.astype(np.uint8)[coords[:, 0], coords[:, 1]]
            frac = ids.sum() / max(len(coords), 1)
            is_fused = frac >= min_overlap_frac
        if is_fused:
            # Find which myotube this nucleus belongs to (model-res myo_lab).
            # In full-res mode, scale coords back down to model resolution.
            if myc_f is not None:
                r_m = np.clip((coords[:, 0] * nuc_lab.shape[0] / H_full).astype(int),
                              0, nuc_lab.shape[0] - 1)
                c_m = np.clip((coords[:, 1] * nuc_lab.shape[1] / W_full).astype(int),
                              0, nuc_lab.shape[1] - 1)
                ids_mt = myo_lab[r_m, c_m]
            else:
                ids_mt = myo_lab[coords[:, 0], coords[:, 1]]
            ids_mt = ids_mt[ids_mt > 0]
            if ids_mt.size > 0:
                # Majority vote: the myotube covering most nucleus pixels.
                unique, counts = np.unique(ids_mt, return_counts=True)
                mt = int(unique[np.argmax(counts)])
                nm.setdefault(mt, []).append(prop.label)
                # NOTE(review): pos only counts nuclei that were assigned to a
                # myotube; confirm this matches the intended MyHC+ definition.
                pos += 1
    # Nuclei per assigned myotube.
    per = [len(v) for v in nm.values()]
    # Fused = nuclei in myotubes with at least two assigned nuclei.
    fused = sum(n for n in per if n >= 2)
    fi = 100.0 * fused / total if total else 0.0
    pct = 100.0 * pos / total if total else 0.0
    avg = float(np.mean(per)) if per else 0.0
    # NOTE(review): px_um is applied to myo_mask's own (model-resolution)
    # pixel grid; confirm callers pass the µm size of a model pixel.
    sa = compute_surface_area(myo_mask, px_um=px_um)
    return {
        "total_nuclei" : total,
        "myHC_positive_nuclei" : int(pos),
        "myHC_positive_percentage" : round(pct, 2),
        "nuclei_fused" : int(fused),
        "myotube_count" : int(len(per)),
        "avg_nuclei_per_myotube" : round(avg, 2),
        "fusion_index" : round(fi, 2),
        "total_area_um2" : sa["total_area_um2"],
        "mean_area_um2" : sa["mean_area_um2"],
        "max_area_um2" : sa["max_area_um2"],
        "_per_myotube_areas" : sa["per_myotube_areas"],
        "_bio_myo_ids" : set(nm.keys()),  # myotube label IDs with ≥1 MyHC+ nucleus
        "_total_cc_count" : int(myo_lab.max()),  # total connected components (for reference)
    }
# ─────────────────────────────────────────────────────────────────────────────
# Overlay helpers
# ─────────────────────────────────────────────────────────────────────────────
def make_simple_overlay(rgb_u8, nuc_mask, myo_mask, nuc_color, myo_color, alpha):
    """Flat-colour overlay used for the ZIP export (fast, no matplotlib).

    Both masks are upscaled (nearest neighbour) to the image size; myotubes
    are tinted first and nuclei on top, each blended as
    ``(1 - alpha) * pixel + alpha * colour``. Returns a uint8 image.
    """
    height, width = rgb_u8.shape[:2]

    def _upsample(mask):
        # 0/1 mask -> 0/255 uint8 image -> resized boolean mask.
        as_u8 = (mask * 255).astype(np.uint8)
        resized = Image.fromarray(as_u8).resize((width, height), Image.NEAREST)
        return np.array(resized) > 0

    nuc = _upsample(nuc_mask)
    myo = _upsample(myo_mask)

    blended = rgb_u8.astype(np.float32)  # astype already returns a copy
    for region, colour in ((myo, myo_color), (nuc, nuc_color)):
        tint = np.array(colour, dtype=np.float32)
        blended[region] = (1 - alpha) * blended[region] + alpha * tint
    return np.clip(blended, 0, 255).astype(np.uint8)
def make_coloured_overlay(rgb_u8: np.ndarray,
                          nuc_lab: np.ndarray,
                          myo_lab: np.ndarray,
                          alpha: float = 0.45,
                          nuc_color: tuple = None,
                          myo_color: tuple = None) -> np.ndarray:
    """
    Tint the labelled regions on top of the original image — no text baked in.

    Label maps are upscaled (nearest neighbour) to the image size. Myotubes
    are blended first and nuclei on top, each as
    ``(1 - alpha) * pixel + alpha * colour``.

    Parameters
    ----------
    rgb_u8           : image to draw on (assumed RGB uint8).
    nuc_lab, myo_lab : integer label maps (0 = background), any resolution.
    alpha            : blend weight of the overlay colour.
    nuc_color, myo_color : RGB tuple, e.g. (0, 255, 255). If given, every
        instance of that type gets this flat colour (this is what the sidebar
        colour pickers control). If None, instances are coloured by label
        value through a colormap ("cool" for nuclei, "autumn" for myotubes).

    Returns
    -------
    uint8 RGB array at the original image resolution.
    """
    orig_h, orig_w = rgb_u8.shape[:2]

    def _resize_lab(lab, h, w):
        return np.array(
            Image.fromarray(lab.astype(np.int32)).resize((w, h), Image.NEAREST)
        )

    def _blend(base, lab_disp, flat_color, cmap_name):
        # Blend one instance layer into ``base`` in place.
        n_inst = int(lab_disp.max())
        if n_inst <= 0:
            return
        mask = lab_disp > 0
        if flat_color is not None:
            colour = np.array(flat_color, dtype=np.float32)
        else:
            # plt.cm.get_cmap is matplotlib.cm.get_cmap, which was removed in
            # Matplotlib 3.9; pyplot.get_cmap works on every release.
            cmap = plt.get_cmap(cmap_name)
            # Map only the masked pixels instead of the whole image.
            norm = (lab_disp[mask] / n_inst).astype(np.float32)
            colour = (cmap(norm)[:, :3] * 255).astype(np.float32)
        base[mask] = (1 - alpha) * base[mask] + alpha * colour

    nuc_disp = _resize_lab(nuc_lab, orig_h, orig_w)
    myo_disp = _resize_lab(myo_lab, orig_h, orig_w)

    base = rgb_u8.astype(np.float32).copy()
    _blend(base, myo_disp, myo_color, "autumn")
    _blend(base, nuc_disp, nuc_color, "cool")
    return np.clip(base, 0, 255).astype(np.uint8)
def make_outline_overlay(rgb_u8: np.ndarray,
                         nuc_lab: np.ndarray,
                         myo_lab: np.ndarray,
                         nuc_color: tuple = (0, 255, 255),
                         myo_color: tuple = (0, 255, 0),
                         line_width: int = 2) -> np.ndarray:
    """
    Draw contour outlines around each labelled instance on the original image.

    Label maps are upscaled (nearest neighbour) to the image size. Outer
    instance boundaries are painted in solid colour: myotubes first, nuclei
    on top. For line_width > 1, myotube edges are thickened with a disk of
    radius line_width - 1 and nuclei edges with radius max(line_width - 2, 0).
    Returns a modified copy of ``rgb_u8``.
    """
    height, width = rgb_u8.shape[:2]

    def _to_display(lab):
        resized = Image.fromarray(lab.astype(np.int32)).resize(
            (width, height), Image.NEAREST)
        return np.array(resized)

    nuc_disp = _to_display(nuc_lab)
    myo_disp = _to_display(myo_lab)

    canvas = rgb_u8.copy()
    # (labels, colour, thickening radius) in drawing order.
    layers = (
        (myo_disp, myo_color, line_width - 1),
        (nuc_disp, nuc_color, max(line_width - 2, 0)),
    )
    for disp, colour, radius in layers:
        if not disp.max() > 0:
            continue
        edges = find_boundaries(disp, mode='outer')
        if line_width > 1:
            edges = binary_dilation(edges, footprint=disk(radius))
        canvas[edges] = colour
    return canvas
def collect_label_positions(nuc_lab: np.ndarray,
                            myo_lab: np.ndarray,
                            img_w: int, img_h: int,
                            bio_myo_ids: set = None) -> dict:
    """
    Gather label anchor points (region centroids) for the SVG viewer, scaled
    from each label map's grid to the original image size (img_w × img_h).

    bio_myo_ids: myotube label IDs that have ≥1 MyHC+ nucleus. When non-empty,
                 only those myotubes are labelled, renumbered "M1", "M2", … in
                 ascending label order. The rest go to "myotubes_nonbio" with
                 an empty id. When None or empty, every connected component
                 is labelled "M<label>" (backward-compatible behaviour).

    Returns:
      { "nuclei":          [ {"id": "1", "x": 123.4, "y": 56.7}, ... ],
        "myotubes":        [ {"id": "M1", "x": ..., "y": ..., "orig_label": 5}, ... ],
        "myotubes_nonbio": [ {"id": "", "x": ..., "y": ..., "orig_label": 3}, ... ] }
    Coordinates are rounded to 0.1 px.
    """
    def _centroid_xy(region, scale_x, scale_y):
        # regionprops centroids are (row, col); return scaled (x, y).
        row, col = region.centroid
        return round(col * scale_x, 1), round(row * scale_y, 1)

    nuc_scale_x = img_w / nuc_lab.shape[1]
    nuc_scale_y = img_h / nuc_lab.shape[0]
    nuclei = []
    for region in measure.regionprops(nuc_lab):
        x, y = _centroid_xy(region, nuc_scale_x, nuc_scale_y)
        nuclei.append({"id": str(region.label), "x": x, "y": y})

    myo_scale_x = img_w / myo_lab.shape[1]
    myo_scale_y = img_h / myo_lab.shape[0]
    myotubes, myotubes_nonbio = [], []

    use_bio_ids = bio_myo_ids is not None and len(bio_myo_ids) > 0
    if not use_bio_ids:
        # Fallback: label all connected components (backward compat).
        for region in measure.regionprops(myo_lab):
            x, y = _centroid_xy(region, myo_scale_x, myo_scale_y)
            myotubes.append({"id": f"M{region.label}", "x": x,
                             "y": y, "orig_label": region.label})
    else:
        rank_of = {label: rank
                   for rank, label in enumerate(sorted(bio_myo_ids), start=1)}
        for region in measure.regionprops(myo_lab):
            x, y = _centroid_xy(region, myo_scale_x, myo_scale_y)
            entry = {"x": x, "y": y, "orig_label": region.label}
            if region.label in rank_of:
                entry["id"] = f"M{rank_of[region.label]}"
                myotubes.append(entry)
            else:
                entry["id"] = ""
                myotubes_nonbio.append(entry)

    return {"nuclei": nuclei, "myotubes": myotubes, "myotubes_nonbio": myotubes_nonbio}
def make_svg_viewer(img_b64: str,
                    img_w: int, img_h: int,
                    label_data: dict,
                    show_nuclei: bool = True,
                    show_myotubes: bool = True,
                    nuc_font_size: int = 11,
                    myo_font_size: int = 22,
                    viewer_height: int = 620) -> str:
    """
    Build a self-contained HTML string containing:
      - a pan-and-zoom SVG viewer (mouse wheel + click-drag, one-finger
        drag + two-finger pinch on touch screens)
      - the coloured overlay PNG as the background image
      - SVG <text> labels in image-pixel coordinates, so they stay aligned
        with the image at any zoom level
      - font-size sliders that resize labels live
      - toggle buttons for the nucleus / myotube label layers
      - count badges in the toolbar

    Parameters
    ----------
    img_b64        : base64-encoded PNG of the coloured overlay (no text)
    img_w, img_h   : pixel dimensions of that PNG (used as the SVG viewBox)
    label_data     : output of collect_label_positions(); only the "nuclei"
                     and "myotubes" lists are rendered and counted
    show_nuclei    : initial visibility of nucleus labels (the toggle
                     button starts dimmed when False)
    show_myotubes  : initial visibility of myotube labels (likewise)
    nuc_font_size  : initial nucleus label font size (px)
    myo_font_size  : initial myotube label font size (px)
    viewer_height  : height of the viewer area in pixels

    Returns
    -------
    HTML string (style + markup + inline script).
    """
    # json.dumps never emits "</" outside string values, but a label id that
    # contains "</script>" would end the inline <script> early. "<\/" is an
    # equivalent escape inside a JSON/JS string.
    labels_json = json.dumps(label_data).replace("</", "<\\/")
    n_nuc = len(label_data.get("nuclei", []))
    n_myo = len(label_data.get("myotubes", []))
    show_nuc_js = "true" if show_nuclei else "false"
    show_myo_js = "true" if show_myotubes else "false"
    nuc_vis = "visible" if show_nuclei else "hidden"
    myo_vis = "visible" if show_myotubes else "hidden"
    # Render the toggle buttons in the same state as their layers.
    nuc_btn_cls = "myo-btn nuc" if show_nuclei else "myo-btn nuc off"
    myo_btn_cls = "myo-btn myo" if show_myotubes else "myo-btn myo off"
    # Non-ASCII glyphs are written as HTML entities so the badges render
    # correctly regardless of how this file or the page is encoded.
    html = f"""
<style>
  .myo-viewer-wrap {{
    background: #0e0e1a;
    border: 1px solid #2a2a4e;
    border-radius: 10px;
    overflow: hidden;
    position: relative;
    user-select: none;
  }}
  .myo-toolbar {{
    display: flex;
    align-items: center;
    gap: 12px;
    padding: 8px 14px;
    background: #13132a;
    border-bottom: 1px solid #2a2a4e;
    flex-wrap: wrap;
  }}
  .myo-badge {{
    background: #1a1a3e;
    border: 1px solid #3a3a6e;
    border-radius: 6px;
    padding: 3px 10px;
    color: #e0e0e0;
    font-size: 13px;
    font-family: Arial, sans-serif;
    white-space: nowrap;
  }}
  .myo-badge span {{ font-weight: bold; }}
  .myo-btn {{
    padding: 4px 12px;
    border-radius: 6px;
    border: 1px solid #444;
    cursor: pointer;
    font-size: 12px;
    font-family: Arial, sans-serif;
    font-weight: bold;
    transition: opacity 0.15s;
  }}
  .myo-btn.nuc {{ background: #003366; color: white; border-color: #4fc3f7; }}
  .myo-btn.myo {{ background: #8B0000; color: white; border-color: #ff6666; }}
  .myo-btn.off {{ opacity: 0.35; }}
  .myo-btn.reset {{ background: #1a1a2e; color: #90caf9; border-color: #3a3a6e; }}
  .myo-slider-wrap {{
    display: flex;
    align-items: center;
    gap: 6px;
    color: #aaa;
    font-size: 12px;
    font-family: Arial, sans-serif;
  }}
  .myo-slider-wrap input {{ width: 70px; accent-color: #4fc3f7; cursor: pointer; }}
  .myo-hint {{
    margin-left: auto;
    color: #555;
    font-size: 11px;
    font-family: Arial, sans-serif;
    white-space: nowrap;
  }}
  .myo-svg-wrap {{
    width: 100%;
    height: {viewer_height}px;
    overflow: hidden;
    cursor: grab;
    position: relative;
  }}
  .myo-svg-wrap:active {{ cursor: grabbing; }}
  svg.myo-svg {{
    width: 100%;
    height: 100%;
    display: block;
  }}
</style>
<div class="myo-viewer-wrap" id="myoViewer">
  <div class="myo-toolbar">
    <div class="myo-badge">&#x1F535; Nuclei &nbsp;<span id="nucCount">{n_nuc}</span></div>
    <div class="myo-badge">&#x1F534; Myotubes &nbsp;<span id="myoCount">{n_myo}</span></div>
    <button class="{nuc_btn_cls}" id="btnNuc" onclick="toggleLayer('nuc')">Nuclei IDs</button>
    <button class="{myo_btn_cls}" id="btnMyo" onclick="toggleLayer('myo')">Myotube IDs</button>
    <button class="myo-btn reset" onclick="resetView()">&#x27F3; Reset</button>
    <div class="myo-slider-wrap">
      Nucleus size:
      <input type="range" id="slNuc" min="4" max="40" value="{nuc_font_size}"
             oninput="setFontSize('nuc', this.value)" />
      <span id="lblNuc">{nuc_font_size}px</span>
    </div>
    <div class="myo-slider-wrap">
      Myotube size:
      <input type="range" id="slMyo" min="8" max="60" value="{myo_font_size}"
             oninput="setFontSize('myo', this.value)" />
      <span id="lblMyo">{myo_font_size}px</span>
    </div>
    <div class="myo-hint">Scroll to zoom &nbsp;&middot;&nbsp; Drag to pan</div>
  </div>
  <div class="myo-svg-wrap" id="svgWrap">
    <svg class="myo-svg" id="mainSvg"
         viewBox="0 0 {img_w} {img_h}"
         preserveAspectRatio="xMidYMid meet">
      <defs>
        <filter id="dropshadow" x="-5%" y="-5%" width="110%" height="110%">
          <feDropShadow dx="0" dy="0" stdDeviation="1.5" flood-color="#000" flood-opacity="0.8"/>
        </filter>
      </defs>
      <!-- background image: the coloured overlay PNG -->
      <image href="data:image/png;base64,{img_b64}"
             x="0" y="0" width="{img_w}" height="{img_h}"
             preserveAspectRatio="xMidYMid meet"/>
      <!-- nuclei labels group -->
      <g id="gNuc" visibility="{nuc_vis}">
      </g>
      <!-- myotube labels group -->
      <g id="gMyo" visibility="{myo_vis}">
      </g>
    </svg>
  </div>
</div>
<script>
(function() {{
  const labels = {labels_json};
  const IMG_W = {img_w};
  const IMG_H = {img_h};
  let nucFontSize = {nuc_font_size};
  let myoFontSize = {myo_font_size};
  let showNuc = {show_nuc_js};
  let showMyo = {show_myo_js};

  // ── Build SVG label elements ─────────────────────────────────────────────
  const NS = "http://www.w3.org/2000/svg";
  function makeLabelGroup(items, fontSize, bgColor, borderColor, isMyo) {{
    const frag = document.createDocumentFragment();
    items.forEach(item => {{
      const g = document.createElementNS(NS, "g");
      g.setAttribute("class", isMyo ? "lbl-myo" : "lbl-nuc");
      // Background rect, sized after the text is measured
      const rect = document.createElementNS(NS, "rect");
      rect.setAttribute("rx", isMyo ? "4" : "3");
      rect.setAttribute("ry", isMyo ? "4" : "3");
      rect.setAttribute("fill", bgColor);
      rect.setAttribute("stroke", borderColor);
      rect.setAttribute("stroke-width", isMyo ? "1.5" : "0");
      rect.setAttribute("opacity", isMyo ? "0.93" : "0.90");
      rect.setAttribute("filter", "url(#dropshadow)");
      // Text
      const txt = document.createElementNS(NS, "text");
      txt.textContent = item.id;
      txt.setAttribute("x", item.x);
      txt.setAttribute("y", item.y);
      txt.setAttribute("text-anchor", "middle");
      txt.setAttribute("dominant-baseline", "central");
      txt.setAttribute("fill", "white");
      txt.setAttribute("font-family", "Arial, sans-serif");
      txt.setAttribute("font-weight", "bold");
      txt.setAttribute("font-size", fontSize);
      txt.setAttribute("paint-order", "stroke");
      g.appendChild(rect);
      g.appendChild(txt);
      frag.appendChild(g);
    }});
    return frag;
  }}

  function positionRects() {{
    // After elements are in the DOM, size and position the backing rects
    document.querySelectorAll(".lbl-nuc, .lbl-myo").forEach(g => {{
      const txt = g.querySelector("text");
      const rect = g.querySelector("rect");
      try {{
        const bb = txt.getBBox();
        const pad = parseFloat(txt.getAttribute("font-size")) * 0.22;
        rect.setAttribute("x", bb.x - pad);
        rect.setAttribute("y", bb.y - pad);
        rect.setAttribute("width", bb.width + pad * 2);
        rect.setAttribute("height", bb.height + pad * 2);
      }} catch(e) {{}}
    }});
  }}

  function rebuildLabels() {{
    const gNuc = document.getElementById("gNuc");
    const gMyo = document.getElementById("gMyo");
    gNuc.innerHTML = "";
    gMyo.innerHTML = "";
    gNuc.appendChild(makeLabelGroup(labels.nuclei, nucFontSize, "#003366", "none", false));
    gMyo.appendChild(makeLabelGroup(labels.myotubes, myoFontSize, "#8B0000", "#FF6666", true));
    // rAF so the browser has laid out the text before we measure it
    requestAnimationFrame(positionRects);
  }}

  // ── Font size controls ────────────────────────────────────────────────────
  window.setFontSize = function(which, val) {{
    val = parseInt(val);
    if (which === "nuc") {{
      nucFontSize = val;
      document.getElementById("lblNuc").textContent = val + "px";
      document.querySelectorAll(".lbl-nuc text").forEach(t => t.setAttribute("font-size", val));
    }} else {{
      myoFontSize = val;
      document.getElementById("lblMyo").textContent = val + "px";
      document.querySelectorAll(".lbl-myo text").forEach(t => t.setAttribute("font-size", val));
    }}
    requestAnimationFrame(positionRects);
  }};

  // ── Layer toggles ─────────────────────────────────────────────────────────
  window.toggleLayer = function(which) {{
    if (which === "nuc") {{
      showNuc = !showNuc;
      document.getElementById("gNuc").setAttribute("visibility", showNuc ? "visible" : "hidden");
      document.getElementById("btnNuc").classList.toggle("off", !showNuc);
    }} else {{
      showMyo = !showMyo;
      document.getElementById("gMyo").setAttribute("visibility", showMyo ? "visible" : "hidden");
      document.getElementById("btnMyo").classList.toggle("off", !showMyo);
    }}
  }};

  // ── Pan + Zoom (pure SVG viewBox manipulation) ────────────────────────────
  const wrap = document.getElementById("svgWrap");
  const svg = document.getElementById("mainSvg");
  let vx = 0, vy = 0, vw = IMG_W, vh = IMG_H;  // current viewBox

  function setVB() {{
    svg.setAttribute("viewBox", `${{vx}} ${{vy}} ${{vw}} ${{vh}}`);
  }}

  // Scroll to zoom, toward the mouse cursor
  wrap.addEventListener("wheel", e => {{
    e.preventDefault();
    const rect = wrap.getBoundingClientRect();
    const mx = (e.clientX - rect.left) / rect.width;   // 0..1
    const my = (e.clientY - rect.top) / rect.height;
    const factor = e.deltaY < 0 ? 0.85 : 1.0 / 0.85;
    const nw = Math.min(IMG_W, Math.max(IMG_W * 0.05, vw * factor));
    const nh = Math.min(IMG_H, Math.max(IMG_H * 0.05, vh * factor));
    vx = vx + mx * (vw - nw);
    vy = vy + my * (vh - nh);
    vw = nw;
    vh = nh;
    // Clamp
    vx = Math.max(0, Math.min(IMG_W - vw, vx));
    vy = Math.max(0, Math.min(IMG_H - vh, vy));
    setVB();
  }}, {{ passive: false }});

  // Drag to pan
  let dragging = false, dragX0, dragY0, vx0, vy0;
  wrap.addEventListener("mousedown", e => {{
    dragging = true;
    dragX0 = e.clientX; dragY0 = e.clientY;
    vx0 = vx; vy0 = vy;
  }});
  window.addEventListener("mousemove", e => {{
    if (!dragging) return;
    const rect = wrap.getBoundingClientRect();
    const scaleX = vw / rect.width;
    const scaleY = vh / rect.height;
    vx = Math.max(0, Math.min(IMG_W - vw, vx0 - (e.clientX - dragX0) * scaleX));
    vy = Math.max(0, Math.min(IMG_H - vh, vy0 - (e.clientY - dragY0) * scaleY));
    setVB();
  }});
  window.addEventListener("mouseup", () => {{ dragging = false; }});

  // Touch support: one finger pans, two fingers pinch-zoom
  let t0 = null, pinch0 = null;
  function anchorPan(touch) {{
    // Remember plain coordinates (not the Touch object) and the viewBox origin
    t0 = {{ x: touch.clientX, y: touch.clientY }};
    vx0 = vx; vy0 = vy;
  }}
  wrap.addEventListener("touchstart", e => {{
    if (e.touches.length === 1) {{
      anchorPan(e.touches[0]);
    }} else if (e.touches.length === 2) {{
      pinch0 = Math.hypot(
        e.touches[0].clientX - e.touches[1].clientX,
        e.touches[0].clientY - e.touches[1].clientY
      );
    }}
  }}, {{ passive: true }});
  wrap.addEventListener("touchmove", e => {{
    e.preventDefault();
    if (e.touches.length === 1 && t0) {{
      const rect = wrap.getBoundingClientRect();
      vx = Math.max(0, Math.min(IMG_W - vw, vx0 - (e.touches[0].clientX - t0.x) * vw / rect.width));
      vy = Math.max(0, Math.min(IMG_H - vh, vy0 - (e.touches[0].clientY - t0.y) * vh / rect.height));
      setVB();
    }} else if (e.touches.length === 2 && pinch0 !== null) {{
      const dist = Math.hypot(
        e.touches[0].clientX - e.touches[1].clientX,
        e.touches[0].clientY - e.touches[1].clientY
      );
      const factor = pinch0 / dist;
      const nw = Math.min(IMG_W, Math.max(IMG_W * 0.05, vw * factor));
      const nh = Math.min(IMG_H, Math.max(IMG_H * 0.05, vh * factor));
      vw = nw; vh = nh;
      vx = Math.max(0, Math.min(IMG_W - vw, vx));
      vy = Math.max(0, Math.min(IMG_H - vh, vy));
      pinch0 = dist;
      setVB();
    }}
  }}, {{ passive: false }});
  // When fingers lift, end the pinch and re-anchor panning on the remaining
  // finger so the view does not jump back to a stale start point.
  function onTouchEnd(e) {{
    if (e.touches.length < 2) pinch0 = null;
    if (e.touches.length === 1) anchorPan(e.touches[0]);
    else if (e.touches.length === 0) t0 = null;
  }}
  wrap.addEventListener("touchend", onTouchEnd, {{ passive: true }});
  wrap.addEventListener("touchcancel", onTouchEnd, {{ passive: true }});

  // ── Reset view ────────────────────────────────────────────────────────────
  window.resetView = function() {{
    vx = 0; vy = 0; vw = IMG_W; vh = IMG_H;
    setVB();
  }};

  // ── Init ──────────────────────────────────────────────────────────────────
  rebuildLabels();
}})();
</script>
"""
    return html
# ─────────────────────────────────────────────────────────────────────────────
# Animated counter
# ─────────────────────────────────────────────────────────────────────────────
def animated_metric(placeholder, label: str, final_val,
                    color: str = "#4fc3f7", steps: int = 20, delay: float = 0.025):
    """Render a metric card whose number counts up from 0 to ``final_val``.

    The card is redrawn ``steps`` times into ``placeholder`` (e.g. an
    ``st.empty()`` slot; anything with a ``markdown(body, unsafe_allow_html=...)``
    method works), sleeping ``delay`` seconds between frames.

    Args:
        placeholder: Target exposing ``markdown``.
        label: Caption shown under the number (inserted as raw HTML).
        final_val: Target value. ``float`` values are shown with one decimal;
            anything else is shown as an integer.
        color: CSS colour of the number.
        steps: Number of animation frames. Values below 1 are treated as 1 so
            the final value is always rendered.
        delay: Seconds to sleep between frames (not after the last frame).
    """
    is_float = isinstance(final_val, float)
    steps = max(1, int(steps))
    for i in range(1, steps + 1):
        # The last frame shows the exact target: final_val * steps / steps goes
        # through float division and can drift for large integers.
        v = final_val if i == steps else final_val * i / steps
        display = f"{v:.1f}" if is_float else str(int(v))
        placeholder.markdown(
            f"""
<div style='text-align:center;padding:12px 6px;border-radius:12px;
background:#1a1a2e;border:1px solid #2a2a4e;margin:4px 0;'>
<div style='font-size:2rem;font-weight:800;color:{color};
line-height:1.1;'>{display}</div>
<div style='font-size:0.75rem;color:#9e9e9e;margin-top:4px;'>{label}</div>
</div>
""",
            unsafe_allow_html=True,
        )
        # No sleep after the final frame; it would only delay the rest of the page.
        if i < steps:
            time.sleep(delay)
# ─────────────────────────────────────────────────────────────────────────────
# Active-learning queue helpers
# ─────────────────────────────────────────────────────────────────────────────
def _ensure_dirs():
    """Create the retraining-queue and corrections folders (with parents) if absent."""
    for directory in (QUEUE_DIR, CORRECTIONS_DIR):
        directory.mkdir(parents=True, exist_ok=True)
def add_to_queue(image_array: np.ndarray, reason: str = "batch",
                 nuc_mask=None, myo_mask=None, metadata: dict = None):
    """Persist an image (optionally with masks) for a later retraining run.

    With BOTH masks, a new ``CORRECTIONS_DIR/<timestamp>/`` folder receives
    ``image.png``, binarised ``nuclei_mask.png`` / ``myotube_mask.png``
    (any value > 0 becomes 255) and ``meta.json`` with ``has_masks: true``.

    Otherwise (no mask, or only one), ``<timestamp>.png`` and
    ``<timestamp>.json`` (``has_masks: false``) go into ``QUEUE_DIR``; a
    lone mask is ignored.

    ``metadata`` is copied. The ``reason`` and ``timestamp`` entries always
    override same-named keys supplied by the caller.
    """
    _ensure_dirs()
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    base_meta = dict(metadata or {})
    base_meta.update(reason=reason, timestamp=stamp)

    if nuc_mask is None or myo_mask is None:
        Image.fromarray(image_array).save(QUEUE_DIR / f"{stamp}.png")
        (QUEUE_DIR / f"{stamp}.json").write_text(
            json.dumps({**base_meta, "has_masks": False}, indent=2)
        )
        return

    target = CORRECTIONS_DIR / stamp
    target.mkdir(parents=True, exist_ok=True)
    Image.fromarray(image_array).save(target / "image.png")
    for mask, file_name in ((nuc_mask, "nuclei_mask.png"),
                            (myo_mask, "myotube_mask.png")):
        Image.fromarray((mask > 0).astype(np.uint8) * 255).save(target / file_name)
    (target / "meta.json").write_text(
        json.dumps({**base_meta, "has_masks": True}, indent=2)
    )
# ─────────────────────────────────────────────────────────────────────────────
# Model (architecture identical to training script)
# ─────────────────────────────────────────────────────────────────────────────
class DoubleConv(nn.Module):
    """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    A 3x3 kernel with padding=1 (default stride 1) keeps H and W unchanged;
    only the channel count goes from ``in_ch`` to ``out_ch``.

    The layers live in ``self.net``, so the checkpoint keys are ``net.0.*``,
    ``net.1.*``, ... Keep the attribute name and layer order identical to the
    training script, or ``load_state_dict`` will fail.
    """
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        # ReLU(True) -> in-place activation.
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(True),
        )
    # Apply both conv stages.
    def forward(self, x): return self.net(x)
class UNet(nn.Module):
    """Four-level U-Net returning per-pixel logits (no final activation).

    Encoder: four DoubleConv + MaxPool2d(2) stages with base, 2x, 4x and 8x
    channels, then a 16x-channel bottleneck. Decoder: four
    ConvTranspose2d(kernel=2, stride=2) upsamplings. Each is concatenated
    with the matching encoder feature map (skip connection) and refined by a
    DoubleConv.

    Because of the four 2x poolings, the skip concatenations only line up
    when input H and W are divisible by 16. Every "Model input size" option
    offered in this file's sidebar satisfies that.

    Attribute names (d1..d4, bn, u1..u4, du1..du4, out) define the state_dict
    keys and must match the training script's checkpoints.

    Args:
        in_ch: Input channels. This file feeds 2, stacked as [MyHC, DAPI].
        out_ch: Output channels. After a sigmoid, this file reads channel 0
            as nuclei and channel 1 as myotubes.
        base: Channel width of the first encoder stage.
    """
    def __init__(self, in_ch=2, out_ch=2, base=32):
        super().__init__()
        # Encoder: (conv block, 2x downsample) pairs.
        self.d1 = DoubleConv(in_ch, base); self.p1 = nn.MaxPool2d(2)
        self.d2 = DoubleConv(base, base*2); self.p2 = nn.MaxPool2d(2)
        self.d3 = DoubleConv(base*2, base*4); self.p3 = nn.MaxPool2d(2)
        self.d4 = DoubleConv(base*4, base*8); self.p4 = nn.MaxPool2d(2)
        self.bn = DoubleConv(base*8, base*16)
        # Decoder: 2x upsample, then a conv block on [upsampled, skip]; the
        # input channels double because of the concatenation.
        self.u4 = nn.ConvTranspose2d(base*16, base*8, 2, 2); self.du4 = DoubleConv(base*16, base*8)
        self.u3 = nn.ConvTranspose2d(base*8, base*4, 2, 2); self.du3 = DoubleConv(base*8, base*4)
        self.u2 = nn.ConvTranspose2d(base*4, base*2, 2, 2); self.du2 = DoubleConv(base*4, base*2)
        self.u1 = nn.ConvTranspose2d(base*2, base, 2, 2); self.du1 = DoubleConv(base*2, base)
        # 1x1 projection to out_ch logit maps.
        self.out = nn.Conv2d(base, out_ch, 1)
    def forward(self, x):
        """Return logits with the same H and W as ``x`` and ``out_ch`` channels."""
        d1=self.d1(x); p1=self.p1(d1)
        d2=self.d2(p1); p2=self.p2(d2)
        d3=self.d3(p2); p3=self.p3(d3)
        d4=self.d4(p3); p4=self.p4(d4)
        b=self.bn(p4)
        # Concatenate along dim 1 (channels) with the matching encoder output.
        x=self.u4(b); x=torch.cat([x,d4],1); x=self.du4(x)
        x=self.u3(x); x=torch.cat([x,d3],1); x=self.du3(x)
        x=self.u2(x); x=torch.cat([x,d2],1); x=self.du2(x)
        x=self.u1(x); x=torch.cat([x,d1],1); x=self.du1(x)
        return self.out(x)
@st.cache_resource
def load_model(device: str):
    """Download the checkpoint from the Hugging Face Hub and return an eval-mode UNet.

    Cached with ``st.cache_resource``, keyed on ``device``, so the body runs
    once per device string per process (until the cache is cleared).

    Args:
        device: torch device string, e.g. "cuda" or "cpu".

    Returns:
        ``UNet(in_ch=2, out_ch=2, base=32)`` with the checkpoint weights,
        moved to ``device`` and switched to eval mode.

    Raises:
        Whatever ``hf_hub_download`` / ``torch.load`` raise on download or
        unpickling failure. ``load_state_dict`` (strict by default) raises
        if the checkpoint keys do not match the architecture.
    """
    # force_download=True bypasses any locally cached copy on each cache miss,
    # so the newest file on the Hub is used.
    local = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME,
                            force_download=True)
    file_sha = sha256_file(local)
    mtime = time.ctime(os.path.getmtime(local))
    size_mb = os.path.getsize(local) / 1e6
    # Debug captions identifying exactly which checkpoint file was loaded.
    # NOTE(review): whether these reappear on cache hits depends on
    # Streamlit's element replay for cached functions - confirm if relied on.
    st.sidebar.markdown("### πŸ” Model debug")
    st.sidebar.caption(f"Repo: `{MODEL_REPO_ID}`")
    st.sidebar.caption(f"File: `{MODEL_FILENAME}`")
    st.sidebar.caption(f"Size: {size_mb:.2f} MB")
    st.sidebar.caption(f"Modified: {mtime}")
    st.sidebar.caption(f"SHA256: `{file_sha[:20]}…`")
    # NOTE(review): torch.load unpickles the downloaded file and weights_only
    # is not passed, so the installed PyTorch version's default applies. Only
    # point MODEL_REPO_ID at a trusted repo, or pass weights_only=True if the
    # checkpoint holds only tensors and plain containers.
    ckpt = torch.load(local, map_location=device)
    # Accept either {"model": state_dict, ...} or a bare state_dict.
    state = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
    model = UNet(in_ch=2, out_ch=2, base=32)
    model.load_state_dict(state)
    model.to(device).eval()
    return model
# ─────────────────────────────────────────────────────────────────────────────
# PAGE CONFIG + CSS
# ─────────────────────────────────────────────────────────────────────────────
# Page setup. NOTE(review): Streamlit requires st.set_page_config to run
# before other Streamlit commands that render output - keep it first.
st.set_page_config(page_title="MyoSeg β€” Myotube Analyser",
                   layout="wide", page_icon="πŸ”¬")
# Dark-theme CSS overrides, plus the .flag-box style used by the
# low-confidence warning rendered after an analysis run.
st.markdown("""
<style>
body, .stApp { background:#0e0e1a; color:#e0e0e0; }
.block-container { max-width:1200px; padding-top:1.25rem; }
h1,h2,h3,h4 { color:#90caf9; }
.flag-box {
background:#3e1a1a; border-left:4px solid #ef5350;
padding:10px 16px; border-radius:8px; margin:8px 0;
}
</style>
""", unsafe_allow_html=True)
st.title("πŸ”¬ MyoSeg β€” Myotube & Nuclei Analyser")
# Inference device: CUDA when available, else CPU. Passed to load_model and
# used for the input tensor during analysis.
device = "cuda" if torch.cuda.is_available() else "cpu"
# ─────────────────────────────────────────────────────────────────────────────
# SIDEBAR
# ─────────────────────────────────────────────────────────────────────────────
# Sidebar: every analysis, overlay and privacy control. Widget values are
# bound to module-level names read by the analysis, preview and
# submission sections below.
with st.sidebar:
    st.caption(f"Device: **{device}**")

    # Which RGB plane (or grayscale) feeds each of the model's 2 input channels.
    st.header("Input mapping")
    src1 = st.selectbox("Model channel 1 (MyHC / myotubes)",
                        ["Red", "Green", "Blue", "Grayscale"], index=0)
    inv1 = st.checkbox("Invert channel 1", value=False)
    src2 = st.selectbox("Model channel 2 (DAPI / nuclei)",
                        ["Red", "Green", "Blue", "Grayscale"], index=2)
    inv2 = st.checkbox("Invert channel 2", value=False)

    # Square side length the channels are resized to before inference.
    st.header("Preprocessing")
    image_size = st.select_slider("Model input size",
                                  options=[256, 384, 512, 640, 768, 1024], value=512)

    # Probability cut-offs applied to the sigmoid outputs.
    st.header("Thresholds")
    thr_nuc = st.slider("Nuclei threshold", 0.05, 0.95, 0.45, 0.01)
    thr_myo = st.slider("Myotube threshold", 0.05, 0.95, 0.40, 0.01)

    st.header("Fusion Index method")
    fi_method = st.radio(
        "FI classification method",
        ["Cytoplasm-hole (accurate, Lair 2025)", "Pixel-overlap (legacy)"],
        index=0,
        help=(
            "Cytoplasm-hole: checks for a MyHC signal dip beneath each nucleus β€” "
            "eliminates false positives from nuclei sitting above/below myotubes in Z. "
            "Pixel-overlap: legacy method that overestimates FI (Lair et al. 2025)."
        )
    )
    use_hole_method = fi_method.startswith("Cytoplasm")
    # The two hole-method knobs are greyed out when pixel-overlap is selected.
    hole_ratio_thr = st.slider(
        "Hole ratio threshold", 0.50, 0.99, 0.85, 0.01,
        help=(
            "A nucleus is counted as fused if its MyHC signal is less than "
            "this fraction of the surrounding cytoplasm ring signal. "
            "Lower = stricter (fewer nuclei counted as fused). "
            "0.85 is the value validated by Lair et al. 2025."
        ),
        disabled=not use_hole_method,
    )
    ring_width_px = st.number_input(
        "Cytoplasm ring width (px)", 2, 20, 6, 1,
        help="Width of the ring around each nucleus used to measure local MyHC intensity.",
        disabled=not use_hole_method,
    )

    st.header("Postprocessing")
    min_nuc_area = st.number_input("Min nucleus area (px)", 0, 10000, 20, 1)
    min_myo_area = st.number_input("Min myotube area (px)", 0, 200000, 500, 10)
    nuc_close_radius = st.number_input("Nuclei close radius", 0, 50, 2, 1)
    myo_open_radius = st.number_input("Myotube open radius", 0, 50, 2, 1,
                                      help="Opening removes small noise without merging separate myotubes. "
                                           "Replaces the old closing radius which was merging adjacent myotubes.")

    st.header("Myotube separation")
    st.caption(
        "These controls break apart touching/bridged myotubes that would "
        "otherwise be counted as a single object."
    )
    myo_erode_radius = st.number_input(
        "Myotube erode radius (px)", 0, 15, 2, 1,
        help=(
            "Erode + re-dilate breaks thin pixel bridges between adjacent "
            "myotubes while preserving their overall size. "
            "Start at 2 px; increase to 3–4 px for very dense cultures. "
            "Set 0 to disable."
        )
    )
    min_myo_aspect_ratio = st.number_input(
        "Min myotube aspect ratio", 0.0, 10.0, 0.0, 0.1,
        help=(
            "Rejects round blobs (debris/artifacts) that are not real myotubes. "
            "Myotubes are elongated (aspect ratio > 3). Round objects have ~1. "
            "Set to 1.5–2.0 to filter false positives in sparse cultures. "
            "Set to 0 to disable (default)."
        )
    )
    myo_max_area_px = st.number_input(
        "Max myotube area before split (pxΒ²)", 0, 500000, 20000, 500,
        help=(
            "Any connected myotube region larger than this is split using "
            "nucleus-seeded watershed. Set to 0 to disable. "
            "Increase for cultures with legitimately large single myotubes."
        )
    )
    myo_split_min_seeds = st.number_input(
        "Min nuclei seeds to split", 2, 20, 2, 1,
        help=(
            "Minimum nucleus centroids required before splitting a large region. "
            "Set to 2 to split merged pairs; increase if single large myotubes "
            "are being incorrectly split."
        )
    )

    st.header("Watershed (nuclei splitting)")
    nuc_ws_min_dist = st.number_input("Min watershed distance", 1, 30, 3, 1)
    nuc_ws_min_area = st.number_input("Min watershed area (px)", 1, 500, 6, 1)

    st.header("Overlay")
    nuc_hex = st.color_picker("Nuclei colour", "#00FFFF")
    myo_hex = st.color_picker("Myotube colour", "#FF0000")
    alpha = st.slider("Overlay alpha", 0.0, 1.0, 0.45, 0.01)
    nuc_rgb = hex_to_rgb(nuc_hex)
    myo_rgb = hex_to_rgb(myo_hex)
    label_nuc = st.checkbox("Show nucleus IDs on overlay", value=True)
    label_myo = st.checkbox("Show myotube IDs on overlay", value=True)

    st.header("Surface area")
    px_um = st.number_input("Pixel size (Β΅m) β€” set for real Β΅mΒ²",
                            value=1.0, min_value=0.01, step=0.01)

    st.header("Active learning")
    enable_al = st.toggle("Enable correction upload", value=True)

    st.header("Privacy & Training")
    # The uploads are always sent to and analysed on the server running this
    # app. Private mode only stops them from being written to the retraining
    # queue / corrections folders, so the help text must not claim "local"
    # processing.
    private_mode = st.toggle(
        "πŸ”’ Private mode",
        value=False,
        help=(
            "When enabled, your images are still uploaded to and analysed on "
            "the server running this app, but they are NOT added to the "
            "retraining queue, NOT saved to "
            "corrections/, and NOT used for model improvement in any way. "
            "Use this for unpublished data or sensitive research images."
        )
    )
    if private_mode:
        st.info(
            "πŸ”’ **Private mode ON** β€” your images will not be used for "
            "training or stored beyond this session."
        )

    # NOTE(review): these definitions describe compute_bio_metrics, which is
    # defined elsewhere; the "MyHC-positive" rule may differ when the
    # cytoplasm-hole FI method is selected - confirm against that helper.
    st.header("Metric definitions")
    with st.expander("Fusion Index"):
        st.write("100 Γ— (nuclei in myotubes with β‰₯2 nuclei) / total nuclei")
    with st.expander("MyHC-positive nucleus"):
        st.write("Counted if β‰₯10% of nucleus pixels overlap a myotube.")
    with st.expander("Surface area"):
        st.write("Pixel count Γ— px_umΒ². Set pixel size for real Β΅mΒ² values.")
# ─────────────────────────────────────────────────────────────────────────────
# FILE UPLOADER
# ─────────────────────────────────────────────────────────────────────────────
# The uploader label tells the user whether uploads may be reused for training.
if private_mode:
    _uploader_label = "Upload 1+ images (png / jpg / tif). πŸ”’ Private mode is ON β€” images will not be stored."
else:
    _uploader_label = "Upload 1+ images (png / jpg / tif). Images may be used for model improvement."

uploads = st.file_uploader(
    _uploader_label,
    accept_multiple_files=True,
    type=["png", "jpg", "jpeg", "tif", "tiff"],
)

# Analysis results persist across reruns in session_state; seed empty slots
# with None so later sections can test for "no results yet".
for _state_key in ("df", "artifacts", "zip_bytes", "bio_metrics"):
    if _state_key in st.session_state:
        continue
    st.session_state[_state_key] = None

# Nothing below makes sense without at least one image.
if not uploads:
    st.info("πŸ‘† Upload one or more fluorescence images to get started.")
    st.stop()

# Loaded only once images are present; cached per device by load_model.
model = load_model(device=device)
# ─────────────────────────────────────────────────────────────────────────────
# RUN ANALYSIS
# ─────────────────────────────────────────────────────────────────────────────
# Analysis only runs when the form button is pressed. Results are stored in
# session_state so later widget interactions can redisplay them without
# re-running the model.
with st.form("run_form"):
    run = st.form_submit_button("β–Ά Run / Rerun analysis", type="primary")
if run:
    import html as _html_esc
    results = []
    artifacts = {}
    all_bio_metrics = {}
    low_conf_flags = []
    # Result keys already used in this run (see de-duplication below).
    used_names = set()
    zip_buf = io.BytesIO()
    with st.spinner("Analysing images…"):
        with zipfile.ZipFile(zip_buf, "w", compression=zipfile.ZIP_DEFLATED) as zf:
            prog = st.progress(0.0)
            for i, up in enumerate(uploads):
                # Each image is keyed by its file stem in `artifacts`,
                # `all_bio_metrics` and the ZIP folder names. Two uploads with
                # the same stem (e.g. "a.png" and "a.tif") would otherwise
                # overwrite each other and write duplicate ZIP entries, so
                # repeats get a numeric suffix: a, a_2, a_3, ...
                base_name = Path(up.name).stem
                name = base_name
                dup_idx = 1
                while name in used_names:
                    dup_idx += 1
                    name = f"{base_name}_{dup_idx}"
                used_names.add(name)

                # NOTE(review): convert("RGB") on 16-bit TIFFs may clip
                # intensities above 255 - confirm inputs are 8-bit.
                rgb_u8 = np.array(
                    Image.open(io.BytesIO(up.getvalue())).convert("RGB"),
                    dtype=np.uint8
                )
                ch1 = get_channel(rgb_u8, src1)  # MyHC channel
                ch2 = get_channel(rgb_u8, src2)  # DAPI / nuclei channel
                if inv1: ch1 = 255 - ch1
                if inv2: ch2 = 255 - ch2

                # Keep the full-resolution MyHC channel for the cytoplasm-hole
                # FI classifier - it must be at original image resolution.
                myc_full = ch1.copy()  # uint8, original resolution

                # Model input: both channels resized to image_size x image_size,
                # stacked as [MyHC, DAPI] into a (1, 2, H, W) float tensor.
                H = W = int(image_size)
                x1 = resize_u8_to_float01(ch1, W, H, Image.BILINEAR)
                x2 = resize_u8_to_float01(ch2, W, H, Image.BILINEAR)
                x = np.stack([x1, x2], 0).astype(np.float32)
                x_t = torch.from_numpy(x).unsqueeze(0).to(device)
                with torch.no_grad():
                    probs = torch.sigmoid(model(x_t)).cpu().numpy()[0]

                # Confidence = mean of the per-channel peak probabilities.
                # Low-confidence images are flagged always, but only queued
                # for retraining when NOT in private mode.
                conf = float(np.mean([probs[0].max(), probs[1].max()]))
                if conf < CONF_FLAG_THR:
                    low_conf_flags.append((name, conf))
                    if not private_mode:
                        add_to_queue(rgb_u8, reason="low_confidence",
                                     metadata={"confidence": conf, "filename": up.name})

                nuc_raw = (probs[0] > float(thr_nuc)).astype(np.uint8)
                myo_raw = (probs[1] > float(thr_myo)).astype(np.uint8)
                nuc_pp, myo_pp = postprocess_masks(
                    nuc_raw, myo_raw,
                    min_nuc_area=int(min_nuc_area),
                    min_myo_area=int(min_myo_area),
                    nuc_close_radius=int(nuc_close_radius),
                    myo_open_radius=int(myo_open_radius),
                    myo_erode_radius=int(myo_erode_radius),
                    min_myo_aspect_ratio=float(min_myo_aspect_ratio),
                )

                # Flat overlay for the ZIP (no labels - just colour regions).
                simple_ov = make_simple_overlay(
                    rgb_u8, nuc_pp, myo_pp, nuc_rgb, myo_rgb, float(alpha)
                )

                # Label maps, shared across all viewers.
                nuc_lab = label_nuclei_watershed(nuc_pp,
                                                 min_distance=int(nuc_ws_min_dist),
                                                 min_nuc_area=int(nuc_ws_min_area))
                myo_lab = label_cc(myo_pp)
                # Split oversized merged myotube regions using nucleus seeds;
                # skipped when myo_max_area_px is 0.
                if int(myo_max_area_px) > 0:
                    myo_lab = split_large_myotubes(
                        myo_lab, nuc_lab,
                        max_area_px=int(myo_max_area_px),
                        min_seeds=int(myo_split_min_seeds),
                    )

                # Coloured pixel overlays (no baked-in text; labels are drawn as SVG).
                inst_px = make_coloured_overlay(rgb_u8, nuc_lab, myo_lab, alpha=float(alpha))
                nuc_only_px = make_coloured_overlay(rgb_u8, nuc_lab, np.zeros_like(myo_lab), alpha=float(alpha))
                myo_only_px = make_coloured_overlay(rgb_u8, np.zeros_like(nuc_lab), myo_lab, alpha=float(alpha))

                # Compute bio metrics FIRST so we know which myotubes are biological.
                bio = compute_bio_metrics(
                    nuc_pp, myo_pp,
                    myc_channel_full=myc_full if use_hole_method else None,
                    nuc_ws_min_distance=int(nuc_ws_min_dist),
                    nuc_ws_min_area=int(nuc_ws_min_area),
                    px_um=float(px_um),
                    ring_width=int(ring_width_px),
                    hole_ratio_thr=float(hole_ratio_thr),
                )
                bio["fi_method"] = "cytoplasm-hole" if use_hole_method else "pixel-overlap"
                # Private "_"-prefixed entries are removed from the table row;
                # the per-myotube areas are kept only in all_bio_metrics.
                per_areas = bio.pop("_per_myotube_areas", [])
                bio_myo_ids = bio.pop("_bio_myo_ids", set())
                total_cc_count = bio.pop("_total_cc_count", 0)
                bio["image"] = name
                results.append(bio)
                all_bio_metrics[name] = {**bio, "_per_myotube_areas": per_areas}

                # Label positions - only biological myotubes get IDs.
                orig_h_img, orig_w_img = rgb_u8.shape[:2]
                label_positions = collect_label_positions(nuc_lab, myo_lab, orig_w_img, orig_h_img,
                                                          bio_myo_ids=bio_myo_ids)

                artifacts[name] = {
                    # raw pixel data - overlays are rebuilt at display time from these
                    "rgb_u8": rgb_u8,
                    "nuc_lab": nuc_lab,
                    "myo_lab": myo_lab,
                    # postprocessed masks (for outline generation)
                    "nuc_pp_arr": nuc_pp,
                    "myo_pp_arr": myo_pp,
                    # static mask PNGs
                    "nuc_pp": png_bytes((nuc_pp * 255).astype(np.uint8)),
                    "myo_pp": png_bytes((myo_pp * 255).astype(np.uint8)),
                    "nuc_raw_bytes": png_bytes((nuc_raw*255).astype(np.uint8)),
                    "myo_raw_bytes": png_bytes((myo_raw*255).astype(np.uint8)),
                    # label positions for the SVG viewer
                    "label_positions": label_positions,
                    # image dimensions
                    "img_w": orig_w_img,
                    "img_h": orig_h_img,
                }

                # ZIP built with the colour settings active at run time.
                outline_ov = make_outline_overlay(rgb_u8, nuc_lab, myo_lab,
                                                  nuc_color=nuc_rgb, myo_color=(0, 255, 0),
                                                  line_width=2)
                zf.writestr(f"{name}/overlay_combined.png", png_bytes(simple_ov))
                zf.writestr(f"{name}/overlay_instance.png", png_bytes(inst_px))
                zf.writestr(f"{name}/overlay_nuclei.png", png_bytes(nuc_only_px))
                zf.writestr(f"{name}/overlay_myotubes.png", png_bytes(myo_only_px))
                zf.writestr(f"{name}/overlay_outlines.png", png_bytes(outline_ov))
                zf.writestr(f"{name}/nuclei_pp.png", artifacts[name]["nuc_pp"])
                zf.writestr(f"{name}/myotube_pp.png", artifacts[name]["myo_pp"])
                zf.writestr(f"{name}/nuclei_raw.png", artifacts[name]["nuc_raw_bytes"])
                zf.writestr(f"{name}/myotube_raw.png", artifacts[name]["myo_raw_bytes"])
                prog.progress((i + 1) / len(uploads))

            df = pd.DataFrame(results).sort_values("image")
            zf.writestr("metrics.csv", df.to_csv(index=False).encode("utf-8"))

    st.session_state.df = df
    st.session_state.artifacts = artifacts
    st.session_state.zip_bytes = zip_buf.getvalue()
    st.session_state.bio_metrics = all_bio_metrics

    if low_conf_flags:
        # Names derive from user-supplied filenames and are rendered as raw
        # HTML, so escape them.
        names_str = ", ".join(
            f"{_html_esc.escape(n)} (conf={c:.2f})" for n, c in low_conf_flags
        )
        # In private mode nothing was queued; say so instead of claiming it was.
        if private_mode:
            flag_title = "Low-confidence images (not queued - Private mode is ON):"
        else:
            flag_title = "Low-confidence images auto-queued for retraining:"
        st.markdown(
            f"<div class='flag-box'>⚠️ <b>{flag_title}</b> "
            f"{names_str}</div>",
            unsafe_allow_html=True,
        )

if st.session_state.df is None:
    st.info("Click **β–Ά Run / Rerun analysis** to generate results.")
    st.stop()
# ─────────────────────────────────────────────────────────────────────────────
# RESULTS TABLE + DOWNLOADS
# ─────────────────────────────────────────────────────────────────────────────
# Results table plus CSV / ZIP downloads and an optional ZIP rebuild that
# re-renders the overlays with the current sidebar colours.
st.subheader("πŸ“‹ Results")
display_cols = [c for c in st.session_state.df.columns if not c.startswith("_")]
st.dataframe(st.session_state.df[display_cols], use_container_width=True, height=320)
c1, c2, c3 = st.columns(3)
with c1:
    st.download_button("⬇️ Download metrics.csv",
                       st.session_state.df[display_cols].to_csv(index=False).encode(),
                       file_name="metrics.csv", mime="text/csv")

# The rebuild (column 3) is handled BEFORE the results.zip button (column 2)
# is created. st.download_button captures its data at render time, so
# rendering it first would make the next click download the stale ZIP.
# Column placement on screen is unaffected by the order of these blocks.
with c3:
    # Rebuild the ZIP with CURRENT colour / alpha settings - no model rerun needed.
    if st.button("🎨 Rebuild ZIP with current colours", help=(
        "Regenerates the overlay images in the ZIP using the current "
        "colour picker and alpha values from the sidebar."
    )):
        new_zip_buf = io.BytesIO()
        with zipfile.ZipFile(new_zip_buf, "w", compression=zipfile.ZIP_DEFLATED) as zf:
            for img_name, art in st.session_state.artifacts.items():
                _r = art["rgb_u8"]
                _nl = art["nuc_lab"]
                _ml = art["myo_lab"]
                _zn = np.zeros_like(_nl)
                _zm = np.zeros_like(_ml)
                # The run-time overlay_combined.png is drawn from the
                # post-processed masks, so use those too. Fall back to the
                # label maps only for artifacts stored without them.
                _nuc_mask = art.get("nuc_pp_arr")
                if _nuc_mask is None:
                    _nuc_mask = (_nl > 0).astype(np.uint8)
                _myo_mask = art.get("myo_pp_arr")
                if _myo_mask is None:
                    _myo_mask = (_ml > 0).astype(np.uint8)
                ov_comb = make_coloured_overlay(_r, _nl, _ml,
                                                alpha=float(alpha),
                                                nuc_color=nuc_rgb, myo_color=myo_rgb)
                ov_nuc = make_coloured_overlay(_r, _nl, _zm,
                                               alpha=float(alpha),
                                               nuc_color=nuc_rgb, myo_color=None)
                ov_myo = make_coloured_overlay(_r, _zn, _ml,
                                               alpha=float(alpha),
                                               nuc_color=None, myo_color=myo_rgb)
                simple = make_simple_overlay(_r, _nuc_mask, _myo_mask,
                                             nuc_rgb, myo_rgb, float(alpha))
                outline = make_outline_overlay(_r, _nl, _ml,
                                               nuc_color=nuc_rgb, myo_color=(0, 255, 0),
                                               line_width=2)
                zf.writestr(f"{img_name}/overlay_combined.png", png_bytes(simple))
                zf.writestr(f"{img_name}/overlay_instance.png", png_bytes(ov_comb))
                zf.writestr(f"{img_name}/overlay_nuclei.png", png_bytes(ov_nuc))
                zf.writestr(f"{img_name}/overlay_myotubes.png", png_bytes(ov_myo))
                zf.writestr(f"{img_name}/overlay_outlines.png", png_bytes(outline))
                zf.writestr(f"{img_name}/nuclei_pp.png", art["nuc_pp"])
                zf.writestr(f"{img_name}/myotube_pp.png", art["myo_pp"])
                zf.writestr(f"{img_name}/nuclei_raw.png", art["nuc_raw_bytes"])
                zf.writestr(f"{img_name}/myotube_raw.png", art["myo_raw_bytes"])
            df_cols = [c for c in st.session_state.df.columns if not c.startswith("_")]
            zf.writestr("metrics.csv", st.session_state.df[df_cols].to_csv(index=False).encode())
        st.session_state.zip_bytes = new_zip_buf.getvalue()
        st.success("ZIP rebuilt with current colours. Click Download results.zip to save.")
with c2:
    st.download_button("⬇️ Download results.zip",
                       st.session_state.zip_bytes,
                       file_name="results.zip", mime="application/zip")
st.divider()
# ─────────────────────────────────────────────────────────────────────────────
# PER-IMAGE PREVIEW + ANIMATED METRICS
# ─────────────────────────────────────────────────────────────────────────────
# Per-image viewer tabs plus the animated metric cards for the chosen image.
st.subheader("πŸ–ΌοΈ Image preview & live metrics")
names = list(st.session_state.artifacts.keys())
pick = st.selectbox("Select image", names)
col_img, col_metrics = st.columns([3, 2], gap="large")
with col_img:
    tabs = st.tabs([
        "πŸ”΅ Combined",
        "πŸ“ Nuclei outlines",
        "πŸ“ Myotube outlines",
        "🟣 Nuclei only",
        "🟠 Myotubes only",
        "πŸ“· Original",
        "⬜ Nuclei mask",
        "⬜ Myotube mask",
    ])
    art = st.session_state.artifacts[pick]
    bio_cur = st.session_state.bio_metrics.get(pick, {})
    lpos = art["label_positions"]
    iw = art["img_w"]
    ih = art["img_h"]
    # Overlays are rebuilt on every rerun from the stored label maps using
    # the current sidebar colour / alpha, so display changes need no model rerun.
    import base64 as _b64_disp

    def _b64png_disp(arr):
        """PNG-encode ``arr`` and return it as a base64 text string."""
        return _b64_disp.b64encode(png_bytes(arr)).decode()

    _rgb = art["rgb_u8"]
    _nl = art["nuc_lab"]
    _ml = art["myo_lab"]
    _zero_nuc = np.zeros_like(_nl)
    _zero_myo = np.zeros_like(_ml)
    inst_b64 = _b64png_disp(make_coloured_overlay(_rgb, _nl, _ml,
                                                  alpha=float(alpha),
                                                  nuc_color=nuc_rgb, myo_color=myo_rgb))
    nuc_only_b64 = _b64png_disp(make_coloured_overlay(_rgb, _nl, _zero_myo,
                                                      alpha=float(alpha),
                                                      nuc_color=nuc_rgb, myo_color=None))
    myo_only_b64 = _b64png_disp(make_coloured_overlay(_rgb, _zero_nuc, _ml,
                                                      alpha=float(alpha),
                                                      nuc_color=None, myo_color=myo_rgb))
    with tabs[0]:
        html_combined = make_svg_viewer(
            inst_b64, iw, ih, lpos,
            show_nuclei=True, show_myotubes=True,
        )
        st.components.v1.html(html_combined, height=680, scrolling=False)
    with tabs[1]:
        # Nuclei-only outlines (empty myotube map, nucleus labels only).
        nuc_outline_img = make_outline_overlay(
            _rgb, _nl, np.zeros_like(_ml),
            nuc_color=nuc_rgb, myo_color=(0, 255, 0),
            line_width=2,
        )
        nuc_outline_b64 = _b64png_disp(nuc_outline_img)
        nuc_outline_lpos = {"nuclei": lpos["nuclei"], "myotubes": [], "myotubes_nonbio": []}
        html_nuc_outline = make_svg_viewer(
            nuc_outline_b64, iw, ih, nuc_outline_lpos,
            show_nuclei=True, show_myotubes=False,
        )
        st.components.v1.html(html_nuc_outline, height=680, scrolling=False)
    with tabs[2]:
        # Myotube-only outlines (empty nucleus map, myotube labels only).
        myo_outline_img = make_outline_overlay(
            _rgb, np.zeros_like(_nl), _ml,
            nuc_color=nuc_rgb, myo_color=(0, 255, 0),
            line_width=2,
        )
        myo_outline_b64 = _b64png_disp(myo_outline_img)
        myo_outline_lpos = {"nuclei": [], "myotubes": lpos["myotubes"],
                            "myotubes_nonbio": lpos.get("myotubes_nonbio", [])}
        html_myo_outline = make_svg_viewer(
            myo_outline_b64, iw, ih, myo_outline_lpos,
            show_nuclei=False, show_myotubes=True,
        )
        st.components.v1.html(html_myo_outline, height=680, scrolling=False)
    with tabs[3]:
        nuc_only_lpos = {"nuclei": lpos["nuclei"], "myotubes": [], "myotubes_nonbio": []}
        html_nuc = make_svg_viewer(
            nuc_only_b64, iw, ih, nuc_only_lpos,
            show_nuclei=True, show_myotubes=False,
        )
        st.components.v1.html(html_nuc, height=680, scrolling=False)
    with tabs[4]:
        myo_only_lpos = {"nuclei": [], "myotubes": lpos["myotubes"],
                         "myotubes_nonbio": lpos.get("myotubes_nonbio", [])}
        html_myo = make_svg_viewer(
            myo_only_b64, iw, ih, myo_only_lpos,
            show_nuclei=False, show_myotubes=True,
        )
        st.components.v1.html(html_myo, height=680, scrolling=False)
    with tabs[5]:
        st.image(art["rgb_u8"], use_container_width=True)
    with tabs[6]:
        st.image(art["nuc_pp"], use_container_width=True)
    with tabs[7]:
        st.image(art["myo_pp"], use_container_width=True)

with col_metrics:
    st.markdown("#### πŸ“Š Live metrics")
    bio = st.session_state.bio_metrics.get(pick, {})
    per_areas = bio.get("_per_myotube_areas", [])
    r1c1, r1c2, r1c3 = st.columns(3)
    r2c1, r2c2, r2c3 = st.columns(3)
    r3c1, r3c2, r3c3 = st.columns(3)
    placeholders = {
        "total_nuclei": r1c1.empty(),
        "myotube_count": r1c2.empty(),
        "myHC_positive_nuclei": r1c3.empty(),
        "myHC_positive_percentage": r2c1.empty(),
        "fusion_index": r2c2.empty(),
        "avg_nuclei_per_myotube": r2c3.empty(),
        "total_area_um2": r3c1.empty(),
        "mean_area_um2": r3c2.empty(),
        "max_area_um2": r3c3.empty(),
    }
    # (metric key, card label, colour, shown as float?)
    specs = [
        ("total_nuclei", "Total nuclei", "#4fc3f7", False),
        ("myotube_count", "Myotubes", "#ff8a65", False),
        ("myHC_positive_nuclei", "MyHC⁺ nuclei", "#a5d6a7", False),
        ("myHC_positive_percentage", "MyHC⁺ %", "#ce93d8", True),
        ("fusion_index", "Fusion index %", "#80cbc4", True),
        ("avg_nuclei_per_myotube", "Avg nuc/myotube", "#80deea", True),
        ("total_area_um2", "Total area (Β΅mΒ²)", "#fff176", True),
        ("mean_area_um2", "Mean area (Β΅mΒ²)", "#ffcc80", True),
        ("max_area_um2", "Max area (Β΅mΒ²)", "#ef9a9a", True),
    ]
    # Every widget interaction reruns the script. Animating all 9 cards each
    # time blocks for about 9 x 20 x 25 ms = 4.5 s before the sections below
    # render. So animate only when the shown image or its results change;
    # otherwise draw the final values in a single frame. The signature is
    # stored before animating, so an interrupted animation is not replayed.
    anim_sig = (pick, tuple(repr(bio.get(spec[0])) for spec in specs))
    animate = st.session_state.get("_metrics_anim_sig") != anim_sig
    st.session_state["_metrics_anim_sig"] = anim_sig
    anim_kwargs = {} if animate else {"steps": 1, "delay": 0.0}
    for key, label, color, is_float in specs:
        val = bio.get(key, 0)
        if val is None:  # missing metric -> show 0 instead of crashing int()/float()
            val = 0
        animated_metric(placeholders[key], label,
                        float(val) if is_float else int(val),
                        color=color, **anim_kwargs)
    if per_areas:
        st.markdown("#### πŸ“ Per-myotube area")
        area_df = pd.DataFrame({
            "Myotube": [f"M{i+1}" for i in range(len(per_areas))],
            "Area (Β΅mΒ²)": per_areas,
        }).set_index("Myotube")
        st.bar_chart(area_df, height=220)
st.divider()
# ─────────────────────────────────────────────────────────────────────────────
# TRAINING CONTRIBUTION β€” User-initiated parameter + image submission
# ─────────────────────────────────────────────────────────────────────────────
# User-initiated training contribution: saves the image, its predicted masks
# and a snapshot of the analysis settings. Hidden entirely in private mode.
if not private_mode and names:
    st.subheader("πŸ“€ Submit image for training")
    st.caption(
        "Once you've tuned the sidebar parameters to get the best results for "
        "this image, click below to submit both the image and your optimized "
        "parameters as a training contribution. This helps MyoSeg learn "
        "better settings for similar images."
    )
    st.caption(
        "Note: the submitted masks come from the most recent analysis run. "
        "If you changed sidebar parameters since then, click "
        "Run / Rerun analysis first so the masks match the saved parameters."
    )
    train_pick = st.selectbox("Image to submit", names, key="train_pick")
    if st.button("πŸ“€ Submit for training", type="primary"):

        def _json_default(obj):
            """Make numpy scalars/arrays and sets JSON-serialisable.

            numpy integer / float32 / bool scalars are not Python int/float
            subclasses; with a plain default=str they would be written as
            strings, so unwrap them to native values first.
            """
            if isinstance(obj, np.generic):
                return obj.item()
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            if isinstance(obj, (set, frozenset)):
                return sorted(obj, key=str)
            return str(obj)

        _ensure_dirs()
        ts = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        folder = CORRECTIONS_DIR / f"params_{ts}"
        folder.mkdir(parents=True, exist_ok=True)

        # Save the original image (original upload resolution).
        train_art = st.session_state.artifacts[train_pick]
        Image.fromarray(train_art["rgb_u8"]).save(folder / "image.png")

        # Save the post-processed masks from the LAST analysis run (not
        # necessarily the current sidebar values).
        # NOTE(review): these presumably stay at model input resolution
        # (image_size x image_size) while image.png is at the original
        # resolution - confirm self_train.py resizes them to match.
        nuc_pp_arr = train_art.get("nuc_pp_arr")
        myo_pp_arr = train_art.get("myo_pp_arr")
        if nuc_pp_arr is not None:
            Image.fromarray((nuc_pp_arr > 0).astype(np.uint8) * 255).save(folder / "nuclei_mask.png")
        if myo_pp_arr is not None:
            Image.fromarray((myo_pp_arr > 0).astype(np.uint8) * 255).save(folder / "myotube_mask.png")

        # Current parameter set, for self_train to learn from. The original
        # keys are unchanged; the extra keys record the remaining settings
        # that change the masks / metrics (channel mapping, FI method,
        # watershed, pixel size).
        param_snapshot = {
            "reason": "user_optimized_params",
            "has_masks": nuc_pp_arr is not None and myo_pp_arr is not None,
            "timestamp": ts,
            "source_image": train_pick,
            "parameters": {
                "thr_nuc": float(thr_nuc),
                "thr_myo": float(thr_myo),
                "min_nuc_area": int(min_nuc_area),
                "min_myo_area": int(min_myo_area),
                "nuc_close_radius": int(nuc_close_radius),
                "myo_open_radius": int(myo_open_radius),
                "myo_erode_radius": int(myo_erode_radius),
                "min_myo_aspect_ratio": float(min_myo_aspect_ratio),
                "myo_max_area_px": int(myo_max_area_px),
                "myo_split_min_seeds": int(myo_split_min_seeds),
                "image_size": int(image_size),
                "channel1_source": src1,
                "channel1_inverted": bool(inv1),
                "channel2_source": src2,
                "channel2_inverted": bool(inv2),
                "fi_method": "cytoplasm-hole" if use_hole_method else "pixel-overlap",
                "hole_ratio_thr": float(hole_ratio_thr),
                "ring_width_px": int(ring_width_px),
                "nuc_ws_min_dist": int(nuc_ws_min_dist),
                "nuc_ws_min_area": int(nuc_ws_min_area),
                "px_um": float(px_um),
            },
            "metrics": st.session_state.bio_metrics.get(train_pick, {}),
        }
        (folder / "meta.json").write_text(json.dumps(param_snapshot, indent=2, default=_json_default))
        st.success(
            f"βœ… **{train_pick}** submitted for training with your optimized parameters. "
            "The model will incorporate this at the next retraining cycle."
        )
# ─────────────────────────────────────────────────────────────────────────────
# ACTIVE LEARNING β€” CORRECTION UPLOAD
# ─────────────────────────────────────────────────────────────────────────────
# Active learning: download predicted masks, correct them externally, and
# upload both corrected masks to be queued with the original image.
if enable_al and not private_mode:
    st.subheader("🧠 Submit corrected labels (Active Learning)")
    st.caption(
        "Download the predicted masks, correct them in ImageJ/FIJI or any "
        "image editor (white = foreground, black = background), then upload "
        "the corrected versions below."
    )
    al_pick = st.selectbox("Correct masks for image", names, key="al_pick")

    # Download buttons for the current masks, so users can edit and re-upload.
    if al_pick in st.session_state.artifacts:
        al_art = st.session_state.artifacts[al_pick]

        def _mask_png(arr_key, png_key):
            """PNG bytes for a predicted mask: re-encode the stored 0/1 array
            when present, otherwise fall back to the pre-encoded PNG."""
            arr = al_art.get(arr_key)
            if arr is None:
                return al_art[png_key]
            return png_bytes((arr * 255).astype(np.uint8))

        dl1, dl2 = st.columns(2)
        with dl1:
            st.download_button(
                "⬇️ Download nuclei mask",
                data=_mask_png("nuc_pp_arr", "nuc_pp"),
                file_name=f"{al_pick}_nuclei_mask.png",
                mime="image/png",
                key="dl_nuc_mask",
            )
        with dl2:
            st.download_button(
                "⬇️ Download myotube mask",
                data=_mask_png("myo_pp_arr", "myo_pp"),
                file_name=f"{al_pick}_myotube_mask.png",
                mime="image/png",
                key="dl_myo_mask",
            )

    acol1, acol2 = st.columns(2)
    with acol1:
        corr_nuc = st.file_uploader("Upload corrected NUCLEI mask",
                                    type=["png", "tif", "tiff"], key="nuc_corr")
    with acol2:
        corr_myo = st.file_uploader("Upload corrected MYOTUBE mask",
                                    type=["png", "tif", "tiff"], key="myo_corr")

    if st.button("βœ… Submit corrections", type="primary"):
        if corr_nuc is None or corr_myo is None:
            st.error("Please upload BOTH a nuclei mask and a myotube mask.")
        else:
            corrected = None
            try:
                # Any non-zero grey value counts as foreground.
                corrected = (
                    (np.array(Image.open(corr_nuc).convert("L")) > 0).astype(np.uint8),
                    (np.array(Image.open(corr_myo).convert("L")) > 0).astype(np.uint8),
                )
            except OSError as err:  # includes PIL.UnidentifiedImageError
                st.error(f"Could not read an uploaded mask: {err}")
            if corrected is not None:
                nuc_arr, myo_arr = corrected
                # The two masks of one image must be a pixel-aligned pair.
                if nuc_arr.shape != myo_arr.shape:
                    st.error(
                        f"Mask sizes differ: nuclei {nuc_arr.shape[1]}x{nuc_arr.shape[0]} "
                        f"vs myotube {myo_arr.shape[1]}x{myo_arr.shape[0]}. "
                        "Upload both corrected masks at the same resolution."
                    )
                else:
                    orig_rgb = st.session_state.artifacts[al_pick]["rgb_u8"]
                    # add_to_queue overwrites any "timestamp" key with its own,
                    # so the submission time is stored under a separate key.
                    add_to_queue(orig_rgb, nuc_mask=nuc_arr, myo_mask=myo_arr,
                                 reason="user_correction",
                                 metadata={"source_image": al_pick,
                                           "submitted_at": datetime.now().isoformat()})
                    st.success(
                        f"βœ… Corrections for **{al_pick}** saved to `corrections/`. "
                        "The model will retrain at the next scheduled cycle."
                    )
elif enable_al and private_mode:
    st.info(
        "πŸ”’ Active learning and training submissions are disabled in Private mode. "
        "Toggle off Private mode in the sidebar to enable."
    )
st.divider()
# ─────────────────────────────────────────────────────────────────────────────
# RETRAINING QUEUE STATUS
# ─────────────────────────────────────────────────────────────────────────────
# Read-only view of the on-disk retraining queue / corrections, the recent
# retraining history, and a manual retraining trigger.
with st.expander("πŸ”§ Self-training queue status"):
    _ensure_dirs()
    q_items = list(QUEUE_DIR.glob("*.json"))
    c_items = list(CORRECTIONS_DIR.glob("*/meta.json"))
    sq1, sq2 = st.columns(2)
    sq1.metric("Images in retraining queue", len(q_items))
    sq2.metric("Corrected label pairs", len(c_items))

    if q_items:
        reasons = {}
        skipped = 0
        for p in q_items:
            try:
                r = json.loads(p.read_text()).get("reason", "unknown")
                reasons[r] = reasons.get(r, 0) + 1
            except (OSError, ValueError, AttributeError, TypeError):
                # Unreadable, non-JSON or non-object sidecar: count it rather
                # than hiding it.
                skipped += 1
        st.write("Queue breakdown:", reasons)
        if skipped:
            st.caption(f"{skipped} queue file(s) could not be read.")

    manifest = Path("manifest.json")
    if manifest.exists():
        try:
            history = json.loads(manifest.read_text())
        except (OSError, ValueError) as err:
            history = None
            st.caption(f"Could not read {manifest}: {err}")
        # Only a list of runs can be sliced into "last 5".
        if isinstance(history, list) and history:
            st.markdown("**Last 5 retraining runs:**")
            hist_df = pd.DataFrame(history[-5:])
            st.dataframe(hist_df, use_container_width=True)

    if st.button("πŸ”„ Trigger retraining now"):
        import subprocess
        import sys
        # Keep the handle so repeated clicks in this session don't start
        # overlapping runs. poll() also reaps the finished child process.
        # NOTE(review): runs started from other sessions are not detected.
        prev_proc = st.session_state.get("_retrain_proc")
        if prev_proc is not None and prev_proc.poll() is None:
            st.warning("A retraining run started from this session is still running.")
        else:
            # sys.executable = the interpreter running this app (same
            # environment/packages); a bare "python" resolves via PATH.
            st.session_state["_retrain_proc"] = subprocess.Popen(
                [sys.executable, "self_train.py", "--manual"]
            )
            st.info("Retraining started in the background. Check terminal / logs for progress.")