Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import asyncio | |
| import os | |
| import glob | |
| import torch | |
| import sys | |
| import builtins | |
| import pandas as pd | |
| import numpy as np | |
| import time | |
| from pathlib import Path | |
| from PIL import Image | |
# --- Safe Input Mocking ---
# Replace the interactive ``input()`` builtin process-wide so that any prompt a
# dependency might raise is answered with "y" instead of waiting on stdin
# (presumably to keep a headless server from blocking — confirm which
# dependency prompts).
def _auto_confirm(*args):
    """Stand-in for ``input()``: ignore the prompt and answer "y"."""
    return "y"


builtins.input = _auto_confirm
| # GenAI & ADK Imports | |
| from google.adk.runners import InMemoryRunner | |
| from google.genai import types | |
| # Project Imports | |
| try: | |
| from cellemetry import root_agent | |
| from cellemetry.config import AnalysisDeps | |
| from transformers import Sam3Processor, Sam3Model | |
| except ImportError as e: | |
| print(f"β οΈ Import Error (Non-fatal for UI startup): {e}") | |
| Sam3Model = None | |
| Sam3Processor = None | |
| root_agent = None | |
| AnalysisDeps = None | |
| # Optional: Distinctipy for better colors | |
| try: | |
| from distinctipy import distinctipy | |
| except ImportError: | |
| distinctipy = None | |
| print("β οΈ distinctipy not found. Using fallback colors.") | |
# --- Global State ---
# SAM3 model/processor handles; populated by load_models().
MODEL_CACHE = dict(
    model=None,
    processor=None,
    device="cpu",
    loaded=False,
)
# Cache used by generate_overlay: the base image for ``current_path`` and a
# mapping of cleaned layer name -> boolean mask.
MASK_CACHE = dict(
    current_path=None,
    base_image=None,
    layers={},
)
# Runner created by the most recent run_analysis call (module-wide, reused by
# follow-up questions).
ACTIVE_RUNNER = None
# --- Dynamic Color Helper ---
def generate_color_palette(n=50):
    """Return exactly ``n`` RGB colours as ``(r, g, b)`` tuples of ints in [0, 255].

    Sources, in order of preference:
      1. ``distinctipy`` (only if it imported successfully at module load);
      2. matplotlib's 20-colour ``tab20`` colormap, cycled;
      3. a fixed 12-colour list, cycled.

    Args:
        n: Number of colours wanted. ``n <= 0`` returns an empty list.

    Returns:
        list[tuple[int, int, int]]: exactly ``n`` colours.
    """
    if n <= 0:
        return []
    if distinctipy:
        print(f"π¨ Generating {n} distinct colors using distinctipy...")
        colors = distinctipy.get_colors(n)
        return [tuple(int(c * 255) for c in color) for color in colors]
    try:
        import matplotlib.pyplot as plt
        cmap = plt.get_cmap('tab20')
        return [tuple(int(c * 255) for c in cmap(i % 20)[:3]) for i in range(n)]
    except Exception:
        # matplotlib is optional; any failure here just means we fall back to
        # the static palette below (deliberate best effort).
        pass
    fallback = [
        (0, 255, 0), (0, 0, 255), (255, 0, 0), (255, 255, 0),
        (0, 255, 255), (255, 0, 255), (255, 128, 0), (128, 0, 255),
        (0, 128, 0), (0, 0, 128), (128, 0, 0), (128, 128, 0),
    ]
    # Cycle the fixed colours and stop at exactly n entries, matching the
    # length the other two branches return.
    return [fallback[i % len(fallback)] for i in range(n)]


# Shared palette for every overlay; generate_overlay picks a layer's colour by
# the layer's index in the sorted list of cached layer names.
COLOR_PALETTE = generate_color_palette(50)
def load_models():
    """Load the SAM3 model and processor into ``MODEL_CACHE`` (idempotent).

    Registered with ``demo.load`` and also called lazily by ``run_analysis``.
    Both objects are loaded into locals first and published to ``MODEL_CACHE``
    only after both succeed, so a failure part-way never leaves a model cached
    without its processor (``run_analysis`` only checks ``MODEL_CACHE["model"]``).

    Returns:
        str: A human-readable status message (loaded, already loaded, or failed).
    """
    if MODEL_CACHE["loaded"]:
        return f"SAM3 already loaded on {MODEL_CACHE['device']}"
    print("--- Loading SAM3 Model ---")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    MODEL_CACHE["device"] = device
    try:
        if Sam3Model is None or Sam3Processor is None:
            raise ImportError("Sam3Model not found. Please check requirements.")
        model = Sam3Model.from_pretrained("facebook/sam3").to(device)
        processor = Sam3Processor.from_pretrained("facebook/sam3")
    except Exception as e:
        print(f"β οΈ SAM3 load failed: {e}")
        return f"β οΈ Model load failed: {e}"
    # Publish atomically: only after both pieces loaded successfully.
    MODEL_CACHE["model"] = model
    MODEL_CACHE["processor"] = processor
    MODEL_CACHE["loaded"] = True
    print(f"β SAM3 loaded on {device}")
    return f"β SAM3 loaded on {device}"
