Spaces:
Running on Zero
Running on Zero
Sync from GitHub via hub-sync
Browse files
app.py
CHANGED
|
@@ -14,11 +14,18 @@ import csv
|
|
| 14 |
import joblib
|
| 15 |
import os
|
| 16 |
|
| 17 |
-
HF_REPO_ID
|
| 18 |
-
HF_REPO_ID2
|
|
|
|
| 19 |
MODEL_OPTIONS = {
|
| 20 |
"Hemocytometer Model": "hemocytometermodel.npy",
|
| 21 |
-
"General Model":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
}
|
| 23 |
|
| 24 |
loaded_models = {}
|
|
@@ -35,16 +42,13 @@ try:
|
|
| 35 |
except Exception as e:
|
| 36 |
print(f"Viability classifier not found or failed to load: {e}")
|
| 37 |
|
| 38 |
-
#
|
| 39 |
MAX_SIDE = 1024
|
| 40 |
MAX_PIXELS = 1024 * 1024
|
| 41 |
|
| 42 |
|
| 43 |
def safe_resize(image_np):
|
| 44 |
-
|
| 45 |
-
Downscale image to fit within MAX_SIDE and MAX_PIXELS while
|
| 46 |
-
preserving aspect ratio. Works for RGB / RGBA / grayscale.
|
| 47 |
-
"""
|
| 48 |
h, w = image_np.shape[:2]
|
| 49 |
total = h * w
|
| 50 |
|
|
@@ -152,11 +156,7 @@ FEATURE_COLS_INFERENCE = [
|
|
| 152 |
|
| 153 |
|
| 154 |
def classify_cells_by_model(image_np, masks):
|
| 155 |
-
|
| 156 |
-
Run the trained LogisticRegression classifier to predict live/dead per cell.
|
| 157 |
-
Returns (dead_count, alive_count, overlay_np, {cell_id: label}).
|
| 158 |
-
Requires VIABILITY_CLF and VIABILITY_SCALER to be loaded.
|
| 159 |
-
"""
|
| 160 |
import numpy as np
|
| 161 |
cell_ids = np.unique(masks)
|
| 162 |
cell_ids = cell_ids[cell_ids > 0]
|
|
@@ -188,11 +188,7 @@ def classify_cells_by_model(image_np, masks):
|
|
| 188 |
|
| 189 |
|
| 190 |
def draw_viability_overlay(image_np, masks, label_map):
|
| 191 |
-
|
| 192 |
-
Draw coloured contours + cell-number labels onto image_np.
|
| 193 |
-
label_map: {cell_id: 0=live, 1=dead}
|
| 194 |
-
Returns a uint8 numpy array.
|
| 195 |
-
"""
|
| 196 |
overlay = image_np.copy()
|
| 197 |
cell_ids = np.unique(masks)
|
| 198 |
cell_ids = cell_ids[cell_ids > 0]
|
|
@@ -220,114 +216,10 @@ def draw_viability_overlay(image_np, masks, label_map):
|
|
| 220 |
(0, 0, 0), -1)
|
| 221 |
cv2.putText(overlay, label_str,
|
| 222 |
(cx - tw//2, cy + th//2),
|
| 223 |
-
font, font_scale, color, thickness, cv2.LINE_AA)
|
| 224 |
return overlay
|
| 225 |
|
| 226 |
|
| 227 |
-
def classify_cells_by_blueness(image_np, masks, threshold_bias):
|
| 228 |
-
"""
|
| 229 |
-
Classify cells as dead (blue) or alive using an adaptive Otsu threshold
|
| 230 |
-
on per-cell blueness scores, with a user bias to fine-tune.
|
| 231 |
-
|
| 232 |
-
Args:
|
| 233 |
-
image_np: RGB image array
|
| 234 |
-
masks: Cellpose segmentation masks
|
| 235 |
-
threshold_bias: Slider value -50..+50; shifts Otsu threshold up/down.
|
| 236 |
-
Negative = more cells classified dead (looser).
|
| 237 |
-
Positive = fewer cells classified dead (stricter).
|
| 238 |
-
0 = pure Otsu (fully automatic).
|
| 239 |
-
|
| 240 |
-
Returns:
|
| 241 |
-
dead_count, alive_count, colored_overlay, otsu_threshold, final_threshold
|
| 242 |
-
"""
|
| 243 |
-
|
| 244 |
-
if len(image_np.shape) == 2:
|
| 245 |
-
image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
|
| 246 |
-
elif len(image_np.shape) == 3 and image_np.shape[2] == 4:
|
| 247 |
-
image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)
|
| 248 |
-
|
| 249 |
-
hsv = cv2.cvtColor(image_np, cv2.COLOR_RGB2HSV)
|
| 250 |
-
|
| 251 |
-
hue = hsv[:, :, 0].astype(np.float32)
|
| 252 |
-
saturation = hsv[:, :, 1].astype(np.float32)
|
| 253 |
-
|
| 254 |
-
# Raw blueness: hue proximity to 115° × saturation
|
| 255 |
-
hue_distance = np.minimum(np.abs(hue - 115), 180 - np.abs(hue - 115))
|
| 256 |
-
hue_score = np.maximum(0, 1 - hue_distance / 65)
|
| 257 |
-
blueness = hue_score * (saturation / 255.0)
|
| 258 |
-
|
| 259 |
-
# --- Compute per-cell mean blueness scores ---
|
| 260 |
-
cell_ids = np.unique(masks)
|
| 261 |
-
cell_ids = cell_ids[cell_ids > 0]
|
| 262 |
-
|
| 263 |
-
if len(cell_ids) == 0:
|
| 264 |
-
blank = image_np.copy()
|
| 265 |
-
return 0, 0, blank, 0.0, 0.0
|
| 266 |
-
|
| 267 |
-
cell_scores = np.array([np.mean(blueness[masks == cid]) for cid in cell_ids])
|
| 268 |
-
|
| 269 |
-
# --- Otsu on the distribution of per-cell scores ---
|
| 270 |
-
# cv2.threshold expects uint8; scale 0-1 → 0-255
|
| 271 |
-
scores_u8 = (np.clip(cell_scores, 0, 1) * 255).astype(np.uint8)
|
| 272 |
-
|
| 273 |
-
if scores_u8.max() == scores_u8.min():
|
| 274 |
-
# All cells identical → Otsu is undefined; use midpoint
|
| 275 |
-
otsu_threshold = float(scores_u8[0]) / 255.0
|
| 276 |
-
else:
|
| 277 |
-
# Reshape to a single-column image so cv2.threshold works
|
| 278 |
-
thresh_val, _ = cv2.threshold(
|
| 279 |
-
scores_u8.reshape(-1, 1), 0, 255,
|
| 280 |
-
cv2.THRESH_BINARY + cv2.THRESH_OTSU
|
| 281 |
-
)
|
| 282 |
-
otsu_threshold = thresh_val / 255.0
|
| 283 |
-
|
| 284 |
-
# --- Apply user bias: slider -50..+50 maps to ±0.20 shift ---
|
| 285 |
-
bias = (threshold_bias / 50.0) * 0.20
|
| 286 |
-
final_threshold = float(np.clip(otsu_threshold + bias, 0.0, 1.0))
|
| 287 |
-
|
| 288 |
-
# --- Classify ---
|
| 289 |
-
dead_cells = [cid for cid, s in zip(cell_ids, cell_scores) if s > final_threshold]
|
| 290 |
-
alive_cells = [cid for cid, s in zip(cell_ids, cell_scores) if s <= final_threshold]
|
| 291 |
-
|
| 292 |
-
# --- Outline-only overlay on raw image with enumerated labels ---
|
| 293 |
-
final_overlay = image_np.copy()
|
| 294 |
-
|
| 295 |
-
# Compute a consistent enumeration order (cell_ids is already sorted ascending)
|
| 296 |
-
cell_enum = {cid: idx + 1 for idx, cid in enumerate(cell_ids)}
|
| 297 |
-
|
| 298 |
-
dead_set = set(dead_cells)
|
| 299 |
-
alive_set = set(alive_cells)
|
| 300 |
-
|
| 301 |
-
for cid in cell_ids:
|
| 302 |
-
cell_mask = (masks == cid).astype(np.uint8)
|
| 303 |
-
contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 304 |
-
color = (220, 50, 50) if cid in dead_set else (50, 220, 80)
|
| 305 |
-
cv2.drawContours(final_overlay, contours, -1, color, thickness=2)
|
| 306 |
-
|
| 307 |
-
# Draw enumeration label at centroid
|
| 308 |
-
ys, xs = np.where(cell_mask)
|
| 309 |
-
if len(ys) > 0:
|
| 310 |
-
cx, cy = int(xs.mean()), int(ys.mean())
|
| 311 |
-
label_str = str(cell_enum[cid])
|
| 312 |
-
font = cv2.FONT_HERSHEY_SIMPLEX
|
| 313 |
-
font_scale = 0.35
|
| 314 |
-
thickness = 1
|
| 315 |
-
(tw, th), _ = cv2.getTextSize(label_str, font, font_scale, thickness)
|
| 316 |
-
# Dark background rectangle for readability
|
| 317 |
-
cv2.rectangle(
|
| 318 |
-
final_overlay,
|
| 319 |
-
(cx - tw // 2 - 1, cy - th // 2 - 1),
|
| 320 |
-
(cx + tw // 2 + 1, cy + th // 2 + 1),
|
| 321 |
-
(0, 0, 0),
|
| 322 |
-
-1
|
| 323 |
-
)
|
| 324 |
-
cv2.putText(
|
| 325 |
-
final_overlay, label_str,
|
| 326 |
-
(cx - tw // 2, cy + th // 2),
|
| 327 |
-
font, font_scale, color, thickness, cv2.LINE_AA
|
| 328 |
-
)
|
| 329 |
-
|
| 330 |
-
return len(dead_cells), len(alive_cells), final_overlay, otsu_threshold, final_threshold
|
| 331 |
|
| 332 |
|
| 333 |
def measure_confluency(masks, image_np):
|
|
@@ -451,12 +343,10 @@ def warp_polygon_to_square(image_np, points):
|
|
| 451 |
|
| 452 |
|
| 453 |
def toggle_stereological_mode(use_stereology):
|
| 454 |
-
"""Show/hide stereological controls based on checkbox"""
|
| 455 |
return gr.update(visible=use_stereology)
|
| 456 |
|
| 457 |
|
| 458 |
def update_exclusion_preview(image, left_width, top_width):
|
| 459 |
-
"""Update the preview image with exclusion zone overlay"""
|
| 460 |
if image is None:
|
| 461 |
return None
|
| 462 |
|
|
@@ -465,9 +355,8 @@ def update_exclusion_preview(image, left_width, top_width):
|
|
| 465 |
return Image.fromarray(overlay)
|
| 466 |
|
| 467 |
|
| 468 |
-
# ---------------------------------------------------------------------------
|
| 469 |
# Patch segmentation
|
| 470 |
-
|
| 471 |
PATCH_SIZE = 512 # target patch side length
|
| 472 |
PATCH_OVERLAP = 64 # overlap border on each edge (pixels)
|
| 473 |
MIN_PATCH_DIM = 256 # don't bother patching if image fits comfortably
|
|
@@ -567,7 +456,7 @@ def _segment_patch(args):
|
|
| 567 |
model = models.CellposeModel(gpu=True, pretrained_model=model_path)
|
| 568 |
loaded_models[model_filename] = model
|
| 569 |
|
| 570 |
-
mask, _, _ = model.eval(patch_np, diameter=None
|
| 571 |
return mask, row_start, col_start
|
| 572 |
|
| 573 |
|
|
@@ -579,7 +468,8 @@ def run_segmentation_patched(image_np, model_filename):
|
|
| 579 |
that patching adds overhead without benefit.
|
| 580 |
"""
|
| 581 |
h, w = image_np.shape[:2]
|
| 582 |
-
|
|
|
|
| 583 |
if model_filename in loaded_models:
|
| 584 |
model = loaded_models[model_filename]
|
| 585 |
else:
|
|
@@ -588,15 +478,16 @@ def run_segmentation_patched(image_np, model_filename):
|
|
| 588 |
|
| 589 |
# Small images: no benefit from patching
|
| 590 |
if max(h, w) <= MIN_PATCH_DIM * 2:
|
| 591 |
-
mask, _, _ = model.eval(image_np, diameter=None
|
| 592 |
return mask, 1 # 1 patch
|
| 593 |
|
| 594 |
patches = _split_patches(image_np)
|
| 595 |
n_patches = len(patches)
|
| 596 |
|
| 597 |
# Build argument list for the thread pool
|
|
|
|
| 598 |
args_list = [
|
| 599 |
-
(patch, r, c, model_filename,
|
| 600 |
for patch, r, c in patches
|
| 601 |
]
|
| 602 |
|
|
@@ -618,6 +509,7 @@ def run_segmentation_patched(image_np, model_filename):
|
|
| 618 |
|
| 619 |
@spaces.GPU
|
| 620 |
def run_segmentation(image, model_choice, min_cell_size, max_cell_size,
|
|
|
|
| 621 |
use_stereology, left_exclusion, top_exclusion,
|
| 622 |
crop_points=None):
|
| 623 |
image_np = np.array(image)
|
|
@@ -661,21 +553,18 @@ def run_segmentation(image, model_choice, min_cell_size, max_cell_size,
|
|
| 661 |
print("p90:", np.percentile(sizes, 90) if len(sizes) > 0 else 0)
|
| 662 |
print("max:", sizes.max() if len(sizes) > 0 else 0)
|
| 663 |
|
| 664 |
-
# Compute recommendation from RAW masks
|
| 665 |
recommend_min = rec_min_size(masks_raw)
|
| 666 |
|
| 667 |
-
#
|
| 668 |
-
min_used = recommend_min if (min_cell_size == 0) else int(min_cell_size)
|
| 669 |
-
|
| 670 |
-
# Apply filters
|
| 671 |
masks = masks_raw.copy()
|
| 672 |
removed_small = 0
|
| 673 |
removed_large = 0
|
| 674 |
|
| 675 |
-
if
|
| 676 |
-
masks, removed_small = filter_mask_by_size(masks,
|
| 677 |
|
| 678 |
-
if max_cell_size > 0:
|
| 679 |
masks, removed_large = filter_mask_by_maxsize(masks, int(max_cell_size))
|
| 680 |
|
| 681 |
# Apply stereological exclusion if enabled
|
|
@@ -687,7 +576,7 @@ def run_segmentation(image, model_choice, min_cell_size, max_cell_size,
|
|
| 687 |
|
| 688 |
filter_msg = ""
|
| 689 |
if removed_small:
|
| 690 |
-
filter_msg += f"Removed {removed_small} small objects (< {
|
| 691 |
if removed_large:
|
| 692 |
filter_msg += f"Removed {removed_large} large objects (> {int(max_cell_size)} pixels).\n"
|
| 693 |
if use_stereology and excluded_count > 0:
|
|
@@ -729,7 +618,7 @@ def run_segmentation(image, model_choice, min_cell_size, max_cell_size,
|
|
| 729 |
pack_array(masks),
|
| 730 |
pack_array(processed_image_np),
|
| 731 |
confluency,
|
| 732 |
-
|
| 733 |
pack_array(raw_image_np),
|
| 734 |
)
|
| 735 |
|
|
@@ -744,13 +633,12 @@ def run_segmentation(image, model_choice, min_cell_size, max_cell_size,
|
|
| 744 |
None,
|
| 745 |
None,
|
| 746 |
0.0,
|
| 747 |
-
|
| 748 |
None,
|
| 749 |
)
|
| 750 |
|
| 751 |
|
| 752 |
def run_viability(stored_masks, stored_image_np):
|
| 753 |
-
"""Run model-based viability classification. Returns overlay + counts + label_map."""
|
| 754 |
if stored_masks is None or stored_image_np is None:
|
| 755 |
return None, 0, 0, 0.0, "Please run segmentation first.", {}
|
| 756 |
if VIABILITY_CLF is None:
|
|
@@ -773,14 +661,20 @@ def run_viability(stored_masks, stored_image_np):
|
|
| 773 |
|
| 774 |
|
| 775 |
def pack_array(arr):
|
| 776 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 777 |
buf = io.BytesIO()
|
| 778 |
-
|
| 779 |
return buf.getvalue()
|
| 780 |
|
| 781 |
|
| 782 |
def unpack_array(data):
|
| 783 |
-
|
|
|
|
| 784 |
|
| 785 |
|
| 786 |
def save_tab_result(cell_count, confluency, viab_percent):
|
|
@@ -820,29 +714,12 @@ def compute_summary(r1, r2, r3, r4):
|
|
| 820 |
return avg_count, avg_conf, avg_viab, "\n".join(lines)
|
| 821 |
|
| 822 |
|
| 823 |
-
|
| 824 |
# Training data export — feature extraction per cell
|
| 825 |
-
|
| 826 |
|
| 827 |
def extract_cell_features(image_np, masks):
|
| 828 |
-
|
| 829 |
-
For every segmented cell, extract a fixed feature vector from the pixels
|
| 830 |
-
inside its mask. Returns a list of dicts, one per cell.
|
| 831 |
-
|
| 832 |
-
Features:
|
| 833 |
-
RGB channels — mean_r, mean_g, mean_b, std_r, std_g, std_b
|
| 834 |
-
HSV channels — mean_h, mean_s, mean_v, std_s, std_v
|
| 835 |
-
Ratios — blue_red_ratio, blue_green_ratio, rg_ratio
|
| 836 |
-
Morphology — area_px, circularity
|
| 837 |
-
Centre/edge profile — inner_brightness, peak_brightness,
|
| 838 |
-
bright_spot_fraction, ring_darkness,
|
| 839 |
-
centre_periphery_ratio, brightness_std_normalised
|
| 840 |
-
|
| 841 |
-
Profile zones are tuned to hemocytometer live-cell morphology:
|
| 842 |
-
a small intense specular highlight at the centre surrounded by a dark
|
| 843 |
-
navy membrane ring. Dead cells are pale blue-grey blobs with no ring
|
| 844 |
-
and no bright spot.
|
| 845 |
-
"""
|
| 846 |
if len(image_np.shape) == 2:
|
| 847 |
image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
|
| 848 |
elif image_np.shape[2] == 4:
|
|
@@ -1003,9 +880,8 @@ def prepare_export(stored_masks, stored_image, threshold_bias):
|
|
| 1003 |
return path, msg
|
| 1004 |
|
| 1005 |
|
| 1006 |
-
|
| 1007 |
# Tab builder
|
| 1008 |
-
# ---------------------------------------------------------------------------
|
| 1009 |
|
| 1010 |
def draw_polygon_overlay(image_pil, points):
|
| 1011 |
"""
|
|
@@ -1063,14 +939,12 @@ def clear_crop_points(image_pil):
|
|
| 1063 |
|
| 1064 |
|
| 1065 |
|
| 1066 |
-
# ---------------------------------------------------------------------------
|
| 1067 |
# Label correction grid
|
| 1068 |
-
# ---------------------------------------------------------------------------
|
| 1069 |
|
| 1070 |
-
THUMB_SIZE = 80
|
| 1071 |
-
GRID_COLS =
|
| 1072 |
-
BORDER = 4
|
| 1073 |
-
LABEL_H = 16
|
| 1074 |
|
| 1075 |
def _crop_cell_thumb(image_np, masks, cid):
|
| 1076 |
"""
|
|
@@ -1102,14 +976,7 @@ def _crop_cell_thumb(image_np, masks, cid):
|
|
| 1102 |
|
| 1103 |
|
| 1104 |
def build_correction_grid(image_np, masks, labelled_features, raw_image_np=None):
|
| 1105 |
-
|
| 1106 |
-
Render all cell thumbnails into a single PIL image grid.
|
| 1107 |
-
Each thumbnail has a coloured border: green=live(0), red=dead(1).
|
| 1108 |
-
A small number in the corner identifies the cell_id.
|
| 1109 |
-
|
| 1110 |
-
Returns the PIL grid image.
|
| 1111 |
-
Cell order in the grid matches the order of labelled_features.
|
| 1112 |
-
"""
|
| 1113 |
if not labelled_features:
|
| 1114 |
placeholder = Image.fromarray(
|
| 1115 |
np.zeros((THUMB_SIZE, THUMB_SIZE, 3), dtype=np.uint8)
|
|
@@ -1256,14 +1123,29 @@ def build_tab(tab_index, masks_state, image_state, result_state):
|
|
| 1256 |
value="Hemocytometer Model"
|
| 1257 |
)
|
| 1258 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1259 |
min_size_slider = gr.Slider(
|
| 1260 |
minimum=0,
|
| 1261 |
maximum=500,
|
| 1262 |
value=0,
|
| 1263 |
step=10,
|
| 1264 |
-
label="Minimum Cell Size (pixels)
|
|
|
|
|
|
|
|
|
|
| 1265 |
)
|
| 1266 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1267 |
max_size_slider = gr.Slider(
|
| 1268 |
minimum=0,
|
| 1269 |
maximum=10000,
|
|
@@ -1405,9 +1287,10 @@ def build_tab(tab_index, masks_state, image_state, result_state):
|
|
| 1405 |
segment_btn.click(
|
| 1406 |
fn=run_segmentation,
|
| 1407 |
inputs=[img_input, model_dropdown, min_size_slider, max_size_slider,
|
|
|
|
| 1408 |
use_stereo, left_excl, top_excl, crop_points_state],
|
| 1409 |
outputs=[cell_count_out, overlay_out, info_out, viability_section,
|
| 1410 |
-
masks_state, image_state, confluency_out,
|
| 1411 |
)
|
| 1412 |
|
| 1413 |
# ---- Run Viability button -------------------------------------------
|
|
@@ -1496,9 +1379,8 @@ def build_tab(tab_index, masks_state, image_state, result_state):
|
|
| 1496 |
|
| 1497 |
|
| 1498 |
|
| 1499 |
-
# ---------------------------------------------------------------------------
|
| 1500 |
# Gradio interface
|
| 1501 |
-
|
| 1502 |
with gr.Blocks(
|
| 1503 |
title="CellposeCellCounter",
|
| 1504 |
theme=gr.themes.Soft(),
|
|
|
|
| 14 |
import joblib
|
| 15 |
import os
|
| 16 |
|
| 17 |
+
HF_REPO_ID = "myang4218/cellposemodel"
|
| 18 |
+
HF_REPO_ID2 = "LiangLabUMB/viability_model"
|
| 19 |
+
HF_REPO_CPSAM = "mouseland/cellpose-sam"
|
| 20 |
MODEL_OPTIONS = {
|
| 21 |
"Hemocytometer Model": "hemocytometermodel.npy",
|
| 22 |
+
"General Model": "generalmodel.npy",
|
| 23 |
+
"Cellpose SAMv2": "cpsam_v2",
|
| 24 |
+
}
|
| 25 |
+
MODEL_REPOS = {
|
| 26 |
+
"hemocytometermodel.npy": HF_REPO_ID,
|
| 27 |
+
"generalmodel.npy": HF_REPO_ID,
|
| 28 |
+
"cpsam_v2": HF_REPO_CPSAM,
|
| 29 |
}
|
| 30 |
|
| 31 |
loaded_models = {}
|
|
|
|
| 42 |
except Exception as e:
|
| 43 |
print(f"Viability classifier not found or failed to load: {e}")
|
| 44 |
|
| 45 |
+
# mobile safe resize limits
|
| 46 |
MAX_SIDE = 1024
|
| 47 |
MAX_PIXELS = 1024 * 1024
|
| 48 |
|
| 49 |
|
| 50 |
def safe_resize(image_np):
|
| 51 |
+
|
|
|
|
|
|
|
|
|
|
| 52 |
h, w = image_np.shape[:2]
|
| 53 |
total = h * w
|
| 54 |
|
|
|
|
| 156 |
|
| 157 |
|
| 158 |
def classify_cells_by_model(image_np, masks):
|
| 159 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
import numpy as np
|
| 161 |
cell_ids = np.unique(masks)
|
| 162 |
cell_ids = cell_ids[cell_ids > 0]
|
|
|
|
| 188 |
|
| 189 |
|
| 190 |
def draw_viability_overlay(image_np, masks, label_map):
|
| 191 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
overlay = image_np.copy()
|
| 193 |
cell_ids = np.unique(masks)
|
| 194 |
cell_ids = cell_ids[cell_ids > 0]
|
|
|
|
| 216 |
(0, 0, 0), -1)
|
| 217 |
cv2.putText(overlay, label_str,
|
| 218 |
(cx - tw//2, cy + th//2),
|
| 219 |
+
font, font_scale, color, thickness, cv2.LINE_AA)
|
| 220 |
return overlay
|
| 221 |
|
| 222 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
|
| 224 |
|
| 225 |
def measure_confluency(masks, image_np):
|
|
|
|
| 343 |
|
| 344 |
|
| 345 |
def toggle_stereological_mode(use_stereology):
|
|
|
|
| 346 |
return gr.update(visible=use_stereology)
|
| 347 |
|
| 348 |
|
| 349 |
def update_exclusion_preview(image, left_width, top_width):
|
|
|
|
| 350 |
if image is None:
|
| 351 |
return None
|
| 352 |
|
|
|
|
| 355 |
return Image.fromarray(overlay)
|
| 356 |
|
| 357 |
|
|
|
|
| 358 |
# Patch segmentation
|
| 359 |
+
|
| 360 |
PATCH_SIZE = 512 # target patch side length
|
| 361 |
PATCH_OVERLAP = 64 # overlap border on each edge (pixels)
|
| 362 |
MIN_PATCH_DIM = 256 # don't bother patching if image fits comfortably
|
|
|
|
| 456 |
model = models.CellposeModel(gpu=True, pretrained_model=model_path)
|
| 457 |
loaded_models[model_filename] = model
|
| 458 |
|
| 459 |
+
mask, _, _ = model.eval(patch_np, diameter=None)
|
| 460 |
return mask, row_start, col_start
|
| 461 |
|
| 462 |
|
|
|
|
| 468 |
that patching adds overhead without benefit.
|
| 469 |
"""
|
| 470 |
h, w = image_np.shape[:2]
|
| 471 |
+
repo = MODEL_REPOS.get(model_filename, HF_REPO_ID)
|
| 472 |
+
model_path = hf_hub_download(repo_id=repo, filename=model_filename)
|
| 473 |
if model_filename in loaded_models:
|
| 474 |
model = loaded_models[model_filename]
|
| 475 |
else:
|
|
|
|
| 478 |
|
| 479 |
# Small images: no benefit from patching
|
| 480 |
if max(h, w) <= MIN_PATCH_DIM * 2:
|
| 481 |
+
mask, _, _ = model.eval(image_np, diameter=None)
|
| 482 |
return mask, 1 # 1 patch
|
| 483 |
|
| 484 |
patches = _split_patches(image_np)
|
| 485 |
n_patches = len(patches)
|
| 486 |
|
| 487 |
# Build argument list for the thread pool
|
| 488 |
+
patch_repo = MODEL_REPOS.get(model_filename, HF_REPO_ID)
|
| 489 |
args_list = [
|
| 490 |
+
(patch, r, c, model_filename, patch_repo)
|
| 491 |
for patch, r, c in patches
|
| 492 |
]
|
| 493 |
|
|
|
|
| 509 |
|
| 510 |
@spaces.GPU
|
| 511 |
def run_segmentation(image, model_choice, min_cell_size, max_cell_size,
|
| 512 |
+
use_min_filter, use_max_filter,
|
| 513 |
use_stereology, left_exclusion, top_exclusion,
|
| 514 |
crop_points=None):
|
| 515 |
image_np = np.array(image)
|
|
|
|
| 553 |
print("p90:", np.percentile(sizes, 90) if len(sizes) > 0 else 0)
|
| 554 |
print("max:", sizes.max() if len(sizes) > 0 else 0)
|
| 555 |
|
| 556 |
+
# Compute recommendation from RAW masks (always shown, never auto-applied)
|
| 557 |
recommend_min = rec_min_size(masks_raw)
|
| 558 |
|
| 559 |
+
# Apply filters only if their checkboxes are enabled
|
|
|
|
|
|
|
|
|
|
| 560 |
masks = masks_raw.copy()
|
| 561 |
removed_small = 0
|
| 562 |
removed_large = 0
|
| 563 |
|
| 564 |
+
if use_min_filter and int(min_cell_size) > 0:
|
| 565 |
+
masks, removed_small = filter_mask_by_size(masks, int(min_cell_size))
|
| 566 |
|
| 567 |
+
if use_max_filter and max_cell_size > 0:
|
| 568 |
masks, removed_large = filter_mask_by_maxsize(masks, int(max_cell_size))
|
| 569 |
|
| 570 |
# Apply stereological exclusion if enabled
|
|
|
|
| 576 |
|
| 577 |
filter_msg = ""
|
| 578 |
if removed_small:
|
| 579 |
+
filter_msg += f"Removed {removed_small} small objects (< {int(min_cell_size)} pixels).\n"
|
| 580 |
if removed_large:
|
| 581 |
filter_msg += f"Removed {removed_large} large objects (> {int(max_cell_size)} pixels).\n"
|
| 582 |
if use_stereology and excluded_count > 0:
|
|
|
|
| 618 |
pack_array(masks),
|
| 619 |
pack_array(processed_image_np),
|
| 620 |
confluency,
|
| 621 |
+
f"Recommended minimum: **{recommend_min} px** (25th percentile of detected cell sizes)",
|
| 622 |
pack_array(raw_image_np),
|
| 623 |
)
|
| 624 |
|
|
|
|
| 633 |
None,
|
| 634 |
None,
|
| 635 |
0.0,
|
| 636 |
+
"",
|
| 637 |
None,
|
| 638 |
)
|
| 639 |
|
| 640 |
|
| 641 |
def run_viability(stored_masks, stored_image_np):
|
|
|
|
| 642 |
if stored_masks is None or stored_image_np is None:
|
| 643 |
return None, 0, 0, 0.0, "Please run segmentation first.", {}
|
| 644 |
if VIABILITY_CLF is None:
|
|
|
|
| 661 |
|
| 662 |
|
| 663 |
def pack_array(arr):
|
| 664 |
+
"""
|
| 665 |
+
Serialise a numpy array to bytes for gr.State storage.
|
| 666 |
+
Uses numpy's .npy format (not PNG) so integer dtypes of any
|
| 667 |
+
magnitude are preserved exactly — PNG is 8-bit only and silently
|
| 668 |
+
truncates cell IDs above 255.
|
| 669 |
+
"""
|
| 670 |
buf = io.BytesIO()
|
| 671 |
+
np.save(buf, arr)
|
| 672 |
return buf.getvalue()
|
| 673 |
|
| 674 |
|
| 675 |
def unpack_array(data):
|
| 676 |
+
buf = io.BytesIO(data)
|
| 677 |
+
return np.load(buf, allow_pickle=False)
|
| 678 |
|
| 679 |
|
| 680 |
def save_tab_result(cell_count, confluency, viab_percent):
|
|
|
|
| 714 |
return avg_count, avg_conf, avg_viab, "\n".join(lines)
|
| 715 |
|
| 716 |
|
| 717 |
+
|
| 718 |
# Training data export — feature extraction per cell
|
| 719 |
+
|
| 720 |
|
| 721 |
def extract_cell_features(image_np, masks):
|
| 722 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 723 |
if len(image_np.shape) == 2:
|
| 724 |
image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
|
| 725 |
elif image_np.shape[2] == 4:
|
|
|
|
| 880 |
return path, msg
|
| 881 |
|
| 882 |
|
| 883 |
+
|
| 884 |
# Tab builder
|
|
|
|
| 885 |
|
| 886 |
def draw_polygon_overlay(image_pil, points):
|
| 887 |
"""
|
|
|
|
| 939 |
|
| 940 |
|
| 941 |
|
|
|
|
| 942 |
# Label correction grid
|
|
|
|
| 943 |
|
| 944 |
+
THUMB_SIZE = 80
|
| 945 |
+
GRID_COLS = 10
|
| 946 |
+
BORDER = 4
|
| 947 |
+
LABEL_H = 16
|
| 948 |
|
| 949 |
def _crop_cell_thumb(image_np, masks, cid):
|
| 950 |
"""
|
|
|
|
| 976 |
|
| 977 |
|
| 978 |
def build_correction_grid(image_np, masks, labelled_features, raw_image_np=None):
|
| 979 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 980 |
if not labelled_features:
|
| 981 |
placeholder = Image.fromarray(
|
| 982 |
np.zeros((THUMB_SIZE, THUMB_SIZE, 3), dtype=np.uint8)
|
|
|
|
| 1123 |
value="Hemocytometer Model"
|
| 1124 |
)
|
| 1125 |
|
| 1126 |
+
gr.Markdown("### Size Filters")
|
| 1127 |
+
|
| 1128 |
+
use_min_filter = gr.Checkbox(
|
| 1129 |
+
label="Enable minimum size filter",
|
| 1130 |
+
value=False,
|
| 1131 |
+
info="Remove objects smaller than the threshold below"
|
| 1132 |
+
)
|
| 1133 |
min_size_slider = gr.Slider(
|
| 1134 |
minimum=0,
|
| 1135 |
maximum=500,
|
| 1136 |
value=0,
|
| 1137 |
step=10,
|
| 1138 |
+
label="Minimum Cell Size (pixels)",
|
| 1139 |
+
)
|
| 1140 |
+
min_size_recommendation = gr.Markdown(
|
| 1141 |
+
value="*Run segmentation to see recommended minimum*",
|
| 1142 |
)
|
| 1143 |
|
| 1144 |
+
use_max_filter = gr.Checkbox(
|
| 1145 |
+
label="Enable maximum size filter",
|
| 1146 |
+
value=False,
|
| 1147 |
+
info="Remove objects larger than the threshold below"
|
| 1148 |
+
)
|
| 1149 |
max_size_slider = gr.Slider(
|
| 1150 |
minimum=0,
|
| 1151 |
maximum=10000,
|
|
|
|
| 1287 |
segment_btn.click(
|
| 1288 |
fn=run_segmentation,
|
| 1289 |
inputs=[img_input, model_dropdown, min_size_slider, max_size_slider,
|
| 1290 |
+
use_min_filter, use_max_filter,
|
| 1291 |
use_stereo, left_excl, top_excl, crop_points_state],
|
| 1292 |
outputs=[cell_count_out, overlay_out, info_out, viability_section,
|
| 1293 |
+
masks_state, image_state, confluency_out, min_size_recommendation, raw_image_state]
|
| 1294 |
)
|
| 1295 |
|
| 1296 |
# ---- Run Viability button -------------------------------------------
|
|
|
|
| 1379 |
|
| 1380 |
|
| 1381 |
|
|
|
|
| 1382 |
# Gradio interface
|
| 1383 |
+
|
| 1384 |
with gr.Blocks(
|
| 1385 |
title="CellposeCellCounter",
|
| 1386 |
theme=gr.themes.Soft(),
|