LiangLabUMB commited on
Commit
18da35a
·
verified ·
1 Parent(s): 93d18c0

Sync from GitHub via hub-sync

Browse files
Files changed (4) hide show
  1. README.md +12 -0
  2. app.py +1546 -0
  3. gitattributes +35 -0
  4. requirements.txt +10 -0
README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: CellposeCellCounterMobile
3
+ emoji: 🚀
4
+ colorFrom: purple
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 5.35.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ short_description: Cell counter for uploaded microscopy images
12
+ ---
app.py ADDED
@@ -0,0 +1,1546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ from cellpose import models
4
+ import numpy as np
5
+ import cv2
6
+ import matplotlib.pyplot as plt
7
+ import tempfile
8
+ from PIL import Image, ImageDraw
9
+ import io
10
+ from huggingface_hub import hf_hub_download
11
+ import base64
12
+ from concurrent.futures import ThreadPoolExecutor, as_completed
13
+ import csv
14
+ import joblib
15
+ import os
16
+
17
+ HF_REPO_ID = "myang4218/cellposemodel"
18
+ HF_REPO_ID2 = "LiangLabUMB/viability_model"
19
+ MODEL_OPTIONS = {
20
+ "Hemocytometer Model": "hemocytometermodel.npy",
21
+ "General Model": "generalmodel.npy"
22
+ }
23
+
24
+ loaded_models = {}
25
+
26
+ VIABILITY_CLF = None
27
+ VIABILITY_SCALER = None
28
+
29
+ try:
30
+ _clf_path = hf_hub_download(repo_id=HF_REPO_ID2, filename="viability_clf.pkl")
31
+ _scaler_path = hf_hub_download(repo_id=HF_REPO_ID2, filename="viability_scaler.pkl")
32
+ VIABILITY_CLF = joblib.load(_clf_path)
33
+ VIABILITY_SCALER = joblib.load(_scaler_path)
34
+ print("✓ Viability classifier loaded.")
35
+ except Exception as e:
36
+ print(f"Viability classifier not found or failed to load: {e}")
37
+
38
+ # ---- mobile-safe size limits (aggressive for Safari) ----
39
+ MAX_SIDE = 1024
40
+ MAX_PIXELS = 1024 * 1024
41
+
42
+
43
+ def safe_resize(image_np):
44
+ """
45
+ Downscale image to fit within MAX_SIDE and MAX_PIXELS while
46
+ preserving aspect ratio. Works for RGB / RGBA / grayscale.
47
+ """
48
+ h, w = image_np.shape[:2]
49
+ total = h * w
50
+
51
+ if max(h, w) <= MAX_SIDE and total <= MAX_PIXELS:
52
+ return image_np
53
+
54
+ # compute scale
55
+ scale_side = MAX_SIDE / max(h, w)
56
+ scale_pixels = (MAX_PIXELS / total) ** 0.5
57
+ scale = min(scale_side, scale_pixels)
58
+
59
+ new_w = max(1, int(w * scale))
60
+ new_h = max(1, int(h * scale))
61
+
62
+ return cv2.resize(image_np, (new_w, new_h), interpolation=cv2.INTER_AREA)
63
+
64
+
65
+ def draw_exclusion_overlay(image_np, left_width_pct, top_width_pct):
66
+
67
+ h, w = image_np.shape[:2]
68
+
69
+ # Convert to PIL for drawing
70
+ img_pil = Image.fromarray(image_np)
71
+ draw = ImageDraw.Draw(img_pil, 'RGBA')
72
+
73
+ # Calculate pixel widths from percentages
74
+ left_px = int(w * left_width_pct / 100)
75
+ top_px = int(h * top_width_pct / 100)
76
+
77
+ # Draw overlays for exclusion zones
78
+ if left_px > 0:
79
+ # Left exclusion zone
80
+ draw.rectangle(
81
+ [(0, 0), (left_px, h)],
82
+ fill=(255, 0, 0, 80) # Semi-transparent red
83
+ )
84
+ # border line
85
+ draw.line([(left_px, 0), (left_px, h)], fill=(255, 0, 0, 255), width=3)
86
+
87
+ if top_px > 0:
88
+ # Top exclusion zone
89
+ draw.rectangle(
90
+ [(0, 0), (w, top_px)],
91
+ fill=(255, 0, 0, 80) # Semi-transparent red
92
+ )
93
+ # border line
94
+ draw.line([(0, top_px), (w, top_px)], fill=(255, 0, 0, 255), width=3)
95
+
96
+ return np.array(img_pil)
97
+
98
+
99
+ def apply_stereological_exclusion(masks, left_width_pct, top_width_pct):
100
+ h, w = masks.shape
101
+
102
+ # Calculate pixel widths from percentages
103
+ left_px = int(w * left_width_pct / 100)
104
+ top_px = int(h * top_width_pct / 100)
105
+
106
+ filtered_masks = masks.copy()
107
+ cell_ids = np.unique(masks)
108
+ cell_ids = cell_ids[cell_ids > 0]
109
+
110
+ excluded_cells = []
111
+ included_cells = []
112
+
113
+ for cell_id in cell_ids:
114
+ cell_mask = (masks == cell_id)
115
+
116
+ # Get cell boundary coordinates
117
+ rows, cols = np.where(cell_mask)
118
+
119
+ # Check if cell touches left exclusion zone
120
+ touches_left = np.any(cols < left_px) if left_px > 0 else False
121
+
122
+ # Check if cell touches top exclusion zone
123
+ touches_top = np.any(rows < top_px) if top_px > 0 else False
124
+
125
+ # Exclude if touching left or top
126
+ if touches_left or touches_top:
127
+ filtered_masks[cell_mask] = 0
128
+ excluded_cells.append(cell_id)
129
+ else:
130
+ included_cells.append(cell_id)
131
+
132
+ # Renumber remaining cells
133
+ unique_ids = np.unique(filtered_masks)
134
+ unique_ids = unique_ids[unique_ids > 0]
135
+
136
+ renumbered_masks = np.zeros_like(filtered_masks)
137
+ for new_id, old_id in enumerate(unique_ids, start=1):
138
+ renumbered_masks[filtered_masks == old_id] = new_id
139
+
140
+ return renumbered_masks, len(excluded_cells), len(included_cells)
141
+
142
+
143
+
144
+ FEATURE_COLS_INFERENCE = [
145
+ "mean_r", "mean_g", "mean_b", "std_r", "std_g", "std_b",
146
+ "mean_h", "mean_s", "mean_v", "std_s", "std_v",
147
+ "blue_red_ratio", "blue_green_ratio", "rg_ratio",
148
+ "inner_brightness", "peak_brightness",
149
+ "bright_spot_fraction", "ring_darkness",
150
+ "centre_periphery_ratio", "brightness_std_normalised",
151
+ ]
152
+
153
+
154
+ def classify_cells_by_model(image_np, masks):
155
+ """
156
+ Run the trained LogisticRegression classifier to predict live/dead per cell.
157
+ Returns (dead_count, alive_count, overlay_np, {cell_id: label}).
158
+ Requires VIABILITY_CLF and VIABILITY_SCALER to be loaded.
159
+ """
160
+ import numpy as np
161
+ cell_ids = np.unique(masks)
162
+ cell_ids = cell_ids[cell_ids > 0]
163
+ if len(cell_ids) == 0:
164
+ return 0, 0, image_np.copy(), {}
165
+
166
+ features = extract_cell_features(image_np, masks)
167
+ if not features:
168
+ return 0, 0, image_np.copy(), {}
169
+
170
+ import numpy as np
171
+ X = np.array([[f[c] for c in FEATURE_COLS_INFERENCE] for f in features], dtype=np.float32)
172
+
173
+ # replace any NaN/Inf with column median
174
+ for j in range(X.shape[1]):
175
+ bad = ~np.isfinite(X[:, j])
176
+ if bad.any():
177
+ X[bad, j] = float(np.nanmedian(X[:, j]))
178
+
179
+ X_scaled = VIABILITY_SCALER.transform(X)
180
+ predictions = VIABILITY_CLF.predict(X_scaled) # 0=live, 1=dead
181
+
182
+ label_map = {int(f["cell_id"]): int(p) for f, p in zip(features, predictions)}
183
+ overlay = draw_viability_overlay(image_np, masks, label_map)
184
+
185
+ dead = int(sum(predictions))
186
+ alive = int(len(predictions) - dead)
187
+ return dead, alive, overlay, label_map
188
+
189
+
190
+ def draw_viability_overlay(image_np, masks, label_map):
191
+ """
192
+ Draw coloured contours + cell-number labels onto image_np.
193
+ label_map: {cell_id: 0=live, 1=dead}
194
+ Returns a uint8 numpy array.
195
+ """
196
+ overlay = image_np.copy()
197
+ cell_ids = np.unique(masks)
198
+ cell_ids = cell_ids[cell_ids > 0]
199
+ cell_enum = {int(cid): idx + 1 for idx, cid in enumerate(sorted(cell_ids))}
200
+
201
+ for cid in cell_ids:
202
+ cid_int = int(cid)
203
+ label = label_map.get(cid_int, 0)
204
+ color = (220, 50, 50) if label == 1 else (50, 220, 80)
205
+ cell_mask = (masks == cid).astype(np.uint8)
206
+ contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
207
+ cv2.drawContours(overlay, contours, -1, color, thickness=2)
208
+
209
+ ys, xs = np.where(cell_mask)
210
+ if len(ys) > 0:
211
+ cx, cy = int(xs.mean()), int(ys.mean())
212
+ label_str = str(cell_enum[cid_int])
213
+ font = cv2.FONT_HERSHEY_SIMPLEX
214
+ font_scale = 0.35
215
+ thickness = 1
216
+ (tw, th), _ = cv2.getTextSize(label_str, font, font_scale, thickness)
217
+ cv2.rectangle(overlay,
218
+ (cx - tw//2 - 1, cy - th//2 - 1),
219
+ (cx + tw//2 + 1, cy + th//2 + 1),
220
+ (0, 0, 0), -1)
221
+ cv2.putText(overlay, label_str,
222
+ (cx - tw//2, cy + th//2),
223
+ font, font_scale, color, thickness, cv2.LINE_AA)
224
+ return overlay
225
+
226
+
227
+ def classify_cells_by_blueness(image_np, masks, threshold_bias):
228
+ """
229
+ Classify cells as dead (blue) or alive using an adaptive Otsu threshold
230
+ on per-cell blueness scores, with a user bias to fine-tune.
231
+
232
+ Args:
233
+ image_np: RGB image array
234
+ masks: Cellpose segmentation masks
235
+ threshold_bias: Slider value -50..+50; shifts Otsu threshold up/down.
236
+ Negative = more cells classified dead (looser).
237
+ Positive = fewer cells classified dead (stricter).
238
+ 0 = pure Otsu (fully automatic).
239
+
240
+ Returns:
241
+ dead_count, alive_count, colored_overlay, otsu_threshold, final_threshold
242
+ """
243
+
244
+ if len(image_np.shape) == 2:
245
+ image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
246
+ elif len(image_np.shape) == 3 and image_np.shape[2] == 4:
247
+ image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)
248
+
249
+ hsv = cv2.cvtColor(image_np, cv2.COLOR_RGB2HSV)
250
+
251
+ hue = hsv[:, :, 0].astype(np.float32)
252
+ saturation = hsv[:, :, 1].astype(np.float32)
253
+
254
+ # Raw blueness: hue proximity to 115° × saturation
255
+ hue_distance = np.minimum(np.abs(hue - 115), 180 - np.abs(hue - 115))
256
+ hue_score = np.maximum(0, 1 - hue_distance / 65)
257
+ blueness = hue_score * (saturation / 255.0)
258
+
259
+ # --- Compute per-cell mean blueness scores ---
260
+ cell_ids = np.unique(masks)
261
+ cell_ids = cell_ids[cell_ids > 0]
262
+
263
+ if len(cell_ids) == 0:
264
+ blank = image_np.copy()
265
+ return 0, 0, blank, 0.0, 0.0
266
+
267
+ cell_scores = np.array([np.mean(blueness[masks == cid]) for cid in cell_ids])
268
+
269
+ # --- Otsu on the distribution of per-cell scores ---
270
+ # cv2.threshold expects uint8; scale 0-1 → 0-255
271
+ scores_u8 = (np.clip(cell_scores, 0, 1) * 255).astype(np.uint8)
272
+
273
+ if scores_u8.max() == scores_u8.min():
274
+ # All cells identical → Otsu is undefined; use midpoint
275
+ otsu_threshold = float(scores_u8[0]) / 255.0
276
+ else:
277
+ # Reshape to a single-column image so cv2.threshold works
278
+ thresh_val, _ = cv2.threshold(
279
+ scores_u8.reshape(-1, 1), 0, 255,
280
+ cv2.THRESH_BINARY + cv2.THRESH_OTSU
281
+ )
282
+ otsu_threshold = thresh_val / 255.0
283
+
284
+ # --- Apply user bias: slider -50..+50 maps to ±0.20 shift ---
285
+ bias = (threshold_bias / 50.0) * 0.20
286
+ final_threshold = float(np.clip(otsu_threshold + bias, 0.0, 1.0))
287
+
288
+ # --- Classify ---
289
+ dead_cells = [cid for cid, s in zip(cell_ids, cell_scores) if s > final_threshold]
290
+ alive_cells = [cid for cid, s in zip(cell_ids, cell_scores) if s <= final_threshold]
291
+
292
+ # --- Outline-only overlay on raw image with enumerated labels ---
293
+ final_overlay = image_np.copy()
294
+
295
+ # Compute a consistent enumeration order (cell_ids is already sorted ascending)
296
+ cell_enum = {cid: idx + 1 for idx, cid in enumerate(cell_ids)}
297
+
298
+ dead_set = set(dead_cells)
299
+ alive_set = set(alive_cells)
300
+
301
+ for cid in cell_ids:
302
+ cell_mask = (masks == cid).astype(np.uint8)
303
+ contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
304
+ color = (220, 50, 50) if cid in dead_set else (50, 220, 80)
305
+ cv2.drawContours(final_overlay, contours, -1, color, thickness=2)
306
+
307
+ # Draw enumeration label at centroid
308
+ ys, xs = np.where(cell_mask)
309
+ if len(ys) > 0:
310
+ cx, cy = int(xs.mean()), int(ys.mean())
311
+ label_str = str(cell_enum[cid])
312
+ font = cv2.FONT_HERSHEY_SIMPLEX
313
+ font_scale = 0.35
314
+ thickness = 1
315
+ (tw, th), _ = cv2.getTextSize(label_str, font, font_scale, thickness)
316
+ # Dark background rectangle for readability
317
+ cv2.rectangle(
318
+ final_overlay,
319
+ (cx - tw // 2 - 1, cy - th // 2 - 1),
320
+ (cx + tw // 2 + 1, cy + th // 2 + 1),
321
+ (0, 0, 0),
322
+ -1
323
+ )
324
+ cv2.putText(
325
+ final_overlay, label_str,
326
+ (cx - tw // 2, cy + th // 2),
327
+ font, font_scale, color, thickness, cv2.LINE_AA
328
+ )
329
+
330
+ return len(dead_cells), len(alive_cells), final_overlay, otsu_threshold, final_threshold
331
+
332
+
333
+ def measure_confluency(masks, image_np):
334
+ tot_pixels = image_np.shape[0] * image_np.shape[1]
335
+ cell_pixels = np.count_nonzero(masks)
336
+ confluency = cell_pixels / tot_pixels * 100
337
+ return confluency
338
+
339
+ def filter_mask_by_size(masks, minimum_pixels):
340
+ filtered_masks = masks.copy()
341
+ cell_ids = np.unique(masks)
342
+ cell_ids = cell_ids[cell_ids > 0]
343
+
344
+ removed_count = 0
345
+
346
+ for cell_id in cell_ids:
347
+ cell_mask = (masks == cell_id)
348
+ cell_pixels = np.count_nonzero(cell_mask)
349
+ if cell_pixels < minimum_pixels:
350
+ filtered_masks[cell_mask] = 0
351
+ removed_count += 1
352
+
353
+ unique_ids = np.unique(filtered_masks)
354
+ unique_ids = unique_ids[unique_ids > 0]
355
+
356
+ renumbered_masks = np.zeros_like(filtered_masks)
357
+ for new_id, old_id in enumerate(unique_ids, start=1):
358
+ renumbered_masks[filtered_masks == old_id] = new_id
359
+
360
+ return renumbered_masks, removed_count
361
+
362
+
363
+ def filter_mask_by_maxsize(masks, maximum_pixels):
364
+ filtered_masks = masks.copy()
365
+ cell_ids = np.unique(masks)
366
+ cell_ids = cell_ids[cell_ids > 0]
367
+
368
+ removed_count = 0
369
+ for cell_id in cell_ids:
370
+ cell_mask = (masks == cell_id)
371
+ cell_pixels = np.count_nonzero(cell_mask)
372
+ if cell_pixels > maximum_pixels:
373
+ filtered_masks[cell_mask] = 0
374
+ removed_count += 1
375
+
376
+ unique_ids = np.unique(filtered_masks)
377
+ unique_ids = unique_ids[unique_ids > 0]
378
+
379
+ renumbered_masks = np.zeros_like(filtered_masks)
380
+ for new_id, old_id in enumerate(unique_ids, start=1):
381
+ renumbered_masks[filtered_masks == old_id] = new_id
382
+
383
+ return renumbered_masks, removed_count
384
+
385
+
386
+ def rec_min_size(masks, q=25):
387
+ ids = np.unique(masks)
388
+ ids = ids[ids > 0]
389
+ if len(ids) == 0:
390
+ return 0
391
+ sizes = np.array([np.count_nonzero(masks == cid) for cid in ids])
392
+ return int(round(np.percentile(sizes, q)))
393
+
394
+
395
+ def apply_polygon_mask(image_pil, points_json):
396
+ """
397
+ Given a PIL image and a JSON string of [[x,y],...] points,
398
+ zero out everything outside the polygon and return a PIL image.
399
+ """
400
+ import json
401
+ if not points_json or points_json.strip() in ("", "[]"):
402
+ return image_pil
403
+ try:
404
+ pts = json.loads(points_json)
405
+ except Exception:
406
+ return image_pil
407
+ if len(pts) < 3:
408
+ return image_pil
409
+
410
+ image_np = np.array(image_pil)
411
+ h, w = image_np.shape[:2]
412
+ poly = np.array(pts, dtype=np.int32)
413
+ poly[:, 0] = np.clip(poly[:, 0], 0, w - 1)
414
+ poly[:, 1] = np.clip(poly[:, 1], 0, h - 1)
415
+ mask = np.zeros((h, w), dtype=np.uint8)
416
+ cv2.fillPoly(mask, [poly], 255)
417
+ if len(image_np.shape) == 3:
418
+ result = np.where(mask[:, :, np.newaxis] == 255, image_np, 0).astype(np.uint8)
419
+ else:
420
+ result = np.where(mask == 255, image_np, 0).astype(np.uint8)
421
+ return Image.fromarray(result)
422
+
423
+ def warp_polygon_to_square(image_np, points):
424
+ pts = np.array(points, dtype=np.float32)
425
+
426
+ s = pts.sum(axis=1)
427
+ diff = np.diff(pts, axis=1).ravel()
428
+ tl = pts[np.argmin(s)]
429
+ br = pts[np.argmax(s)]
430
+ tr = pts[np.argmin(diff)]
431
+ bl = pts[np.argmax(diff)]
432
+ src = np.array([tl, tr, br, bl], dtype=np.float32)
433
+
434
+ w1 = np.linalg.norm(br-bl)
435
+ w2 = np.linalg.norm(tr-tl)
436
+ h1 = np.linalg.norm(tr-br)
437
+ h2 = np.linalg.norm(tl-bl)
438
+ out_w = int(max(w1, w2))
439
+ out_h = int(max(h1, h2))
440
+
441
+ dst = np.array(
442
+ [[0, 0],
443
+ [out_w - 1, 0],
444
+ [out_w - 1, out_h - 1],
445
+ [0, out_h - 1]],
446
+ dtype=np.float32)
447
+
448
+ M = cv2.getPerspectiveTransform(src, dst)
449
+ warped = cv2.warpPerspective(image_np, M, (out_w, out_h))
450
+ return warped
451
+
452
+
453
+ def toggle_stereological_mode(use_stereology):
454
+ """Show/hide stereological controls based on checkbox"""
455
+ return gr.update(visible=use_stereology)
456
+
457
+
458
+ def update_exclusion_preview(image, left_width, top_width):
459
+ """Update the preview image with exclusion zone overlay"""
460
+ if image is None:
461
+ return None
462
+
463
+ image_np = np.array(image)
464
+ overlay = draw_exclusion_overlay(image_np, left_width, top_width)
465
+ return Image.fromarray(overlay)
466
+
467
+
468
+ # ---------------------------------------------------------------------------
469
+ # Patch segmentation
470
+ # ---------------------------------------------------------------------------
471
+ PATCH_SIZE = 512 # target patch side length
472
+ PATCH_OVERLAP = 64 # overlap border on each edge (pixels)
473
+ MIN_PATCH_DIM = 256 # don't bother patching if image fits comfortably
474
+
475
+
476
+ def _split_patches(image_np, patch_size=PATCH_SIZE, overlap=PATCH_OVERLAP):
477
+ """
478
+ Split image into overlapping patches.
479
+ Returns list of (patch_np, row_start, col_start) tuples.
480
+ """
481
+ h, w = image_np.shape[:2]
482
+ patches = []
483
+ row = 0
484
+ while row < h:
485
+ row_end = min(row + patch_size, h)
486
+ col = 0
487
+ while col < w:
488
+ col_end = min(col + patch_size, w)
489
+ patch = image_np[row:row_end, col:col_end]
490
+ patches.append((patch, row, col))
491
+ if col_end == w:
492
+ break
493
+ col += patch_size - overlap
494
+ if row_end == h:
495
+ break
496
+ row += patch_size - overlap
497
+ return patches
498
+
499
+
500
+ def _merge_patch_masks(patch_results, full_h, full_w, overlap=PATCH_OVERLAP):
501
+ """
502
+ Stitch per-patch masks into a single full-image mask.
503
+
504
+ Strategy:
505
+ - Each patch gets a unique ID offset so cell IDs never collide.
506
+ - Patches are pasted into the canvas using a priority canvas that
507
+ gives interior pixels precedence over overlap-border pixels.
508
+ - After pasting, cells whose centroids fall in the overlap zone
509
+ of two adjacent patches are deduplicated: if two cells from
510
+ different patches share >50% IoU they are the same cell — keep
511
+ the one whose centroid is furthest from a patch edge.
512
+ """
513
+ full_mask = np.zeros((full_h, full_w), dtype=np.int32)
514
+ # track which patch_idx owns each pixel (used for overlap resolution)
515
+ owner_map = np.full((full_h, full_w), -1, dtype=np.int32)
516
+ # distance-to-nearest-edge for the owning patch (higher = more central)
517
+ priority = np.zeros((full_h, full_w), dtype=np.float32)
518
+
519
+ id_offset = 0
520
+ patch_meta = [] # (offset, row_start, col_start, patch_h, patch_w)
521
+
522
+ for patch_idx, (mask_patch, row_start, col_start) in enumerate(patch_results):
523
+ ph, pw = mask_patch.shape
524
+ # offset all non-zero IDs so they're globally unique
525
+ shifted = np.where(mask_patch > 0, mask_patch + id_offset, 0).astype(np.int32)
526
+
527
+ # compute per-pixel priority = min distance to any patch edge
528
+ rows_idx = np.arange(ph)
529
+ cols_idx = np.arange(pw)
530
+ dist_r = np.minimum(rows_idx, ph - 1 - rows_idx) # (ph,)
531
+ dist_c = np.minimum(cols_idx, pw - 1 - cols_idx) # (pw,)
532
+ pri_patch = np.minimum(dist_r[:, None], dist_c[None, :]) # (ph, pw)
533
+
534
+ roi_full = full_mask [row_start:row_start+ph, col_start:col_start+pw]
535
+ roi_owner = owner_map [row_start:row_start+ph, col_start:col_start+pw]
536
+ roi_pri = priority [row_start:row_start+ph, col_start:col_start+pw]
537
+
538
+ # where this patch has higher priority, overwrite
539
+ better = pri_patch > roi_pri
540
+ roi_full [better] = shifted [better]
541
+ roi_owner[better] = patch_idx
542
+ roi_pri [better] = pri_patch [better]
543
+
544
+ max_id = int(mask_patch.max())
545
+ patch_meta.append((id_offset, row_start, col_start, ph, pw))
546
+ id_offset += max_id + 1
547
+
548
+ # --- Renumber to compact sequential IDs ---
549
+ unique_ids = np.unique(full_mask)
550
+ unique_ids = unique_ids[unique_ids > 0]
551
+ renumbered = np.zeros_like(full_mask)
552
+ for new_id, old_id in enumerate(unique_ids, start=1):
553
+ renumbered[full_mask == old_id] = new_id
554
+
555
+ return renumbered
556
+
557
+
558
+ def _segment_patch(args):
559
+ """Worker: run cellpose on a single patch. Called from a thread pool."""
560
+ patch_np, row_start, col_start, model_filename, hf_repo = args
561
+ # Each thread uses the shared loaded_models cache (GIL-safe for reads;
562
+ # model.eval() releases the GIL during GPU work so threads overlap.)
563
+ model_path = hf_hub_download(repo_id=hf_repo, filename=model_filename)
564
+ if model_filename in loaded_models:
565
+ model = loaded_models[model_filename]
566
+ else:
567
+ model = models.CellposeModel(gpu=True, pretrained_model=model_path)
568
+ loaded_models[model_filename] = model
569
+
570
+ mask, _, _ = model.eval(patch_np, diameter=None, channels=[0, 0])
571
+ return mask, row_start, col_start
572
+
573
+
574
+ def run_segmentation_patched(image_np, model_filename):
575
+ """
576
+ Split image into overlapping patches, run Cellpose on each in parallel,
577
+ then stitch back into a single full-resolution mask.
578
+ Falls back to whole-image segmentation if the image is small enough
579
+ that patching adds overhead without benefit.
580
+ """
581
+ h, w = image_np.shape[:2]
582
+ model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=model_filename)
583
+ if model_filename in loaded_models:
584
+ model = loaded_models[model_filename]
585
+ else:
586
+ model = models.CellposeModel(gpu=True, pretrained_model=model_path)
587
+ loaded_models[model_filename] = model
588
+
589
+ # Small images: no benefit from patching
590
+ if max(h, w) <= MIN_PATCH_DIM * 2:
591
+ mask, _, _ = model.eval(image_np, diameter=None, channels=[0, 0])
592
+ return mask, 1 # 1 patch
593
+
594
+ patches = _split_patches(image_np)
595
+ n_patches = len(patches)
596
+
597
+ # Build argument list for the thread pool
598
+ args_list = [
599
+ (patch, r, c, model_filename, HF_REPO_ID)
600
+ for patch, r, c in patches
601
+ ]
602
+
603
+ patch_results = [] # (mask, row_start, col_start) in submission order
604
+
605
+ # ThreadPoolExecutor: GPU kernels release the GIL so threads overlap on GPU
606
+ with ThreadPoolExecutor(max_workers=min(n_patches, 4)) as pool:
607
+ futures = {pool.submit(_segment_patch, a): a for a in args_list}
608
+ for future in as_completed(futures):
609
+ mask_patch, row_start, col_start = future.result()
610
+ patch_results.append((mask_patch, row_start, col_start))
611
+
612
+ # Re-sort by (row, col) so stitching is deterministic
613
+ patch_results.sort(key=lambda x: (x[1], x[2]))
614
+
615
+ full_mask = _merge_patch_masks(patch_results, h, w)
616
+ return full_mask, n_patches
617
+
618
+
619
+ @spaces.GPU
620
+ def run_segmentation(image, model_choice, min_cell_size, max_cell_size,
621
+ use_stereology, left_exclusion, top_exclusion,
622
+ crop_points=None):
623
+ image_np = np.array(image)
624
+ image_np = safe_resize(image_np)
625
+
626
+ raw_image_np = image_np.copy()
627
+
628
+ # Apply polygon crop mask if the user drew one (need ≥3 points for a polygon)
629
+ if crop_points and len(crop_points) >= 3:
630
+ import json
631
+ pts_json = json.dumps(crop_points)
632
+ image_pil_masked = apply_polygon_mask(Image.fromarray(image_np), pts_json)
633
+ image_np = np.array(image_pil_masked)
634
+
635
+ if len(crop_points) == 4:
636
+ image_np = warp_polygon_to_square(image_np, crop_points)
637
+
638
+
639
+ try:
640
+ model_filename = MODEL_OPTIONS[model_choice]
641
+
642
+ # Process image format to RGB
643
+ if len(image_np.shape) == 2:
644
+ processed_image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
645
+ elif len(image_np.shape) == 3 and image_np.shape[2] == 4:
646
+ processed_image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)
647
+ else:
648
+ processed_image_np = image_np
649
+
650
+ # Run patch-parallel Cellpose segmentation
651
+ masks_raw, n_patches = run_segmentation_patched(processed_image_np, model_filename)
652
+
653
+ ids = np.unique(masks_raw)
654
+ ids = ids[ids > 0]
655
+
656
+ sizes = np.array([np.count_nonzero(masks_raw == cid) for cid in ids])
657
+
658
+ print("num_cells:", len(ids))
659
+ print("mean:", sizes.mean() if len(sizes) > 0 else 0)
660
+ print("median:", np.median(sizes) if len(sizes) > 0 else 0)
661
+ print("p90:", np.percentile(sizes, 90) if len(sizes) > 0 else 0)
662
+ print("max:", sizes.max() if len(sizes) > 0 else 0)
663
+
664
+ # Compute recommendation from RAW masks
665
+ recommend_min = rec_min_size(masks_raw)
666
+
667
+ # If user sets slider to 0, use the recommendation
668
+ min_used = recommend_min if (min_cell_size == 0) else int(min_cell_size)
669
+
670
+ # Apply filters
671
+ masks = masks_raw.copy()
672
+ removed_small = 0
673
+ removed_large = 0
674
+
675
+ if min_used > 0:
676
+ masks, removed_small = filter_mask_by_size(masks, min_used)
677
+
678
+ if max_cell_size > 0:
679
+ masks, removed_large = filter_mask_by_maxsize(masks, int(max_cell_size))
680
+
681
+ # Apply stereological exclusion if enabled
682
+ excluded_count = 0
683
+ if use_stereology:
684
+ masks, excluded_count, included_count = apply_stereological_exclusion(
685
+ masks, left_exclusion, top_exclusion
686
+ )
687
+
688
+ filter_msg = ""
689
+ if removed_small:
690
+ filter_msg += f"Removed {removed_small} small objects (< {min_used} pixels).\n"
691
+ if removed_large:
692
+ filter_msg += f"Removed {removed_large} large objects (> {int(max_cell_size)} pixels).\n"
693
+ if use_stereology and excluded_count > 0:
694
+ filter_msg += f"Stereological exclusion: {excluded_count} cells excluded (touching left/top zones).\n"
695
+
696
+ cell_count = len(np.unique(masks)) - 1
697
+ confluency = measure_confluency(masks, processed_image_np)
698
+
699
+ # Create a basic segmentation overlay (without viability)
700
+ segmentation_overlay = processed_image_np.copy().astype(np.float32)
701
+ if masks.max() > 0:
702
+ np.random.seed(42) # For consistent random colors
703
+ colors = np.random.randint(0, 255, size=(masks.max() + 1, 3))
704
+ colors[0] = [0, 0, 0]
705
+ colored_mask = colors[masks]
706
+ alpha = 0.4
707
+ segmentation_overlay = (1 - alpha) * segmentation_overlay + alpha * colored_mask
708
+ segmentation_overlay = np.clip(segmentation_overlay, 0, 255).astype(np.uint8)
709
+
710
+ # Add exclusion zone overlay if stereology is enabled
711
+ if use_stereology:
712
+ segmentation_overlay = draw_exclusion_overlay(segmentation_overlay, left_exclusion, top_exclusion)
713
+
714
+ info_msg = ""
715
+ if filter_msg:
716
+ info_msg += filter_msg
717
+ info_msg += f"Segmentation complete! Found {cell_count} cells.\n"
718
+ info_msg += f"Confluency: {confluency:.1f}%\n"
719
+ info_msg += f"Processed as {n_patches} patch{'es' if n_patches > 1 else ''} (parallel).\n"
720
+ if use_stereology:
721
+ info_msg += f"Stereological counting enabled (Left: {left_exclusion}%, Top: {top_exclusion}%)\n"
722
+ info_msg += "Now run the viability classification model for viability assessment."
723
+
724
+ return (
725
+ cell_count,
726
+ Image.fromarray(segmentation_overlay),
727
+ info_msg,
728
+ gr.update(visible=True),
729
+ pack_array(masks),
730
+ pack_array(processed_image_np),
731
+ confluency,
732
+ gr.update(value=recommend_min),
733
+ pack_array(raw_image_np),
734
+ )
735
+
736
+ except Exception as e:
737
+ import traceback
738
+ traceback.print_exc()
739
+ return (
740
+ 0,
741
+ None,
742
+ f"Error during segmentation: {str(e)}",
743
+ gr.update(visible=False),
744
+ None,
745
+ None,
746
+ 0.0,
747
+ gr.update(),
748
+ None,
749
+ )
750
+
751
+
752
+ def run_viability(stored_masks, stored_image_np):
753
+ """Run model-based viability classification. Returns overlay + counts + label_map."""
754
+ if stored_masks is None or stored_image_np is None:
755
+ return None, 0, 0, 0.0, "Please run segmentation first.", {}
756
+ if VIABILITY_CLF is None:
757
+ return None, 0, 0, 0.0, "No viability model found. Add viability_clf.pkl and viability_scaler.pkl to the app directory.", {}
758
+
759
+ masks = unpack_array(stored_masks)
760
+ image_np = unpack_array(stored_image_np)
761
+
762
+ try:
763
+ dead, alive, overlay_np, label_map = classify_cells_by_model(image_np, masks)
764
+ total = alive + dead
765
+ viab_pct = (alive / total * 100) if total > 0 else 0.0
766
+ confluency = measure_confluency(masks, image_np)
767
+ info_msg = f"Total cells: {total}\nLive (green): {alive}\nDead (red): {dead}\n"
768
+ info_msg += f"Viability: {viab_pct:.1f}%\nConfluency: {confluency:.1f}%"
769
+ return Image.fromarray(overlay_np), alive, dead, viab_pct, info_msg, label_map
770
+ except Exception as e:
771
+ import traceback; traceback.print_exc()
772
+ return None, 0, 0, 0.0, f"Error: {str(e)}", {}
773
+
774
+
775
+ def pack_array(arr):
776
+ pil = Image.fromarray(arr.astype(np.uint8))
777
+ buf = io.BytesIO()
778
+ pil.save(buf, format="PNG")
779
+ return buf.getvalue()
780
+
781
+
782
+ def unpack_array(data):
783
+ return np.array(Image.open(io.BytesIO(data)))
784
+
785
+
786
+ def save_tab_result(cell_count, confluency, viab_percent):
787
+ """Package per-tab results into a dict for Tab 5 summary."""
788
+ return {
789
+ "cell_count": float(cell_count) if cell_count is not None else None,
790
+ "confluency": float(confluency) if confluency is not None else None,
791
+ "viab_percent": float(viab_percent) if viab_percent is not None else None,
792
+ }
793
+
794
+
795
+ def compute_summary(r1, r2, r3, r4):
796
+ """Average cell count, confluency, and viability across tabs that have data."""
797
+ all_results = [r1, r2, r3, r4]
798
+ valid = [(i + 1, r) for i, r in enumerate(all_results)
799
+ if r is not None and r.get("cell_count") is not None]
800
+
801
+ if not valid:
802
+ return (
803
+ 0.0, 0.0, 0.0,
804
+ "No data yet — run segmentation in at least one tab, then click Refresh Summary."
805
+ )
806
+
807
+ avg_count = sum(r["cell_count"] for _, r in valid) / len(valid)
808
+ avg_conf = sum(r["confluency"] for _, r in valid) / len(valid)
809
+ avg_viab = sum(r["viab_percent"] for _, r in valid) / len(valid)
810
+
811
+ lines = [f"Tab {tab_num}: {r['cell_count']:.0f} cells | "
812
+ f"{r['confluency']:.1f}% confluency | "
813
+ f"{r['viab_percent']:.1f}% viability"
814
+ for tab_num, r in valid]
815
+ lines.append(f"\nAverages ({len(valid)} tab{'s' if len(valid) > 1 else ''}):")
816
+ lines.append(f" Cell count: {avg_count:.1f}")
817
+ lines.append(f" Confluency: {avg_conf:.1f}%")
818
+ lines.append(f" Viability: {avg_viab:.1f}%")
819
+
820
+ return avg_count, avg_conf, avg_viab, "\n".join(lines)
821
+
822
+
823
+ # ---------------------------------------------------------------------------
824
+ # Training data export — feature extraction per cell
825
+ # ---------------------------------------------------------------------------
826
+
827
+ def extract_cell_features(image_np, masks):
828
+ """
829
+ For every segmented cell, extract a fixed feature vector from the pixels
830
+ inside its mask. Returns a list of dicts, one per cell.
831
+
832
+ Features:
833
+ RGB channels — mean_r, mean_g, mean_b, std_r, std_g, std_b
834
+ HSV channels — mean_h, mean_s, mean_v, std_s, std_v
835
+ Ratios — blue_red_ratio, blue_green_ratio, rg_ratio
836
+ Morphology — area_px, circularity
837
+ Centre/edge profile — inner_brightness, peak_brightness,
838
+ bright_spot_fraction, ring_darkness,
839
+ centre_periphery_ratio, brightness_std_normalised
840
+
841
+ Profile zones are tuned to hemocytometer live-cell morphology:
842
+ a small intense specular highlight at the centre surrounded by a dark
843
+ navy membrane ring. Dead cells are pale blue-grey blobs with no ring
844
+ and no bright spot.
845
+ """
846
+ if len(image_np.shape) == 2:
847
+ image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
848
+ elif image_np.shape[2] == 4:
849
+ image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)
850
+
851
+ hsv = cv2.cvtColor(image_np, cv2.COLOR_RGB2HSV).astype(np.float32)
852
+
853
+ h_img, w_img = image_np.shape[:2]
854
+ grid_y, grid_x = np.mgrid[:h_img, :w_img]
855
+
856
+ cell_ids = np.unique(masks)
857
+ cell_ids = cell_ids[cell_ids > 0]
858
+ rows = []
859
+
860
+ for cid in cell_ids:
861
+ cell_mask = (masks == cid)
862
+ pixels_rgb = image_np[cell_mask].astype(np.float32)
863
+ pixels_hsv = hsv[cell_mask]
864
+
865
+ r, g, b = pixels_rgb[:, 0], pixels_rgb[:, 1], pixels_rgb[:, 2]
866
+ h, s, v = pixels_hsv[:, 0], pixels_hsv[:, 1], pixels_hsv[:, 2]
867
+
868
+ eps = 1e-6
869
+ blue_red_ratio = b.mean() / (r.mean() + eps)
870
+ blue_green_ratio = b.mean() / (g.mean() + eps)
871
+ rg_ratio = r.mean() / (g.mean() + eps)
872
+
873
+ area_px = int(cell_mask.sum())
874
+ contours, _ = cv2.findContours(
875
+ cell_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
876
+ )
877
+ perimeter = cv2.arcLength(contours[0], True) if contours else 1.0
878
+ circularity = (4 * np.pi * area_px / (perimeter ** 2 + eps)) if perimeter > 0 else 0.0
879
+
880
+ ys_cell = grid_y[cell_mask].astype(np.float32)
881
+ xs_cell = grid_x[cell_mask].astype(np.float32)
882
+ centroid_y = ys_cell.mean()
883
+ centroid_x = xs_cell.mean()
884
+
885
+ cell_radius = np.sqrt(area_px / np.pi) + eps
886
+ dist_norm = np.sqrt((xs_cell - centroid_x)**2 + (ys_cell - centroid_y)**2) / cell_radius
887
+
888
+ v_all = hsv[:, :, 2][cell_mask]
889
+
890
+ # Tight inner core (15% radius) — captures specular highlight spot only
891
+ inner_mask = dist_norm < 0.15
892
+ # Membrane ring zone (20-60%) — dark navy ring on live cells
893
+ ring_mask = (dist_norm >= 0.20) & (dist_norm <= 0.60)
894
+ # Outer zone (>60%) — denominator for centre ratio
895
+ outer_mask = dist_norm > 0.60
896
+
897
+ inner_brightness = float(v_all[inner_mask].mean()) if inner_mask.any() else float(v.mean())
898
+ ring_brightness = float(v_all[ring_mask].mean()) if ring_mask.any() else float(v.mean())
899
+ outer_brightness = float(v_all[outer_mask].mean()) if outer_mask.any() else float(v.mean())
900
+
901
+ # Peak V — specular spot is just a few pixels so mean dilutes it
902
+ peak_brightness = float(v_all.max())
903
+
904
+ # Fraction of cell pixels with V > 200 (specular highlight region)
905
+ bright_spot_fraction = float((v_all > 200).sum()) / (len(v_all) + eps)
906
+
907
+ # Ring darkness: ratio of ring zone to outer zone brightness
908
+ # Live: ring << outer (dark membrane ring) -> ratio < 1
909
+ # Dead: uniform blob -> ratio ~ 1
910
+ ring_darkness = ring_brightness / (outer_brightness + eps)
911
+
912
+ centre_periphery_ratio = inner_brightness / (outer_brightness + eps)
913
+
914
+ brightness_std_normalised = float(v.std()) / (float(v.mean()) + eps)
915
+
916
+ rows.append({
917
+ "cell_id": int(cid),
918
+ "mean_r": float(r.mean()),
919
+ "mean_g": float(g.mean()),
920
+ "mean_b": float(b.mean()),
921
+ "std_r": float(r.std()),
922
+ "std_g": float(g.std()),
923
+ "std_b": float(b.std()),
924
+ "mean_h": float(h.mean()),
925
+ "mean_s": float(s.mean()),
926
+ "mean_v": float(v.mean()),
927
+ "std_s": float(s.std()),
928
+ "std_v": float(v.std()),
929
+ "blue_red_ratio": round(blue_red_ratio, 5),
930
+ "blue_green_ratio": round(blue_green_ratio, 5),
931
+ "rg_ratio": round(rg_ratio, 5),
932
+ "area_px": area_px,
933
+ "circularity": round(float(circularity), 5),
934
+ "inner_brightness": round(inner_brightness, 3),
935
+ "peak_brightness": round(peak_brightness, 3),
936
+ "bright_spot_fraction": round(bright_spot_fraction, 6),
937
+ "ring_darkness": round(ring_darkness, 5),
938
+ "centre_periphery_ratio": round(centre_periphery_ratio, 5),
939
+ "brightness_std_normalised": round(brightness_std_normalised, 5),
940
+ })
941
+
942
+ return rows
943
+
944
+ def attach_viability_labels(cell_features, masks, image_np, label_map=None):
945
+ """
946
+ Attach model predictions (from label_map) to each feature dict.
947
+ label_map: {cell_id: 0=live, 1=dead} from classify_cells_by_model.
948
+ If label_map is None, defaults all labels to 0 (live).
949
+ """
950
+ if not cell_features:
951
+ return []
952
+ labelled = []
953
+ for feat in cell_features:
954
+ row = dict(feat)
955
+ cid = int(feat["cell_id"])
956
+ row["label"] = int(label_map.get(cid, 0)) if label_map else 0
957
+ row["corrected"] = False
958
+ labelled.append(row)
959
+ return labelled
960
+
961
+
962
+ def export_cell_data_csv(cell_data):
963
+ """Write cell_data list-of-dicts to a temp CSV and return the file path."""
964
+ if not cell_data:
965
+ return None
966
+ tmp = tempfile.NamedTemporaryFile(
967
+ mode="w", suffix=".csv", delete=False, newline=""
968
+ )
969
+ # Union of all keys across rows so any late-added keys (e.g. "corrected") are included
970
+ fieldnames = list(dict.fromkeys(k for row in cell_data for k in row.keys()))
971
+ writer = csv.DictWriter(tmp, fieldnames=fieldnames, extrasaction="ignore")
972
+ writer.writeheader()
973
+ writer.writerows(cell_data)
974
+ tmp.close()
975
+ return tmp.name
976
+
977
+
978
+ def prepare_export(stored_masks, stored_image, threshold_bias):
979
+ """
980
+ Called by the Export button. Unpacks state, extracts features,
981
+ attaches labels, writes CSV, returns (path, status_message).
982
+ """
983
+ if stored_masks is None or stored_image is None:
984
+ return None, "Run segmentation first before exporting."
985
+
986
+ masks = unpack_array(stored_masks)
987
+ image_np = unpack_array(stored_image)
988
+
989
+ features = extract_cell_features(image_np, masks)
990
+ if not features:
991
+ return None, "No cells found to export."
992
+
993
+ labelled = attach_viability_labels(features, masks, image_np, threshold_bias)
994
+ path = export_cell_data_csv(labelled)
995
+
996
+ n = len(labelled)
997
+ dead = sum(1 for r in labelled if r["label"] == 1)
998
+ alive = n - dead
999
+ msg = (f"Exported {n} cells ({alive} live, {dead} dead) — "
1000
+ f"threshold bias={threshold_bias:+d}.\n"
1001
+ f"Columns: {', '.join(list(labelled[0].keys())[:6])}… "
1002
+ f"({len(labelled[0])} total).")
1003
+ return path, msg
1004
+
1005
+
1006
+ # ---------------------------------------------------------------------------
1007
+ # Tab builder
1008
+ # ---------------------------------------------------------------------------
1009
+
1010
+ def draw_polygon_overlay(image_pil, points):
1011
+ """
1012
+ Draw numbered vertex dots and polygon edges onto a copy of image_pil.
1013
+ points: list of (x, y) tuples in original image pixel space.
1014
+ Returns a new PIL image.
1015
+ """
1016
+ img = image_pil.copy().convert("RGBA")
1017
+ overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
1018
+ draw = ImageDraw.Draw(overlay)
1019
+
1020
+ if len(points) >= 2:
1021
+ # Draw edges
1022
+ for i in range(len(points) - 1):
1023
+ draw.line([points[i], points[i + 1]], fill=(74, 170, 255, 220), width=3)
1024
+ if len(points) == 4:
1025
+ draw.line([points[-1], points[0]], fill=(74, 170, 255, 220), width=3)
1026
+ # Semi-transparent fill
1027
+ draw.polygon(points, fill=(74, 170, 255, 50))
1028
+
1029
+ # Draw vertex dots + numbers
1030
+ r = max(8, min(img.width, img.height) // 60)
1031
+ for i, (x, y) in enumerate(points):
1032
+ draw.ellipse([x - r, y - r, x + r, y + r],
1033
+ fill=(74, 170, 255, 255), outline=(255, 255, 255, 255))
1034
+ draw.text((x, y), str(i + 1), fill=(255, 255, 255, 255), anchor="mm")
1035
+
1036
+ combined = Image.alpha_composite(img, overlay)
1037
+ return combined.convert("RGB")
1038
+
1039
+
1040
+ def add_crop_point(image_pil, points, evt: gr.SelectData):
1041
+ """
1042
+ Called by gr.Image .select(). Appends the clicked coordinate,
1043
+ redraws the overlay, returns (updated_image, updated_points).
1044
+ Ignores clicks once 4 points are set.
1045
+ """
1046
+ if image_pil is None:
1047
+ return image_pil, points
1048
+ if points is None:
1049
+ points = []
1050
+ if len(points) >= 4:
1051
+ return draw_polygon_overlay(image_pil, points), points
1052
+
1053
+ x, y = int(evt.index[0]), int(evt.index[1])
1054
+ new_points = points + [(x, y)]
1055
+ return draw_polygon_overlay(image_pil, new_points), new_points
1056
+
1057
+
1058
+ def clear_crop_points(image_pil):
1059
+ """Reset polygon — return original image with no overlay and empty points."""
1060
+ return image_pil, []
1061
+
1062
+
1063
+
1064
+
1065
+
1066
+ # ---------------------------------------------------------------------------
1067
+ # Label correction grid
1068
+ # ---------------------------------------------------------------------------
1069
+
1070
+ THUMB_SIZE = 80 # each cell thumbnail is THUMB_SIZE × THUMB_SIZE px
1071
+ GRID_COLS = 6 # thumbnails per row
1072
+ BORDER = 4 # coloured border thickness in px
1073
+ LABEL_H = 16 # height of the text label strip at the bottom of each thumb
1074
+
1075
+ def _crop_cell_thumb(image_np, masks, cid):
1076
+ """
1077
+ Return a tight square crop of the cell, padded to THUMB_SIZE × THUMB_SIZE.
1078
+ """
1079
+ ys, xs = np.where(masks == cid)
1080
+ if len(ys) == 0:
1081
+ return Image.fromarray(np.zeros((THUMB_SIZE, THUMB_SIZE, 3), dtype=np.uint8))
1082
+
1083
+ y0, y1 = ys.min(), ys.max() + 1
1084
+ x0, x1 = xs.min(), xs.max() + 1
1085
+
1086
+ # add a small context border around the tight bounding box
1087
+ pad = max(4, int(max(y1 - y0, x1 - x0) * 0.15))
1088
+ h, w = image_np.shape[:2]
1089
+ y0c = max(0, y0 - pad)
1090
+ y1c = min(h, y1 + pad)
1091
+ x0c = max(0, x0 - pad)
1092
+ x1c = min(w, x1 + pad)
1093
+
1094
+ crop = image_np[y0c:y1c, x0c:x1c].copy()
1095
+
1096
+ # dim pixels that don't belong to this cell
1097
+ dim_mask = (masks[y0c:y1c, x0c:x1c] != cid)
1098
+ crop[dim_mask] = (crop[dim_mask] * 0.3).astype(np.uint8)
1099
+
1100
+ pil = Image.fromarray(crop).resize((THUMB_SIZE, THUMB_SIZE), Image.LANCZOS)
1101
+ return pil
1102
+
1103
+
1104
+ def build_correction_grid(image_np, masks, labelled_features, raw_image_np=None):
1105
+ """
1106
+ Render all cell thumbnails into a single PIL image grid.
1107
+ Each thumbnail has a coloured border: green=live(0), red=dead(1).
1108
+ A small number in the corner identifies the cell_id.
1109
+
1110
+ Returns the PIL grid image.
1111
+ Cell order in the grid matches the order of labelled_features.
1112
+ """
1113
+ if not labelled_features:
1114
+ placeholder = Image.fromarray(
1115
+ np.zeros((THUMB_SIZE, THUMB_SIZE, 3), dtype=np.uint8)
1116
+ )
1117
+ return placeholder
1118
+
1119
+ thumb_src = raw_image_np if raw_image_np is not None else image_np
1120
+
1121
+ n = len(labelled_features)
1122
+ n_cols = GRID_COLS
1123
+ n_rows = (n + n_cols - 1) // n_cols
1124
+
1125
+ cell_h = THUMB_SIZE + 2 * BORDER + LABEL_H
1126
+ cell_w = THUMB_SIZE + 2 * BORDER
1127
+
1128
+ grid_w = n_cols * cell_w
1129
+ grid_h = n_rows * cell_h
1130
+
1131
+ grid = Image.new("RGB", (grid_w, grid_h), (30, 30, 30))
1132
+ draw = ImageDraw.Draw(grid)
1133
+
1134
+ for idx, feat in enumerate(labelled_features):
1135
+ cid = feat["cell_id"]
1136
+ label = feat["label"] # 0=live, 1=dead (may have been corrected)
1137
+ color = (220, 50, 50) if label == 1 else (50, 200, 80)
1138
+
1139
+ thumb = _crop_cell_thumb(thumb_src, masks, cid)
1140
+
1141
+ col = idx % n_cols
1142
+ row = idx // n_cols
1143
+ x0 = col * cell_w
1144
+ y0 = row * cell_h
1145
+
1146
+ # coloured border rectangle
1147
+ draw.rectangle([x0, y0, x0 + cell_w - 1, y0 + cell_h - 1], outline=color, width=BORDER)
1148
+
1149
+ # paste thumbnail inside border
1150
+ grid.paste(thumb, (x0 + BORDER, y0 + BORDER))
1151
+
1152
+ # small cell-id label strip
1153
+ strip_y = y0 + BORDER + THUMB_SIZE
1154
+ draw.rectangle([x0, strip_y, x0 + cell_w - 1, y0 + cell_h - 1],
1155
+ fill=(20, 20, 20))
1156
+ draw.text((x0 + BORDER + 2, strip_y + 1),
1157
+ f"#{cid} {'D' if label == 1 else 'L'}",
1158
+ fill=color)
1159
+
1160
+ return grid
1161
+
1162
+
1163
+ def toggle_cell_label(labelled_features, image_np, masks, raw_image_np, evt: gr.SelectData):
1164
+ """
1165
+ Called when user taps the correction grid image.
1166
+ Maps the tap pixel coordinate back to which thumbnail was tapped,
1167
+ flips that cell's label, rebuilds and returns the updated grid.
1168
+ """
1169
+ if not labelled_features or image_np is None:
1170
+ return build_correction_grid(image_np, masks, labelled_features), labelled_features
1171
+
1172
+ cell_w = THUMB_SIZE + 2 * BORDER
1173
+ cell_h = THUMB_SIZE + 2 * BORDER + LABEL_H
1174
+
1175
+ px, py = int(evt.index[0]), int(evt.index[1])
1176
+ col = px // cell_w
1177
+ row = py // cell_h
1178
+ idx = row * GRID_COLS + col
1179
+
1180
+ if idx < 0 or idx >= len(labelled_features):
1181
+ return build_correction_grid(image_np, masks, labelled_features, raw_image_np), labelled_features
1182
+
1183
+ # Flip the label
1184
+ updated = list(labelled_features) # shallow copy of list
1185
+ cell = dict(updated[idx]) # copy the dict so we don't mutate in place
1186
+ cell["label"] = 1 - cell["label"] # 0→1 or 1→0
1187
+ cell["corrected"] = True
1188
+ updated[idx] = cell
1189
+
1190
+ grid = build_correction_grid(image_np, masks, updated, raw_image_np)
1191
+ n_corrected = sum(1 for f in updated if f.get("corrected"))
1192
+ return grid, updated, f"Tapped cell #{cell['cell_id']} → {'Dead' if cell['label']==1 else 'Live'}. {n_corrected} correction(s) total."
1193
+
1194
+
1195
+ def prepare_export_corrected(stored_masks, stored_image, labelled_features, label_map):
1196
+ """Export CSV using labelled_features with any manual corrections applied."""
1197
+ if stored_masks is None or stored_image is None:
1198
+ return None, "Run segmentation first before exporting."
1199
+ masks = unpack_array(stored_masks)
1200
+ image_np = unpack_array(stored_image)
1201
+ if not labelled_features:
1202
+ features = extract_cell_features(image_np, masks)
1203
+ labelled_features = attach_viability_labels(features, masks, image_np, label_map)
1204
+ if not labelled_features:
1205
+ return None, "No cells found to export."
1206
+ path = export_cell_data_csv(labelled_features)
1207
+ n = len(labelled_features)
1208
+ dead = sum(1 for r in labelled_features if r["label"] == 1)
1209
+ alive = n - dead
1210
+ corrected = sum(1 for r in labelled_features if r.get("corrected"))
1211
+ msg = (f"Exported {n} cells ({alive} live, {dead} dead). "
1212
+ f"{corrected} label(s) manually corrected.")
1213
+ return path, msg
1214
+
1215
+ def build_tab(tab_index, masks_state, image_state, result_state):
1216
+ with gr.Tab(f"Tab {tab_index}"):
1217
+ gr.Markdown("Run segmentation")
1218
+
1219
+ # Per-tab state: list of (x,y) crop polygon points
1220
+ crop_points_state = gr.State(value=[])
1221
+ # Clean copy of the uploaded image (no polygon drawn on it)
1222
+ base_image_state = gr.State(value=None)
1223
+ #raw image state
1224
+ raw_image_state = gr.State(value=None)
1225
+
1226
+ with gr.Row():
1227
+ with gr.Column():
1228
+ img_input = gr.Image(
1229
+ type="pil",
1230
+ label="Upload image",
1231
+ image_mode="RGB",
1232
+ height=512
1233
+ )
1234
+
1235
+ gr.Markdown(
1236
+ "### Crop region (optional)\n"
1237
+ "Click/tap up to **4 points** on the image below to define the region "
1238
+ "to segment. The polygon will be drawn as you click. "
1239
+ "Leave empty to segment the full image."
1240
+ )
1241
+
1242
+ crop_display = gr.Image(
1243
+ type="pil",
1244
+ label="Click to set crop vertices (up to 4)",
1245
+ interactive=True,
1246
+ height=400,
1247
+ )
1248
+
1249
+ crop_status = gr.Markdown("*Upload an image to enable cropping*")
1250
+
1251
+ clear_crop_btn = gr.Button("✕ Clear crop points", size="sm")
1252
+
1253
+ model_dropdown = gr.Dropdown(
1254
+ choices=list(MODEL_OPTIONS.keys()),
1255
+ label="Select Model",
1256
+ value="Hemocytometer Model"
1257
+ )
1258
+
1259
+ min_size_slider = gr.Slider(
1260
+ minimum=0,
1261
+ maximum=500,
1262
+ value=0,
1263
+ step=10,
1264
+ label="Minimum Cell Size (pixels). Leave at zero for automated recommendation",
1265
+ )
1266
+
1267
+ max_size_slider = gr.Slider(
1268
+ minimum=0,
1269
+ maximum=10000,
1270
+ value=10000,
1271
+ step=10,
1272
+ label="Maximum Cell Size (pixels)",
1273
+ )
1274
+
1275
+ gr.Markdown("### Stereological Counting")
1276
+ use_stereo = gr.Checkbox(
1277
+ label="Enable Stereological Counting",
1278
+ value=False,
1279
+ info="Use unbiased stereological rules for cell counting"
1280
+ )
1281
+
1282
+ with gr.Group(visible=False) as stereo_controls:
1283
+ gr.Markdown("""
1284
+ **Stereological Counting Rules:**
1285
+ - Cells touching LEFT or TOP exclusion zones are EXCLUDED
1286
+ - Cells touching RIGHT or BOTTOM edges are INCLUDED
1287
+ - This provides unbiased counting for quantification
1288
+ """)
1289
+
1290
+ excl_preview = gr.Image(
1291
+ type="pil",
1292
+ label="Exclusion Zone Preview (Red = Excluded)",
1293
+ height=500
1294
+ )
1295
+
1296
+ left_excl = gr.Slider(
1297
+ minimum=0,
1298
+ maximum=50,
1299
+ value=10,
1300
+ step=1,
1301
+ label="Left Exclusion Width (%)",
1302
+ info="Width of left exclusion zone"
1303
+ )
1304
+
1305
+ top_excl = gr.Slider(
1306
+ minimum=0,
1307
+ maximum=50,
1308
+ value=10,
1309
+ step=1,
1310
+ label="Top Exclusion Width (%)",
1311
+ info="Width of top exclusion zone"
1312
+ )
1313
+
1314
+ segment_btn = gr.Button("🔬 Run Segmentation", variant="primary", size="lg")
1315
+
1316
+ with gr.Column():
1317
+ cell_count_out = gr.Number(label="Total Cells Detected", precision=0)
1318
+ confluency_out = gr.Number(label="Confluency (%)", precision=1)
1319
+ overlay_out = gr.Image(type="pil", label="Segmentation Result")
1320
+ info_out = gr.Textbox(label="Processing Info", lines=4)
1321
+
1322
+ with gr.Group(visible=False) as viability_section:
1323
+ gr.Markdown("### Viability Assessment (Trypan Blue)")
1324
+
1325
+ viab_run_btn = gr.Button("Run Viability Analysis", variant="primary")
1326
+
1327
+ with gr.Row():
1328
+ live_count_out = gr.Number(label="Live Cells (Green)", precision=0)
1329
+ dead_count_out = gr.Number(label="Dead Cells (Red)", precision=0)
1330
+
1331
+ viab_overlay = gr.Image(type="pil", label="Viability (Green=Live · Red=Dead)")
1332
+ viab_percent_out = gr.Number(label="Viability (%)", precision=1)
1333
+ viab_info = gr.Textbox(label="Analysis Results", lines=4)
1334
+
1335
+ gr.Markdown("### Label Correction & Export")
1336
+ gr.Markdown(
1337
+ "After running viability, click **Build correction grid** to review every cell. "
1338
+ "**Green border = Live, Red border = Dead** (model predictions). "
1339
+ "Tap any thumbnail to flip its label — the counts and overlay update instantly. "
1340
+ "Export the corrected CSV for retraining."
1341
+ )
1342
+
1343
+ build_grid_btn = gr.Button("🔲 Build correction grid", variant="secondary")
1344
+ labelled_state = gr.State(value=[])
1345
+ label_map_state = gr.State(value={})
1346
+
1347
+ correction_grid = gr.Image(
1348
+ type="pil",
1349
+ label="Tap a cell to flip its label (green=live · red=dead)",
1350
+ interactive=True,
1351
+ visible=False,
1352
+ )
1353
+ correction_status = gr.Markdown(visible=False)
1354
+
1355
+ with gr.Row():
1356
+ export_btn = gr.Button("⬇️ Export corrected CSV", variant="secondary")
1357
+ export_info = gr.Textbox(label="Export status", lines=2, interactive=False)
1358
+ export_file = gr.File(label="Download CSV", visible=False)
1359
+
1360
+ # ---- Event handlers ------------------------------------------------
1361
+
1362
+ use_stereo.change(
1363
+ fn=toggle_stereological_mode,
1364
+ inputs=[use_stereo],
1365
+ outputs=[stereo_controls]
1366
+ )
1367
+
1368
+ def on_image_upload(img):
1369
+ if img is None:
1370
+ return None, None, "*Upload an image to enable cropping*"
1371
+ return img, img, "*Image loaded — click up to 4 points to define crop region*"
1372
+
1373
+ img_input.change(
1374
+ fn=on_image_upload,
1375
+ inputs=[img_input],
1376
+ outputs=[crop_display, base_image_state, crop_status]
1377
+ ).then(fn=lambda: [], outputs=[crop_points_state])
1378
+
1379
+ img_input.change(fn=update_exclusion_preview,
1380
+ inputs=[img_input, left_excl, top_excl], outputs=[excl_preview])
1381
+ left_excl.change(fn=update_exclusion_preview,
1382
+ inputs=[img_input, left_excl, top_excl], outputs=[excl_preview])
1383
+ top_excl.change(fn=update_exclusion_preview,
1384
+ inputs=[img_input, left_excl, top_excl], outputs=[excl_preview])
1385
+
1386
+ def on_crop_click(base_img, points, evt: gr.SelectData):
1387
+ updated_img, updated_pts = add_crop_point(base_img, points, evt)
1388
+ n = len(updated_pts)
1389
+ status = (f"*{n} / 4 points set — keep clicking*" if n < 4
1390
+ else "*4 points set ✓ — click **✕ Clear** to redo, or run segmentation*")
1391
+ return updated_img, updated_pts, status
1392
+
1393
+ crop_display.select(fn=on_crop_click,
1394
+ inputs=[base_image_state, crop_points_state],
1395
+ outputs=[crop_display, crop_points_state, crop_status])
1396
+
1397
+ def on_clear_crop(base_img):
1398
+ img, pts = clear_crop_points(base_img)
1399
+ return img, pts, "*Points cleared — click to set new vertices*"
1400
+
1401
+ clear_crop_btn.click(fn=on_clear_crop,
1402
+ inputs=[base_image_state],
1403
+ outputs=[crop_display, crop_points_state, crop_status])
1404
+
1405
+ segment_btn.click(
1406
+ fn=run_segmentation,
1407
+ inputs=[img_input, model_dropdown, min_size_slider, max_size_slider,
1408
+ use_stereo, left_excl, top_excl, crop_points_state],
1409
+ outputs=[cell_count_out, overlay_out, info_out, viability_section,
1410
+ masks_state, image_state, confluency_out, min_size_slider, raw_image_state]
1411
+ )
1412
+
1413
+ # ---- Run Viability button -------------------------------------------
1414
+ def on_run_viability(stored_masks, stored_image):
1415
+ overlay, alive, dead, viab_pct, info, label_map = run_viability(stored_masks, stored_image)
1416
+ return overlay, alive, dead, viab_pct, info, label_map
1417
+
1418
+ viab_run_btn.click(
1419
+ fn=on_run_viability,
1420
+ inputs=[masks_state, image_state],
1421
+ outputs=[viab_overlay, live_count_out, dead_count_out,
1422
+ viab_percent_out, viab_info, label_map_state]
1423
+ ).then(
1424
+ fn=save_tab_result,
1425
+ inputs=[cell_count_out, confluency_out, viab_percent_out],
1426
+ outputs=[result_state]
1427
+ )
1428
+
1429
+ # ---- Build correction grid -----------------------------------------
1430
+ def on_build_grid(stored_masks, stored_image, label_map, stored_raw_image):
1431
+ if stored_masks is None or stored_image is None or not label_map:
1432
+ return (gr.update(visible=False), [],
1433
+ gr.update(value="*Run viability analysis first.*", visible=True))
1434
+ masks = unpack_array(stored_masks)
1435
+ image_np = unpack_array(stored_image)
1436
+ raw_image_np = unpack_array(stored_raw_image) if stored_raw_image is not None else None
1437
+ features = extract_cell_features(image_np, masks)
1438
+ labelled = attach_viability_labels(features, masks, image_np, label_map)
1439
+ if not labelled:
1440
+ return (gr.update(visible=False), [],
1441
+ gr.update(value="*No cells found.*", visible=True))
1442
+ grid = build_correction_grid(image_np, masks, labelled, raw_image_np)
1443
+ n = len(labelled)
1444
+ dead = sum(1 for r in labelled if r["label"] == 1)
1445
+ msg = (f"*{n} cells — {n-dead} live (green), {dead} dead (red). "
1446
+ f"Tap any thumbnail to flip its label.*")
1447
+ return gr.update(value=grid, visible=True), labelled, gr.update(value=msg, visible=True)
1448
+
1449
+ build_grid_btn.click(
1450
+ fn=on_build_grid,
1451
+ inputs=[masks_state, image_state, label_map_state, raw_image_state],
1452
+ outputs=[correction_grid, labelled_state, correction_status]
1453
+ )
1454
+
1455
+ # ---- Grid tap — flip label, update overlay + counts ----------------
1456
+ def on_grid_tap(labelled, stored_masks, stored_image, stored_raw_image, evt: gr.SelectData):
1457
+ if not labelled or stored_masks is None:
1458
+ return None, labelled, "", 0, 0, 0.0, None, {}
1459
+ masks = unpack_array(stored_masks)
1460
+ image_np = unpack_array(stored_image)
1461
+ raw_image_np = unpack_array(stored_raw_image) if stored_raw_image is not None else None
1462
+ grid, updated, msg = toggle_cell_label(labelled, image_np, masks, raw_image_np, evt)
1463
+
1464
+ # Rebuild label_map from corrected labelled list
1465
+ new_label_map = {int(f["cell_id"]): int(f["label"]) for f in updated}
1466
+ overlay_np = draw_viability_overlay(image_np, masks, new_label_map)
1467
+ dead = sum(1 for f in updated if f["label"] == 1)
1468
+ alive = len(updated) - dead
1469
+ total = alive + dead
1470
+ viab_pct = (alive / total * 100) if total > 0 else 0.0
1471
+
1472
+ return (grid, updated, f"*{msg}*",
1473
+ alive, dead, viab_pct,
1474
+ Image.fromarray(overlay_np), new_label_map)
1475
+
1476
+ correction_grid.select(
1477
+ fn=on_grid_tap,
1478
+ inputs=[labelled_state, masks_state, image_state, raw_image_state],
1479
+ outputs=[correction_grid, labelled_state, correction_status,
1480
+ live_count_out, dead_count_out, viab_percent_out,
1481
+ viab_overlay, label_map_state]
1482
+ )
1483
+
1484
+ # ---- Export --------------------------------------------------------
1485
+ def on_export(stored_masks, stored_image, labelled, label_map):
1486
+ path, msg = prepare_export_corrected(stored_masks, stored_image, labelled, label_map)
1487
+ if path is None:
1488
+ return gr.update(visible=False), msg
1489
+ return gr.update(value=path, visible=True), msg
1490
+
1491
+ export_btn.click(
1492
+ fn=on_export,
1493
+ inputs=[masks_state, image_state, labelled_state, label_map_state],
1494
+ outputs=[export_file, export_info]
1495
+ )
1496
+
1497
+
1498
+
1499
+ # ---------------------------------------------------------------------------
1500
+ # Gradio interface
1501
+ # ---------------------------------------------------------------------------
1502
+ with gr.Blocks(
1503
+ title="CellposeCellCounter",
1504
+ theme=gr.themes.Soft(),
1505
+ ) as demo:
1506
+ gr.Markdown("# CellposeCellCounter")
1507
+ gr.Markdown("For accurate cell confluency, crop the image to display only desired area. Note that some image file types are not yet supported. PNG and JPEG are preferred.")
1508
+
1509
+ # Shared mask/image state (one pair per tab so tabs don't clobber each other)
1510
+ masks_states = [gr.State(value=None) for _ in range(4)]
1511
+ image_states = [gr.State(value=None) for _ in range(4)]
1512
+ result_states = [gr.State(value=None) for _ in range(4)]
1513
+
1514
+ # Build Tabs 1–4 with a loop
1515
+ for i in range(4):
1516
+ build_tab(i + 1, masks_states[i], image_states[i], result_states[i])
1517
+
1518
+ # -------------------------------------------------------------------------
1519
+ # Tab 5 — Summary
1520
+ # -------------------------------------------------------------------------
1521
+ with gr.Tab("Tab 5 — Summary"):
1522
+ gr.Markdown("## Average Results Across All Tabs")
1523
+ gr.Markdown(
1524
+ "Run segmentation in one or more tabs, "
1525
+ "then click **Refresh Summary** to see the averages."
1526
+ )
1527
+
1528
+ refresh_btn = gr.Button("🔄 Refresh Summary", variant="primary", size="lg")
1529
+
1530
+ with gr.Row():
1531
+ avg_count_out = gr.Number(label="Avg Cell Count", precision=1)
1532
+ avg_conf_out = gr.Number(label="Avg Confluency (%)", precision=1)
1533
+ avg_viab_out = gr.Number(label="Avg Viability (%)", precision=1)
1534
+
1535
+ summary_box = gr.Textbox(label="Per-Tab Breakdown", lines=10)
1536
+
1537
+ refresh_btn.click(
1538
+ fn=compute_summary,
1539
+ inputs=result_states, # list of 4 gr.State components
1540
+ outputs=[avg_count_out, avg_conf_out, avg_viab_out, summary_box]
1541
+ )
1542
+
1543
+
1544
+
1545
+ if __name__ == "__main__":
1546
+ demo.launch()
gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ cellpose
2
+ gradio>=5.35.0
3
+ spaces
4
+ opencv-python
5
+ matplotlib
6
+ Pillow
7
+ numpy
8
+ huggingface_hub
9
+ joblib
10
+ scikit-learn