File size: 23,671 Bytes
4deb2a3
 
 
 
 
 
 
90297a9
4deb2a3
 
90297a9
4deb2a3
 
 
 
 
 
 
 
 
90297a9
1ce6867
 
4deb2a3
 
1ce6867
 
 
 
 
 
 
4deb2a3
1ce6867
 
4deb2a3
90297a9
1ce6867
 
 
4deb2a3
1ce6867
 
4deb2a3
1ce6867
4deb2a3
90297a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4deb2a3
 
 
90297a9
4deb2a3
 
 
 
90297a9
4deb2a3
 
 
90297a9
4deb2a3
 
 
 
 
90297a9
4deb2a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90297a9
4deb2a3
90297a9
4deb2a3
 
 
 
 
 
90297a9
4deb2a3
 
 
 
 
 
 
 
 
 
 
 
 
 
90297a9
4deb2a3
 
 
90297a9
4deb2a3
 
 
 
 
 
 
 
 
 
 
 
 
 
1ce6867
28ebb60
 
784e413
90297a9
784e413
 
28ebb60
784e413
 
 
 
 
28ebb60
784e413
 
 
28ebb60
784e413
 
 
26e54cc
766d78d
28ebb60
 
 
 
5ca7d24
 
 
 
 
 
 
 
 
28ebb60
5ca7d24
 
 
28ebb60
5ca7d24
 
 
 
 
28ebb60
 
85cbde9
28ebb60
 
85cbde9
26e54cc
85cbde9
 
 
4deb2a3
90297a9
 
 
 
 
 
 
 
 
 
 
 
 
 
28ebb60
4deb2a3
90297a9
 
1ce6867
 
 
4deb2a3
 
 
 
 
 
 
 
 
 
 
1ce6867
 
 
 
4deb2a3
1ce6867
4deb2a3
 
28ebb60
 
de1204f
 
 
 
 
90297a9
 
 
 
 
1ce6867
28ebb60
 
4deb2a3
28ebb60
 
 
 
1ce6867
5ca7d24
 
26e54cc
28ebb60
 
5ca7d24
 
28ebb60
5ca7d24
90297a9
 
 
 
 
 
5ca7d24
28ebb60
 
 
 
 
90297a9
 
28ebb60
4deb2a3
 
28ebb60
4deb2a3
 
 
90297a9
4deb2a3
28ebb60
4deb2a3
 
 
 
90297a9
 
 
 
4deb2a3
28ebb60
 
 
 
4deb2a3
90297a9
 
28ebb60
4deb2a3
 
 
 
 
 
 
 
 
28ebb60
4deb2a3
 
 
90297a9
 
28ebb60
 
 
 
 
 
 
 
 
 
766d78d
4deb2a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ce6867
4deb2a3
 
 
 
 
1ce6867
 
4deb2a3
 
 
90297a9
4deb2a3
1ce6867
4deb2a3
1ce6867
4deb2a3
1ce6867
4deb2a3
 
 
 
 
90297a9
af3397f
784e413
 
28ebb60
784e413
28ebb60
784e413
90297a9
5ca7d24
 
17fdd27
 
5ca7d24
 
 
90297a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ce6867
4deb2a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90297a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4deb2a3
1ce6867
90297a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4deb2a3
90297a9
4deb2a3
 
 
 
 
 
 
 
 
 
 
 
 
 
90297a9
4deb2a3
 
 
90297a9
4deb2a3
90297a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4deb2a3
90297a9
 
 
4deb2a3
90297a9
 
 
 
4deb2a3
 
 
90297a9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
import gradio as gr
import spaces
from cellpose import models
import numpy as np
import cv2
import matplotlib.pyplot as plt
import tempfile
from PIL import Image, ImageDraw
import io
from huggingface_hub import hf_hub_download
import base64 

# Hugging Face Hub repository hosting the Cellpose weight files listed below.
HF_REPO_ID = "myang4218/cellposemodel"

# Dropdown label -> weight file name inside HF_REPO_ID.
MODEL_OPTIONS = {
    "Hemocytometer Model": "hemocytometermodel.npy",
    "General Model": "generalmodel.npy",
}