# --- Helpers ---
# Colour words that clean_layer_name drops from layer file names.
_LAYER_COLOR_WORDS = frozenset({
    'blue', 'green', 'red', 'yellow', 'cyan', 'magenta',
    'orange', 'purple', 'white', 'black', 'gray', 'grey',
    'pink', 'brown', 'lime', 'teal',
})


def clean_layer_name(filename):
    """Turn a layer file path into a display name.

    ``'/tmp/data_blue_nuclei.npz'`` -> ``'Nuclei'``. A leading ``data_`` and a
    trailing ``.npz`` are removed only at the ends of the base name (so
    ``data_metadata_cells.npz`` becomes ``'Metadata Cells'``, not
    ``'Metacells'``), underscore-separated colour words are dropped, and the
    remaining words are title-cased. If every word is a colour word, the
    colour words are kept instead.

    Args:
        filename: Path or bare file name of a layer file.

    Returns:
        str: Title-cased display name.
    """
    raw = os.path.basename(filename)
    if raw.startswith("data_"):
        raw = raw[len("data_"):]
    if raw.endswith(".npz"):
        raw = raw[:-len(".npz")]
    # Skip empty segments created by doubled/leading/trailing underscores.
    words = [w for w in raw.split('_') if w]
    kept = [w for w in words if w.lower() not in _LAYER_COLOR_WORDS]
    return " ".join(kept or words).title()
def load_excel_data(logs_text):
    """Load the newest ``.xlsx`` report and reshape its result sheets for display.

    Candidates are ``/tmp/*.xlsx`` plus ``*.xlsx`` in the working directory;
    the most recently modified one is used. The ``Morphology``, ``Spatial`` and
    ``Relational`` sheets are each transposed so the first column's values
    become column headers and the former column names form a ``Metric`` column.

    Args:
        logs_text: Unused; kept for call-site compatibility.

    Returns:
        tuple: ``(report_path_or_None, morphology_df, spatial_df, relational_df)``.
        Missing sheets, or a workbook that cannot be read, yield a one-row
        ``Status`` placeholder frame instead.
    """
    placeholder = pd.DataFrame({"Status": ["No Data Available"]})
    # dict.fromkeys de-duplicates while preserving order (e.g. when cwd is /tmp).
    candidates = list(dict.fromkeys(glob.glob("/tmp/*.xlsx") + glob.glob("*.xlsx")))
    if not candidates:
        return None, placeholder, placeholder, placeholder
    try:
        report_file = max(candidates, key=os.path.getmtime)
    except OSError:
        # A candidate disappeared between glob() and stat(); treat as no report.
        return None, placeholder, placeholder, placeholder

    def process_sheet(xls, sheet_name):
        """Return ``sheet_name`` transposed with a ``Metric`` column, or the placeholder."""
        if sheet_name not in xls.sheet_names:
            return placeholder
        df = pd.read_excel(xls, sheet_name=sheet_name)
        if not df.empty and len(df.columns) > 0:
            df = df.set_index(df.columns[0]).T.reset_index()
            df = df.rename(columns={df.columns[0]: "Metric"})
        return df

    try:
        # The context manager closes the workbook's file handle.
        with pd.ExcelFile(report_file, engine='openpyxl') as xls:
            morph = process_sheet(xls, "Morphology")
            spatial = process_sheet(xls, "Spatial")
            relational = process_sheet(xls, "Relational")
    except Exception as e:
        print(f"β οΈ Error reading Excel: {e}")
        return report_file, placeholder, placeholder, placeholder
    return report_file, morph, spatial, relational
def get_available_layers():
    """Return the sorted, de-duplicated display names of every ``/tmp/data_*.npz`` layer file."""
    unique_names = {clean_layer_name(path) for path in glob.glob("/tmp/data_*.npz")}
    return sorted(unique_names)
