LiangLabUMB commited on
Commit
86fb55a
·
verified ·
1 Parent(s): 3ea28ee

Sync from GitHub via hub-sync

Browse files
Files changed (1) hide show
  1. app.py +68 -186
app.py CHANGED
@@ -14,11 +14,18 @@ import csv
14
  import joblib
15
  import os
16
 
17
- HF_REPO_ID = "myang4218/cellposemodel"
18
- HF_REPO_ID2 = "LiangLabUMB/viability_model"
 
19
  MODEL_OPTIONS = {
20
  "Hemocytometer Model": "hemocytometermodel.npy",
21
- "General Model": "generalmodel.npy"
 
 
 
 
 
 
22
  }
23
 
24
  loaded_models = {}
@@ -35,16 +42,13 @@ try:
35
  except Exception as e:
36
  print(f"Viability classifier not found or failed to load: {e}")
37
 
38
- # ---- mobile-safe size limits (aggressive for Safari) ----
39
  MAX_SIDE = 1024
40
  MAX_PIXELS = 1024 * 1024
41
 
42
 
43
  def safe_resize(image_np):
44
- """
45
- Downscale image to fit within MAX_SIDE and MAX_PIXELS while
46
- preserving aspect ratio. Works for RGB / RGBA / grayscale.
47
- """
48
  h, w = image_np.shape[:2]
49
  total = h * w
50
 
@@ -152,11 +156,7 @@ FEATURE_COLS_INFERENCE = [
152
 
153
 
154
  def classify_cells_by_model(image_np, masks):
155
- """
156
- Run the trained LogisticRegression classifier to predict live/dead per cell.
157
- Returns (dead_count, alive_count, overlay_np, {cell_id: label}).
158
- Requires VIABILITY_CLF and VIABILITY_SCALER to be loaded.
159
- """
160
  import numpy as np
161
  cell_ids = np.unique(masks)
162
  cell_ids = cell_ids[cell_ids > 0]
@@ -188,11 +188,7 @@ def classify_cells_by_model(image_np, masks):
188
 
189
 
190
  def draw_viability_overlay(image_np, masks, label_map):
191
- """
192
- Draw coloured contours + cell-number labels onto image_np.
193
- label_map: {cell_id: 0=live, 1=dead}
194
- Returns a uint8 numpy array.
195
- """
196
  overlay = image_np.copy()
197
  cell_ids = np.unique(masks)
198
  cell_ids = cell_ids[cell_ids > 0]
@@ -220,114 +216,10 @@ def draw_viability_overlay(image_np, masks, label_map):
220
  (0, 0, 0), -1)