# Instantiated models keyed by weight file name; populated on first use during segmentation.
loaded_models = {}

# ---- mobile-safe size limits (aggressive for Safari) ----
MAX_SIDE = 1024  # longest allowed image side, in pixels
MAX_PIXELS = 1024 * 1024  # largest allowed image area, in pixels


def safe_resize(image_np):
    """
    Shrink an image so it satisfies both MAX_SIDE and MAX_PIXELS.

    The aspect ratio is preserved and only the first two axes are inspected,
    so grayscale, RGB and RGBA arrays are all accepted. An image that already
    fits both limits is returned unchanged (the very same array object);
    otherwise a cv2.INTER_AREA-resampled copy is returned.
    """
    height, width = image_np.shape[:2]
    longest = max(height, width)
    area = height * width

    fits = longest <= MAX_SIDE and area <= MAX_PIXELS
    if fits:
        return image_np

    # The tighter of the two constraints decides the shrink factor.
    factor = min(MAX_SIDE / longest, (MAX_PIXELS / area) ** 0.5)
    target_size = (max(1, int(width * factor)), max(1, int(height * factor)))  # cv2 wants (w, h)
    return cv2.resize(image_np, target_size, interpolation=cv2.INTER_AREA)


def draw_exclusion_overlay(image_np, left_width_pct, top_width_pct):
    """
    Return a copy of an image with the stereological exclusion bands shaded.

    The left band (left_width_pct % of the width) and the top band
    (top_width_pct % of the height) are filled with semi-transparent red and
    bounded by a 3-px opaque red line. Band widths use int(size * pct / 100).

    Args:
        image_np: array accepted by PIL.Image.fromarray (grayscale, RGB or RGBA).
        left_width_pct: left band width as a percentage of the image width.
        top_width_pct: top band height as a percentage of the image height.

    Returns:
        RGB numpy array with the same height and width as image_np.
    """
    h, w = image_np.shape[:2]

    # ImageDraw only alpha-blends RGBA fills onto an RGB base: on an RGBA base
    # the fill would overwrite pixels, and a grayscale base is rejected.
    img_pil = Image.fromarray(image_np).convert("RGB")
    draw = ImageDraw.Draw(img_pil, "RGBA")

    left_px = int(w * left_width_pct / 100)
    top_px = int(h * top_width_pct / 100)

    shade = (255, 0, 0, 80)    # semi-transparent red fill
    border = (255, 0, 0, 255)  # opaque red boundary line

    if left_px > 0:
        draw.rectangle([(0, 0), (left_px, h)], fill=shade)
        draw.line([(left_px, 0), (left_px, h)], fill=border, width=3)

    if top_px > 0:
        draw.rectangle([(0, 0), (w, top_px)], fill=shade)
        draw.line([(0, top_px), (w, top_px)], fill=border, width=3)

    return np.array(img_pil)


def apply_stereological_exclusion(masks, left_width_pct, top_width_pct):
    """
    Drop cells touching the left or top exclusion band and renumber the rest.

    Args:
        masks: 2-D integer label image; 0 is background, positive values are cell ids.
        left_width_pct: left band width as a percentage of the image width.
        top_width_pct: top band height as a percentage of the image height.

    Returns:
        (renumbered_masks, excluded_count, included_count). Surviving cells are
        relabelled 1..k in ascending order of their original id; everything
        else (background, excluded cells, non-positive labels) becomes 0.
        The input array is not modified.
    """
    h, w = masks.shape
    left_px = int(w * left_width_pct / 100)
    top_px = int(h * top_width_pct / 100)

    # A cell touches a band iff one of its pixels lies inside it, so the set
    # of touching ids is exactly the set of labels present in each band.
    # The > 0 guards also prevent negative widths from slicing from the end.
    band_pixels = []
    if left_px > 0:
        band_pixels.append(masks[:, :left_px].ravel())
    if top_px > 0:
        band_pixels.append(masks[:top_px, :].ravel())
    touching_ids = np.unique(np.concatenate(band_pixels)) if band_pixels else masks[:0, :0].ravel()

    # One sort of the label image; `inverse` maps each pixel to its index in `ids`.
    ids, inverse = np.unique(masks, return_inverse=True)
    is_cell = ids > 0
    is_excluded = is_cell & np.isin(ids, touching_ids)
    keep = is_cell & ~is_excluded

    new_labels = np.zeros(ids.size, dtype=masks.dtype)
    new_labels[keep] = np.arange(1, np.count_nonzero(keep) + 1)
    renumbered_masks = new_labels[inverse].reshape(masks.shape)

    return renumbered_masks, int(np.count_nonzero(is_excluded)), int(np.count_nonzero(keep))


