File size: 29,428 Bytes
92e2d37
 
 
 
 
666f8a5
 
1c3a61d
5ed0f39
206a302
92e2d37
666f8a5
92e2d37
666f8a5
2cd8aa5
63f35ec
92e2d37
 
 
 
 
ef78770
 
 
 
 
 
 
 
 
0b0bc00
92e2d37
1b3ec1b
 
 
 
 
 
 
666f8a5
92e2d37
 
 
ef78770
 
92e2d37
 
6c70288
 
5646c55
 
6c70288
194946e
6c70288
ab16831
1b3ec1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96c5ca5
92e2d37
6c70288
ef78770
92e2d37
ef78770
92e2d37
 
 
 
 
ef78770
 
 
92e2d37
 
ef78770
92e2d37
ef78770
92e2d37
666f8a5
ef78770
92e2d37
5ed0f39
3c46cb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194946e
1c3a61d
206a302
 
 
 
 
 
5ed0f39
666f8a5
206a302
b9ba433
cedd38a
 
 
 
 
 
 
 
 
 
 
 
b9ba433
206a302
666f8a5
194946e
206a302
666f8a5
5ed0f39
 
 
 
3c46cb0
 
666f8a5
b9ba433
 
ef78770
b9ba433
3c46cb0
b9ba433
 
 
 
 
6c70288
1b3ec1b
 
 
 
 
5ed0f39
 
 
57d9eac
1b3ec1b
 
 
 
 
6c70288
 
1b3ec1b
 
 
 
 
 
6c70288
1b3ec1b
6c70288
 
 
 
3c46cb0
1b3ec1b
 
 
 
 
6c70288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab16831
6c70288
 
 
57d9eac
 
 
 
96c5ca5
2cdd74c
6c70288
 
 
96c5ca5
1b3ec1b
96c5ca5
 
 
 
1b3ec1b
96c5ca5
5646c55
6c70288
 
 
 
 
 
 
5ed0f39
6c70288
b9ba433
5ed0f39
 
194946e
0b0bc00
 
666f8a5
ef78770
0b0bc00
ef78770
 
92e2d37
0b0bc00
92e2d37
 
5646c55
92e2d37
5ed0f39
 
6c70288
5646c55
6c70288
 
 
92e2d37
 
ef78770
 
 
0b0bc00
ef78770
 
9e1244f
 
0b0bc00
9e1244f
 
f412687
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b0bc00
f412687
 
dda77e2
 
 
 
 
 
 
 
 
 
 
83c8b1b
 
dda77e2
 
83c8b1b
dda77e2
 
0b0bc00
 
dda77e2
 
 
 
 
 
 
 
0b0bc00
dda77e2
 
 
 
 
 
 
 
 
 
 
 
 
 
0b0bc00
dda77e2
 
 
0b0bc00
dda77e2
 
 
0b0bc00
dda77e2
 
 
 
 
 
6c70288
dda77e2
 
 
 
 
83c8b1b
dda77e2
 
 
0b0bc00
dda77e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b0bc00
 
dda77e2
0b0bc00
dda77e2
 
 
83c8b1b
 
 
dda77e2
 
ce672ba
0b0bc00
dda77e2
 
f412687
 
 
 
 
 
 
0b0bc00
f412687
 
 
0b0bc00
f412687
 
dda77e2
 
 
83c8b1b
1b3ec1b
ce672ba
dda77e2
 
 
0b0bc00
dda77e2
 
 
ce672ba
0b0bc00
 
 
dda77e2
 
f412687
0b0bc00
dda77e2
 
 
 
1b3ec1b
dda77e2
 
 
 
 
 
 
0b0bc00
 
dda77e2
 
0b0bc00
1b3ec1b
 
 
 
 
 
 
 
0b0bc00
1b3ec1b
 
 
 
ce672ba
 
 
 
 
 
8b62e36
 
 
1b3ec1b
dda77e2
 
 
 
 
 
 
0b0bc00
dda77e2
 
46544ed
 
5646c55
46544ed
57d9eac
 
46544ed
 
5646c55
 
46544ed
 
 
 
5646c55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46544ed
 
 
dda77e2
 
 
46544ed
 
 
 
 
 
f8151e3
46544ed
 
 
 
 
 
 
 
 
 
f8151e3
 
 
eb998aa
 
f8151e3
 
 
 
 
 
 
 
46544ed
 
 
96c5ca5
46544ed
 
09bce3d
46544ed
 
 
 
 
 
 
 
 
dda77e2
 
83c8b1b
46544ed
 
 
 
 
 
 
8b62e36
0b0bc00
46544ed
 
dda77e2
46544ed
dda77e2
96c5ca5
5646c55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dda77e2
 
57d9eac
 
 
 
dda77e2
 
 
 
 
 
 
 
 
 