def update_opacity_sliders(layers):
    """Build updates for the four fixed opacity sliders.

    Slider ``k`` is made visible, labelled after ``layers[k]`` and reset to 0.6
    when that layer exists; every other slider is hidden.

    Returns:
        list: Four ``gr.update`` objects, one per slider.
    """
    slider_slots = 4
    return [
        gr.update(visible=True, label=f"{layers[slot]} Opacity", value=0.6)
        if slot < len(layers)
        else gr.update(visible=False)
        for slot in range(slider_slots)
    ]
# --- OPTIMIZED OVERLAY GENERATION ---
def _load_layer_mask(file_path, size):
    """Load one ``.npz`` mask file as a boolean (H, W) array resized to PIL ``size`` (w, h).

    Uses the ``masks`` entry if present, else the first array in the file.
    Returns None when that array is empty.
    """
    with np.load(file_path) as data:  # closes the underlying zip file
        masks = data['masks'] if 'masks' in data.files else data[data.files[0]]
    if masks.size == 0:
        return None
    if masks.ndim >= 3:
        # Union over every leading axis (e.g. N instances, or N x 1 channels),
        # keeping the trailing two axes as the image plane.
        combined_mask = masks.reshape(-1, *masks.shape[-2:]).max(axis=0)
    else:
        combined_mask = masks
    # Threshold before the uint8 cast: casting integer label maps directly
    # wraps values such as 256 to 0 and silently drops those pixels.
    # NOTE(review): assumes masks are binary/label maps where > 0 means foreground.
    mask_pil = Image.fromarray((combined_mask > 0).astype(np.uint8) * 255)
    if mask_pil.size != size:
        mask_pil = mask_pil.resize(size, Image.Resampling.NEAREST)
    return np.array(mask_pil, dtype=bool)


def _refresh_mask_cache(image_path_str, force_reload):
    """Ensure MASK_CACHE holds the base image for ``image_path_str`` and its layer masks.

    Returns False if the base image could not be opened, True otherwise.
    """
    path_changed = MASK_CACHE["current_path"] != image_path_str
    if force_reload or path_changed:
        # Masks cached for another image must not be reused for this one.
        MASK_CACHE["layers"] = {}
    needs_scan = path_changed or MASK_CACHE["base_image"] is None or not MASK_CACHE["layers"]
    if not needs_scan:
        return True
    print(f"π Caching masks for {os.path.basename(image_path_str)}...")
    if path_changed or MASK_CACHE["base_image"] is None:
        try:
            with Image.open(image_path_str) as img:  # release the file handle
                MASK_CACHE["base_image"] = img.convert("RGBA")
        except Exception as e:
            print(f"Failed to load base image: {e}")
            return False
        MASK_CACHE["current_path"] = image_path_str
    base_size = MASK_CACHE["base_image"].size
    # Sorted so that, if two files map to the same layer name, the winner is deterministic.
    for file_path in sorted(glob.glob("/tmp/data_*.npz")):
        layer_name = clean_layer_name(file_path)
        if layer_name in MASK_CACHE["layers"]:
            continue
        try:
            mask = _load_layer_mask(file_path, base_size)
        except Exception as e:
            print(f"Failed to cache layer {layer_name}: {e}")
            continue
        if mask is not None:
            MASK_CACHE["layers"][layer_name] = mask
    return True


def generate_overlay(image_path_str, selected_layers, layer_opacities=None, force_reload=False):
    """Composite the selected mask layers over the base image.

    Args:
        image_path_str: Path of the base image; falsy returns None.
        selected_layers: Layer display names to draw; None draws nothing.
        layer_opacities: Optional ``{layer_name: opacity}``. Missing or None
            entries use 0.6; values are clamped to [0, 1].
        force_reload: If True, clear cached masks so layer files written since
            the last call (e.g. by the agent) are picked up.

    Returns:
        PIL.Image.Image | None: RGB composite, or None when no base image is available.
    """
    if not image_path_str:
        return None
    if not _refresh_mask_cache(image_path_str, force_reload):
        return None
    base_image = MASK_CACHE["base_image"]
    if base_image is None:
        return None
    layers = MASK_CACHE["layers"]
    # Colour is keyed on the layer's rank among ALL cached layers, so it stays
    # stable when the user toggles other layers on or off.
    color_rank = {name: rank for rank, name in enumerate(sorted(layers))}
    opacities = layer_opacities or {}
    overlay_accum = Image.new('RGBA', base_image.size, (0, 0, 0, 0))
    for layer_name in selected_layers or []:
        mask_bool = layers.get(layer_name)
        if mask_bool is None:
            continue
        color = COLOR_PALETTE[color_rank[layer_name] % len(COLOR_PALETTE)]
        opacity = opacities.get(layer_name)
        if opacity is None:
            opacity = 0.6
        # Clamp so out-of-range values cannot wrap the uint8 alpha channel.
        alpha = int(255 * min(max(float(opacity), 0.0), 1.0))
        layer_rgba = np.zeros((mask_bool.shape[0], mask_bool.shape[1], 4), dtype=np.uint8)
        layer_rgba[mask_bool] = (*color, alpha)
        layer_img = Image.fromarray(layer_rgba, 'RGBA')
        overlay_accum = Image.alpha_composite(overlay_accum, layer_img)
    result = Image.alpha_composite(base_image, overlay_accum)
    return result.convert("RGB")