def classify_cells_by_blueness(image_np, masks, blue_threshold):
    """
    Classify cells as dead (blue) or alive based on single blueness metric

    Per pixel, blueness = hue_score * saturation / 255, where hue_score falls
    linearly from 1 at hue 115 to 0 at 65 hue units away (OpenCV hue scale,
    wrapping at 180). A cell is dead when its mean blueness exceeds the
    threshold, otherwise alive.

    Args:
        image_np: image array (grayscale, RGB or RGBA; converted to RGB).
        masks: Cellpose label image with the same height/width as image_np;
            0 is background, positive values are cell ids.
        blue_threshold: Single threshold value (0-100) for blueness detection

    Returns:
        dead_count, alive_count, colored_overlay -- the overlay is a uint8 RGB
        image with dead cells tinted red and live cells tinted green (alpha 0.4).
    """
    if image_np.ndim == 2:
        image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
    elif image_np.ndim == 3 and image_np.shape[2] == 4:
        image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)

    hsv = cv2.cvtColor(image_np, cv2.COLOR_RGB2HSV)
    hue = hsv[:, :, 0].astype(np.float32)
    saturation = hsv[:, :, 1].astype(np.float32)

    # Circular hue distance to 115 (blue); 65 gives a broad blue window.
    hue_distance = np.minimum(np.abs(hue - 115), 180 - np.abs(hue - 115))
    hue_score = np.maximum(0, 1 - hue_distance / 65)
    blueness = hue_score * (saturation / 255.0)

    threshold = blue_threshold / 100.0

    # Single pass over the label image instead of one full-image scan per cell:
    # `inverse` maps each pixel to the index of its label in `ids`.
    ids, inverse = np.unique(masks, return_inverse=True)
    flat_index = inverse.ravel()
    pixel_counts = np.bincount(flat_index, minlength=ids.size)
    blue_sums = np.bincount(flat_index, weights=blueness.ravel(), minlength=ids.size)
    mean_blueness = blue_sums / np.maximum(pixel_counts, 1)  # every listed id has >= 1 pixel

    is_cell = ids > 0
    is_dead = is_cell & (mean_blueness > threshold)
    is_alive = is_cell & ~is_dead

    # Paint per-pixel via label lookup tables, then blend with the original.
    pixel_index = inverse.reshape(masks.shape)
    overlay = image_np.astype(np.float32)
    overlay[is_dead[pixel_index]] = [255, 0, 0]
    overlay[is_alive[pixel_index]] = [0, 255, 0]

    alpha = 0.4
    final_overlay = (1 - alpha) * image_np.astype(np.float32) + alpha * overlay
    final_overlay = np.clip(final_overlay, 0, 255).astype(np.uint8)

    return int(np.count_nonzero(is_dead)), int(np.count_nonzero(is_alive)), final_overlay


def measure_confluency(masks, image_np):
    """Return the percentage (0-100) of image pixels covered by any non-zero mask label."""
    rows, cols = image_np.shape[0], image_np.shape[1]
    covered = np.count_nonzero(masks)
    return covered / (rows * cols) * 100
    
