import asyncio
import builtins
import glob
import mimetypes
import os
import sys
import threading
import time
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
import torch
from PIL import Image

# --- Safe Input Mocking ---
# Every input() call in this process now returns "y" immediately. It is
# installed before the agent/model imports below, presumably so an interactive
# confirmation prompt raised by a dependency cannot block the server.
# NOTE(review): confirm which dependency actually prompts.
builtins.input = lambda *args: "y"

# GenAI & ADK Imports
from google.adk.runners import InMemoryRunner  # noqa: E402
from google.genai import types  # noqa: E402

# Project Imports
try:
    from cellemetry import root_agent
    from cellemetry.config import AnalysisDeps
    from transformers import Sam3Processor, Sam3Model
except ImportError as e:
    print(f"⚠️ Import Error (Non-fatal for UI startup): {e}")
    Sam3Model = None
    Sam3Processor = None
    root_agent = None
    AnalysisDeps = None

# Optional: Distinctipy for better colors
try:
    from distinctipy import distinctipy
except ImportError:
    distinctipy = None
    print("⚠️ distinctipy not found. Using fallback colors.")

# --- Global State ---
# NOTE(review): these caches and ACTIVE_RUNNER are module-level, so they are
# shared by every browser session; concurrent users overwrite each other.
MODEL_CACHE = {
    "model": None,
    "processor": None,
    "device": "cpu",
    "loaded": False,
}
MASK_CACHE = {
    "current_path": None,  # image path the cached base image belongs to
    "base_image": None,  # RGBA PIL image of that path
    "layers": {},  # layer display name -> bool mask sized like base_image
}
ACTIVE_RUNNER = None

# Serialises load_models(): demo.load and the chat handler can both trigger it,
# and two concurrent loads would each allocate a copy of SAM3.
_MODEL_LOAD_LOCK = threading.Lock()

# Number of per-layer opacity sliders in the UI; layers beyond this use the default.
NUM_OPACITY_SLIDERS = 4
DEFAULT_OPACITY = 0.6

# Inline image MIME types listed by the Gemini image-understanding docs; any
# other type keeps the historical "image/png" label.
_GENAI_IMAGE_MIME_TYPES = frozenset(
    {"image/png", "image/jpeg", "image/webp", "image/heic", "image/heif"}
)

# Colour words stripped from mask file names when deriving layer names.
_COLOR_WORDS = frozenset(
    {
        "blue", "green", "red", "yellow", "cyan", "magenta", "orange", "purple",
        "white", "black", "gray", "grey", "pink", "brown", "lime", "teal",
    }
)