5646c55
dda77e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83c8b1b
dda77e2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
import gradio as gr
import asyncio
import os
import glob
import torch
import sys
import builtins
import pandas as pd
import numpy as np
import time
from pathlib import Path
from PIL import Image

# --- Safe Input Mocking ---
# Replace the interactive prompt process-wide: any code that calls input()
# after this point receives "y" immediately instead of reading from stdin.
def _auto_confirm(*args):
    """Stand-in for ``builtins.input`` that answers every prompt with "y"."""
    return "y"


builtins.input = _auto_confirm

# GenAI & ADK Imports
from google.adk.runners import InMemoryRunner
from google.genai import types

# Project Imports
try:
    from cellemetry import root_agent
    from cellemetry.config import AnalysisDeps
    from transformers import Sam3Processor, Sam3Model
except ImportError as e:
    print(f"⚠️ Import Error (Non-fatal for UI startup): {e}")
    Sam3Model = None
    Sam3Processor = None
    root_agent = None
    AnalysisDeps = None

# Optional: Distinctipy for better colors
try:
    from distinctipy import distinctipy
except ImportError:
    distinctipy = None
    print("⚠️ distinctipy not found. Using fallback colors.")

# --- Global State ---
# SAM3 handles, filled in lazily by load_models(); "loaded" is flipped to True
# once the model and processor are both in place.
MODEL_CACHE = dict(model=None, processor=None, device="cpu", loaded=False)

# Overlay cache maintained by generate_overlay(): the path and RGBA copy of the
# image currently displayed, plus one boolean mask array per layer name.
MASK_CACHE = dict(current_path=None, base_image=None, layers={})

# Agent runner created by run_analysis() and reused for follow-up messages.
ACTIVE_RUNNER = None

# --- Dynamic Color Helper ---
def generate_color_palette(n=50):
    """Return ``n`` RGB colors as tuples of ints in the range 0-255.

    Color sources are tried in order:

    1. ``distinctipy``, when it was importable at module load.
    2. matplotlib's ``tab20`` colormap (repeats after 20 entries).
    3. A built-in list of 12 colors, cycled as often as needed.

    Args:
        n: Number of colors requested. Zero or a negative value yields an
            empty list from the matplotlib and built-in sources.

    Returns:
        A list of ``(r, g, b)`` tuples.
    """
    if distinctipy:
        print(f"🎨 Generating {n} distinct colors using distinctipy...")
        colors = distinctipy.get_colors(n)
        return [tuple(int(c * 255) for c in color) for color in colors]

    try:
        import matplotlib.pyplot as plt
        cmap = plt.get_cmap('tab20')
        return [tuple(int(c * 255) for c in cmap(i % 20)[:3]) for i in range(n)]
    except Exception:
        # matplotlib is optional; fall through to the built-in colors.
        pass

    base_colors = [
        (0, 255, 0), (0, 0, 255), (255, 0, 0), (255, 255, 0),
        (0, 255, 255), (255, 0, 255), (255, 128, 0), (128, 0, 255),
        (0, 128, 0), (0, 0, 128), (128, 0, 0), (128, 128, 0),
    ]
    # Cycle the base colors so the result has exactly n entries, matching the
    # other two sources (the list used to be over-long by up to 11 entries).
    return [base_colors[i % len(base_colors)] for i in range(n)]

COLOR_PALETTE = generate_color_palette(50)

def load_models():
    """Load the SAM3 model and processor into ``MODEL_CACHE``.

    Does nothing when the cache is already marked as loaded. The device is
    "cuda" when ``torch`` reports CUDA as available, otherwise "cpu". The
    model and processor are written to the cache together, only after both
    have loaded, so a failure never leaves a model cached without its
    processor.

    Returns:
        ``None`` if the models were already loaded; otherwise a status string
        describing success or the failure reason. Load errors are reported
        through the return value and stdout rather than raised.
    """
    if MODEL_CACHE["loaded"]:
        return None

    print("--- Loading SAM3 Model ---")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    MODEL_CACHE["device"] = device

    try:
        if Sam3Model is None:
            raise ImportError("Sam3Model not found. Please check requirements.")

        model = Sam3Model.from_pretrained("facebook/sam3").to(device)
        processor = Sam3Processor.from_pretrained("facebook/sam3")
    except Exception as e:
        # Top-level boundary: the UI must keep running even if loading fails.
        print(f"⚠️ SAM3 load failed: {e}")
        return f"⚠️ Model load failed: {e}"

    # Publish both handles at once so readers never see a partial cache.
    MODEL_CACHE["model"] = model
    MODEL_CACHE["processor"] = processor
    MODEL_CACHE["loaded"] = True
    status = f"✅ SAM3 loaded on {device}"
    print(status)
    return status

# --- Helpers ---
_LAYER_COLOR_WORDS = frozenset({
    'blue', 'green', 'red', 'yellow', 'cyan', 'magenta',
    'orange', 'purple', 'white', 'black', 'gray', 'grey',
    'pink', 'brown', 'lime', 'teal',
})