def filter_mask_by_size(masks, minimum_pixels):
    """
    Remove labelled objects smaller than a minimum area and renumber the rest.

    Args:
        masks: integer label image; 0 is background, positive values are object ids.
        minimum_pixels: objects with fewer pixels than this are removed.

    Returns:
        (renumbered_masks, removed_count). Kept objects are relabelled 1..k in
        ascending order of their original id; the input is not modified.
    """
    # One sort gives every label, its pixel count, and a per-pixel index into
    # `ids`, replacing a full-image scan per object.
    ids, inverse, counts = np.unique(masks, return_inverse=True, return_counts=True)
    is_cell = ids > 0
    removed = is_cell & (counts < minimum_pixels)
    keep = is_cell & ~removed

    new_labels = np.zeros(ids.size, dtype=masks.dtype)
    new_labels[keep] = np.arange(1, np.count_nonzero(keep) + 1)
    renumbered_masks = new_labels[inverse].reshape(masks.shape)

    return renumbered_masks, int(np.count_nonzero(removed))


def filter_mask_by_maxsize(masks, maximum_pixels):
    """
    Remove labelled objects larger than a maximum area and renumber the rest.

    Args:
        masks: integer label image; 0 is background, positive values are object ids.
        maximum_pixels: objects with more pixels than this are removed.

    Returns:
        (renumbered_masks, removed_count). Kept objects are relabelled 1..k in
        ascending order of their original id; the input is not modified.
    """
    # One sort gives every label, its pixel count, and a per-pixel index into
    # `ids`, replacing a full-image scan per object.
    ids, inverse, counts = np.unique(masks, return_inverse=True, return_counts=True)
    is_cell = ids > 0
    removed = is_cell & (counts > maximum_pixels)
    keep = is_cell & ~removed

    new_labels = np.zeros(ids.size, dtype=masks.dtype)
    new_labels[keep] = np.arange(1, np.count_nonzero(keep) + 1)
    renumbered_masks = new_labels[inverse].reshape(masks.shape)

    return renumbered_masks, int(np.count_nonzero(removed))


def rec_min_size(masks, q=25):
    """
    Recommend a minimum object size from the detected objects' areas.

    Args:
        masks: integer label image; 0 is background, positive values are object ids.
        q: percentile (0-100) of the object-area distribution to return.

    Returns:
        The q-th percentile of per-object pixel counts, rounded to an int,
        or 0 when there are no objects.
    """
    # Pixel count of every positive label in a single pass.
    _, sizes = np.unique(masks[masks > 0], return_counts=True)
    if sizes.size == 0:
        return 0
    return int(round(np.percentile(sizes, q)))


def toggle_stereological_mode(use_stereology):
    """Build a Gradio update that sets the stereology controls' visibility to the checkbox value."""
    visibility = gr.update(visible=use_stereology)
    return visibility


def update_exclusion_preview(image, left_width, top_width):
    """
    Render the current exclusion bands on top of the uploaded image.

    Returns a PIL image for the preview component, or None while no image
    has been uploaded.
    """
    if image is not None:
        shaded = draw_exclusion_overlay(np.array(image), left_width, top_width)
        return Image.fromarray(shaded)
    return None


def _load_cellpose_model(model_choice):
    """Return the CellposeModel for a dropdown choice, downloading and caching it on first use."""
    model_filename = MODEL_OPTIONS[model_choice]
    model = loaded_models.get(model_filename)
    if model is None:
        # Only touch the Hub on a cache miss: hf_hub_download may contact the
        # Hub to validate its local cache, which is wasted latency (and a
        # failure point) once the model is already in memory.
        model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=model_filename)
        model = models.CellposeModel(gpu=True, pretrained_model=model_path)
        loaded_models[model_filename] = model
    return model


def _to_rgb(image_np):
    """Convert a grayscale or RGBA array to 3-channel RGB; any other array is returned as-is."""
    if image_np.ndim == 2:
        return cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
    if image_np.ndim == 3 and image_np.shape[2] == 4:
        return cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)
    return image_np