# --- Core Logic ---
def _analysis_status(history, session_id):
    """Build one 12-item run_analysis yield that carries only chat and session id.

    Order: (history, session_id, overlay, layer checkbox value, report file,
    morphology, spatial, relational, slider1..slider4). Overlay and report are
    cleared (None), the checkbox value is emptied, and the three tables and
    four sliders are ``gr.skip()`` so they are not re-rendered mid-stream.
    """
    return (history, session_id, None, [], None,
            gr.skip(), gr.skip(), gr.skip(),
            gr.skip(), gr.skip(), gr.skip(), gr.skip())


async def run_analysis(image_path_str, user_prompt, session_id_state):
    """Run the root agent on an image and stream progress to the UI.

    Async generator. Each yielded value is a 12-tuple in the layout of
    ``_analysis_status``; only the final yield carries real results (overlay,
    layer checkbox choices, Excel report, three metric tables and four slider
    updates).

    Side effects: deletes previous ``/tmp/out_*.png``, ``/tmp/data_*.npz`` and
    ``/tmp/*.xlsx`` files, resets ``MASK_CACHE``, and replaces the module-global
    ``ACTIVE_RUNNER`` used for follow-up questions.

    Args:
        image_path_str: Path to the image to analyse; falsy yields an empty
            history and stops.
        user_prompt: Instruction sent to the agent together with the image.
        session_id_state: Unused; kept for call-site compatibility.
    """
    global ACTIVE_RUNNER
    import mimetypes

    if not MODEL_CACHE["loaded"]:
        yield _analysis_status([], None)
        # load_models() is synchronous and may take minutes; run it in a worker
        # thread so the event loop (and every other session) keeps running.
        await asyncio.get_running_loop().run_in_executor(None, load_models)
    if not image_path_str:
        yield _analysis_status([], None)
        return

    # Remove the previous run's artefacts so stale layers/reports are not reused.
    stale_files = glob.glob("/tmp/out_*.png") + glob.glob("/tmp/data_*.npz") + glob.glob("/tmp/*.xlsx")
    for stale_file in stale_files:
        try:
            os.remove(stale_file)
        except OSError:
            pass  # already gone or not removable; cleanup is best effort

    MASK_CACHE["current_path"] = None
    MASK_CACHE["base_image"] = None
    MASK_CACHE["layers"] = {}

    image_path = Path(image_path_str)
    if MODEL_CACHE["model"] is None:
        error_msg = "β Model failed to load. Please check logs."
        yield _analysis_status([{"role": "assistant", "content": error_msg}], None)
        return
    if AnalysisDeps is None:
        error_msg = "β Project imports failed. 'AnalysisDeps' is missing. Check your 'cellemetry' package installation."
        yield _analysis_status([{"role": "assistant", "content": error_msg}], None)
        return

    try:
        deps = AnalysisDeps(
            sam_model=MODEL_CACHE["model"],
            sam_processor=MODEL_CACHE["processor"],
            image_path=image_path,
            device=MODEL_CACHE["device"],
            pixel_size_microns=None
        )
        if root_agent is None:
            raise ValueError("Root agent is not loaded.")
        runner = InMemoryRunner(agent=root_agent, app_name="cellemetry_demo")
        session = await runner.session_service.create_session(
            app_name="cellemetry_demo",
            user_id="demo_user",
            state=deps.to_state_dict()
        )
        session_id = session.id
    except Exception as e:
        error_msg = f"β Agent Initialization Failed: {str(e)}"
        print(error_msg)
        yield _analysis_status([{"role": "assistant", "content": error_msg}], None)
        return
    # Publish the runner only once its session exists.
    ACTIVE_RUNNER = runner

    image_bytes = image_path.read_bytes()
    # Declare the file's real image type (the bundled examples are JPEGs);
    # fall back to PNG, the previous fixed value, when it cannot be guessed.
    mime_type = mimetypes.guess_type(image_path.name)[0]
    if not (mime_type and mime_type.startswith("image/")):
        mime_type = "image/png"
    content = types.Content(
        role="user",
        parts=[
            types.Part.from_text(text=user_prompt),
            types.Part.from_bytes(data=image_bytes, mime_type=mime_type),
        ]
    )

    logs = [f"π **Starting analysis** on {MODEL_CACHE['device']}..."]

    def render(log_list):
        """Chat history for the current log: the prompt plus one assistant message."""
        return [
            {"role": "user", "content": f"\n\n{user_prompt}"},
            {"role": "assistant", "content": "\n\n".join(log_list)},
        ]

    yield _analysis_status(render(logs), session_id)

    try:
        async for event in runner.run_async(user_id="demo_user", session_id=session_id, new_message=content):
            author = event.author
            function_calls = event.get_function_calls()
            if function_calls:
                for fc in function_calls:
                    logs.append(f"π§ **{author}**: Calling `{fc.name}`")
                yield _analysis_status(render(logs), session_id)
            if event.content and event.content.parts:
                streaming_prefix = f"π¬ **{author}**"
                for part in event.content.parts:
                    if hasattr(part, 'text') and part.text:
                        # Partial text replaces this author's in-progress line;
                        # final text promotes that line to a completed one.
                        if event.partial:
                            line = f"π¬ **{author}**: {part.text}..."
                        else:
                            line = f"β **{author}**: {part.text}"
                        if logs and logs[-1].startswith(streaming_prefix):
                            logs[-1] = line
                        else:
                            logs.append(line)
                yield _analysis_status(render(logs), session_id)
    except Exception as e:
        logs.append(f"β Error: {e}")
        yield _analysis_status(render(logs), session_id)
        return

    logs.append("\nβ **Analysis Complete!** Loading results...")
    yield _analysis_status(render(logs), session_id)
    await asyncio.sleep(0.5)

    # Join with blank lines, as the streamed messages do, so each log entry
    # stays its own markdown paragraph in the final message.
    full_log_text = "\n\n".join(logs)
    report_file, df_m, df_s, df_r = load_excel_data(full_log_text)
    layers = get_available_layers()
    initial_overlay = generate_overlay(image_path_str, layers)
    completion_msg = f"\n\n---\n\n⨠**Analysis finished!** Found {len(layers)} layer(s). Results are now available in the Segmentation and Quantitative Results tabs."
    full_log_text += completion_msg
    final_history = [
        {"role": "user", "content": f"\n\n{user_prompt}"},
        {"role": "assistant", "content": full_log_text},
    ]
    slider_updates = update_opacity_sliders(layers)
    # Final yield is the only one carrying real result data.
    yield (final_history, session_id, initial_overlay,
           gr.CheckboxGroup(choices=layers, value=layers),
           report_file, df_m, df_s, df_r, *slider_updates)