def clean_layer_name(filename):
    """Turn a mask file name into a human-readable layer name.

    Example: ``'data_blue_nuclei.npz'`` -> ``'Nuclei'``.

    The directory, a leading ``data_`` prefix and a trailing ``.npz`` suffix
    are stripped (only at the ends, so ``data_metadata_cells.npz`` keeps the
    word "metadata" intact). The remainder is split on underscores, common
    color words are dropped, and the rest is title-cased. If every word is a
    color, the color words are kept so the name is never emptied by the
    filter.

    Args:
        filename: Path or bare file name of the mask file.

    Returns:
        The display name; an empty string if nothing remains after stripping.
    """
    stem = os.path.basename(filename)
    if stem.endswith(".npz"):
        stem = stem[:-len(".npz")]
    if stem.startswith("data_"):
        stem = stem[len("data_"):]

    # Ignore empty pieces produced by doubled/leading/trailing underscores.
    words = [word for word in stem.split('_') if word]
    non_color_words = [w for w in words if w.lower() not in _LAYER_COLOR_WORDS]
    return " ".join(non_color_words or words).title()

def load_excel_data(logs_text):
    """Load the most recently modified Excel report into display tables.

    Looks for ``*.xlsx`` files in ``/tmp`` and the current directory and picks
    the one with the latest modification time. The "Morphology", "Spatial" and
    "Relational" sheets are each transposed: the first column becomes the
    header row and the former column names are placed in a "Metric" column.

    Args:
        logs_text: Unused; kept so existing callers keep working.

    Returns:
        ``(report_path, morphology, spatial, relational)``. ``report_path`` is
        ``None`` when no workbook exists. A sheet that is missing or empty, or
        a workbook that cannot be read, is represented by a one-cell
        placeholder DataFrame.
    """
    placeholder = pd.DataFrame({"Status": ["No Data Available"]})
    candidates = glob.glob("/tmp/*.xlsx") + glob.glob("*.xlsx")

    if not candidates:
        return None, placeholder, placeholder, placeholder

    report_file = max(candidates, key=os.path.getmtime)

    def read_transposed(workbook, sheet_name):
        """Return the transposed sheet, or the placeholder if unavailable."""
        if sheet_name not in workbook.sheet_names:
            return placeholder
        df = pd.read_excel(workbook, sheet_name)
        if df.empty or len(df.columns) == 0:
            return placeholder
        df = df.set_index(df.columns[0]).T.reset_index()
        return df.rename(columns={df.columns[0]: "Metric"})

    try:
        # The context manager closes the workbook handle even if a sheet
        # fails to parse.
        with pd.ExcelFile(report_file, engine='openpyxl') as xls:
            morph = read_transposed(xls, "Morphology")
            spatial = read_transposed(xls, "Spatial")
            relational = read_transposed(xls, "Relational")
    except Exception as e:
        print(f"⚠️ Error reading Excel: {e}")
        return report_file, placeholder, placeholder, placeholder

    return report_file, morph, spatial, relational

def get_available_layers():
    """Return the sorted, de-duplicated display names of the mask files in /tmp.

    Every ``/tmp/data_*.npz`` file contributes the name produced by
    ``clean_layer_name``.
    """
    unique_names = {clean_layer_name(path) for path in glob.glob("/tmp/data_*.npz")}
    return sorted(unique_names)

def update_opacity_sliders(layers):
    """Build the updates for the four opacity sliders.

    Slider ``k`` is shown, labelled "<layer> Opacity" and reset to 0.6 when
    ``layers`` has an entry at index ``k``; otherwise it is hidden. Layers
    beyond the fourth get no slider.

    Returns:
        A list of exactly four Gradio updates, in slider order.
    """
    slider_count = 4
    return [
        gr.update(visible=True, label=f"{layers[slot]} Opacity", value=0.6)
        if slot < len(layers)
        else gr.update(visible=False)
        for slot in range(slider_count)
    ]

# --- OPTIMIZED OVERLAY GENERATION ---
def _load_layer_mask(file_path, target_size):
    """Read one mask archive and return it as a boolean array, or None if empty.

    The ``masks`` entry is used when present, otherwise the first array in the
    archive. A 3-D array is collapsed with a max over axis 0. The mask is
    resized with nearest-neighbour sampling to ``target_size`` given as
    ``(width, height)``.
    """
    # Close the archive promptly; the array is fully read on access, so it
    # stays valid after the ``with`` block.
    with np.load(file_path) as data:
        masks = data['masks'] if 'masks' in data else data[data.files[0]]

    if masks.size == 0:
        return None

    combined_mask = np.max(masks, axis=0) if masks.ndim == 3 else masks
    mask_pil = Image.fromarray(combined_mask.astype(np.uint8) * 255)
    if mask_pil.size != target_size:
        mask_pil = mask_pil.resize(target_size, Image.Resampling.NEAREST)
    return np.array(mask_pil, dtype=bool)