def _label_color_overlay(image_rgb, masks, alpha=0.4, seed=42):
    """Blend one reproducible random colour per label over an RGB image (label 0 blends with black)."""
    blended = image_rgb.astype(np.float32)
    if masks.max() > 0:
        # A private RandomState gives the same colours as seeding the global
        # RNG would, without mutating NumPy's global random state.
        colors = np.random.RandomState(seed).randint(0, 255, size=(int(masks.max()) + 1, 3))
        colors[0] = [0, 0, 0]
        blended = (1 - alpha) * blended + alpha * colors[masks]
    return np.clip(blended, 0, 255).astype(np.uint8)


def _segmentation_failure(message):
    """Build the 8-value output tuple run_segmentation returns when no result can be produced."""
    return (
        0,
        None,
        message,
        gr.update(visible=False),
        None,
        None,
        0.0,
        gr.update(),
    )


@spaces.GPU
def run_segmentation(image, model_choice, min_cell_size, max_cell_size,
                     use_stereology, left_exclusion, top_exclusion):
    """
    Segment an image with the selected Cellpose model and report count and confluency.

    Args:
        image: uploaded PIL image, or None when nothing has been uploaded.
        model_choice: key of MODEL_OPTIONS.
        min_cell_size: minimum object area in pixels; 0 means "use rec_min_size".
        max_cell_size: maximum object area in pixels; 0 disables the filter.
        use_stereology: whether to drop cells touching the left/top bands.
        left_exclusion, top_exclusion: band sizes in percent of width/height.

    Returns:
        8-tuple for the "Run Segmentation" outputs: (cell_count, overlay PIL
        image, info text, viability-section visibility update, packed masks,
        packed RGB image, confluency %, min-size slider update).
    """
    if image is None:
        return _segmentation_failure("Please upload an image before running segmentation.")

    try:
        processed_image_np = _to_rgb(safe_resize(np.array(image)))
        model = _load_cellpose_model(model_choice)

        masks_raw, _flows, _styles = model.eval(processed_image_np, diameter=None, channels=[0, 0])

        # Per-object pixel areas in a single pass (label 0 is background).
        _, raw_sizes = np.unique(masks_raw[masks_raw > 0], return_counts=True)
        has_cells = raw_sizes.size > 0
        print("num_cells:", raw_sizes.size)
        print("mean:", raw_sizes.mean() if has_cells else 0)
        print("median:", np.median(raw_sizes) if has_cells else 0)
        print("p90:", np.percentile(raw_sizes, 90) if has_cells else 0)
        print("max:", raw_sizes.max() if has_cells else 0)

        # Recommendation is computed from the unfiltered masks.
        recommend_min = rec_min_size(masks_raw)
        auto_min = min_cell_size == 0
        min_used = recommend_min if auto_min else int(min_cell_size)

        masks = masks_raw.copy()
        removed_small = 0
        removed_large = 0
        if min_used > 0:
            masks, removed_small = filter_mask_by_size(masks, min_used)
        if max_cell_size > 0:
            masks, removed_large = filter_mask_by_maxsize(masks, int(max_cell_size))

        excluded_count = 0
        if use_stereology:
            masks, excluded_count, _ = apply_stereological_exclusion(
                masks, left_exclusion, top_exclusion
            )

        filter_msg = ""
        if auto_min:
            filter_msg += f"Automatic minimum cell size: {min_used} pixels.\n"
        if removed_small:
            filter_msg += f"Removed {removed_small} small objects (< {min_used} pixels).\n"
        if removed_large:
            filter_msg += f"Removed {removed_large} large objects (> {int(max_cell_size)} pixels).\n"
        if use_stereology and excluded_count > 0:
            filter_msg += f"Stereological exclusion: {excluded_count} cells excluded (touching left/top zones).\n"

        # Count positive labels; len(unique) - 1 is off by one whenever no
        # background pixel remains.
        cell_count = int(np.count_nonzero(np.unique(masks) > 0))
        confluency = measure_confluency(masks, processed_image_np)

        segmentation_overlay = _label_color_overlay(processed_image_np, masks)
        if use_stereology:
            segmentation_overlay = draw_exclusion_overlay(segmentation_overlay, left_exclusion, top_exclusion)

        info_msg = filter_msg
        info_msg += f"Segmentation complete! Found {cell_count} cells.\n"
        info_msg += f"Confluency: {confluency:.1f}%\n"
        if use_stereology:
            info_msg += f"Stereological counting enabled (Left: {left_exclusion}%, Top: {top_exclusion}%)\n"
        info_msg += "Now adjust the Blue Threshold for viability assessment."

        return (
            cell_count,
            Image.fromarray(segmentation_overlay),
            info_msg,
            gr.update(visible=True),
            pack_array(masks),
            pack_array(processed_image_np),
            confluency,
            # Leave the slider alone: writing the recommendation back would
            # overwrite a user-chosen value and silently turn "0 = auto" into
            # a fixed threshold for the next image. The value used is reported
            # in the info text instead.
            gr.update(),
        )

    except Exception as e:
        import traceback
        traceback.print_exc()
        return _segmentation_failure(f"Error during segmentation: {str(e)}")


