LiangLabUMB commited on
Commit
4b9a7f7
Β·
verified Β·
1 Parent(s): 87fa218

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +1429 -0
app.py ADDED
@@ -0,0 +1,1429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ from cellpose import models
4
+ import numpy as np
5
+ import cv2
6
+ import matplotlib.pyplot as plt
7
+ import tempfile
8
+ from PIL import Image, ImageDraw
9
+ import io
10
+ from huggingface_hub import hf_hub_download
11
+ import base64
12
+ from concurrent.futures import ThreadPoolExecutor, as_completed
13
+ import csv
14
+ import joblib
15
+ import os
16
+ import xgboost # required for loading viability_xgb_clf.pkl
17
+
18
+ HF_REPO_ID = "myang4218/cellposemodel"
19
+ HF_REPO_ID2 = "LiangLabUMB/viability_model"
20
+ HF_REPO_CPSAM = "mouseland/cellpose-sam"
21
+ MODEL_OPTIONS = {
22
+ "Hemocytometer Model": "hemocytometermodel.npy",
23
+ "General Model": "generalmodel.npy",
24
+ "Cellpose SAMv2": "cpsam_v2",
25
+ }
26
+ MODEL_REPOS = {
27
+ "hemocytometermodel.npy": HF_REPO_ID,
28
+ "generalmodel.npy": HF_REPO_ID,
29
+ "cpsam_v2": HF_REPO_CPSAM,
30
+ }
31
+
32
+ loaded_models = {}
33
+
34
+ VIABILITY_CLF = None
35
+ VIABILITY_SCALER = None
36
+
37
+ try:
38
+ _clf_path = hf_hub_download(repo_id=HF_REPO_ID2, filename="viability_xgb_clf.pkl")
39
+ _scaler_path = hf_hub_download(repo_id=HF_REPO_ID2, filename="viability_xgb_scaler.pkl")
40
+ VIABILITY_CLF = joblib.load(_clf_path)
41
+ VIABILITY_SCALER = joblib.load(_scaler_path)
42
+ print("βœ“ Viability classifier loaded.")
43
+ except Exception as e:
44
+ print(f"Viability classifier not found or failed to load: {e}")
45
+
46
+ # mobile safe resize limits
47
+ MAX_SIDE = 1024
48
+ MAX_PIXELS = 1024 * 1024
49
+
50
+
51
+ def safe_resize(image_np):
52
+
53
+ h, w = image_np.shape[:2]
54
+ total = h * w
55
+
56
+ if max(h, w) <= MAX_SIDE and total <= MAX_PIXELS:
57
+ return image_np
58
+
59
+ # compute scale
60
+ scale_side = MAX_SIDE / max(h, w)
61
+ scale_pixels = (MAX_PIXELS / total) ** 0.5
62
+ scale = min(scale_side, scale_pixels)
63
+
64
+ new_w = max(1, int(w * scale))
65
+ new_h = max(1, int(h * scale))
66
+
67
+ return cv2.resize(image_np, (new_w, new_h), interpolation=cv2.INTER_AREA)
68
+
69
+
70
+ def draw_exclusion_overlay(image_np, left_width_pct, top_width_pct):
71
+
72
+ h, w = image_np.shape[:2]
73
+
74
+ # Convert to PIL for drawing
75
+ img_pil = Image.fromarray(image_np)
76
+ draw = ImageDraw.Draw(img_pil, 'RGBA')
77
+
78
+ # Calculate pixel widths from percentages
79
+ left_px = int(w * left_width_pct / 100)
80
+ top_px = int(h * top_width_pct / 100)
81
+
82
+ # Draw overlays for exclusion zones
83
+ if left_px > 0:
84
+ # Left exclusion zone
85
+ draw.rectangle(
86
+ [(0, 0), (left_px, h)],
87
+ fill=(255, 0, 0, 80) # Semi-transparent red
88
+ )
89
+ # border line
90
+ draw.line([(left_px, 0), (left_px, h)], fill=(255, 0, 0, 255), width=3)
91
+
92
+ if top_px > 0:
93
+ # Top exclusion zone
94
+ draw.rectangle(
95
+ [(0, 0), (w, top_px)],
96
+ fill=(255, 0, 0, 80) # Semi-transparent red
97
+ )
98
+ # border line
99
+ draw.line([(0, top_px), (w, top_px)], fill=(255, 0, 0, 255), width=3)
100
+
101
+ return np.array(img_pil)
102
+
103
+
104
+ def apply_stereological_exclusion(masks, left_width_pct, top_width_pct):
105
+ h, w = masks.shape
106
+
107
+ # Calculate pixel widths from percentages
108
+ left_px = int(w * left_width_pct / 100)
109
+ top_px = int(h * top_width_pct / 100)
110
+
111
+ filtered_masks = masks.copy()
112
+ cell_ids = np.unique(masks)
113
+ cell_ids = cell_ids[cell_ids > 0]
114
+
115
+ excluded_cells = []
116
+ included_cells = []
117
+
118
+ for cell_id in cell_ids:
119
+ cell_mask = (masks == cell_id)
120
+
121
+ # Get cell boundary coordinates
122
+ rows, cols = np.where(cell_mask)
123
+
124
+ # Check if cell touches left exclusion zone
125
+ touches_left = np.any(cols < left_px) if left_px > 0 else False
126
+
127
+ # Check if cell touches top exclusion zone
128
+ touches_top = np.any(rows < top_px) if top_px > 0 else False
129
+
130
+ # Exclude if touching left or top
131
+ if touches_left or touches_top:
132
+ filtered_masks[cell_mask] = 0
133
+ excluded_cells.append(cell_id)
134
+ else:
135
+ included_cells.append(cell_id)
136
+
137
+ # Renumber remaining cells
138
+ unique_ids = np.unique(filtered_masks)
139
+ unique_ids = unique_ids[unique_ids > 0]
140
+
141
+ renumbered_masks = np.zeros_like(filtered_masks)
142
+ for new_id, old_id in enumerate(unique_ids, start=1):
143
+ renumbered_masks[filtered_masks == old_id] = new_id
144
+
145
+ return renumbered_masks, len(excluded_cells), len(included_cells)
146
+
147
+
148
+
149
+ FEATURE_COLS_INFERENCE = [
150
+ "mean_r", "mean_g", "mean_b", "std_r", "std_g", "std_b",
151
+ "mean_h", "mean_s", "mean_v", "std_s", "std_v",
152
+ "blue_red_ratio", "blue_green_ratio", "rg_ratio",
153
+ "inner_brightness", "peak_brightness",
154
+ "bright_spot_fraction", "ring_darkness",
155
+ "centre_periphery_ratio", "brightness_std_normalised",
156
+ ]
157
+
158
+
159
+ def classify_cells_by_model(image_np, masks):
160
+
161
+ import numpy as np
162
+ cell_ids = np.unique(masks)
163
+ cell_ids = cell_ids[cell_ids > 0]
164
+ if len(cell_ids) == 0:
165
+ return 0, 0, image_np.copy(), {}
166
+
167
+ features = extract_cell_features(image_np, masks)
168
+ if not features:
169
+ return 0, 0, image_np.copy(), {}
170
+
171
+ import numpy as np
172
+ X = np.array([[f[c] for c in FEATURE_COLS_INFERENCE] for f in features], dtype=np.float32)
173
+
174
+ # replace any NaN/Inf with column median
175
+ for j in range(X.shape[1]):
176
+ bad = ~np.isfinite(X[:, j])
177
+ if bad.any():
178
+ X[bad, j] = float(np.nanmedian(X[:, j]))
179
+
180
+ X_scaled = VIABILITY_SCALER.transform(X)
181
+ predictions = VIABILITY_CLF.predict(X_scaled) # 0=live, 1=dead
182
+
183
+ label_map = {int(f["cell_id"]): int(p) for f, p in zip(features, predictions)}
184
+ overlay = draw_viability_overlay(image_np, masks, label_map)
185
+
186
+ dead = int(sum(predictions))
187
+ alive = int(len(predictions) - dead)
188
+ return dead, alive, overlay, label_map
189
+
190
+
191
+ def draw_viability_overlay(image_np, masks, label_map):
192
+
193
+ overlay = image_np.copy()
194
+ cell_ids = np.unique(masks)
195
+ cell_ids = cell_ids[cell_ids > 0]
196
+ cell_enum = {int(cid): idx + 1 for idx, cid in enumerate(sorted(cell_ids))}
197
+
198
+ for cid in cell_ids:
199
+ cid_int = int(cid)
200
+ label = label_map.get(cid_int, 0)
201
+ color = (220, 50, 50) if label == 1 else (50, 220, 80)
202
+ cell_mask = (masks == cid).astype(np.uint8)
203
+ contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
204
+ cv2.drawContours(overlay, contours, -1, color, thickness=2)
205
+
206
+ ys, xs = np.where(cell_mask)
207
+ if len(ys) > 0:
208
+ cx, cy = int(xs.mean()), int(ys.mean())
209
+ label_str = str(cell_enum[cid_int])
210
+ font = cv2.FONT_HERSHEY_SIMPLEX
211
+ font_scale = 0.35
212
+ thickness = 1
213
+ (tw, th), _ = cv2.getTextSize(label_str, font, font_scale, thickness)
214
+ cv2.rectangle(overlay,
215
+ (cx - tw//2 - 1, cy - th//2 - 1),
216
+ (cx + tw//2 + 1, cy + th//2 + 1),
217
+ (0, 0, 0), -1)
218
+ cv2.putText(overlay, label_str,
219
+ (cx - tw//2, cy + th//2),
220
+ font, font_scale, color, thickness, cv2.LINE_AA)
221
+ return overlay
222
+
223
+
224
+
225
+
226
+ def measure_confluency(masks, image_np):
227
+ tot_pixels = image_np.shape[0] * image_np.shape[1]
228
+ cell_pixels = np.count_nonzero(masks)
229
+ confluency = cell_pixels / tot_pixels * 100
230
+ return confluency
231
+
232
+ def filter_mask_by_size(masks, minimum_pixels):
233
+ filtered_masks = masks.copy()
234
+ cell_ids = np.unique(masks)
235
+ cell_ids = cell_ids[cell_ids > 0]
236
+
237
+ removed_count = 0
238
+
239
+ for cell_id in cell_ids:
240
+ cell_mask = (masks == cell_id)
241
+ cell_pixels = np.count_nonzero(cell_mask)
242
+ if cell_pixels < minimum_pixels:
243
+ filtered_masks[cell_mask] = 0
244
+ removed_count += 1
245
+
246
+ unique_ids = np.unique(filtered_masks)
247
+ unique_ids = unique_ids[unique_ids > 0]
248
+
249
+ renumbered_masks = np.zeros_like(filtered_masks)
250
+ for new_id, old_id in enumerate(unique_ids, start=1):
251
+ renumbered_masks[filtered_masks == old_id] = new_id
252
+
253
+ return renumbered_masks, removed_count
254
+
255
+
256
+ def filter_mask_by_maxsize(masks, maximum_pixels):
257
+ filtered_masks = masks.copy()
258
+ cell_ids = np.unique(masks)
259
+ cell_ids = cell_ids[cell_ids > 0]
260
+
261
+ removed_count = 0
262
+ for cell_id in cell_ids:
263
+ cell_mask = (masks == cell_id)
264
+ cell_pixels = np.count_nonzero(cell_mask)
265
+ if cell_pixels > maximum_pixels:
266
+ filtered_masks[cell_mask] = 0
267
+ removed_count += 1
268
+
269
+ unique_ids = np.unique(filtered_masks)
270
+ unique_ids = unique_ids[unique_ids > 0]
271
+
272
+ renumbered_masks = np.zeros_like(filtered_masks)
273
+ for new_id, old_id in enumerate(unique_ids, start=1):
274
+ renumbered_masks[filtered_masks == old_id] = new_id
275
+
276
+ return renumbered_masks, removed_count
277
+
278
+
279
+ def rec_min_size(masks, q=25):
280
+ ids = np.unique(masks)
281
+ ids = ids[ids > 0]
282
+ if len(ids) == 0:
283
+ return 0
284
+ sizes = np.array([np.count_nonzero(masks == cid) for cid in ids])
285
+ return int(round(np.percentile(sizes, q)))
286
+
287
+
288
+ def apply_polygon_mask(image_pil, points_json):
289
+ """
290
+ Given a PIL image and a JSON string of [[x,y],...] points,
291
+ zero out everything outside the polygon and return a PIL image.
292
+ """
293
+ import json
294
+ if not points_json or points_json.strip() in ("", "[]"):
295
+ return image_pil
296
+ try:
297
+ pts = json.loads(points_json)
298
+ except Exception:
299
+ return image_pil
300
+ if len(pts) < 3:
301
+ return image_pil
302
+
303
+ image_np = np.array(image_pil)
304
+ h, w = image_np.shape[:2]
305
+ poly = np.array(pts, dtype=np.int32)
306
+ poly[:, 0] = np.clip(poly[:, 0], 0, w - 1)
307
+ poly[:, 1] = np.clip(poly[:, 1], 0, h - 1)
308
+ mask = np.zeros((h, w), dtype=np.uint8)
309
+ cv2.fillPoly(mask, [poly], 255)
310
+ if len(image_np.shape) == 3:
311
+ result = np.where(mask[:, :, np.newaxis] == 255, image_np, 0).astype(np.uint8)
312
+ else:
313
+ result = np.where(mask == 255, image_np, 0).astype(np.uint8)
314
+ return Image.fromarray(result)
315
+
316
+ def warp_polygon_to_square(image_np, points):
317
+ pts = np.array(points, dtype=np.float32)
318
+
319
+ s = pts.sum(axis=1)
320
+ diff = np.diff(pts, axis=1).ravel()
321
+ tl = pts[np.argmin(s)]
322
+ br = pts[np.argmax(s)]
323
+ tr = pts[np.argmin(diff)]
324
+ bl = pts[np.argmax(diff)]
325
+ src = np.array([tl, tr, br, bl], dtype=np.float32)
326
+
327
+ w1 = np.linalg.norm(br-bl)
328
+ w2 = np.linalg.norm(tr-tl)
329
+ h1 = np.linalg.norm(tr-br)
330
+ h2 = np.linalg.norm(tl-bl)
331
+ out_w = int(max(w1, w2))
332
+ out_h = int(max(h1, h2))
333
+
334
+ dst = np.array(
335
+ [[0, 0],
336
+ [out_w - 1, 0],
337
+ [out_w - 1, out_h - 1],
338
+ [0, out_h - 1]],
339
+ dtype=np.float32)
340
+
341
+ M = cv2.getPerspectiveTransform(src, dst)
342
+ warped = cv2.warpPerspective(image_np, M, (out_w, out_h))
343
+ return warped
344
+
345
+
346
+ def toggle_stereological_mode(use_stereology):
347
+ return gr.update(visible=use_stereology)
348
+
349
+
350
+ def update_exclusion_preview(image, left_width, top_width):
351
+ if image is None:
352
+ return None
353
+
354
+ image_np = np.array(image)
355
+ overlay = draw_exclusion_overlay(image_np, left_width, top_width)
356
+ return Image.fromarray(overlay)
357
+
358
+
359
+ # Patch segmentation
360
+
361
+ PATCH_SIZE = 512 # target patch side length
362
+ PATCH_OVERLAP = 64 # overlap border on each edge (pixels)
363
+ MIN_PATCH_DIM = 256 # don't bother patching if image fits comfortably
364
+
365
+
366
+ def _split_patches(image_np, patch_size=PATCH_SIZE, overlap=PATCH_OVERLAP):
367
+ """
368
+ Split image into overlapping patches.
369
+ Returns list of (patch_np, row_start, col_start) tuples.
370
+ """
371
+ h, w = image_np.shape[:2]
372
+ patches = []
373
+ row = 0
374
+ while row < h:
375
+ row_end = min(row + patch_size, h)
376
+ col = 0
377
+ while col < w:
378
+ col_end = min(col + patch_size, w)
379
+ patch = image_np[row:row_end, col:col_end]
380
+ patches.append((patch, row, col))
381
+ if col_end == w:
382
+ break
383
+ col += patch_size - overlap
384
+ if row_end == h:
385
+ break
386
+ row += patch_size - overlap
387
+ return patches
388
+
389
+
390
+ def _merge_patch_masks(patch_results, full_h, full_w, overlap=PATCH_OVERLAP):
391
+ """
392
+ Stitch per-patch masks into a single full-image mask.
393
+
394
+ Strategy:
395
+ - Each patch gets a unique ID offset so cell IDs never collide.
396
+ - Patches are pasted into the canvas using a priority canvas that
397
+ gives interior pixels precedence over overlap-border pixels.
398
+ - After pasting, cells whose centroids fall in the overlap zone
399
+ of two adjacent patches are deduplicated: if two cells from
400
+ different patches share >50% IoU they are the same cell β€” keep
401
+ the one whose centroid is furthest from a patch edge.
402
+ """
403
+ full_mask = np.zeros((full_h, full_w), dtype=np.int32)
404
+ # track which patch_idx owns each pixel (used for overlap resolution)
405
+ owner_map = np.full((full_h, full_w), -1, dtype=np.int32)
406
+ # distance-to-nearest-edge for the owning patch (higher = more central)
407
+ priority = np.zeros((full_h, full_w), dtype=np.float32)
408
+
409
+ id_offset = 0
410
+ patch_meta = [] # (offset, row_start, col_start, patch_h, patch_w)
411
+
412
+ for patch_idx, (mask_patch, row_start, col_start) in enumerate(patch_results):
413
+ ph, pw = mask_patch.shape
414
+ # offset all non-zero IDs so they're globally unique
415
+ shifted = np.where(mask_patch > 0, mask_patch + id_offset, 0).astype(np.int32)
416
+
417
+ # compute per-pixel priority = min distance to any patch edge
418
+ rows_idx = np.arange(ph)
419
+ cols_idx = np.arange(pw)
420
+ dist_r = np.minimum(rows_idx, ph - 1 - rows_idx) # (ph,)
421
+ dist_c = np.minimum(cols_idx, pw - 1 - cols_idx) # (pw,)
422
+ pri_patch = np.minimum(dist_r[:, None], dist_c[None, :]) # (ph, pw)
423
+
424
+ roi_full = full_mask [row_start:row_start+ph, col_start:col_start+pw]
425
+ roi_owner = owner_map [row_start:row_start+ph, col_start:col_start+pw]
426
+ roi_pri = priority [row_start:row_start+ph, col_start:col_start+pw]
427
+
428
+ # where this patch has higher priority, overwrite
429
+ better = pri_patch > roi_pri
430
+ roi_full [better] = shifted [better]
431
+ roi_owner[better] = patch_idx
432
+ roi_pri [better] = pri_patch [better]
433
+
434
+ max_id = int(mask_patch.max())
435
+ patch_meta.append((id_offset, row_start, col_start, ph, pw))
436
+ id_offset += max_id + 1
437
+
438
+ # --- Renumber to compact sequential IDs ---
439
+ unique_ids = np.unique(full_mask)
440
+ unique_ids = unique_ids[unique_ids > 0]
441
+ renumbered = np.zeros_like(full_mask)
442
+ for new_id, old_id in enumerate(unique_ids, start=1):
443
+ renumbered[full_mask == old_id] = new_id
444
+
445
+ return renumbered
446
+
447
+
448
+ def _segment_patch(args):
449
+ """Worker: run cellpose on a single patch. Called from a thread pool."""
450
+ patch_np, row_start, col_start, model_filename, hf_repo = args
451
+ # Each thread uses the shared loaded_models cache (GIL-safe for reads;
452
+ # model.eval() releases the GIL during GPU work so threads overlap.)
453
+ model_path = hf_hub_download(repo_id=hf_repo, filename=model_filename)
454
+ if model_filename in loaded_models:
455
+ model = loaded_models[model_filename]
456
+ else:
457
+ model = models.CellposeModel(gpu=True, pretrained_model=model_path)
458
+ loaded_models[model_filename] = model
459
+
460
+ mask, _, _ = model.eval(patch_np, diameter=None)
461
+ return mask, row_start, col_start
462
+
463
+
464
+ def run_segmentation_patched(image_np, model_filename):
465
+ """
466
+ Split image into overlapping patches, run Cellpose on each in parallel,
467
+ then stitch back into a single full-resolution mask.
468
+ Falls back to whole-image segmentation if the image is small enough
469
+ that patching adds overhead without benefit.
470
+ """
471
+ h, w = image_np.shape[:2]
472
+ repo = MODEL_REPOS.get(model_filename, HF_REPO_ID)
473
+ model_path = hf_hub_download(repo_id=repo, filename=model_filename)
474
+ if model_filename in loaded_models:
475
+ model = loaded_models[model_filename]
476
+ else:
477
+ model = models.CellposeModel(gpu=True, pretrained_model=model_path)
478
+ loaded_models[model_filename] = model
479
+
480
+ # Small images: no benefit from patching
481
+ if max(h, w) <= MIN_PATCH_DIM * 2:
482
+ mask, _, _ = model.eval(image_np, diameter=None)
483
+ return mask, 1 # 1 patch
484
+
485
+ patches = _split_patches(image_np)
486
+ n_patches = len(patches)
487
+
488
+ # Build argument list for the thread pool
489
+ patch_repo = MODEL_REPOS.get(model_filename, HF_REPO_ID)
490
+ args_list = [
491
+ (patch, r, c, model_filename, patch_repo)
492
+ for patch, r, c in patches
493
+ ]
494
+
495
+ patch_results = [] # (mask, row_start, col_start) in submission order
496
+
497
+ # ThreadPoolExecutor: GPU kernels release the GIL so threads overlap on GPU
498
+ with ThreadPoolExecutor(max_workers=min(n_patches, 4)) as pool:
499
+ futures = {pool.submit(_segment_patch, a): a for a in args_list}
500
+ for future in as_completed(futures):
501
+ mask_patch, row_start, col_start = future.result()
502
+ patch_results.append((mask_patch, row_start, col_start))
503
+
504
+ # Re-sort by (row, col) so stitching is deterministic
505
+ patch_results.sort(key=lambda x: (x[1], x[2]))
506
+
507
+ full_mask = _merge_patch_masks(patch_results, h, w)
508
+ return full_mask, n_patches
509
+
510
+
511
+ @spaces.GPU
512
+ def run_segmentation(image, model_choice, min_cell_size, max_cell_size,
513
+ use_min_filter, use_max_filter,
514
+ use_stereology, left_exclusion, top_exclusion,
515
+ crop_points=None):
516
+ image_np = np.array(image)
517
+ image_np = safe_resize(image_np)
518
+
519
+ raw_image_np = image_np.copy()
520
+
521
+ # Apply polygon crop mask if the user drew one (need β‰₯3 points for a polygon)
522
+ if crop_points and len(crop_points) >= 3:
523
+ import json
524
+ pts_json = json.dumps(crop_points)
525
+ image_pil_masked = apply_polygon_mask(Image.fromarray(image_np), pts_json)
526
+ image_np = np.array(image_pil_masked)
527
+
528
+ if len(crop_points) == 4:
529
+ image_np = warp_polygon_to_square(image_np, crop_points)
530
+
531
+
532
+ try:
533
+ model_filename = MODEL_OPTIONS[model_choice]
534
+
535
+ # Process image format to RGB
536
+ if len(image_np.shape) == 2:
537
+ processed_image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
538
+ elif len(image_np.shape) == 3 and image_np.shape[2] == 4:
539
+ processed_image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)
540
+ else:
541
+ processed_image_np = image_np
542
+
543
+ # Run patch-parallel Cellpose segmentation
544
+ masks_raw, n_patches = run_segmentation_patched(processed_image_np, model_filename)
545
+
546
+ ids = np.unique(masks_raw)
547
+ ids = ids[ids > 0]
548
+
549
+ sizes = np.array([np.count_nonzero(masks_raw == cid) for cid in ids])
550
+
551
+ print("num_cells:", len(ids))
552
+ print("mean:", sizes.mean() if len(sizes) > 0 else 0)
553
+ print("median:", np.median(sizes) if len(sizes) > 0 else 0)
554
+ print("p90:", np.percentile(sizes, 90) if len(sizes) > 0 else 0)
555
+ print("max:", sizes.max() if len(sizes) > 0 else 0)
556
+
557
+ # Compute recommendation from RAW masks (always shown, never auto-applied)
558
+ recommend_min = rec_min_size(masks_raw)
559
+
560
+ # Apply filters only if their checkboxes are enabled
561
+ masks = masks_raw.copy()
562
+ removed_small = 0
563
+ removed_large = 0
564
+
565
+ if use_min_filter and int(min_cell_size) > 0:
566
+ masks, removed_small = filter_mask_by_size(masks, int(min_cell_size))
567
+
568
+ if use_max_filter and max_cell_size > 0:
569
+ masks, removed_large = filter_mask_by_maxsize(masks, int(max_cell_size))
570
+
571
+ # Apply stereological exclusion if enabled
572
+ excluded_count = 0
573
+ if use_stereology:
574
+ masks, excluded_count, included_count = apply_stereological_exclusion(
575
+ masks, left_exclusion, top_exclusion
576
+ )
577
+
578
+ filter_msg = ""
579
+ if removed_small:
580
+ filter_msg += f"Removed {removed_small} small objects (< {int(min_cell_size)} pixels).\n"
581
+ if removed_large:
582
+ filter_msg += f"Removed {removed_large} large objects (> {int(max_cell_size)} pixels).\n"
583
+ if use_stereology and excluded_count > 0:
584
+ filter_msg += f"Stereological exclusion: {excluded_count} cells excluded (touching left/top zones).\n"
585
+
586
+ cell_count = len(np.unique(masks)) - 1
587
+ confluency = measure_confluency(masks, processed_image_np)
588
+
589
+ # Create a basic segmentation overlay (without viability)
590
+ segmentation_overlay = processed_image_np.copy().astype(np.float32)
591
+ if masks.max() > 0:
592
+ np.random.seed(42) # For consistent random colors
593
+ colors = np.random.randint(0, 255, size=(masks.max() + 1, 3))
594
+ colors[0] = [0, 0, 0]
595
+ colored_mask = colors[masks]
596
+ alpha = 0.4
597
+ segmentation_overlay = (1 - alpha) * segmentation_overlay + alpha * colored_mask
598
+ segmentation_overlay = np.clip(segmentation_overlay, 0, 255).astype(np.uint8)
599
+
600
+ # Add exclusion zone overlay if stereology is enabled
601
+ if use_stereology:
602
+ segmentation_overlay = draw_exclusion_overlay(segmentation_overlay, left_exclusion, top_exclusion)
603
+
604
+ info_msg = ""
605
+ if filter_msg:
606
+ info_msg += filter_msg
607
+ info_msg += f"Segmentation complete! Found {cell_count} cells.\n"
608
+ info_msg += f"Confluency: {confluency:.1f}%\n"
609
+ info_msg += f"Processed as {n_patches} patch{'es' if n_patches > 1 else ''} (parallel).\n"
610
+ if use_stereology:
611
+ info_msg += f"Stereological counting enabled (Left: {left_exclusion}%, Top: {top_exclusion}%)\n"
612
+ info_msg += "Now run the viability classification model for viability assessment."
613
+
614
+ return (
615
+ cell_count,
616
+ Image.fromarray(segmentation_overlay),
617
+ info_msg,
618
+ gr.update(visible=True),
619
+ pack_array(masks),
620
+ pack_array(processed_image_np),
621
+ confluency,
622
+ f"Recommended minimum: **{recommend_min} px** (25th percentile of detected cell sizes)",
623
+ pack_array(raw_image_np),
624
+ )
625
+
626
+ except Exception as e:
627
+ import traceback
628
+ traceback.print_exc()
629
+ return (
630
+ 0,
631
+ None,
632
+ f"Error during segmentation: {str(e)}",
633
+ gr.update(visible=False),
634
+ None,
635
+ None,
636
+ 0.0,
637
+ "",
638
+ None,
639
+ )
640
+
641
+
642
+ def run_viability(stored_masks, stored_image_np):
643
+ if stored_masks is None or stored_image_np is None:
644
+ return None, 0, 0, 0.0, "Please run segmentation first.", {}
645
+ if VIABILITY_CLF is None:
646
+ return None, 0, 0, 0.0, "No viability model loaded. Check that viability_xgb_clf.pkl and viability_xgb_scaler.pkl are present in the LiangLabUMB/viability_model HuggingFace repo and that the Space has restarted after upload.", {}
647
+
648
+ masks = unpack_array(stored_masks)
649
+ image_np = unpack_array(stored_image_np)
650
+
651
+ try:
652
+ dead, alive, overlay_np, label_map = classify_cells_by_model(image_np, masks)
653
+ total = alive + dead
654
+ viab_pct = (alive / total * 100) if total > 0 else 0.0
655
+ confluency = measure_confluency(masks, image_np)
656
+ info_msg = f"Total cells: {total}\nLive (green): {alive}\nDead (red): {dead}\n"
657
+ info_msg += f"Viability: {viab_pct:.1f}%\nConfluency: {confluency:.1f}%"
658
+ return Image.fromarray(overlay_np), alive, dead, viab_pct, info_msg, label_map
659
+ except Exception as e:
660
+ import traceback; traceback.print_exc()
661
+ return None, 0, 0, 0.0, f"Error: {str(e)}", {}
662
+
663
+
664
+ def pack_array(arr):
665
+ """
666
+ Serialise a numpy array to bytes for gr.State storage.
667
+ Uses numpy's .npy format (not PNG) so integer dtypes of any
668
+ magnitude are preserved exactly β€” PNG is 8-bit only and silently
669
+ truncates cell IDs above 255.
670
+ """
671
+ buf = io.BytesIO()
672
+ np.save(buf, arr)
673
+ return buf.getvalue()
674
+
675
+
676
+ def unpack_array(data):
677
+ buf = io.BytesIO(data)
678
+ return np.load(buf, allow_pickle=False)
679
+
680
+
681
+ def save_tab_result(cell_count, confluency, viab_percent):
682
+ """Package per-tab results into a dict for Tab 5 summary."""
683
+ return {
684
+ "cell_count": float(cell_count) if cell_count is not None else None,
685
+ "confluency": float(confluency) if confluency is not None else None,
686
+ "viab_percent": float(viab_percent) if viab_percent is not None else None,
687
+ }
688
+
689
+
690
+ def compute_summary(r1, r2, r3, r4):
691
+ """Average cell count, confluency, and viability across tabs that have data."""
692
+ all_results = [r1, r2, r3, r4]
693
+ valid = [(i + 1, r) for i, r in enumerate(all_results)
694
+ if r is not None and r.get("cell_count") is not None]
695
+
696
+ if not valid:
697
+ return (
698
+ 0.0, 0.0, 0.0,
699
+ "No data yet β€” run segmentation in at least one tab, then click Refresh Summary."
700
+ )
701
+
702
+ avg_count = sum(r["cell_count"] for _, r in valid) / len(valid)
703
+ avg_conf = sum(r["confluency"] for _, r in valid) / len(valid)
704
+ avg_viab = sum(r["viab_percent"] for _, r in valid) / len(valid)
705
+
706
+ lines = [f"Tab {tab_num}: {r['cell_count']:.0f} cells | "
707
+ f"{r['confluency']:.1f}% confluency | "
708
+ f"{r['viab_percent']:.1f}% viability"
709
+ for tab_num, r in valid]
710
+ lines.append(f"\nAverages ({len(valid)} tab{'s' if len(valid) > 1 else ''}):")
711
+ lines.append(f" Cell count: {avg_count:.1f}")
712
+ lines.append(f" Confluency: {avg_conf:.1f}%")
713
+ lines.append(f" Viability: {avg_viab:.1f}%")
714
+
715
+ return avg_count, avg_conf, avg_viab, "\n".join(lines)
716
+
717
+
718
+
719
+ # Training data export β€” feature extraction per cell
720
+
721
+
722
+ def extract_cell_features(image_np, masks):
723
+
724
+ if len(image_np.shape) == 2:
725
+ image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
726
+ elif image_np.shape[2] == 4:
727
+ image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)
728
+
729
+ hsv = cv2.cvtColor(image_np, cv2.COLOR_RGB2HSV).astype(np.float32)
730
+
731
+ h_img, w_img = image_np.shape[:2]
732
+ grid_y, grid_x = np.mgrid[:h_img, :w_img]
733
+
734
+ cell_ids = np.unique(masks)
735
+ cell_ids = cell_ids[cell_ids > 0]
736
+ rows = []
737
+
738
+ for cid in cell_ids:
739
+ cell_mask = (masks == cid)
740
+ pixels_rgb = image_np[cell_mask].astype(np.float32)
741
+ pixels_hsv = hsv[cell_mask]
742
+
743
+ r, g, b = pixels_rgb[:, 0], pixels_rgb[:, 1], pixels_rgb[:, 2]
744
+ h, s, v = pixels_hsv[:, 0], pixels_hsv[:, 1], pixels_hsv[:, 2]
745
+
746
+ eps = 1e-6
747
+ blue_red_ratio = b.mean() / (r.mean() + eps)
748
+ blue_green_ratio = b.mean() / (g.mean() + eps)
749
+ rg_ratio = r.mean() / (g.mean() + eps)
750
+
751
+ area_px = int(cell_mask.sum())
752
+ contours, _ = cv2.findContours(
753
+ cell_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
754
+ )
755
+ perimeter = cv2.arcLength(contours[0], True) if contours else 1.0
756
+ circularity = (4 * np.pi * area_px / (perimeter ** 2 + eps)) if perimeter > 0 else 0.0
757
+
758
+ ys_cell = grid_y[cell_mask].astype(np.float32)
759
+ xs_cell = grid_x[cell_mask].astype(np.float32)
760
+ centroid_y = ys_cell.mean()
761
+ centroid_x = xs_cell.mean()
762
+
763
+ cell_radius = np.sqrt(area_px / np.pi) + eps
764
+ dist_norm = np.sqrt((xs_cell - centroid_x)**2 + (ys_cell - centroid_y)**2) / cell_radius
765
+
766
+ v_all = hsv[:, :, 2][cell_mask]
767
+
768
+ # Tight inner core (15% radius) β€” captures specular highlight spot only
769
+ inner_mask = dist_norm < 0.15
770
+ # Membrane ring zone (20-60%) β€” dark navy ring on live cells
771
+ ring_mask = (dist_norm >= 0.20) & (dist_norm <= 0.60)
772
+ # Outer zone (>60%) β€” denominator for centre ratio
773
+ outer_mask = dist_norm > 0.60
774
+
775
+ inner_brightness = float(v_all[inner_mask].mean()) if inner_mask.any() else float(v.mean())
776
+ ring_brightness = float(v_all[ring_mask].mean()) if ring_mask.any() else float(v.mean())
777
+ outer_brightness = float(v_all[outer_mask].mean()) if outer_mask.any() else float(v.mean())
778
+
779
+ # Peak V β€” specular spot is just a few pixels so mean dilutes it
780
+ peak_brightness = float(v_all.max())
781
+
782
+ # Fraction of cell pixels with V > 200 (specular highlight region)
783
+ bright_spot_fraction = float((v_all > 200).sum()) / (len(v_all) + eps)
784
+
785
+ # Ring darkness: ratio of ring zone to outer zone brightness
786
+ # Live: ring << outer (dark membrane ring) -> ratio < 1
787
+ # Dead: uniform blob -> ratio ~ 1
788
+ ring_darkness = ring_brightness / (outer_brightness + eps)
789
+
790
+ centre_periphery_ratio = inner_brightness / (outer_brightness + eps)
791
+
792
+ brightness_std_normalised = float(v.std()) / (float(v.mean()) + eps)
793
+
794
+ rows.append({
795
+ "cell_id": int(cid),
796
+ "mean_r": float(r.mean()),
797
+ "mean_g": float(g.mean()),
798
+ "mean_b": float(b.mean()),
799
+ "std_r": float(r.std()),
800
+ "std_g": float(g.std()),
801
+ "std_b": float(b.std()),
802
+ "mean_h": float(h.mean()),
803
+ "mean_s": float(s.mean()),
804
+ "mean_v": float(v.mean()),
805
+ "std_s": float(s.std()),
806
+ "std_v": float(v.std()),
807
+ "blue_red_ratio": round(blue_red_ratio, 5),
808
+ "blue_green_ratio": round(blue_green_ratio, 5),
809
+ "rg_ratio": round(rg_ratio, 5),
810
+ "area_px": area_px,
811
+ "circularity": round(float(circularity), 5),
812
+ "inner_brightness": round(inner_brightness, 3),
813
+ "peak_brightness": round(peak_brightness, 3),
814
+ "bright_spot_fraction": round(bright_spot_fraction, 6),
815
+ "ring_darkness": round(ring_darkness, 5),
816
+ "centre_periphery_ratio": round(centre_periphery_ratio, 5),
817
+ "brightness_std_normalised": round(brightness_std_normalised, 5),
818
+ })
819
+
820
+ return rows
821
+
822
+ def attach_viability_labels(cell_features, masks, image_np, label_map=None):
823
+ """
824
+ Attach model predictions (from label_map) to each feature dict.
825
+ label_map: {cell_id: 0=live, 1=dead} from classify_cells_by_model.
826
+ If label_map is None, defaults all labels to 0 (live).
827
+ """
828
+ if not cell_features:
829
+ return []
830
+ labelled = []
831
+ for feat in cell_features:
832
+ row = dict(feat)
833
+ cid = int(feat["cell_id"])
834
+ row["label"] = int(label_map.get(cid, 0)) if label_map else 0
835
+ row["corrected"] = False
836
+ labelled.append(row)
837
+ return labelled
838
+
839
+
840
+ def export_cell_data_csv(cell_data):
841
+ """Write cell_data list-of-dicts to a temp CSV and return the file path."""
842
+ if not cell_data:
843
+ return None
844
+ tmp = tempfile.NamedTemporaryFile(
845
+ mode="w", suffix=".csv", delete=False, newline=""
846
+ )
847
+ # Union of all keys across rows so any late-added keys (e.g. "corrected") are included
848
+ fieldnames = list(dict.fromkeys(k for row in cell_data for k in row.keys()))
849
+ writer = csv.DictWriter(tmp, fieldnames=fieldnames, extrasaction="ignore")
850
+ writer.writeheader()
851
+ writer.writerows(cell_data)
852
+ tmp.close()
853
+ return tmp.name
854
+
855
+
856
+ def prepare_export(stored_masks, stored_image, threshold_bias):
857
+ """
858
+ Called by the Export button. Unpacks state, extracts features,
859
+ attaches labels, writes CSV, returns (path, status_message).
860
+ """
861
+ if stored_masks is None or stored_image is None:
862
+ return None, "Run segmentation first before exporting."
863
+
864
+ masks = unpack_array(stored_masks)
865
+ image_np = unpack_array(stored_image)
866
+
867
+ features = extract_cell_features(image_np, masks)
868
+ if not features:
869
+ return None, "No cells found to export."
870
+
871
+ labelled = attach_viability_labels(features, masks, image_np, threshold_bias)
872
+ path = export_cell_data_csv(labelled)
873
+
874
+ n = len(labelled)
875
+ dead = sum(1 for r in labelled if r["label"] == 1)
876
+ alive = n - dead
877
+ msg = (f"Exported {n} cells ({alive} live, {dead} dead) β€” "
878
+ f"threshold bias={threshold_bias:+d}.\n"
879
+ f"Columns: {', '.join(list(labelled[0].keys())[:6])}… "
880
+ f"({len(labelled[0])} total).")
881
+ return path, msg
882
+
883
+
884
+
885
+ # Tab builder
886
+
887
+ def draw_polygon_overlay(image_pil, points):
888
+ """
889
+ Draw numbered vertex dots and polygon edges onto a copy of image_pil.
890
+ points: list of (x, y) tuples in original image pixel space.
891
+ Returns a new PIL image.
892
+ """
893
+ img = image_pil.copy().convert("RGBA")
894
+ overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
895
+ draw = ImageDraw.Draw(overlay)
896
+
897
+ if len(points) >= 2:
898
+ # Draw edges
899
+ for i in range(len(points) - 1):
900
+ draw.line([points[i], points[i + 1]], fill=(74, 170, 255, 220), width=3)
901
+ if len(points) == 4:
902
+ draw.line([points[-1], points[0]], fill=(74, 170, 255, 220), width=3)
903
+ # Semi-transparent fill
904
+ draw.polygon(points, fill=(74, 170, 255, 50))
905
+
906
+ # Draw vertex dots + numbers
907
+ r = max(8, min(img.width, img.height) // 60)
908
+ for i, (x, y) in enumerate(points):
909
+ draw.ellipse([x - r, y - r, x + r, y + r],
910
+ fill=(74, 170, 255, 255), outline=(255, 255, 255, 255))
911
+ draw.text((x, y), str(i + 1), fill=(255, 255, 255, 255), anchor="mm")
912
+
913
+ combined = Image.alpha_composite(img, overlay)
914
+ return combined.convert("RGB")
915
+
916
+
917
+ def add_crop_point(image_pil, points, evt: gr.SelectData):
918
+ """
919
+ Called by gr.Image .select(). Appends the clicked coordinate,
920
+ redraws the overlay, returns (updated_image, updated_points).
921
+ Ignores clicks once 4 points are set.
922
+ """
923
+ if image_pil is None:
924
+ return image_pil, points
925
+ if points is None:
926
+ points = []
927
+ if len(points) >= 4:
928
+ return draw_polygon_overlay(image_pil, points), points
929
+
930
+ x, y = int(evt.index[0]), int(evt.index[1])
931
+ new_points = points + [(x, y)]
932
+ return draw_polygon_overlay(image_pil, new_points), new_points
933
+
934
+
935
+ def clear_crop_points(image_pil):
936
+ """Reset polygon β€” return original image with no overlay and empty points."""
937
+ return image_pil, []
938
+
939
+
940
+
941
+
942
+
943
+ # Label correction grid
944
+
945
+ THUMB_SIZE = 80
946
+ GRID_COLS = 10
947
+ BORDER = 4
948
+ LABEL_H = 16
949
+
950
+ def _crop_cell_thumb(image_np, masks, cid):
951
+ """
952
+ Return a tight square crop of the cell, padded to THUMB_SIZE Γ— THUMB_SIZE.
953
+ """
954
+ ys, xs = np.where(masks == cid)
955
+ if len(ys) == 0:
956
+ return Image.fromarray(np.zeros((THUMB_SIZE, THUMB_SIZE, 3), dtype=np.uint8))
957
+
958
+ y0, y1 = ys.min(), ys.max() + 1
959
+ x0, x1 = xs.min(), xs.max() + 1
960
+
961
+ # add a small context border around the tight bounding box
962
+ pad = max(4, int(max(y1 - y0, x1 - x0) * 0.15))
963
+ h, w = image_np.shape[:2]
964
+ y0c = max(0, y0 - pad)
965
+ y1c = min(h, y1 + pad)
966
+ x0c = max(0, x0 - pad)
967
+ x1c = min(w, x1 + pad)
968
+
969
+ crop = image_np[y0c:y1c, x0c:x1c].copy()
970
+
971
+ # dim pixels that don't belong to this cell
972
+ dim_mask = (masks[y0c:y1c, x0c:x1c] != cid)
973
+ crop[dim_mask] = (crop[dim_mask] * 0.3).astype(np.uint8)
974
+
975
+ pil = Image.fromarray(crop).resize((THUMB_SIZE, THUMB_SIZE), Image.LANCZOS)
976
+ return pil
977
+
978
+
979
+ def build_correction_grid(image_np, masks, labelled_features, raw_image_np=None):
980
+
981
+ if not labelled_features:
982
+ placeholder = Image.fromarray(
983
+ np.zeros((THUMB_SIZE, THUMB_SIZE, 3), dtype=np.uint8)
984
+ )
985
+ return placeholder
986
+
987
+ thumb_src = raw_image_np if raw_image_np is not None else image_np
988
+
989
+ n = len(labelled_features)
990
+ n_cols = GRID_COLS
991
+ n_rows = (n + n_cols - 1) // n_cols
992
+
993
+ cell_h = THUMB_SIZE + 2 * BORDER + LABEL_H
994
+ cell_w = THUMB_SIZE + 2 * BORDER
995
+
996
+ grid_w = n_cols * cell_w
997
+ grid_h = n_rows * cell_h
998
+
999
+ grid = Image.new("RGB", (grid_w, grid_h), (30, 30, 30))
1000
+ draw = ImageDraw.Draw(grid)
1001
+
1002
+ for idx, feat in enumerate(labelled_features):
1003
+ cid = feat["cell_id"]
1004
+ label = feat["label"] # 0=live, 1=dead (may have been corrected)
1005
+ color = (220, 50, 50) if label == 1 else (50, 200, 80)
1006
+
1007
+ thumb = _crop_cell_thumb(thumb_src, masks, cid)
1008
+
1009
+ col = idx % n_cols
1010
+ row = idx // n_cols
1011
+ x0 = col * cell_w
1012
+ y0 = row * cell_h
1013
+
1014
+ # coloured border rectangle
1015
+ draw.rectangle([x0, y0, x0 + cell_w - 1, y0 + cell_h - 1], outline=color, width=BORDER)
1016
+
1017
+ # paste thumbnail inside border
1018
+ grid.paste(thumb, (x0 + BORDER, y0 + BORDER))
1019
+
1020
+ # small cell-id label strip
1021
+ strip_y = y0 + BORDER + THUMB_SIZE
1022
+ draw.rectangle([x0, strip_y, x0 + cell_w - 1, y0 + cell_h - 1],
1023
+ fill=(20, 20, 20))
1024
+ draw.text((x0 + BORDER + 2, strip_y + 1),
1025
+ f"#{cid} {'D' if label == 1 else 'L'}",
1026
+ fill=color)
1027
+
1028
+ return grid
1029
+
1030
+
1031
+ def toggle_cell_label(labelled_features, image_np, masks, raw_image_np, evt: gr.SelectData):
1032
+ """
1033
+ Called when user taps the correction grid image.
1034
+ Maps the tap pixel coordinate back to which thumbnail was tapped,
1035
+ flips that cell's label, rebuilds and returns the updated grid.
1036
+ """
1037
+ if not labelled_features or image_np is None:
1038
+ return build_correction_grid(image_np, masks, labelled_features), labelled_features
1039
+
1040
+ cell_w = THUMB_SIZE + 2 * BORDER
1041
+ cell_h = THUMB_SIZE + 2 * BORDER + LABEL_H
1042
+
1043
+ px, py = int(evt.index[0]), int(evt.index[1])
1044
+ col = px // cell_w
1045
+ row = py // cell_h
1046
+ idx = row * GRID_COLS + col
1047
+
1048
+ if idx < 0 or idx >= len(labelled_features):
1049
+ return build_correction_grid(image_np, masks, labelled_features, raw_image_np), labelled_features
1050
+
1051
+ # Flip the label
1052
+ updated = list(labelled_features) # shallow copy of list
1053
+ cell = dict(updated[idx]) # copy the dict so we don't mutate in place
1054
+ cell["label"] = 1 - cell["label"] # 0β†’1 or 1β†’0
1055
+ cell["corrected"] = True
1056
+ updated[idx] = cell
1057
+
1058
+ grid = build_correction_grid(image_np, masks, updated, raw_image_np)
1059
+ n_corrected = sum(1 for f in updated if f.get("corrected"))
1060
+ return grid, updated, f"Tapped cell #{cell['cell_id']} β†’ {'Dead' if cell['label']==1 else 'Live'}. {n_corrected} correction(s) total."
1061
+
1062
+
1063
+ def prepare_export_corrected(stored_masks, stored_image, labelled_features, label_map):
1064
+ """Export CSV using labelled_features with any manual corrections applied."""
1065
+ if stored_masks is None or stored_image is None:
1066
+ return None, "Run segmentation first before exporting."
1067
+ masks = unpack_array(stored_masks)
1068
+ image_np = unpack_array(stored_image)
1069
+ if not labelled_features:
1070
+ features = extract_cell_features(image_np, masks)
1071
+ labelled_features = attach_viability_labels(features, masks, image_np, label_map)
1072
+ if not labelled_features:
1073
+ return None, "No cells found to export."
1074
+ path = export_cell_data_csv(labelled_features)
1075
+ n = len(labelled_features)
1076
+ dead = sum(1 for r in labelled_features if r["label"] == 1)
1077
+ alive = n - dead
1078
+ corrected = sum(1 for r in labelled_features if r.get("corrected"))
1079
+ msg = (f"Exported {n} cells ({alive} live, {dead} dead). "
1080
+ f"{corrected} label(s) manually corrected.")
1081
+ return path, msg
1082
+
1083
+ def build_tab(tab_index, masks_state, image_state, result_state):
1084
+ with gr.Tab(f"Tab {tab_index}"):
1085
+ gr.Markdown("Run segmentation")
1086
+
1087
+ # Per-tab state: list of (x,y) crop polygon points
1088
+ crop_points_state = gr.State(value=[])
1089
+ # Clean copy of the uploaded image (no polygon drawn on it)
1090
+ base_image_state = gr.State(value=None)
1091
+ #raw image state
1092
+ raw_image_state = gr.State(value=None)
1093
+
1094
+ with gr.Row():
1095
+ with gr.Column():
1096
+ img_input = gr.Image(
1097
+ type="pil",
1098
+ label="Upload image",
1099
+ image_mode="RGB",
1100
+ height=512
1101
+ )
1102
+
1103
+ gr.Markdown(
1104
+ "### Crop region (optional)\n"
1105
+ "Click/tap up to **4 points** on the image below to define the region "
1106
+ "to segment. The polygon will be drawn as you click. "
1107
+ "Leave empty to segment the full image."
1108
+ )
1109
+
1110
+ crop_display = gr.Image(
1111
+ type="pil",
1112
+ label="Click to set crop vertices (up to 4)",
1113
+ interactive=True,
1114
+ height=400,
1115
+ )
1116
+
1117
+ crop_status = gr.Markdown("*Upload an image to enable cropping*")
1118
+
1119
+ clear_crop_btn = gr.Button("βœ• Clear crop points", size="sm")
1120
+
1121
+ model_dropdown = gr.Dropdown(
1122
+ choices=list(MODEL_OPTIONS.keys()),
1123
+ label="Select Model",
1124
+ value="Hemocytometer Model"
1125
+ )
1126
+
1127
+ gr.Markdown("### Size Filters")
1128
+
1129
+ use_min_filter = gr.Checkbox(
1130
+ label="Enable minimum size filter",
1131
+ value=False,
1132
+ info="Remove objects smaller than the threshold below"
1133
+ )
1134
+ min_size_slider = gr.Slider(
1135
+ minimum=0,
1136
+ maximum=500,
1137
+ value=0,
1138
+ step=10,
1139
+ label="Minimum Cell Size (pixels)",
1140
+ )
1141
+ min_size_recommendation = gr.Markdown(
1142
+ value="*Run segmentation to see recommended minimum*",
1143
+ )
1144
+
1145
+ use_max_filter = gr.Checkbox(
1146
+ label="Enable maximum size filter",
1147
+ value=False,
1148
+ info="Remove objects larger than the threshold below"
1149
+ )
1150
+ max_size_slider = gr.Slider(
1151
+ minimum=0,
1152
+ maximum=10000,
1153
+ value=10000,
1154
+ step=10,
1155
+ label="Maximum Cell Size (pixels)",
1156
+ )
1157
+
1158
+ gr.Markdown("### Stereological Counting")
1159
+ use_stereo = gr.Checkbox(
1160
+ label="Enable Stereological Counting",
1161
+ value=False,
1162
+ info="Use unbiased stereological rules for cell counting"
1163
+ )
1164
+
1165
+ with gr.Group(visible=False) as stereo_controls:
1166
+ gr.Markdown("""
1167
+ **Stereological Counting Rules:**
1168
+ - Cells touching LEFT or TOP exclusion zones are EXCLUDED
1169
+ - Cells touching RIGHT or BOTTOM edges are INCLUDED
1170
+ - This provides unbiased counting for quantification
1171
+ """)
1172
+
1173
+ excl_preview = gr.Image(
1174
+ type="pil",
1175
+ label="Exclusion Zone Preview (Red = Excluded)",
1176
+ height=500
1177
+ )
1178
+
1179
+ left_excl = gr.Slider(
1180
+ minimum=0,
1181
+ maximum=50,
1182
+ value=10,
1183
+ step=1,
1184
+ label="Left Exclusion Width (%)",
1185
+ info="Width of left exclusion zone"
1186
+ )
1187
+
1188
+ top_excl = gr.Slider(
1189
+ minimum=0,
1190
+ maximum=50,
1191
+ value=10,
1192
+ step=1,
1193
+ label="Top Exclusion Width (%)",
1194
+ info="Width of top exclusion zone"
1195
+ )
1196
+
1197
+ segment_btn = gr.Button("πŸ”¬ Run Segmentation", variant="primary", size="lg")
1198
+
1199
+ with gr.Column():
1200
+ cell_count_out = gr.Number(label="Total Cells Detected", precision=0)
1201
+ confluency_out = gr.Number(label="Confluency (%)", precision=1)
1202
+ overlay_out = gr.Image(type="pil", label="Segmentation Result")
1203
+ info_out = gr.Textbox(label="Processing Info", lines=4)
1204
+
1205
+ with gr.Group(visible=False) as viability_section:
1206
+ gr.Markdown("### Viability Assessment (Trypan Blue)")
1207
+
1208
+ viab_run_btn = gr.Button("Run Viability Analysis", variant="primary")
1209
+
1210
+ with gr.Row():
1211
+ live_count_out = gr.Number(label="Live Cells (Green)", precision=0)
1212
+ dead_count_out = gr.Number(label="Dead Cells (Red)", precision=0)
1213
+
1214
+ viab_overlay = gr.Image(type="pil", label="Viability (Green=Live Β· Red=Dead)")
1215
+ viab_percent_out = gr.Number(label="Viability (%)", precision=1)
1216
+ viab_info = gr.Textbox(label="Analysis Results", lines=4)
1217
+
1218
+ gr.Markdown("### Label Correction & Export")
1219
+ gr.Markdown(
1220
+ "After running viability, click **Build correction grid** to review every cell. "
1221
+ "**Green border = Live, Red border = Dead** (model predictions). "
1222
+ "Tap any thumbnail to flip its label β€” the counts and overlay update instantly. "
1223
+ "Export the corrected CSV for retraining."
1224
+ )
1225
+
1226
+ build_grid_btn = gr.Button("πŸ”² Build correction grid", variant="secondary")
1227
+ labelled_state = gr.State(value=[])
1228
+ label_map_state = gr.State(value={})
1229
+
1230
+ correction_grid = gr.Image(
1231
+ type="pil",
1232
+ label="Tap a cell to flip its label (green=live Β· red=dead)",
1233
+ interactive=True,
1234
+ visible=False,
1235
+ )
1236
+ correction_status = gr.Markdown(visible=False)
1237
+
1238
+ with gr.Row():
1239
+ export_btn = gr.Button("⬇️ Export corrected CSV", variant="secondary")
1240
+ export_info = gr.Textbox(label="Export status", lines=2, interactive=False)
1241
+ export_file = gr.File(label="Download CSV", visible=False)
1242
+
1243
+ # ---- Event handlers ------------------------------------------------
1244
+
1245
+ use_stereo.change(
1246
+ fn=toggle_stereological_mode,
1247
+ inputs=[use_stereo],
1248
+ outputs=[stereo_controls]
1249
+ )
1250
+
1251
+ def on_image_upload(img):
1252
+ if img is None:
1253
+ return None, None, "*Upload an image to enable cropping*"
1254
+ return img, img, "*Image loaded β€” click up to 4 points to define crop region*"
1255
+
1256
+ img_input.change(
1257
+ fn=on_image_upload,
1258
+ inputs=[img_input],
1259
+ outputs=[crop_display, base_image_state, crop_status]
1260
+ ).then(fn=lambda: [], outputs=[crop_points_state])
1261
+
1262
+ img_input.change(fn=update_exclusion_preview,
1263
+ inputs=[img_input, left_excl, top_excl], outputs=[excl_preview])
1264
+ left_excl.change(fn=update_exclusion_preview,
1265
+ inputs=[img_input, left_excl, top_excl], outputs=[excl_preview])
1266
+ top_excl.change(fn=update_exclusion_preview,
1267
+ inputs=[img_input, left_excl, top_excl], outputs=[excl_preview])
1268
+
1269
+ def on_crop_click(base_img, points, evt: gr.SelectData):
1270
+ updated_img, updated_pts = add_crop_point(base_img, points, evt)
1271
+ n = len(updated_pts)
1272
+ status = (f"*{n} / 4 points set β€” keep clicking*" if n < 4
1273
+ else "*4 points set βœ“ β€” click **βœ• Clear** to redo, or run segmentation*")
1274
+ return updated_img, updated_pts, status
1275
+
1276
+ crop_display.select(fn=on_crop_click,
1277
+ inputs=[base_image_state, crop_points_state],
1278
+ outputs=[crop_display, crop_points_state, crop_status])
1279
+
1280
+ def on_clear_crop(base_img):
1281
+ img, pts = clear_crop_points(base_img)
1282
+ return img, pts, "*Points cleared β€” click to set new vertices*"
1283
+
1284
+ clear_crop_btn.click(fn=on_clear_crop,
1285
+ inputs=[base_image_state],
1286
+ outputs=[crop_display, crop_points_state, crop_status])
1287
+
1288
+ segment_btn.click(
1289
+ fn=run_segmentation,
1290
+ inputs=[img_input, model_dropdown, min_size_slider, max_size_slider,
1291
+ use_min_filter, use_max_filter,
1292
+ use_stereo, left_excl, top_excl, crop_points_state],
1293
+ outputs=[cell_count_out, overlay_out, info_out, viability_section,
1294
+ masks_state, image_state, confluency_out, min_size_recommendation, raw_image_state]
1295
+ )
1296
+
1297
+ # ---- Run Viability button -------------------------------------------
1298
+ def on_run_viability(stored_masks, stored_image):
1299
+ overlay, alive, dead, viab_pct, info, label_map = run_viability(stored_masks, stored_image)
1300
+ return overlay, alive, dead, viab_pct, info, label_map
1301
+
1302
+ viab_run_btn.click(
1303
+ fn=on_run_viability,
1304
+ inputs=[masks_state, image_state],
1305
+ outputs=[viab_overlay, live_count_out, dead_count_out,
1306
+ viab_percent_out, viab_info, label_map_state]
1307
+ ).then(
1308
+ fn=save_tab_result,
1309
+ inputs=[cell_count_out, confluency_out, viab_percent_out],
1310
+ outputs=[result_state]
1311
+ )
1312
+
1313
+ # ---- Build correction grid -----------------------------------------
1314
+ def on_build_grid(stored_masks, stored_image, label_map, stored_raw_image):
1315
+ if stored_masks is None or stored_image is None or not label_map:
1316
+ return (gr.update(visible=False), [],
1317
+ gr.update(value="*Run viability analysis first.*", visible=True))
1318
+ masks = unpack_array(stored_masks)
1319
+ image_np = unpack_array(stored_image)
1320
+ raw_image_np = unpack_array(stored_raw_image) if stored_raw_image is not None else None
1321
+ features = extract_cell_features(image_np, masks)
1322
+ labelled = attach_viability_labels(features, masks, image_np, label_map)
1323
+ if not labelled:
1324
+ return (gr.update(visible=False), [],
1325
+ gr.update(value="*No cells found.*", visible=True))
1326
+ grid = build_correction_grid(image_np, masks, labelled, raw_image_np)
1327
+ n = len(labelled)
1328
+ dead = sum(1 for r in labelled if r["label"] == 1)
1329
+ msg = (f"*{n} cells β€” {n-dead} live (green), {dead} dead (red). "
1330
+ f"Tap any thumbnail to flip its label.*")
1331
+ return gr.update(value=grid, visible=True), labelled, gr.update(value=msg, visible=True)
1332
+
1333
+ build_grid_btn.click(
1334
+ fn=on_build_grid,
1335
+ inputs=[masks_state, image_state, label_map_state, raw_image_state],
1336
+ outputs=[correction_grid, labelled_state, correction_status]
1337
+ )
1338
+
1339
+ # ---- Grid tap β€” flip label, update overlay + counts ----------------
1340
+ def on_grid_tap(labelled, stored_masks, stored_image, stored_raw_image, evt: gr.SelectData):
1341
+ if not labelled or stored_masks is None:
1342
+ return None, labelled, "", 0, 0, 0.0, None, {}
1343
+ masks = unpack_array(stored_masks)
1344
+ image_np = unpack_array(stored_image)
1345
+ raw_image_np = unpack_array(stored_raw_image) if stored_raw_image is not None else None
1346
+ grid, updated, msg = toggle_cell_label(labelled, image_np, masks, raw_image_np, evt)
1347
+
1348
+ # Rebuild label_map from corrected labelled list
1349
+ new_label_map = {int(f["cell_id"]): int(f["label"]) for f in updated}
1350
+ overlay_np = draw_viability_overlay(image_np, masks, new_label_map)
1351
+ dead = sum(1 for f in updated if f["label"] == 1)
1352
+ alive = len(updated) - dead
1353
+ total = alive + dead
1354
+ viab_pct = (alive / total * 100) if total > 0 else 0.0
1355
+
1356
+ return (grid, updated, f"*{msg}*",
1357
+ alive, dead, viab_pct,
1358
+ Image.fromarray(overlay_np), new_label_map)
1359
+
1360
+ correction_grid.select(
1361
+ fn=on_grid_tap,
1362
+ inputs=[labelled_state, masks_state, image_state, raw_image_state],
1363
+ outputs=[correction_grid, labelled_state, correction_status,
1364
+ live_count_out, dead_count_out, viab_percent_out,
1365
+ viab_overlay, label_map_state]
1366
+ )
1367
+
1368
+ # ---- Export --------------------------------------------------------
1369
+ def on_export(stored_masks, stored_image, labelled, label_map):
1370
+ path, msg = prepare_export_corrected(stored_masks, stored_image, labelled, label_map)
1371
+ if path is None:
1372
+ return gr.update(visible=False), msg
1373
+ return gr.update(value=path, visible=True), msg
1374
+
1375
+ export_btn.click(
1376
+ fn=on_export,
1377
+ inputs=[masks_state, image_state, labelled_state, label_map_state],
1378
+ outputs=[export_file, export_info]
1379
+ )
1380
+
1381
+
1382
+
1383
+ # Gradio interface
1384
+
1385
+ with gr.Blocks(
1386
+ title="CellposeCellCounter",
1387
+ theme=gr.themes.Soft(),
1388
+ ) as demo:
1389
+ gr.Markdown("# CellposeCellCounter")
1390
+ gr.Markdown("For accurate cell confluency, crop the image to display only desired area. Note that some image file types are not yet supported. PNG and JPEG are preferred.")
1391
+
1392
+ # Shared mask/image state (one pair per tab so tabs don't clobber each other)
1393
+ masks_states = [gr.State(value=None) for _ in range(4)]
1394
+ image_states = [gr.State(value=None) for _ in range(4)]
1395
+ result_states = [gr.State(value=None) for _ in range(4)]
1396
+
1397
+ # Build Tabs 1–4 with a loop
1398
+ for i in range(4):
1399
+ build_tab(i + 1, masks_states[i], image_states[i], result_states[i])
1400
+
1401
+ # -------------------------------------------------------------------------
1402
+ # Tab 5 β€” Summary
1403
+ # -------------------------------------------------------------------------
1404
+ with gr.Tab("Tab 5 β€” Summary"):
1405
+ gr.Markdown("## Average Results Across All Tabs")
1406
+ gr.Markdown(
1407
+ "Run segmentation in one or more tabs, "
1408
+ "then click **Refresh Summary** to see the averages."
1409
+ )
1410
+
1411
+ refresh_btn = gr.Button("πŸ”„ Refresh Summary", variant="primary", size="lg")
1412
+
1413
+ with gr.Row():
1414
+ avg_count_out = gr.Number(label="Avg Cell Count", precision=1)
1415
+ avg_conf_out = gr.Number(label="Avg Confluency (%)", precision=1)
1416
+ avg_viab_out = gr.Number(label="Avg Viability (%)", precision=1)
1417
+
1418
+ summary_box = gr.Textbox(label="Per-Tab Breakdown", lines=10)
1419
+
1420
+ refresh_btn.click(
1421
+ fn=compute_summary,
1422
+ inputs=result_states, # list of 4 gr.State components
1423
+ outputs=[avg_count_out, avg_conf_out, avg_viab_out, summary_box]
1424
+ )
1425
+
1426
+
1427
+
1428
+ if __name__ == "__main__":
1429
+ demo.launch()