# --- Dynamic Color Helper ---
def generate_color_palette(n=50):
    """Return a list of ``n`` distinct RGB colours with channels in [0, 255].

    Tries distinctipy, then matplotlib's ``tab20`` colormap (which repeats
    every 20 entries), then a fixed 12-colour list repeated to length ``n``.
    """
    if distinctipy:
        print(f"🎨 Generating {n} distinct colors using distinctipy...")
        return [tuple(int(c * 255) for c in color) for color in distinctipy.get_colors(n)]

    try:
        import matplotlib.pyplot as plt

        cmap = plt.get_cmap("tab20")
        return [tuple(int(c * 255) for c in cmap(i % 20)[:3]) for i in range(n)]
    except Exception:
        # Best effort: matplotlib is optional; fall through to the fixed list.
        pass

    base = [
        (0, 255, 0), (0, 0, 255), (255, 0, 0), (255, 255, 0),
        (0, 255, 255), (255, 0, 255), (255, 128, 0), (128, 0, 255),
        (0, 128, 0), (0, 0, 128), (128, 0, 0), (128, 128, 0),
    ]
    return (base * (n // len(base) + 1))[:n]


COLOR_PALETTE = generate_color_palette(50)


def load_models():
    """Load the SAM3 model and processor into MODEL_CACHE (idempotent).

    Called after app startup via ``demo.load`` and lazily by ``run_analysis``.

    Returns:
        None if the model was already loaded; otherwise a status string
        describing success or the load error (errors are not raised).
    """
    with _MODEL_LOAD_LOCK:
        if MODEL_CACHE["loaded"]:
            return None

        print("--- Loading SAM3 Model ---")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        MODEL_CACHE["device"] = device

        try:
            if Sam3Model is None:
                raise ImportError("Sam3Model not found. Please check requirements.")
            model = Sam3Model.from_pretrained("facebook/sam3").to(device)
            processor = Sam3Processor.from_pretrained("facebook/sam3")
        except Exception as e:
            print(f"⚠️ SAM3 load failed: {e}")
            return f"⚠️ Model load failed: {e}"

        # Publish model and processor together so callers never see only one of them.
        MODEL_CACHE["model"] = model
        MODEL_CACHE["processor"] = processor
        MODEL_CACHE["loaded"] = True
        print(f"✅ SAM3 loaded on {device}")
        return f"✅ SAM3 loaded on {device}"


# --- Helpers ---
def clean_layer_name(filename):
    """Derive a display name from a mask file name.

    'data_blue_nuclei.npz' -> 'Nuclei'. A leading 'data_' and trailing '.npz'
    are removed, colour words are dropped, and the rest is title-cased. If
    only colour words remain, all parts are kept.
    """
    raw = os.path.basename(filename).removeprefix("data_").removesuffix(".npz")
    kept = [part for part in raw.split("_") if part and part.lower() not in _COLOR_WORDS]
    if not kept:
        return raw.replace("_", " ").title()
    return " ".join(kept).title()


def load_excel_data(logs_text):
    """Load the newest Excel report from /tmp or the working directory.

    Args:
        logs_text: Unused; kept so existing callers keep working.

    Returns:
        ``(report_path, morphology_df, spatial_df, relational_df)``.
        ``report_path`` is None when no ``.xlsx`` file exists. Missing or
        unreadable sheets are replaced by a one-row placeholder frame.
    """
    placeholder = pd.DataFrame({"Status": ["No Data Available"]})
    candidates = glob.glob("/tmp/*.xlsx") + glob.glob("*.xlsx")
    if not candidates:
        return None, placeholder, placeholder, placeholder

    report_file = max(candidates, key=os.path.getmtime)

    def process_sheet(xls, sheet_name):
        # Sheets store one metric per row; transpose so each metric becomes a
        # column and the former header row becomes the "Metric" column.
        if sheet_name not in xls.sheet_names:
            return placeholder
        df = pd.read_excel(xls, sheet_name)
        if df.empty or len(df.columns) == 0:
            return placeholder
        df = df.set_index(df.columns[0]).T.reset_index()
        return df.rename(columns={df.columns[0]: "Metric"})

    try:
        # The context manager closes the workbook handle (previously leaked).
        with pd.ExcelFile(report_file, engine="openpyxl") as xls:
            morph = process_sheet(xls, "Morphology")
            spatial = process_sheet(xls, "Spatial")
            relational = process_sheet(xls, "Relational")
        return report_file, morph, spatial, relational
    except Exception as e:
        print(f"⚠️ Error reading Excel: {e}")
        return report_file, placeholder, placeholder, placeholder


def get_available_layers():
    """Return the sorted, de-duplicated layer names of all /tmp/data_*.npz files."""
    return sorted({clean_layer_name(path) for path in glob.glob("/tmp/data_*.npz")})


def update_opacity_sliders(layers):
    """Return one gr.update per opacity slider.

    Slider ``i`` is shown, labelled for ``layers[i]`` and reset to the default
    opacity. Sliders without a matching layer are hidden.
    """
    updates = []
    for i in range(NUM_OPACITY_SLIDERS):
        if i < len(layers):
            updates.append(gr.update(visible=True, label=f"{layers[i]} Opacity", value=DEFAULT_OPACITY))
        else:
            updates.append(gr.update(visible=False))
    return updates


def _load_layer_mask(file_path, size):
    """Return the union of the masks in one ``.npz`` file as a bool array.

    Uses the ``masks`` entry if present, otherwise the archive's first array.
    The mask is resized with nearest-neighbour sampling to ``size``
    (width, height) when needed. Returns None if the stored array is empty.
    """
    # np.load keeps an .npz archive open until closed; the context manager releases it.
    with np.load(file_path) as data:
        masks = data["masks"] if "masks" in data.files else data[data.files[0]]
    if masks.size == 0:
        return None
    if masks.ndim > 2:
        # Collapse every leading axis into one mask: (..., H, W) -> (H, W).
        masks = masks.reshape(-1, *masks.shape[-2:]).max(axis=0)
    # Any nonzero value counts as foreground (uint8 casting used to wrap labels >= 256).
    mask_img = Image.fromarray((masks != 0).astype(np.uint8) * 255)
    if mask_img.size != size:
        mask_img = mask_img.resize(size, Image.Resampling.NEAREST)
    return np.array(mask_img, dtype=bool)


# --- OPTIMIZED OVERLAY GENERATION ---
def generate_overlay(image_path_str, selected_layers, layer_opacities=None, force_reload=False):
    """Composite the selected mask layers over the image at ``image_path_str``.

    Args:
        image_path_str: Path of the base image; falsy returns None.
        selected_layers: Layer names to draw (None is treated as empty).
        layer_opacities: Optional mapping of layer name to opacity in [0, 1];
            unlisted layers use DEFAULT_OPACITY.
        force_reload: If True, drop cached masks so new agent output in
            /tmp/data_*.npz is picked up.

    Returns:
        An RGB PIL image, or None if the base image cannot be loaded.
    """
    if not image_path_str:
        return None

    path_changed = MASK_CACHE["current_path"] != image_path_str
    if force_reload or path_changed:
        # Masks cached for another image may not match its size; never reuse them.
        MASK_CACHE["layers"] = {}

    if path_changed or MASK_CACHE["base_image"] is None or not MASK_CACHE["layers"]:
        print(f"🔄 Caching masks for {os.path.basename(image_path_str)}...")
        try:
            if path_changed or MASK_CACHE["base_image"] is None:
                with Image.open(image_path_str) as opened:
                    MASK_CACHE["base_image"] = opened.convert("RGBA")
                MASK_CACHE["current_path"] = image_path_str
            base_size = MASK_CACHE["base_image"].size

            # Always rescan: the agent may have written new layer files.
            for file_path in glob.glob("/tmp/data_*.npz"):
                layer_name = clean_layer_name(file_path)
                if layer_name in MASK_CACHE["layers"]:
                    continue
                try:
                    mask = _load_layer_mask(file_path, base_size)
                except Exception as e:
                    print(f"Failed to cache layer {layer_name}: {e}")
                    continue
                if mask is not None:
                    MASK_CACHE["layers"][layer_name] = mask
        except Exception as e:
            print(f"Failed to load base image: {e}")
            return None

    base_image = MASK_CACHE["base_image"]
    if base_image is None:
        return None

    # Colours follow each layer's position among all cached layers (sorted), so a
    # layer keeps its colour regardless of which layers are selected.
    color_index = {name: i for i, name in enumerate(sorted(MASK_CACHE["layers"]))}
    overlay_accum = Image.new("RGBA", base_image.size, (0, 0, 0, 0))

    for layer_name in selected_layers or []:
        mask_bool = MASK_CACHE["layers"].get(layer_name)
        if mask_bool is None:
            continue
        color = COLOR_PALETTE[color_index[layer_name] % len(COLOR_PALETTE)]
        opacity = DEFAULT_OPACITY
        if layer_opacities and layer_name in layer_opacities:
            opacity = layer_opacities[layer_name]

        layer_rgba = np.zeros((*mask_bool.shape, 4), dtype=np.uint8)
        layer_rgba[mask_bool] = (*color, int(255 * opacity))
        overlay_accum = Image.alpha_composite(overlay_accum, Image.fromarray(layer_rgba))

    return Image.alpha_composite(base_image, overlay_accum).convert("RGB")


# --- Core Logic ---
def _skips(count):
    """Return ``count`` gr.skip() markers (leave those outputs unchanged)."""
    return [gr.skip() for _ in range(count)]


def _status_outputs(history, session_id):
    """Build a streaming-time run_analysis output tuple.

    Sets the chat history and session id, clears the overlay, layer
    selection and report file, and leaves the tables and sliders unchanged.
    """
    return (history, session_id, None, [], None, *_skips(3), *_skips(NUM_OPACITY_SLIDERS))


async def run_analysis(image_path_str, user_prompt, session_id_state):
    """Run the Cellemetry agent on one image and stream progress.

    Args:
        image_path_str: Path of the uploaded image.
        user_prompt: Text instruction sent to the agent with the image.
        session_id_state: Unused; a fresh runner and session are always created.

    Yields:
        12-tuples of (chat history, session id, overlay, layer checkbox,
        report file, morphology df, spatial df, relational df, 4 slider
        updates). Only the last tuple carries real results.
    """
    global ACTIVE_RUNNER

    if not MODEL_CACHE["loaded"]:
        yield _status_outputs([], None)
        # Model loading is slow and blocking; keep it off the event loop.
        await asyncio.to_thread(load_models)

    if not image_path_str:
        yield _status_outputs([], None)
        return

    # Remove artefacts of a previous run so stale layers/reports are not shown.
    for stale in glob.glob("/tmp/out_*.png") + glob.glob("/tmp/data_*.npz") + glob.glob("/tmp/*.xlsx"):
        try:
            os.remove(stale)
        except OSError:
            pass  # best effort: already gone or not removable

    MASK_CACHE["current_path"] = None
    MASK_CACHE["base_image"] = None
    MASK_CACHE["layers"] = {}

    image_path = Path(image_path_str)

    if MODEL_CACHE["model"] is None:
        error_msg = "❌ Model failed to load. Please check logs."
        yield _status_outputs([{"role": "assistant", "content": error_msg}], None)
        return

    if AnalysisDeps is None:
        error_msg = "❌ Project imports failed. 'AnalysisDeps' is missing. Check your 'cellemetry' package installation."
        yield _status_outputs([{"role": "assistant", "content": error_msg}], None)
        return

    try:
        deps = AnalysisDeps(
            sam_model=MODEL_CACHE["model"],
            sam_processor=MODEL_CACHE["processor"],
            image_path=image_path,
            device=MODEL_CACHE["device"],
            pixel_size_microns=None,
        )
        if root_agent is None:
            raise ValueError("Root agent is not loaded.")
        ACTIVE_RUNNER = InMemoryRunner(agent=root_agent, app_name="cellemetry_demo")
        session = await ACTIVE_RUNNER.session_service.create_session(
            app_name="cellemetry_demo", user_id="demo_user", state=deps.to_state_dict()
        )
        session_id = session.id
    except Exception as e:
        error_msg = f"❌ Agent Initialization Failed: {str(e)}"
        print(error_msg)
        yield _status_outputs([{"role": "assistant", "content": error_msg}], None)
        return

    # Label the image with its real type (e.g. the .jpg examples); unsupported
    # or unknown types keep the previous "image/png" label.
    guessed_mime = mimetypes.guess_type(image_path.name)[0]
    mime_type = guessed_mime if guessed_mime in _GENAI_IMAGE_MIME_TYPES else "image/png"
    content = types.Content(
        role="user",
        parts=[
            types.Part.from_text(text=user_prompt),
            types.Part.from_bytes(data=image_path.read_bytes(), mime_type=mime_type),
        ],
    )

    logs = [f"🔄 **Starting analysis** on {MODEL_CACHE['device']}..."]
    display_path = image_path_str.replace(" ", "%20")
    user_msg = f"![](file={display_path})\n\n{user_prompt}"

    def chat_history(log_lines):
        # Blank-line separation keeps each log entry a separate markdown paragraph.
        return [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": "\n\n".join(log_lines)},
        ]

    yield _status_outputs(chat_history(logs), session_id)

    try:
        async for event in ACTIVE_RUNNER.run_async(user_id="demo_user", session_id=session_id, new_message=content):
            author = event.author

            function_calls = event.get_function_calls()
            if function_calls:
                for fc in function_calls:
                    logs.append(f"🔧 **{author}**: Calling `{fc.name}`")
                yield _status_outputs(chat_history(logs), session_id)

            if event.content and event.content.parts:
                for part in event.content.parts:
                    text = getattr(part, "text", None)
                    if not text:
                        continue
                    if event.partial:
                        line = f"💬 **{author}**: {text}..."
                    else:
                        line = f"✅ **{author}**: {text}"
                    # An in-progress (💬) line from the same author is replaced, not duplicated.
                    if logs and logs[-1].startswith(f"💬 **{author}**"):
                        logs[-1] = line
                    else:
                        logs.append(line)
                yield _status_outputs(chat_history(logs), session_id)
    except Exception as e:
        logs.append(f"❌ Error: {e}")
        yield _status_outputs(chat_history(logs), session_id)
        return

    logs.append("\n✅ **Analysis Complete!** Loading results...")
    yield _status_outputs(chat_history(logs), session_id)
    # NOTE(review): presumably gives the agent's output files time to land on disk — confirm.
    await asyncio.sleep(0.5)

    full_log_text = "\n\n".join(logs)
    report_file, df_m, df_s, df_r = load_excel_data(full_log_text)
    layers = get_available_layers()
    initial_overlay = generate_overlay(image_path_str, layers)

    completion_msg = f"\n\n---\n\n✨ **Analysis finished!** Found {len(layers)} layer(s). Results are now available in the Segmentation and Quantitative Results tabs."
    final_history = [
        {"role": "user", "content": user_msg},
        {"role": "assistant", "content": full_log_text + completion_msg},
    ]

    # The final yield is the only one carrying real result data.
    yield (
        final_history,
        session_id,
        initial_overlay,
        gr.CheckboxGroup(choices=layers, value=layers),
        report_file,
        df_m,
        df_s,
        df_r,
        *update_opacity_sliders(layers),
    )


def _panel_visibility(welcome=None, loading=None, results=None):
    """Return updates for (welcome_overlay, loading_overlay, results_container).

    Each argument is a bool that sets that panel's visibility, or None to
    leave it unchanged.
    """
    return tuple(
        gr.update() if flag is None else gr.update(visible=flag)
        for flag in (welcome, loading, results)
    )


async def unified_chat_handler(message, history, session_id, current_img_path):
    """Route a chat submission to a new analysis, a follow-up, or a hint.

    A new file, or an image with no session, starts ``run_analysis``. Text
    with an existing session is sent to ACTIVE_RUNNER as a follow-up.
    Anything else produces a welcome/hint message.

    Yields:
        17-tuples matching the outputs wired in ``chat_input.submit``:
        chatbot, session id, image path, overlay, layer checkbox, report file,
        three tables, four sliders, chat input, and three panel visibilities.
    """
    if history is None:
        history = []

    if isinstance(message, dict):
        user_text = (message.get("text") or "").strip()
        files = message.get("files") or []
    else:
        user_text = "" if message is None else str(message).strip()
        files = []

    image_path = None
    if files:
        first = files[0]
        image_path = first if isinstance(first, str) else first.get("path")
    elif current_img_path:
        image_path = current_img_path

    slider_skips = _skips(NUM_OPACITY_SLIDERS)

    # CASE 1: INITIAL ANALYSIS
    if image_path and (not session_id or files):
        if not user_text:
            user_text = "Analyze this microscopy image."
        display_path = image_path.replace(" ", "%20")
        user_entry = {"role": "user", "content": f"![](file={display_path})\n\n{user_text}"}

        history.append(dict(user_entry))
        history.append({"role": "assistant", "content": "🔄 Starting analysis (Model loading may take a moment)..."})

        def shown_history(result_history):
            # run_analysis yields [] while the model loads and a lone assistant
            # message on early failures; keep the user's turn visible in both.
            if not result_history:
                return history
            if result_history[0].get("role") != "user":
                return [dict(user_entry), *result_history]
            return list(result_history)

        # Show loading, hide welcome and results.
        yield (
            history, session_id, image_path,
            *[gr.update() for _ in range(6)], *slider_skips,
            None, *_panel_visibility(welcome=False, loading=True, results=False),
        )

        final_result = None
        try:
            async for result in run_analysis(image_path, user_text, session_id):
                final_result = result
                # Pass run_analysis outputs (skips or real data) straight through.
                yield (shown_history(result[0]), result[1], image_path, *result[2:], None, *_panel_visibility())
        except Exception as e:
            history.append({"role": "assistant", "content": f"❌ Critical Error: {str(e)}"})
            yield (
                history, session_id, image_path,
                None, gr.CheckboxGroup(), None, *_skips(3), *slider_skips,
                None, *_panel_visibility(welcome=False, loading=False, results=True),
            )
            return

        if final_result is not None:
            # Hide loading, show results.
            yield (
                shown_history(final_result[0]), final_result[1], image_path, *final_result[2:],
                None, *_panel_visibility(welcome=False, loading=False, results=True),
            )
        return

    # CASE 2: FOLLOW-UP ANALYSIS
    if session_id and user_text:
        unchanged_results = _skips(6 + NUM_OPACITY_SLIDERS)

        history.append({"role": "user", "content": user_text})
        history.append({"role": "assistant", "content": "💭 Thinking..."})
        # No loading overlay for follow-ups; skip result outputs to avoid UI jitter.
        yield (history, session_id, current_img_path, *unchanged_results, None, *_panel_visibility())

        if not ACTIVE_RUNNER:
            history[-1]["content"] = "⚠️ Session expired or Agent not initialized."
            yield (history, None, current_img_path, *unchanged_results, None, *_panel_visibility())
            return

        content = types.Content(role="user", parts=[types.Part.from_text(text=user_text)])
        accumulated_response = ""
        try:
            async for event in ACTIVE_RUNNER.run_async(user_id="demo_user", session_id=session_id, new_message=content):
                if event.content and event.content.parts:
                    for part in event.content.parts:
                        if getattr(part, "text", None):
                            accumulated_response += part.text
                            history[-1]["content"] = accumulated_response
                            yield (history, session_id, current_img_path, *unchanged_results, None, *_panel_visibility())
        except Exception as e:
            history[-1]["content"] = f"❌ Error: {e}"
            yield (history, session_id, current_img_path, *unchanged_results, None, *_panel_visibility())
            return

        # The agent may have produced new layers or reports; refresh everything.
        report_file, df_m, df_s, df_r = load_excel_data("")
        layers = get_available_layers()
        new_overlay = generate_overlay(current_img_path, layers, force_reload=True)
        yield (
            history, session_id, current_img_path,
            new_overlay, gr.CheckboxGroup(value=layers, choices=layers),
            report_file, df_m, df_s, df_r, *update_opacity_sliders(layers),
            None, *_panel_visibility(),
        )
        return

    # CASE 3: nothing actionable
    if not history:
        history = [{"role": "assistant", "content": "👋 Welcome! Upload a microscopy image and describe what you'd like to analyze."}]
    else:
        history.append({"role": "assistant", "content": "⚠️ Please provide a question or upload a new image."})
    yield (
        history, session_id, current_img_path,
        *[gr.update() for _ in range(6)], *slider_skips,
        None, *_panel_visibility(),
    )


def regenerate_overlay_with_opacity(img_path, selected_layers, op1, op2, op3, op4):
    """Redraw the overlay after a layer-checkbox or opacity-slider change.

    Slider values are assigned in order to the first NUM_OPACITY_SLIDERS
    sorted layer names, matching ``update_opacity_sliders``. An empty
    selection yields just the base image.
    """
    if not img_path:
        return None
    layer_names = get_available_layers()[:NUM_OPACITY_SLIDERS]
    opacities = dict(zip(layer_names, (op1, op2, op3, op4)))
    return generate_overlay(img_path, selected_layers or [], opacities)


# --- UI Layout ---
custom_css = """
/* 1. Global Margin Setting */
#main_container {
    margin-left: 10% !important;
    margin-right: 10% !important;
    width: auto !important;
}
/* 2. Fix Panel Width */
.right-panel {
    min-width: 600px !important;
    flex-grow: 2 !important;
}
/* 3. Consistent Blue Border for Results Panel */
.bordered-panel {
    border: 2px solid #3498db !important;
    border-radius: 8px !important;
    padding: 10px !important;
    background: #ffffff !important;
}
/* 4. Fix Table Width Overflow */
.gradio-dataframe td {
    white-space: normal !important;
}
.gradio-dataframe {
    overflow-x: auto !important;
    max-width: 100% !important;
    display: block !important;
}
"""

# NOTE(review): these panels currently hold only plain text lines ("Logo" reads
# like leftover alt text); if a logo image and styled headings are intended,
# restore that markup here.
_WELCOME_HTML = f"""
Logo

Welcome to Cellemetry

Upload a microscopy image to get started with AI-powered cell analysis and segmentation

👈 Use the chat on the left to begin

"""

_LOADING_HTML = """

⚙️ Analyzing

Your image is being processed...

"""

with gr.Blocks(title="Cellemetry Agent", css=custom_css) as demo:
    session_id_state = gr.State(None)
    current_image_path = gr.State(None)

    with gr.Column(elem_id="main_container"):
        with gr.Row():
            # --- LEFT COLUMN (Chat) ---
            with gr.Column(scale=1, min_width=300):
                chatbot = gr.Chatbot(
                    label="Agent Conversation",
                    height=400,
                    value=[{"role": "assistant", "content": "👋 Welcome to Cellemetry! Upload a microscopy image and describe what you'd like to analyze."}],
                    show_label=True,
                )
                chat_input = gr.MultimodalTextbox(
                    file_types=["image"],
                    placeholder="Upload an image and describe your analysis...",
                    show_label=False,
                    submit_btn="Send",
                )

                # --- Examples Component ---
                # NOTE: requires an 'examples' folder containing 'sample_1.jpg' and 'sample_2.jpg'.
                example_data = [
                    [{"text": "Analyze this image and describe the cell morphology.", "files": ["examples/sample_1.jpg"]}],
                    [{"text": "Segment the nuclei and calculate spatial distribution.", "files": ["examples/sample_2.jpg"]}],
                ]
                gr.Examples(
                    examples=example_data,
                    inputs=chat_input,
                    label="Try an Example",
                )

            # --- RIGHT COLUMN (Results) ---
            with gr.Column(scale=2, elem_classes=["right-panel"]):
                # Welcome overlay
                with gr.Column(visible=True, elem_id="welcome-overlay") as welcome_overlay:
                    gr.HTML(_WELCOME_HTML)

                # Loading overlay
                with gr.Column(visible=False, elem_id="loading-overlay") as loading_overlay:
                    gr.HTML(_LOADING_HTML)

                # Results tabs
                with gr.Column(visible=False, elem_classes=["bordered-panel"]) as results_container:
                    with gr.Tabs() as results_tabs:
                        with gr.Tab("🔍 Segmentation"):
                            with gr.Row():
                                with gr.Column(scale=3):
                                    overlay_output = gr.Image(label="Segmentation Result", height=780, type="pil")
                                with gr.Column(scale=1):
                                    gr.Markdown("**Layer Controls**")
                                    layer_checkboxes = gr.CheckboxGroup(label="Visible Layers", choices=[], value=[], interactive=True)
                                    gr.Markdown("**Opacity Controls**")
                                    opacity_slider_1 = gr.Slider(minimum=0, maximum=1, value=DEFAULT_OPACITY, step=0.1, label="Layer 1 Opacity", visible=False)
                                    opacity_slider_2 = gr.Slider(minimum=0, maximum=1, value=DEFAULT_OPACITY, step=0.1, label="Layer 2 Opacity", visible=False)
                                    opacity_slider_3 = gr.Slider(minimum=0, maximum=1, value=DEFAULT_OPACITY, step=0.1, label="Layer 3 Opacity", visible=False)
                                    opacity_slider_4 = gr.Slider(minimum=0, maximum=1, value=DEFAULT_OPACITY, step=0.1, label="Layer 4 Opacity", visible=False)

                        with gr.Tab("📊 Quantitative Results"):
                            download_btn = gr.File(label="Download Excel Report")
                            with gr.Tabs():
                                with gr.Tab("Morphology"):
                                    tbl_morph = gr.Dataframe(interactive=False, wrap=True)
                                with gr.Tab("Spatial"):
                                    tbl_spatial = gr.Dataframe(interactive=False, wrap=True)
                                with gr.Tab("Relational"):
                                    tbl_rel = gr.Dataframe(interactive=False, wrap=True)

    opacity_sliders = [opacity_slider_1, opacity_slider_2, opacity_slider_3, opacity_slider_4]

    # Output order must match the 17-tuples yielded by unified_chat_handler.
    chat_input.submit(
        fn=unified_chat_handler,
        inputs=[chat_input, chatbot, session_id_state, current_image_path],
        outputs=[
            chatbot, session_id_state, current_image_path,
            overlay_output, layer_checkboxes, download_btn,
            tbl_morph, tbl_spatial, tbl_rel,
            *opacity_sliders,
            chat_input, welcome_overlay, loading_overlay, results_container,
        ],
    )

    for component in [layer_checkboxes, *opacity_sliders]:
        component.change(
            fn=regenerate_overlay_with_opacity,
            inputs=[current_image_path, layer_checkboxes, *opacity_sliders],
            outputs=[overlay_output],
        )

    demo.load(load_models)

if __name__ == "__main__":
    demo.queue().launch(
        ssr_mode=False,
        theme=gr.themes.Soft(),
        server_name="0.0.0.0",
        server_port=7860,
        allowed_paths=[".", "/tmp"],
    )