def generate_overlay(image_path_str, selected_layers, layer_opacities=None, force_reload=False):
    """Composite the selected mask layers over the base image.

    The base image and the per-layer masks are kept in ``MASK_CACHE``. The
    image is (re)loaded when the path changes; mask files matching
    ``/tmp/data_*.npz`` are scanned whenever the image changed or no layers
    are cached yet, and only layers not already cached are read.

    Args:
        image_path_str: Path of the base image. A falsy value returns ``None``.
        selected_layers: Layer names to draw; ``None`` is treated as empty.
            Names without a cached mask are ignored.
        layer_opacities: Optional mapping of layer name to opacity in [0, 1].
            Layers without an entry use 0.6.
        force_reload: If True, drop the cached layers first so mask files
            written since the last call are picked up.

    Returns:
        An RGB ``PIL.Image``, or ``None`` if the base image cannot be loaded.
    """
    if not image_path_str:
        return None

    if force_reload:
        MASK_CACHE["layers"] = {}

    image_changed = (
        MASK_CACHE["current_path"] != image_path_str
        or MASK_CACHE["base_image"] is None
    )

    if image_changed or not MASK_CACHE["layers"]:
        print(f"🔄 Caching masks for {os.path.basename(image_path_str)}...")

        if image_changed:
            try:
                with Image.open(image_path_str) as source:
                    base_img = source.convert("RGBA")
            except Exception as e:
                print(f"Failed to load base image: {e}")
                return None
            MASK_CACHE["base_image"] = base_img
            MASK_CACHE["current_path"] = image_path_str
            # Cached masks were resized for the previous image; keeping them
            # could make alpha_composite fail on a size mismatch.
            MASK_CACHE["layers"] = {}
        else:
            base_img = MASK_CACHE["base_image"]

        cached_layers = MASK_CACHE["layers"]
        for file_path in glob.glob("/tmp/data_*.npz"):
            layer_name = clean_layer_name(file_path)
            if layer_name in cached_layers:
                continue
            try:
                mask = _load_layer_mask(file_path, base_img.size)
            except Exception as e:
                # One unreadable mask file must not hide the other layers.
                print(f"Failed to cache layer {layer_name}: {e}")
                continue
            if mask is not None:
                cached_layers[layer_name] = mask

    base_image = MASK_CACHE["base_image"]
    if base_image is None:
        return None

    cached_layers = MASK_CACHE["layers"]
    # A layer's color is fixed by its position among the sorted cached names,
    # so it does not change when other layers are toggled.
    palette_slot = {
        name: position % len(COLOR_PALETTE)
        for position, name in enumerate(sorted(cached_layers))
    }

    overlay_accum = Image.new('RGBA', base_image.size, (0, 0, 0, 0))

    for layer_name in selected_layers or []:
        mask_bool = cached_layers.get(layer_name)
        if mask_bool is None:
            continue

        color = COLOR_PALETTE[palette_slot[layer_name]]

        opacity = 0.6
        if layer_opacities and layer_name in layer_opacities:
            opacity = layer_opacities[layer_name]

        layer_rgba = np.zeros((mask_bool.shape[0], mask_bool.shape[1], 4), dtype=np.uint8)
        layer_rgba[mask_bool] = (*color, int(255 * opacity))
        layer_img = Image.fromarray(layer_rgba, 'RGBA')
        overlay_accum = Image.alpha_composite(overlay_accum, layer_img)

    result = Image.alpha_composite(base_image, overlay_accum)
    return result.convert("RGB")