221
  cv2.putText(overlay, label_str,
222
  (cx - tw//2, cy + th//2),
223
- font, font_scale, color, thickness, cv2.LINE_AA)
224
  return overlay
225
 
226
 
227
- def classify_cells_by_blueness(image_np, masks, threshold_bias):
228
- """
229
- Classify cells as dead (blue) or alive using an adaptive Otsu threshold
230
- on per-cell blueness scores, with a user bias to fine-tune.
231
-
232
- Args:
233
- image_np: RGB image array
234
- masks: Cellpose segmentation masks
235
- threshold_bias: Slider value -50..+50; shifts Otsu threshold up/down.
236
- Negative = more cells classified dead (looser).
237
- Positive = fewer cells classified dead (stricter).
238
- 0 = pure Otsu (fully automatic).
239
-
240
- Returns:
241
- dead_count, alive_count, colored_overlay, otsu_threshold, final_threshold
242
- """
243
-
244
- if len(image_np.shape) == 2:
245
- image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
246
- elif len(image_np.shape) == 3 and image_np.shape[2] == 4:
247
- image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)
248
-
249
- hsv = cv2.cvtColor(image_np, cv2.COLOR_RGB2HSV)
250
-
251
- hue = hsv[:, :, 0].astype(np.float32)
252
- saturation = hsv[:, :, 1].astype(np.float32)
253
-
254
- # Raw blueness: hue proximity to 115° × saturation
255
- hue_distance = np.minimum(np.abs(hue - 115), 180 - np.abs(hue - 115))
256
- hue_score = np.maximum(0, 1 - hue_distance / 65)
257
- blueness = hue_score * (saturation / 255.0)
258
-
259
- # --- Compute per-cell mean blueness scores ---
260
- cell_ids = np.unique(masks)
261
- cell_ids = cell_ids[cell_ids > 0]
262
-
263
- if len(cell_ids) == 0:
264
- blank = image_np.copy()
265
- return 0, 0, blank, 0.0, 0.0
266
-
267
- cell_scores = np.array([np.mean(blueness[masks == cid]) for cid in cell_ids])
268
-
269
- # --- Otsu on the distribution of per-cell scores ---
270
- # cv2.threshold expects uint8; scale 0-1 → 0-255
271
- scores_u8 = (np.clip(cell_scores, 0, 1) * 255).astype(np.uint8)
272
-
273
- if scores_u8.max() == scores_u8.min():
274
- # All cells identical → Otsu is undefined; use midpoint
275
- otsu_threshold = float(scores_u8[0]) / 255.0
276
- else:
277
- # Reshape to a single-column image so cv2.threshold works
278
- thresh_val, _ = cv2.threshold(
279
- scores_u8.reshape(-1, 1), 0, 255,
280
- cv2.THRESH_BINARY + cv2.THRESH_OTSU
281
- )
282
- otsu_threshold = thresh_val / 255.0
283
-
284
- # --- Apply user bias: slider -50..+50 maps to ±0.20 shift ---
285
- bias = (threshold_bias / 50.0) * 0.20
286
- final_threshold = float(np.clip(otsu_threshold + bias, 0.0, 1.0))
287
-
288
- # --- Classify ---
289
- dead_cells = [cid for cid, s in zip(cell_ids, cell_scores) if s > final_threshold]
290
- alive_cells = [cid for cid, s in zip(cell_ids, cell_scores) if s <= final_threshold]
291
-
292
- # --- Outline-only overlay on raw image with enumerated labels ---
293
- final_overlay = image_np.copy()
294
-
295
- # Compute a consistent enumeration order (cell_ids is already sorted ascending)
296
- cell_enum = {cid: idx + 1 for idx, cid in enumerate(cell_ids)}
297
-
298
- dead_set = set(dead_cells)
299
- alive_set = set(alive_cells)
300
-
301
- for cid in cell_ids:
302
- cell_mask = (masks == cid).astype(np.uint8)
303
- contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
304
- color = (220, 50, 50) if cid in dead_set else (50, 220, 80)
305
- cv2.drawContours(final_overlay, contours, -1, color, thickness=2)
306
-
307
- # Draw enumeration label at centroid
308
- ys, xs = np.where(cell_mask)
309
- if len(ys) > 0:
310
- cx, cy = int(xs.mean()), int(ys.mean())
311
- label_str = str(cell_enum[cid])
312
- font = cv2.FONT_HERSHEY_SIMPLEX
313
- font_scale = 0.35
314
- thickness = 1
315
- (tw, th), _ = cv2.getTextSize(label_str, font, font_scale, thickness)
316
- # Dark background rectangle for readability
317
- cv2.rectangle(
318
- final_overlay,
319
- (cx - tw // 2 - 1, cy - th // 2 - 1),
320
- (cx + tw // 2 + 1, cy + th // 2 + 1),
321
- (0, 0, 0),
322
- -1
323
- )
324
- cv2.putText(
325
- final_overlay, label_str,
326
- (cx - tw // 2, cy + th // 2),
327
- font, font_scale, color, thickness, cv2.LINE_AA
328
- )
329
-
330
- return len(dead_cells), len(alive_cells), final_overlay, otsu_threshold, final_threshold
331
 
332
 
333
  def measure_confluency(masks, image_np):
@@ -451,12 +343,10 @@ def warp_polygon_to_square(image_np, points):
451
 
452
 
453
  def toggle_stereological_mode(use_stereology):
454
- """Show/hide stereological controls based on checkbox"""
455
  return gr.update(visible=use_stereology)
456
 
457
 
458
  def update_exclusion_preview(image, left_width, top_width):
459
- """Update the preview image with exclusion zone overlay"""
460
  if image is None:
461
  return None
462
 
@@ -465,9 +355,8 @@ def update_exclusion_preview(image, left_width, top_width):
465
  return Image.fromarray(overlay)
466
 
467
 
468
- # ---------------------------------------------------------------------------
469
  # Patch segmentation
470
- # ---------------------------------------------------------------------------
471
  PATCH_SIZE = 512 # target patch side length
472
  PATCH_OVERLAP = 64 # overlap border on each edge (pixels)
473
  MIN_PATCH_DIM = 256 # don't bother patching if image fits comfortably
@@ -567,7 +456,7 @@ def _segment_patch(args):
567
  model = models.CellposeModel(gpu=True, pretrained_model=model_path)
568
  loaded_models[model_filename] = model
569
 
570
- mask, _, _ = model.eval(patch_np, diameter=None, channels=[0, 0])
571
  return mask, row_start, col_start
572
 
573
 
@@ -579,7 +468,8 @@ def run_segmentation_patched(image_np, model_filename):
579
  that patching adds overhead without benefit.
580
  """
581
  h, w = image_np.shape[:2]
582
- model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=model_filename)
 
583
  if model_filename in loaded_models:
584
  model = loaded_models[model_filename]
585
  else:
@@ -588,15 +478,16 @@ def run_segmentation_patched(image_np, model_filename):
588
 
589
  # Small images: no benefit from patching
590
  if max(h, w) <= MIN_PATCH_DIM * 2:
591
- mask, _, _ = model.eval(image_np, diameter=None, channels=[0, 0])
592
  return mask, 1 # 1 patch
593
 
594
  patches = _split_patches(image_np)
595
  n_patches = len(patches)
596
 
597
  # Build argument list for the thread pool
 
598
  args_list = [
599
- (patch, r, c, model_filename, HF_REPO_ID)
600
  for patch, r, c in patches
601
  ]
602
 
@@ -618,6 +509,7 @@ def run_segmentation_patched(image_np, model_filename):
618
 
619
  @spaces.GPU
620
  def run_segmentation(image, model_choice, min_cell_size, max_cell_size,
 
621
  use_stereology, left_exclusion, top_exclusion,
622
  crop_points=None):
623
  image_np = np.array(image)
@@ -661,21 +553,18 @@ def run_segmentation(image, model_choice, min_cell_size, max_cell_size,
661
  print("p90:", np.percentile(sizes, 90) if len(sizes) > 0 else 0)
662
  print("max:", sizes.max() if len(sizes) > 0 else 0)
663
 
664
- # Compute recommendation from RAW masks
665
  recommend_min = rec_min_size(masks_raw)
666
 
667
- # If user sets slider to 0, use the recommendation
668
- min_used = recommend_min if (min_cell_size == 0) else int(min_cell_size)
669
-
670
- # Apply filters
671
  masks = masks_raw.copy()
672
  removed_small = 0
673
  removed_large = 0
674
 
675
- if min_used > 0:
676
- masks, removed_small = filter_mask_by_size(masks, min_used)
677
 
678
- if max_cell_size > 0:
679
  masks, removed_large = filter_mask_by_maxsize(masks, int(max_cell_size))
680
 
681
  # Apply stereological exclusion if enabled
@@ -687,7 +576,7 @@ def run_segmentation(image, model_choice, min_cell_size, max_cell_size,
687
 
688
  filter_msg = ""
689
  if removed_small:
690
- filter_msg += f"Removed {removed_small} small objects (< {min_used} pixels).\n"
691
  if removed_large:
692
  filter_msg += f"Removed {removed_large} large objects (> {int(max_cell_size)} pixels).\n"
693
  if use_stereology and excluded_count > 0:
@@ -729,7 +618,7 @@ def run_segmentation(image, model_choice, min_cell_size, max_cell_size,
729
  pack_array(masks),
730
  pack_array(processed_image_np),
731
  confluency,
732
- gr.update(value=recommend_min),
733
  pack_array(raw_image_np),
734
  )
735
 
@@ -744,13 +633,12 @@ def run_segmentation(image, model_choice, min_cell_size, max_cell_size,
744
  None,
745
  None,
746
  0.0,
747
- gr.update(),
748
  None,
749
  )
750
 
751
 
752
  def run_viability(stored_masks, stored_image_np):
753
- """Run model-based viability classification. Returns overlay + counts + label_map."""
754
  if stored_masks is None or stored_image_np is None:
755
  return None, 0, 0, 0.0, "Please run segmentation first.", {}
756
  if VIABILITY_CLF is None:
@@ -773,14 +661,20 @@ def run_viability(stored_masks, stored_image_np):
773
 
774
 
775
  def pack_array(arr):
776
- pil = Image.fromarray(arr.astype(np.uint8))
 
 
 
 
 
777
  buf = io.BytesIO()
778
- pil.save(buf, format="PNG")
779
  return buf.getvalue()
780
 
781
 
782
  def unpack_array(data):
783
- return np.array(Image.open(io.BytesIO(data)))
 
784
 
785
 
786
  def save_tab_result(cell_count, confluency, viab_percent):
@@ -820,29 +714,12 @@ def compute_summary(r1, r2, r3, r4):
820
  return avg_count, avg_conf, avg_viab, "\n".join(lines)
821
 
822
 
823
- # ---------------------------------------------------------------------------
824
  # Training data export — feature extraction per cell
825
- # ---------------------------------------------------------------------------
826
 
827
  def extract_cell_features(image_np, masks):
828
- """
829
- For every segmented cell, extract a fixed feature vector from the pixels
830
- inside its mask. Returns a list of dicts, one per cell.
831
-
832
- Features:
833
- RGB channels — mean_r, mean_g, mean_b, std_r, std_g, std_b
834
- HSV channels — mean_h, mean_s, mean_v, std_s, std_v
835
- Ratios — blue_red_ratio, blue_green_ratio, rg_ratio
836
- Morphology — area_px, circularity
837
- Centre/edge profile — inner_brightness, peak_brightness,
838
- bright_spot_fraction, ring_darkness,
839
- centre_periphery_ratio, brightness_std_normalised
840
-
841
- Profile zones are tuned to hemocytometer live-cell morphology:
842
- a small intense specular highlight at the centre surrounded by a dark
843
- navy membrane ring. Dead cells are pale blue-grey blobs with no ring
844
- and no bright spot.
845
- """
846
  if len(image_np.shape) == 2:
847
  image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
848
  elif image_np.shape[2] == 4:
@@ -1003,9 +880,8 @@ def prepare_export(stored_masks, stored_image, threshold_bias):
1003
  return path, msg
1004
 
1005
 
1006
- # ---------------------------------------------------------------------------
1007
  # Tab builder
1008
- # ---------------------------------------------------------------------------
1009
 
1010
  def draw_polygon_overlay(image_pil, points):
1011
  """
@@ -1063,14 +939,12 @@ def clear_crop_points(image_pil):
1063
 
1064
 
1065
 
1066
- # ---------------------------------------------------------------------------
1067
  # Label correction grid
1068
- # ---------------------------------------------------------------------------
1069
 
1070
- THUMB_SIZE = 80 # each cell thumbnail is THUMB_SIZE × THUMB_SIZE px
1071
- GRID_COLS = 6 # thumbnails per row
1072
- BORDER = 4 # coloured border thickness in px
1073
- LABEL_H = 16 # height of the text label strip at the bottom of each thumb
1074
 
1075
  def _crop_cell_thumb(image_np, masks, cid):
1076
  """
@@ -1102,14 +976,7 @@ def _crop_cell_thumb(image_np, masks, cid):
1102
 
1103
 
1104
  def build_correction_grid(image_np, masks, labelled_features, raw_image_np=None):
1105
- """
1106
- Render all cell thumbnails into a single PIL image grid.
1107
- Each thumbnail has a coloured border: green=live(0), red=dead(1).
1108
- A small number in the corner identifies the cell_id.
1109
-
1110
- Returns the PIL grid image.
1111
- Cell order in the grid matches the order of labelled_features.
1112
- """
1113
  if not labelled_features:
1114
  placeholder = Image.fromarray(
1115
  np.zeros((THUMB_SIZE, THUMB_SIZE, 3), dtype=np.uint8)
@@ -1256,14 +1123,29 @@ def build_tab(tab_index, masks_state, image_state, result_state):
1256
  value="Hemocytometer Model"
1257
  )
1258
 
 
 
 
 
 
 
 
1259
  min_size_slider = gr.Slider(
1260
  minimum=0,
1261
  maximum=500,
1262
  value=0,
1263
  step=10,
1264
- label="Minimum Cell Size (pixels). Leave at zero for automated recommendation",
 
 
 
1265
  )
1266
 
 
 
 
 
 
1267
  max_size_slider = gr.Slider(
1268
  minimum=0,
1269
  maximum=10000,
@@ -1405,9 +1287,10 @@ def build_tab(tab_index, masks_state, image_state, result_state):
1405
  segment_btn.click(
1406
  fn=run_segmentation,
1407
  inputs=[img_input, model_dropdown, min_size_slider, max_size_slider,
 
1408
  use_stereo, left_excl, top_excl, crop_points_state],
1409
  outputs=[cell_count_out, overlay_out, info_out, viability_section,
1410
- masks_state, image_state, confluency_out, min_size_slider, raw_image_state]
1411
  )
1412
 
1413
  # ---- Run Viability button -------------------------------------------
@@ -1496,9 +1379,8 @@ def build_tab(tab_index, masks_state, image_state, result_state):
1496
 
1497
 
1498
 
1499
- # ---------------------------------------------------------------------------
1500
  # Gradio interface
1501
- # ---------------------------------------------------------------------------
1502
  with gr.Blocks(
1503
  title="CellposeCellCounter",
1504
  theme=gr.themes.Soft(),
 
14
  import joblib
15
  import os
16
 
17
+ HF_REPO_ID = "myang4218/cellposemodel"
18
+ HF_REPO_ID2 = "LiangLabUMB/viability_model"
19
+ HF_REPO_CPSAM = "mouseland/cellpose-sam"
20
  MODEL_OPTIONS = {
21
  "Hemocytometer Model": "hemocytometermodel.npy",
22
+ "General Model": "generalmodel.npy",
23
+ "Cellpose SAMv2": "cpsam_v2",
24
+ }
25
+ MODEL_REPOS = {
26
+ "hemocytometermodel.npy": HF_REPO_ID,
27
+ "generalmodel.npy": HF_REPO_ID,
28
+ "cpsam_v2": HF_REPO_CPSAM,
29
  }
30
 
31
  loaded_models = {}
 
42
  except Exception as e:
43
  print(f"Viability classifier not found or failed to load: {e}")
44
 
45
+ # mobile safe resize limits
46
  MAX_SIDE = 1024
47
  MAX_PIXELS = 1024 * 1024
48
 
49
 
50
  def safe_resize(image_np):
51
+
 
 
 
52
  h, w = image_np.shape[:2]
53
  total = h * w
54
 
 
156
 
157
 
158
  def classify_cells_by_model(image_np, masks):
159
+
 
 
 
 
160
  import numpy as np
161
  cell_ids = np.unique(masks)
162
  cell_ids = cell_ids[cell_ids > 0]
 
188
 
189
 
190
  def draw_viability_overlay(image_np, masks, label_map):
191
+
 
 
 
 
192
  overlay = image_np.copy()
193
  cell_ids = np.unique(masks)
194
  cell_ids = cell_ids[cell_ids > 0]
 
216
  (0, 0, 0), -1)
217
  cv2.putText(overlay, label_str,
218
  (cx - tw//2, cy + th//2),
219
+ font, font_scale, color, thickness, cv2.LINE_AA)
220
  return overlay
221
 
222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
 
225
  def measure_confluency(masks, image_np):
 
343
 
344
 
345
  def toggle_stereological_mode(use_stereology):
 
346
  return gr.update(visible=use_stereology)
347
 
348
 
349
  def update_exclusion_preview(image, left_width, top_width):
 
350
  if image is None:
351
  return None
352
 
 
355
  return Image.fromarray(overlay)
356
 
357
 
 
358
  # Patch segmentation
359
+
360
  PATCH_SIZE = 512 # target patch side length
361
  PATCH_OVERLAP = 64 # overlap border on each edge (pixels)
362
  MIN_PATCH_DIM = 256 # don't bother patching if image fits comfortably
 
456
  model = models.CellposeModel(gpu=True, pretrained_model=model_path)
457
  loaded_models[model_filename] = model
458
 
459
+ mask, _, _ = model.eval(patch_np, diameter=None)
460
  return mask, row_start, col_start
461
 
462
 
 
468
  that patching adds overhead without benefit.
469
  """
470
  h, w = image_np.shape[:2]
471
+ repo = MODEL_REPOS.get(model_filename, HF_REPO_ID)
472
+ model_path = hf_hub_download(repo_id=repo, filename=model_filename)
473
  if model_filename in loaded_models:
474
  model = loaded_models[model_filename]
475
  else:
 
478
 
479
  # Small images: no benefit from patching
480
  if max(h, w) <= MIN_PATCH_DIM * 2:
481
+ mask, _, _ = model.eval(image_np, diameter=None)
482
  return mask, 1 # 1 patch
483
 
484
  patches = _split_patches(image_np)
485
  n_patches = len(patches)
486
 
487
  # Build argument list for the thread pool
488
+ patch_repo = MODEL_REPOS.get(model_filename, HF_REPO_ID)
489
  args_list = [
490
+ (patch, r, c, model_filename, patch_repo)
491
  for patch, r, c in patches
492
  ]
493
 
 
509
 
510
  @spaces.GPU
511
  def run_segmentation(image, model_choice, min_cell_size, max_cell_size,
512
+ use_min_filter, use_max_filter,
513
  use_stereology, left_exclusion, top_exclusion,
514
  crop_points=None):
515
  image_np = np.array(image)
 
553
  print("p90:", np.percentile(sizes, 90) if len(sizes) > 0 else 0)
554
  print("max:", sizes.max() if len(sizes) > 0 else 0)
555
 
556
+ # Compute recommendation from RAW masks (always shown, never auto-applied)
557
  recommend_min = rec_min_size(masks_raw)
558
 
559
+ # Apply filters only if their checkboxes are enabled
 
 
 
560
  masks = masks_raw.copy()
561
  removed_small = 0
562
  removed_large = 0
563
 
564
+ if use_min_filter and int(min_cell_size) > 0:
565
+ masks, removed_small = filter_mask_by_size(masks, int(min_cell_size))
566
 
567
+ if use_max_filter and max_cell_size > 0:
568
  masks, removed_large = filter_mask_by_maxsize(masks, int(max_cell_size))
569
 
570
  # Apply stereological exclusion if enabled
 
576
 
577
  filter_msg = ""
578
  if removed_small:
579
+ filter_msg += f"Removed {removed_small} small objects (< {int(min_cell_size)} pixels).\n"
580
  if removed_large:
581
  filter_msg += f"Removed {removed_large} large objects (> {int(max_cell_size)} pixels).\n"
582
  if use_stereology and excluded_count > 0:
 
618
  pack_array(masks),
619
  pack_array(processed_image_np),
620
  confluency,
621
+ f"Recommended minimum: **{recommend_min} px** (25th percentile of detected cell sizes)",
622
  pack_array(raw_image_np),
623
  )
624
 
 
633
  None,
634
  None,
635
  0.0,
636
+ "",
637
  None,
638
  )
639
 
640
 
641
  def run_viability(stored_masks, stored_image_np):
 
642
  if stored_masks is None or stored_image_np is None:
643
  return None, 0, 0, 0.0, "Please run segmentation first.", {}
644
  if VIABILITY_CLF is None:
 
661
 
662
 
663
  def pack_array(arr):
664
+ """
665
+ Serialise a numpy array to bytes for gr.State storage.
666
+ Uses numpy's .npy format (not PNG) so integer dtypes of any
667
+ magnitude are preserved exactly — PNG is 8-bit only and silently
668
+ truncates cell IDs above 255.
669
+ """
670
  buf = io.BytesIO()
671
+ np.save(buf, arr)
672
  return buf.getvalue()
673
 
674
 
675
  def unpack_array(data):
676
+ buf = io.BytesIO(data)
677
+ return np.load(buf, allow_pickle=False)
678
 
679
 
680
  def save_tab_result(cell_count, confluency, viab_percent):
 
714
  return avg_count, avg_conf, avg_viab, "\n".join(lines)
715
 
716
 
717
+
718
  # Training data export — feature extraction per cell
719
+
720
 
721
  def extract_cell_features(image_np, masks):
722
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
723
  if len(image_np.shape) == 2:
724
  image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
725
  elif image_np.shape[2] == 4:
 
880
  return path, msg
881
 
882
 
883
+
884
  # Tab builder
 
885
 
886
  def draw_polygon_overlay(image_pil, points):
887
  """
 
939
 
940
 
941
 
 
942
  # Label correction grid
 
943
 
944
+ THUMB_SIZE = 80
945
+ GRID_COLS = 10
946
+ BORDER = 4
947
+ LABEL_H = 16
948
 
949
  def _crop_cell_thumb(image_np, masks, cid):
950
  """
 
976
 
977
 
978
  def build_correction_grid(image_np, masks, labelled_features, raw_image_np=None):
979
+
 
 
 
 
 
 
 
980
  if not labelled_features:
981
  placeholder = Image.fromarray(
982
  np.zeros((THUMB_SIZE, THUMB_SIZE, 3), dtype=np.uint8)
 
1123
  value="Hemocytometer Model"
1124
  )
1125
 
1126
+ gr.Markdown("### Size Filters")
1127
+
1128
+ use_min_filter = gr.Checkbox(
1129
+ label="Enable minimum size filter",
1130
+ value=False,
1131
+ info="Remove objects smaller than the threshold below"
1132
+ )
1133
  min_size_slider = gr.Slider(
1134
  minimum=0,
1135
  maximum=500,
1136
  value=0,
1137
  step=10,
1138
+ label="Minimum Cell Size (pixels)",
1139
+ )
1140
+ min_size_recommendation = gr.Markdown(
1141
+ value="*Run segmentation to see recommended minimum*",
1142
  )
1143
 
1144
+ use_max_filter = gr.Checkbox(
1145
+ label="Enable maximum size filter",
1146
+ value=False,
1147
+ info="Remove objects larger than the threshold below"
1148
+ )
1149
  max_size_slider = gr.Slider(
1150
  minimum=0,
1151
  maximum=10000,
 
1287
  segment_btn.click(
1288
  fn=run_segmentation,
1289
  inputs=[img_input, model_dropdown, min_size_slider, max_size_slider,
1290
+ use_min_filter, use_max_filter,
1291
  use_stereo, left_excl, top_excl, crop_points_state],
1292
  outputs=[cell_count_out, overlay_out, info_out, viability_section,
1293
+ masks_state, image_state, confluency_out, min_size_recommendation, raw_image_state]
1294
  )
1295
 
1296
  # ---- Run Viability button -------------------------------------------
 
1379
 
1380
 
1381
 
 
1382
  # Gradio interface
1383
+
1384
  with gr.Blocks(
1385
  title="CellposeCellCounter",
1386
  theme=gr.themes.Soft(),