def update_viability_realtime(blue_threshold, stored_masks, stored_image_np):
    """
    Re-classify stored cells as live/dead for a new blue threshold.

    Args:
        blue_threshold: threshold in percent (0-100) passed to classify_cells_by_blueness.
        stored_masks: packed label image produced by run_segmentation, or None.
        stored_image_np: packed RGB image produced by run_segmentation, or None.

    Returns:
        (overlay PIL image or None, live count, dead count, viability %, info text).
    """
    # Guard against the slider moving before any segmentation has run.
    if stored_masks is None or stored_image_np is None:
        return None, 0, 0, 0.0, "Please run segmentation first."

    try:
        # Decoding is inside the try so a bad payload yields the error message
        # instead of an unhandled exception in the Gradio callback.
        masks = unpack_array(stored_masks)
        image_np = unpack_array(stored_image_np)

        dead_count, alive_count, viability_overlay_np = classify_cells_by_blueness(
            image_np, masks, blue_threshold
        )

        total_count = alive_count + dead_count
        viability_percent = (alive_count / total_count * 100) if total_count > 0 else 0.0
        confluency = measure_confluency(masks, image_np)

        overlay_image = Image.fromarray(viability_overlay_np)
        info_msg = f"Total cells: {total_count}\nLive (green): {alive_count}\nDead (red): {dead_count}\n"
        info_msg += f"Viability: {viability_percent:.1f}%\nConfluency: {confluency:.1f}%\nBlue threshold: {blue_threshold}%"

        return overlay_image, alive_count, dead_count, viability_percent, info_msg

    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, 0, 0, 0.0, f"Error updating viability: {str(e)}"


_NPY_MAGIC = b"\x93NUMPY"  # leading bytes of every .npy payload


def pack_array(arr):
    """
    Serialize an array to bytes for storage in a gr.State, losslessly.

    The .npy format keeps dtype and values exactly; the previous uint8 PNG
    encoding wrapped label ids above 255 (cell 256 became background, 257
    merged with cell 1), corrupting masks with many cells.
    """
    buf = io.BytesIO()
    np.save(buf, np.asarray(arr), allow_pickle=False)
    return buf.getvalue()


def unpack_array(data):
    """
    Inverse of pack_array.

    Payloads that are not .npy (e.g. PNG bytes) are decoded with PIL as a
    fallback, so both encodings can be read.
    """
    if bytes(data[:len(_NPY_MAGIC)]) == _NPY_MAGIC:
        return np.load(io.BytesIO(data), allow_pickle=False)
    return np.array(Image.open(io.BytesIO(data)))