# --- Core Logic ---
async def run_analysis(image_path_str, user_prompt, session_id_state):
    """Run the agent on an image and stream progress as an async generator.

    Every yielded item is a 12-tuple:
    ``(chat_history, session_id, overlay, layer_checkboxes, report_file,
    morphology_df, spatial_df, relational_df, slider_1, ..., slider_4)``.
    Intermediate yields only carry the chat history (and the session id once
    it exists); the tables and sliders receive ``gr.skip()`` so they are not
    redrawn while streaming. Only the final yield carries real results.

    Args:
        image_path_str: Path of the image to analyze. If falsy, one empty
            frame is yielded and the generator ends.
        user_prompt: Text sent to the agent together with the image.
        session_id_state: Unused; kept for interface compatibility.

    Side effects:
        Deletes earlier outputs in ``/tmp`` (``out_*.png``, ``data_*.npz``,
        ``*.xlsx``), resets ``MASK_CACHE`` and replaces the module-level
        ``ACTIVE_RUNNER``.
    """
    import mimetypes

    global ACTIVE_RUNNER

    # FIX: Use gr.skip() for updates to prevent UI jitter during streaming
    skipped_updates = [gr.skip()] * 4

    def progress_frame(chat, sid=None):
        """Build a frame that updates only the chat (and session id)."""
        return (chat, sid, None, [], None, gr.skip(), gr.skip(), gr.skip(), *skipped_updates)

    if not MODEL_CACHE["loaded"]:
        yield progress_frame([])
        # Loading the model is slow and synchronous; run it in a worker
        # thread so the event loop keeps serving other requests.
        await asyncio.get_running_loop().run_in_executor(None, load_models)

    if not image_path_str:
        yield progress_frame([])
        return

    # Cleanup: remove artifacts of a previous run (best effort).
    stale_files = (
        glob.glob("/tmp/out_*.png")
        + glob.glob("/tmp/data_*.npz")
        + glob.glob("/tmp/*.xlsx")
    )
    for stale in stale_files:
        try:
            os.remove(stale)
        except OSError:
            pass

    # Reset Cache
    MASK_CACHE["current_path"] = None
    MASK_CACHE["base_image"] = None
    MASK_CACHE["layers"] = {}

    image_path = Path(image_path_str)

    if MODEL_CACHE["model"] is None:
        error_msg = "❌ Model failed to load. Please check logs."
        yield progress_frame([{"role": "assistant", "content": error_msg}])
        return

    if AnalysisDeps is None:
        error_msg = "❌ Project imports failed. 'AnalysisDeps' is missing. Check your 'cellemetry' package installation."
        yield progress_frame([{"role": "assistant", "content": error_msg}])
        return

    try:
        deps = AnalysisDeps(
            sam_model=MODEL_CACHE["model"],
            sam_processor=MODEL_CACHE["processor"],
            image_path=image_path,
            device=MODEL_CACHE["device"],
            pixel_size_microns=None
        )

        if root_agent is None:
            raise ValueError("Root agent is not loaded.")

        ACTIVE_RUNNER = InMemoryRunner(agent=root_agent, app_name="cellemetry_demo")

        session = await ACTIVE_RUNNER.session_service.create_session(
            app_name="cellemetry_demo",
            user_id="demo_user",
            state=deps.to_state_dict()
        )
        session_id = session.id

    except Exception as e:
        error_msg = f"❌ Agent Initialization Failed: {str(e)}"
        print(error_msg)
        yield progress_frame([{"role": "assistant", "content": error_msg}])
        return

    # Label the bytes with the image's real type (the bundled examples are
    # JPEG); fall back to PNG when the extension is unknown.
    mime_type, _ = mimetypes.guess_type(image_path.name)
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/png"

    image_bytes = image_path.read_bytes()
    content = types.Content(
        role="user",
        parts=[
            types.Part.from_text(text=user_prompt),
            types.Part.from_bytes(data=image_bytes, mime_type=mime_type),
        ]
    )

    logs = [f"🔄 **Starting analysis** on {MODEL_CACHE['device']}..."]

    display_path = image_path_str.replace(" ", "%20")

    def yield_status(log_list):
        """Render the user message plus the accumulated log as chat history."""
        full_log = "\n\n".join(log_list)
        user_msg = f"![](file={display_path})\n\n{user_prompt}"
        return [{"role": "user", "content": user_msg}, {"role": "assistant", "content": full_log}]

    # FIX: Yield skips instead of updates
    yield progress_frame(yield_status(logs), session_id)

    try:
        async for event in ACTIVE_RUNNER.run_async(user_id="demo_user", session_id=session.id, new_message=content):
            author = event.author

            if event.get_function_calls():
                for fc in event.get_function_calls():
                    logs.append(f"🔧 **{author}**: Calling `{fc.name}`")
                    yield progress_frame(yield_status(logs), session_id)

            if event.content and event.content.parts:
                for part in event.content.parts:
                    if hasattr(part, 'text') and part.text:
                        # A streaming entry from the same author is replaced
                        # in place instead of appending a new log line.
                        streaming = bool(logs) and logs[-1].startswith(f"💬 **{author}**")
                        if event.partial:
                            entry = f"💬 **{author}**: {part.text}..."
                        else:
                            entry = f"✅ **{author}**: {part.text}"
                        if streaming:
                            logs[-1] = entry
                        else:
                            logs.append(entry)
                        yield progress_frame(yield_status(logs), session_id)

    except Exception as e:
        logs.append(f"❌ Error: {e}")
        yield progress_frame(yield_status(logs), session_id)
        return

    logs.append("\n✅ **Analysis Complete!** Loading results...")
    yield progress_frame(yield_status(logs), session_id)

    await asyncio.sleep(0.5)

    full_log_text = "\n".join(logs)
    report_file, df_m, df_s, df_r = load_excel_data(full_log_text)
    layers = get_available_layers()

    initial_overlay = generate_overlay(image_path_str, layers)

    completion_msg = f"\n\n---\n\n✨ **Analysis finished!** Found {len(layers)} layer(s). Results are now available in the Segmentation and Quantitative Results tabs."
    full_log_text += completion_msg

    final_user_msg = f"![](file={display_path})\n\n{user_prompt}"
    final_history = [{"role": "user", "content": final_user_msg}, {"role": "assistant", "content": full_log_text}]
    slider_updates = update_opacity_sliders(layers)

    # Final yield is the ONLY one with real data
    yield final_history, session_id, initial_overlay, gr.CheckboxGroup(choices=layers, value=layers), report_file, df_m, df_s, df_r, *slider_updates