def _prepend_user_turn(history, user_text):
    """Return a copy of ``history`` whose first message is the user's prompt.

    An existing leading user turn is replaced, so the prompt appears once.
    Otherwise the user turn is inserted in front, so histories that hold only
    an assistant message (e.g. run_analysis error yields) keep that message.
    """
    user_turn = {"role": "user", "content": f"\n\n{user_text}"}
    updated = list(history or [])
    if updated and isinstance(updated[0], dict) and updated[0].get("role") == "user":
        updated[0] = user_turn
    else:
        updated.insert(0, user_turn)
    return updated


async def unified_chat_handler(message, history, session_id, current_img_path):
    """Route a chat submission to a new analysis or a follow-up question.

    Async generator bound to ``chat_input.submit``. Every yield is a 17-tuple
    in the order of that event's ``outputs``: chatbot, session id, current
    image path, overlay, layer checkboxes, report file, three tables, four
    opacity sliders, chat input (cleared with None), and the welcome /
    loading / results panel visibilities.

    Cases:
        1. An image is attached, or an image is known but there is no session:
           run a full analysis through ``run_analysis``.
        2. A session exists and there is text: ask ``ACTIVE_RUNNER`` a
           follow-up question, then refresh overlay, tables and sliders.
        3. Otherwise: show a welcome or "please provide input" message.
    """
    if history is None:
        history = []
    if isinstance(message, dict):
        # "text"/"files" may be present but None; treat them as empty.
        user_text = (message.get("text") or "").strip()
        files = message.get("files") or []
    else:
        user_text = str(message).strip()
        files = []

    image_path = None
    if files:
        first_file = files[0]
        image_path = first_file if isinstance(first_file, str) else first_file.get("path")
    elif current_img_path:
        image_path = current_img_path

    # Leave the tables/sliders untouched on intermediate yields.
    skipped_updates = [gr.skip()] * 4

    # CASE 1: INITIAL ANALYSIS
    if image_path and (not session_id or files):
        if not user_text:
            user_text = "Analyze this microscopy image."
        history.append({"role": "user", "content": f"\n\n{user_text}"})
        history.append({"role": "assistant", "content": "π Starting analysis (Model loading may take a moment)..."})
        # Show the loading panel, hide welcome and results.
        yield (history, session_id, image_path, *([gr.update()] * 6), *skipped_updates, None,
               gr.update(visible=False), gr.update(visible=True), gr.update(visible=False))
        final_result = None
        try:
            async for result in run_analysis(image_path, user_text, session_id):
                final_result = result
                # result[2:] is overlay, checkboxes, report, 3 tables, 4 sliders.
                yield (_prepend_user_turn(result[0], user_text), result[1], image_path,
                       *result[2:], None, gr.update(), gr.update(), gr.update())
        except Exception as e:
            history.append({"role": "assistant", "content": f"β Critical Error: {str(e)}"})
            yield (history, session_id, image_path, None, gr.CheckboxGroup(), None,
                   gr.skip(), gr.skip(), gr.skip(), *skipped_updates, None,
                   gr.update(visible=False), gr.update(visible=False), gr.update(visible=True))
            return
        if final_result:
            # Re-emit the last result with the loading panel hidden and results shown.
            yield (_prepend_user_turn(final_result[0], user_text), final_result[1], image_path,
                   *final_result[2:], None,
                   gr.update(visible=False), gr.update(visible=False), gr.update(visible=True))
        return

    # CASE 2: FOLLOW-UP ANALYSIS
    elif session_id and user_text:
        history.append({"role": "user", "content": user_text})
        history.append({"role": "assistant", "content": "π Thinking..."})
        # No loading overlay for follow-ups; skip every result component while streaming.
        passthrough = (*([gr.skip()] * 6), *skipped_updates, None, gr.update(), gr.update(), gr.update())
        yield (history, session_id, current_img_path, *passthrough)
        # Snapshot the global so a concurrent run_analysis cannot swap it mid-stream.
        runner = ACTIVE_RUNNER
        if not runner:
            history[-1]["content"] = "β οΈ Session expired or Agent not initialized."
            yield (history, None, current_img_path, *passthrough)
            return
        content = types.Content(role="user", parts=[types.Part.from_text(text=user_text)])
        accumulated_response = ""
        try:
            async for event in runner.run_async(user_id="demo_user", session_id=session_id, new_message=content):
                if event.content and event.content.parts:
                    for part in event.content.parts:
                        if hasattr(part, 'text') and part.text:
                            accumulated_response += part.text
                            history[-1]["content"] = accumulated_response
                            yield (history, session_id, current_img_path, *passthrough)
        except Exception as e:
            history[-1]["content"] = f"β Error: {e}"
            yield (history, session_id, current_img_path, *passthrough)
            return
        # The agent may have written new layers/reports; reload everything.
        report_file, df_m, df_s, df_r = load_excel_data("")
        layers = get_available_layers()
        new_overlay = generate_overlay(current_img_path, layers, force_reload=True)
        slider_updates = update_opacity_sliders(layers)
        yield (
            history,
            session_id,
            current_img_path,
            new_overlay,
            gr.CheckboxGroup(value=layers, choices=layers),
            report_file,
            df_m, df_s, df_r,
            *slider_updates,
            None,
            gr.update(),
            gr.update(),
            gr.update()
        )
        return

    # CASE 3: NOTHING ACTIONABLE
    else:
        if not history:
            history = [{"role": "assistant", "content": "π Welcome! Upload a microscopy image and describe what you'd like to analyze."}]
        else:
            history.append({"role": "assistant", "content": "β οΈ Please provide a question or upload a new image."})
        yield (history, session_id, current_img_path, *([gr.update()] * 6), *skipped_updates, None,
               gr.update(), gr.update(), gr.update())