# Gradio interface
# Builds the app: one "Cell Quantification" tab (segmentation, optional
# stereological exclusion, viability assessment) plus an instructions accordion.
with gr.Blocks(
    title="CellposeCellCounter",
    theme=gr.themes.Soft(),
) as demo:
    gr.Markdown("# CellposeCellCounter")
    gr.Markdown("For accurate cell confluency, crop the image to display only desired area. Note that some image file types are not yet supported. PNG and JPEG are preferred.")

    # Define State components to store masks and image data across function calls.
    # run_segmentation writes the packed masks/RGB image here (outputs 5 and 6);
    # update_viability_realtime reads them back on every threshold change.
    masks_state = gr.State(value=None)
    image_state = gr.State(value=None)

    with gr.Tab("Cell Quantification"):
        gr.Markdown("Run segmentation")

        with gr.Row():
            # Left column: inputs and segmentation settings.
            with gr.Column():
                img_input = gr.Image(
                    type="pil",
                    label="Microscopy image",
                    image_mode="RGB",
                    height=512
                )

                # Keys of MODEL_OPTIONS; run_segmentation maps them to weight files.
                model_dropdown1 = gr.Dropdown(
                    choices=list(MODEL_OPTIONS.keys()),
                    label="Select Model",
                    value="Hemocytometer Model"
                )

                # A value of 0 tells run_segmentation to use rec_min_size instead.
                min_size_slider1 = gr.Slider(
                    minimum=0,
                    maximum=500,
                    value=0,
                    step=10,
                    label="Minimum Cell Size (pixels). Leave at zero for automated recommendation",
                )

                # run_segmentation applies the max-size filter only when this is > 0.
                max_size_slider1 = gr.Slider(
                    minimum=0,
                    maximum=1000,
                    value=1000,
                    step=10,
                    label="Maximum Cell Size (pixels)",
                )

                # Stereological counting option
                gr.Markdown("### Stereological Counting")
                use_stereology_checkbox = gr.Checkbox(
                    label="Enable Stereological Counting",
                    value=False,
                    info="Use unbiased stereological rules for cell counting"
                )

                # Stereological controls (initially hidden; toggled by the checkbox above)
                with gr.Group(visible=False) as stereology_controls:
                    gr.Markdown("""
                    **Stereological Counting Rules:**
                    - Cells touching LEFT or TOP exclusion zones are EXCLUDED
                    - Cells touching RIGHT or BOTTOM edges are INCLUDED
                    - This provides unbiased counting for quantification
                    """)

                    # Redrawn by update_exclusion_preview on image/slider changes.
                    exclusion_preview = gr.Image(
                        type="pil",
                        label="Exclusion Zone Preview (Red = Excluded)",
                        height=300
                    )

                    # Band sizes are percentages of image width/height, so the
                    # same values apply to the full-size preview and the
                    # resized image that is actually segmented.
                    left_exclusion_slider = gr.Slider(
                        minimum=0,
                        maximum=50,
                        value=10,
                        step=1,
                        label="Left Exclusion Width (%)",
                        info="Width of left exclusion zone"
                    )

                    top_exclusion_slider = gr.Slider(
                        minimum=0,
                        maximum=50,
                        value=10,
                        step=1,
                        label="Top Exclusion Width (%)",
                        info="Width of top exclusion zone"
                    )

                segment_btn1 = gr.Button("🔬 Run Segmentation", variant="primary", size="lg")

            # Right column: segmentation results.
            with gr.Column():
                cell_count_output1 = gr.Number(label="Total Cells Detected", precision=0)
                confluency_output1 = gr.Number(label="Confluency (%)", precision=1)
                overlay_output1 = gr.Image(type="pil", label="Segmentation Result")
                info_output1 = gr.Textbox(label="Processing Info", lines=4)

        # Viability Assessment Section (hidden until run_segmentation returns
        # a visibility update for it as its 4th output)
        with gr.Group(visible=False) as viability_section1:
            gr.Markdown("### Viability Assessment (Trypan Blue)")
            gr.Markdown("Adjust the threshold to classify cells as live (green) or dead (red).")

            with gr.Row():
                with gr.Column():
                    # Percent threshold on mean per-cell blueness (see classify_cells_by_blueness).
                    blue_threshold1 = gr.Slider(
                        minimum=0,
                        maximum=100,
                        value=25,
                        step=1,
                        label="Blue Threshold (%)",
                        info="Higher values = more selective for blue cells"
                    )

                with gr.Column():
                    live_count_output1 = gr.Number(label="Live Cells (Green)", precision=0)
                    dead_count_output1 = gr.Number(label="Dead Cells (Red)", precision=0)

            viability_overlay1 = gr.Image(type="pil", label="Viability Assessment (Green=Live, Red=Dead)")
            viability_percent_output1 = gr.Number(label="Viability (%)", precision=1)
            viability_info1 = gr.Textbox(label="Analysis Results", lines=5)

        # Event handlers

        # Toggle stereological controls visibility
        use_stereology_checkbox.change(
            fn=toggle_stereological_mode,
            inputs=[use_stereology_checkbox],
            outputs=[stereology_controls]
        )

        # Update exclusion preview when image is uploaded or sliders change.
        # The three handlers are identical; each listens to a different trigger.
        img_input.change(
            fn=update_exclusion_preview,
            inputs=[img_input, left_exclusion_slider, top_exclusion_slider],
            outputs=[exclusion_preview]
        )

        left_exclusion_slider.change(
            fn=update_exclusion_preview,
            inputs=[img_input, left_exclusion_slider, top_exclusion_slider],
            outputs=[exclusion_preview]
        )

        top_exclusion_slider.change(
            fn=update_exclusion_preview,
            inputs=[img_input, left_exclusion_slider, top_exclusion_slider],
            outputs=[exclusion_preview]
        )

        # Run segmentation. The outputs list order must match the 8-tuple
        # returned by run_segmentation; the last output feeds back into the
        # minimum-size slider.
        segment_btn1.click(
            fn=run_segmentation,
            inputs=[
                img_input, 
                model_dropdown1, 
                min_size_slider1, 
                max_size_slider1,
                use_stereology_checkbox,
                left_exclusion_slider,
                top_exclusion_slider
            ],
            outputs=[
                cell_count_output1, 
                overlay_output1, 
                info_output1, 
                viability_section1, 
                masks_state, 
                image_state, 
                confluency_output1, 
                min_size_slider1
            ]
        ).then(  # Chain the initial viability assessment after segmentation
            fn=update_viability_realtime,
            inputs=[blue_threshold1, masks_state, image_state],
            outputs=[viability_overlay1, live_count_output1, dead_count_output1, viability_percent_output1, viability_info1]
        )

        # Slider changes update viability in real-time (reuses the stored
        # masks/image; no re-segmentation)
        blue_threshold1.change(
            fn=update_viability_realtime,
            inputs=[blue_threshold1, masks_state, image_state],
            outputs=[viability_overlay1, live_count_output1, dead_count_output1, viability_percent_output1, viability_info1]
        )

    # Instructions
    with gr.Accordion("Instructions", open=False):
        gr.Markdown("""
        ### How to use:

        1. **Upload and Segment**:
            - Upload your microscopy image.
            - Select a Cellpose model (e.g., "Hemocytometer Model" for suspension culture).
            - **(Optional)** Enable Stereological Counting for unbiased quantification.
            - Click "Run Segmentation".

        2. **Stereological Counting** (Optional):
            - Check "Enable Stereological Counting" to use unbiased counting rules.
            - Adjust the Left and Top exclusion zone widths using the sliders.
            - Preview shows excluded areas in red.
            - **Counting Rules**:
                - Cells touching LEFT or TOP exclusion zones are EXCLUDED
                - Cells touching RIGHT or BOTTOM edges are INCLUDED
                - This ensures unbiased, systematic counting

        3. **Analysis Results**:
            - **Cell Count**: Total number of detected cells (after exclusions if using stereology)
            - **Confluency**: Percentage of image area covered by cells

        4. **Real-time Viability Assessment (Trypan Blue)**:
            - After segmentation, the viability section will become visible.
            - Adjust the **"Blue Threshold (%)"** slider in real-time.
            - **Lower values (10-20%)** are more sensitive.
            - **Higher values (30-50%)** are more selective.
            - Green cells = Live, Red cells = Dead.

        5. **Interpreting Results**:
            - The app displays total, live, and dead cell counts, viability percentage, and confluency.
            - If stereological counting is enabled, excluded cells are noted in the processing info.
        """)

# Start the Gradio server only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()