async def unified_chat_handler(message, history, session_id, current_img_path):
    """Route a chat submission to a new analysis or a follow-up question.

    Every yielded item is a 17-tuple matching the submit handler's outputs:
    ``(chat_history, session_id, image_path, overlay, layer_checkboxes,
    report_file, morphology_df, spatial_df, relational_df, slider_1..4,
    chat_input_value, welcome_panel, loading_panel, results_panel)``.

    Three cases are handled:

    1. An image is available and either no session exists or a new file was
       attached: run a full analysis through ``run_analysis``.
    2. A session exists and text was entered: send the text to the active
       runner, then reload the report, layers and overlay.
    3. Anything else: show a welcome message or a hint.

    Args:
        message: A multimodal payload dict with "text" and "files", or any
            other value, which is converted to text.
        history: Current chat history (list of role/content dicts) or None.
        session_id: Agent session id from an earlier analysis, if any.
        current_img_path: Path of the image from an earlier analysis, if any.
    """
    if history is None:
        history = []

    def noop(count):
        """Return ``count`` fresh empty updates (component left as is)."""
        return [gr.update() for _ in range(count)]

    def skip(count):
        """Return ``count`` fresh skip markers (component not re-rendered)."""
        return [gr.skip() for _ in range(count)]

    is_payload = isinstance(message, dict)
    user_text = message.get("text", "").strip() if is_payload else str(message).strip()
    files = message.get("files", []) if is_payload else []

    image_path = None
    if files:
        image_path = files[0] if isinstance(files[0], str) else files[0].get("path")
    elif current_img_path:
        image_path = current_img_path

    # CASE 1: INITIAL ANALYSIS
    if image_path and (not session_id or files):
        if not user_text:
            user_text = "Analyze this microscopy image."

        display_path = image_path.replace(" ", "%20")
        user_entry = {"role": "user", "content": f"![](file={display_path})\n\n{user_text}"}

        history.append(user_entry)
        history.append({"role": "assistant", "content": "🔄 Starting analysis (Model loading may take a moment)..."})

        # Show Loading, Hide Results
        yield (
            history, session_id, image_path,
            *noop(6), *skip(4),
            None,
            gr.update(visible=False), gr.update(visible=True), gr.update(visible=False),
        )

        final_result = None
        try:
            async for result in run_analysis(image_path, user_text, session_id):
                final_result = result
                updated_history = result[0].copy()
                if files and len(updated_history) > 0:
                    updated_history[0] = {"role": "user", "content": f"![](file={display_path})\n\n{user_text}"}

                # Pass through the skips/data from run_analysis
                yield (updated_history, result[1], image_path, *result[2:], None, *noop(3))
        except Exception as e:
            history.append({"role": "assistant", "content": f"❌ Critical Error: {str(e)}"})
            yield (
                history, session_id, image_path,
                None, gr.CheckboxGroup(), None,
                *skip(7),
                None,
                gr.update(visible=False), gr.update(visible=False), gr.update(visible=True),
            )
            return

        if final_result:
            updated_history = final_result[0].copy()
            if files and len(updated_history) > 0:
                updated_history[0] = {"role": "user", "content": f"![](file={display_path})\n\n{user_text}"}

            # Hide Loading, Show Results
            yield (
                updated_history, final_result[1], image_path,
                *final_result[2:],
                None,
                gr.update(visible=False), gr.update(visible=False), gr.update(visible=True),
            )
        return

    # CASE 2: FOLLOW-UP ANALYSIS
    elif session_id and user_text:
        history.append({"role": "user", "content": user_text})
        history.append({"role": "assistant", "content": "💭 Thinking..."})

        # Don't show loading overlay for follow-ups
        # FIX: Send gr.skip() to all result components to prevent jitter
        yield (history, session_id, current_img_path, *skip(10), None, *noop(3))

        if not ACTIVE_RUNNER:
            history[-1]["content"] = "⚠️ Session expired or Agent not initialized."
            yield (history, None, current_img_path, *skip(10), None, *noop(3))
            return

        content = types.Content(role="user", parts=[types.Part.from_text(text=user_text)])
        accumulated_response = ""

        try:
            async for event in ACTIVE_RUNNER.run_async(user_id="demo_user", session_id=session_id, new_message=content):
                if event.content and event.content.parts:
                    for part in event.content.parts:
                        if hasattr(part, 'text') and part.text:
                            accumulated_response += part.text
                            history[-1]["content"] = accumulated_response
                            # FIX: Keep sending gr.skip() during stream
                            yield (history, session_id, current_img_path, *skip(10), None, *noop(3))
        except Exception as e:
            history[-1]["content"] = f"❌ Error: {e}"
            yield (history, session_id, current_img_path, *skip(10), None, *noop(3))
            return

        report_file, df_m, df_s, df_r = load_excel_data("")
        layers = get_available_layers()

        # force_reload picks up mask files the agent wrote during this turn.
        new_overlay = generate_overlay(current_img_path, layers, force_reload=True)
        slider_updates = update_opacity_sliders(layers)

        # Final yield updates the components
        yield (
            history,
            session_id,
            current_img_path,
            new_overlay,
            gr.CheckboxGroup(value=layers, choices=layers),
            report_file,
            df_m, df_s, df_r,
            *slider_updates,
            None,
            *noop(3),
        )
        return

    else:
        if not history:
            history = [{"role": "assistant", "content": "👋 Welcome! Upload a microscopy image and describe what you'd like to analyze."}]
        else:
            history.append({"role": "assistant", "content": "⚠️ Please provide a question or upload a new image."})
        yield (history, session_id, current_img_path, *noop(6), *skip(4), None, *noop(3))