# --- UI Layout ---
# Stylesheet passed to gr.Blocks(css=...). "#main_container" matches the
# elem_id of the outer column, ".right-panel" and ".bordered-panel" match
# elem_classes set on the results column/panel. ".gradio-dataframe" is
# presumably Gradio's own class on gr.Dataframe — verify for the installed
# Gradio version.
custom_css = """
/* 1. Global Margin Setting */
#main_container {
margin-left: 10% !important;
margin-right: 10% !important;
width: auto !important;
}
/* 2. Fix Panel Width */
.right-panel {
min-width: 600px !important;
flex-grow: 2 !important;
}
/* 3. Consistent Blue Border for Results Panel */
.bordered-panel {
border: 2px solid #3498db !important;
border-radius: 8px !important;
padding: 10px !important;
background: #ffffff !important;
}
/* 4. Fix Table Width Overflow */
.gradio-dataframe td {
white-space: normal !important;
}
.gradio-dataframe {
overflow-x: auto !important;
max-width: 100% !important;
display: block !important;
}
"""
# Page layout and event wiring. Inside these `with` blocks, the order in which
# components are created is the order they appear on screen.
with gr.Blocks(title="Cellemetry Agent", css=custom_css) as demo:
    # gr.State values are kept per browser session: the agent session id and
    # the path of the image currently being analysed.
    session_id_state = gr.State(None)
    current_image_path = gr.State(None)
    with gr.Column(elem_id="main_container"):
        with gr.Row():
            # --- LEFT COLUMN (Chat) ---
            with gr.Column(scale=1, min_width=300):
                chatbot = gr.Chatbot(
                    label="Agent Conversation",
                    height=400,
                    value=[{"role": "assistant", "content": "π Welcome to Cellemetry! Upload a microscopy image and describe what you'd like to analyze."}],
                    show_label=True
                )
                chat_input = gr.MultimodalTextbox(
                    file_types=["image"],
                    placeholder="Upload an image and describe your analysis...",
                    show_label=False,
                    submit_btn="Send"
                )
                # --- NEW: Examples Component ---
                # NOTE: Ensure you have an 'examples' folder with 'sample_1.jpg' and 'sample_2.jpg'
                example_data = [
                    [{"text": "Analyze this image and describe the cell morphology.", "files": ["examples/sample_1.jpg"]}],
                    [{"text": "Segment the nuclei and calculate spatial distribution.", "files": ["examples/sample_2.jpg"]}],
                ]
                gr.Examples(
                    examples=example_data,
                    inputs=chat_input,
                    label="Try an Example",
                )
            # --- RIGHT COLUMN (Results) ---
            # The welcome, loading and results panels below are shown/hidden by
            # the visibility updates at the end of unified_chat_handler's yields.
            with gr.Column(scale=2, elem_classes=["right-panel"]):
                # Welcome overlay
                with gr.Column(visible=True, elem_id="welcome-overlay") as welcome_overlay:
                    gr.HTML(f"""
<div style="display: flex; flex-direction: column; align-items: center; justify-content: center; min-height: 780px; padding: 40px; background: #ffffff; border-radius: 8px; border: 2px solid #3498db;">
<div style="text-align: center;">
<div style='text-align: center;'>
<img src="https://raw.githubusercontent.com/hmgill/Cellemetry/main/logo.png" alt="Logo" style="height:200px; display: block; margin: 0 auto;">
</div>
<h2 style="color: #333; margin: 20px 0 10px; font-weight: 600; font-size: 28px;">Welcome to Cellemetry</h2>
<p style="color: #666; font-size: 16px; max-width: 400px; margin: 0 auto 30px; line-height: 1.6;">Upload a microscopy image to get started with AI-powered cell analysis and segmentation</p>
<div style="padding: 20px; background: #fff; border-radius: 8px; border-left: 4px solid #3498db; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
<p style="color: #555; margin: 0; font-size: 14px;">π Use the chat on the left to begin</p>
</div>
</div>
</div>
""")
                # Loading overlay
                with gr.Column(visible=False, elem_id="loading-overlay") as loading_overlay:
                    gr.HTML("""
<div style="display: flex; flex-direction: column; align-items: center; justify-content: center; height: 780px; background: rgba(255, 255, 255, 0.95); border-radius: 8px; box-shadow: 0 4px 6px rgba(0,0,0,0.1);">
<div style="text-align: center;">
<div style="border: 8px solid #f3f3f3; border-top: 8px solid #3498db; border-radius: 50%; width: 60px; height: 60px; animation: spin 1s linear infinite; margin: 0 auto 20px;"></div>
<h3 style="color: #555; margin: 0;">βοΈ Analyzing</h3>
<p style="color: #888; margin-top: 10px;">Your image is being processed...</p>
</div>
<style>@keyframes spin { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } }</style>
</div>
""")
                # Results tabs
                with gr.Column(visible=False, elem_classes=["bordered-panel"]) as results_container:
                    with gr.Tabs() as results_tabs:
                        with gr.Tab("π Segmentation"):
                            with gr.Row():
                                with gr.Column(scale=3):
                                    overlay_output = gr.Image(label="Segmentation Result", height=780, type="pil")
                                with gr.Column(scale=1):
                                    gr.Markdown("**Layer Controls**")
                                    layer_checkboxes = gr.CheckboxGroup(label="Visible Layers", choices=[], value=[], interactive=True)
                                    gr.Markdown("**Opacity Controls**")
                                    # Four fixed sliders; update_opacity_sliders shows and
                                    # relabels one per available layer and hides the rest.
                                    opacity_slider_1 = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.1, label="Layer 1 Opacity", visible=False)
                                    opacity_slider_2 = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.1, label="Layer 2 Opacity", visible=False)
                                    opacity_slider_3 = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.1, label="Layer 3 Opacity", visible=False)
                                    opacity_slider_4 = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.1, label="Layer 4 Opacity", visible=False)
                        with gr.Tab("π Quantitative Results"):
                            download_btn = gr.File(label="Download Excel Report")
                            with gr.Tabs():
                                with gr.Tab("Morphology"):
                                    tbl_morph = gr.Dataframe(interactive=False, wrap=True)
                                with gr.Tab("Spatial"):
                                    tbl_spatial = gr.Dataframe(interactive=False, wrap=True)
                                with gr.Tab("Relational"):
                                    tbl_rel = gr.Dataframe(interactive=False, wrap=True)

    def regenerate_overlay_with_opacity(img_path, selected_layers, op1, op2, op3, op4):
        """Re-render the overlay for the current layer selection and slider values.

        Slider k is applied to the k-th entry of get_available_layers() (sorted),
        the same order update_opacity_sliders uses to label the sliders.
        Returns None when no image is loaded.
        """
        # FIX: Allow empty selected_layers to pass through (returns just the base image)
        if not img_path: return None
        if selected_layers is None: selected_layers = []
        opacities = {}
        opacity_values = [op1, op2, op3, op4]
        all_layers = get_available_layers()
        for i, layer in enumerate(all_layers[:4]):
            opacities[layer] = opacity_values[i]
        return generate_overlay(img_path, selected_layers, opacities)

    # The outputs list is matched positionally to the 17-tuples that
    # unified_chat_handler yields; keep both in the same order.
    chat_input.submit(
        fn=unified_chat_handler,
        inputs=[chat_input, chatbot, session_id_state, current_image_path],
        outputs=[chatbot, session_id_state, current_image_path, overlay_output, layer_checkboxes, download_btn, tbl_morph, tbl_spatial, tbl_rel, opacity_slider_1, opacity_slider_2, opacity_slider_3, opacity_slider_4, chat_input, welcome_overlay, loading_overlay, results_container]
    )
    # Re-render the overlay whenever the layer selection or any opacity slider changes.
    for component in [layer_checkboxes, opacity_slider_1, opacity_slider_2, opacity_slider_3, opacity_slider_4]:
        component.change(
            fn=regenerate_overlay_with_opacity,
            inputs=[current_image_path, layer_checkboxes, opacity_slider_1, opacity_slider_2, opacity_slider_3, opacity_slider_4],
            outputs=[overlay_output]
        )
    # Trigger SAM3 loading when a page is opened; load_models returns early once loaded.
    demo.load(load_models)
if __name__ == "__main__":
    # Serve on every interface at port 7860, with request queuing enabled.
    # NOTE(review): allowed_paths "." lets Gradio's file route serve anything
    # under the working directory; "/tmp" is where masks and Excel reports are
    # written. Consider narrowing "." (e.g. to "examples") once you have
    # confirmed what actually needs serving.
    launch_options = {
        "ssr_mode": False,
        "theme": gr.themes.Soft(),
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "allowed_paths": [".", "/tmp"],
    }
    queued_app = demo.queue()
    queued_app.launch(**launch_options)