# --- UI Layout ---

# Stylesheet handed to gr.Blocks below. The selectors target the
# "main_container" element id and the "right-panel" / "bordered-panel"
# element classes assigned in the layout, plus Gradio's dataframe class.
custom_css = """
/* 1. Global Margin Setting */
#main_container {
    margin-left: 10% !important;
    margin-right: 10% !important;
    width: auto !important;
}

/* 2. Fix Panel Width */
.right-panel {
    min-width: 600px !important;
    flex-grow: 2 !important;
}

/* 3. Consistent Blue Border for Results Panel */
.bordered-panel {
    border: 2px solid #3498db !important;
    border-radius: 8px !important;
    padding: 10px !important;
    background: #ffffff !important;
}

/* 4. Fix Table Width Overflow */
.gradio-dataframe td {
    white-space: normal !important;
}
.gradio-dataframe {
    overflow-x: auto !important;
    max-width: 100% !important;
    display: block !important;
}
"""

# Two-column layout: chat on the left; on the right exactly one of the
# welcome, loading or results panels is visible at a time (toggled by
# unified_chat_handler through its last three outputs).
with gr.Blocks(title="Cellemetry Agent", css=custom_css) as demo:
    session_id_state = gr.State(None)
    current_image_path = gr.State(None)
    
    with gr.Column(elem_id="main_container"):
        with gr.Row():
            # --- LEFT COLUMN (Chat) ---
            with gr.Column(scale=1, min_width=300):
                chatbot = gr.Chatbot(
                    label="Agent Conversation", 
                    height=400, 
                    value=[{"role": "assistant", "content": "👋 Welcome to Cellemetry! Upload a microscopy image and describe what you'd like to analyze."}],
                    show_label=True
                )
                chat_input = gr.MultimodalTextbox(
                    file_types=["image"],
                    placeholder="Upload an image and describe your analysis...",
                    show_label=False,
                    submit_btn="Send"
                )

                # --- Examples Component ---
                # NOTE: Requires an 'examples' folder containing 'sample_1.jpg' and 'sample_2.jpg'
                example_data = [
                    [{"text": "Analyze this image and describe the cell morphology.", "files": ["examples/sample_1.jpg"]}],
                    [{"text": "Segment the nuclei and calculate spatial distribution.", "files": ["examples/sample_2.jpg"]}],
                ]

                gr.Examples(
                    examples=example_data,
                    inputs=chat_input,
                    label="Try an Example",
                )

            # --- RIGHT COLUMN (Results) ---
            with gr.Column(scale=2, elem_classes=["right-panel"]):
                
                # Welcome overlay
                with gr.Column(visible=True, elem_id="welcome-overlay") as welcome_overlay:
                    gr.HTML("""
                        <div style="display: flex; flex-direction: column; align-items: center; justify-content: center; min-height: 780px; padding: 40px; background: #ffffff; border-radius: 8px; border: 2px solid #3498db;">
                            <div style="text-align: center;">
                                <div style='text-align: center;'>
                                    <img src="https://raw.githubusercontent.com/hmgill/Cellemetry/main/logo.png" alt="Logo" style="height:200px; display: block; margin: 0 auto;">
                                </div>
                                <h2 style="color: #333; margin: 20px 0 10px; font-weight: 600; font-size: 28px;">Welcome to Cellemetry</h2>
                                <p style="color: #666; font-size: 16px; max-width: 400px; margin: 0 auto 30px; line-height: 1.6;">Upload a microscopy image to get started with AI-powered cell analysis and segmentation</p>
                                <div style="padding: 20px; background: #fff; border-radius: 8px; border-left: 4px solid #3498db; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
                                    <p style="color: #555; margin: 0; font-size: 14px;">👈 Use the chat on the left to begin</p>
                                </div>
                            </div>
                        </div>
                    """)
                
                # Loading overlay
                with gr.Column(visible=False, elem_id="loading-overlay") as loading_overlay:
                    gr.HTML("""
                        <div style="display: flex; flex-direction: column; align-items: center; justify-content: center; height: 780px; background: rgba(255, 255, 255, 0.95); border-radius: 8px; box-shadow: 0 4px 6px rgba(0,0,0,0.1);">
                            <div style="text-align: center;">
                                <div style="border: 8px solid #f3f3f3; border-top: 8px solid #3498db; border-radius: 50%; width: 60px; height: 60px; animation: spin 1s linear infinite; margin: 0 auto 20px;"></div>
                                <h3 style="color: #555; margin: 0;">⚙️ Analyzing</h3>
                                <p style="color: #888; margin-top: 10px;">Your image is being processed...</p>
                            </div>
                            <style>@keyframes spin { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } }</style>
                        </div>
                    """)
                
                # Results tabs
                with gr.Column(visible=False, elem_classes=["bordered-panel"]) as results_container:
                    with gr.Tabs() as results_tabs:
                        with gr.Tab("🔍 Segmentation"):
                            with gr.Row():
                                with gr.Column(scale=3):
                                    overlay_output = gr.Image(label="Segmentation Result", height=780, type="pil")
                                with gr.Column(scale=1):
                                    gr.Markdown("**Layer Controls**")
                                    layer_checkboxes = gr.CheckboxGroup(label="Visible Layers", choices=[], value=[], interactive=True)
                                    gr.Markdown("**Opacity Controls**")
                                    opacity_slider_1 = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.1, label="Layer 1 Opacity", visible=False)
                                    opacity_slider_2 = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.1, label="Layer 2 Opacity", visible=False)
                                    opacity_slider_3 = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.1, label="Layer 3 Opacity", visible=False)
                                    opacity_slider_4 = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.1, label="Layer 4 Opacity", visible=False)
                        
                        with gr.Tab("📊 Quantitative Results"):
                            download_btn = gr.File(label="Download Excel Report")
                            with gr.Tabs():
                                with gr.Tab("Morphology"):
                                    tbl_morph = gr.Dataframe(interactive=False, wrap=True)
                                with gr.Tab("Spatial"):
                                    tbl_spatial = gr.Dataframe(interactive=False, wrap=True)
                                with gr.Tab("Relational"):
                                    tbl_rel = gr.Dataframe(interactive=False, wrap=True)

    def regenerate_overlay_with_opacity(img_path, selected_layers, op1, op2, op3, op4):
        """Redraw the overlay after a checkbox or opacity slider change.

        Slider ``k`` controls the ``k``-th name returned by
        ``get_available_layers()``; layers beyond the fourth have no slider.
        An empty selection is allowed and yields just the base image.
        Returns ``None`` when no image has been analyzed yet.
        """
        if not img_path:
            return None
        if selected_layers is None:
            selected_layers = []

        # zip stops at the shorter input, so at most four layers are mapped.
        opacities = dict(zip(get_available_layers(), (op1, op2, op3, op4)))
        return generate_overlay(img_path, selected_layers, opacities)

    # The order of outputs must match the 17-tuple yielded by unified_chat_handler.
    chat_input.submit(
        fn=unified_chat_handler,
        inputs=[chat_input, chatbot, session_id_state, current_image_path],
        outputs=[chatbot, session_id_state, current_image_path, overlay_output, layer_checkboxes, download_btn, tbl_morph, tbl_spatial, tbl_rel, opacity_slider_1, opacity_slider_2, opacity_slider_3, opacity_slider_4, chat_input, welcome_overlay, loading_overlay, results_container]
    )
    
    for component in [layer_checkboxes, opacity_slider_1, opacity_slider_2, opacity_slider_3, opacity_slider_4]:
        component.change(
            fn=regenerate_overlay_with_opacity,
            inputs=[current_image_path, layer_checkboxes, opacity_slider_1, opacity_slider_2, opacity_slider_3, opacity_slider_4],
            outputs=[overlay_output]
        )

    # Start loading the model when the page loads rather than at import time.
    demo.load(load_models)

if __name__ == "__main__":
    # Enable request queuing, then serve on all interfaces at port 7860.
    queued_app = demo.queue()
    launch_options = dict(
        ssr_mode=False,
        theme=gr.themes.Soft(),
        server_name="0.0.0.0",
        server_port=7860,
        # NOTE(review): "." allows the whole working directory alongside
        # /tmp — confirm that this is not broader than needed.
        allowed_paths=[".", "/tmp"],
    )
    queued_app.launch(**launch_options)