# Cellemetry — Gradio application entry point (app.py).
import gradio as gr
import asyncio
import os
import glob
import torch
import sys
import builtins
import pandas as pd
import numpy as np
import time
from pathlib import Path
from PIL import Image
# --- Safe Input Mocking ---
# Auto-answer "y" to any interactive input() prompt so third-party setup
# code cannot block this headless Space. NOTE(review): presumably some
# dependency prompts on first run — confirm which one before removing.
builtins.input = lambda *args: "y"
# GenAI & ADK Imports
from google.adk.runners import InMemoryRunner
from google.genai import types
# Project Imports
try:
from cellemetry import root_agent
from cellemetry.config import AnalysisDeps
from transformers import Sam3Processor, Sam3Model
except ImportError as e:
print(f"⚠️ Import Error (Non-fatal for UI startup): {e}")
Sam3Model = None
Sam3Processor = None
root_agent = None
AnalysisDeps = None
# Optional: Distinctipy for better colors
try:
from distinctipy import distinctipy
except ImportError:
distinctipy = None
print("⚠️ distinctipy not found. Using fallback colors.")
# --- Global State ---
# Cache for the heavy SAM3 model so it is loaded at most once per process.
MODEL_CACHE = {
    "model": None,       # Sam3Model instance once loaded
    "processor": None,   # Sam3Processor instance once loaded
    "device": "cpu",     # "cuda" when available, else "cpu"
    "loaded": False      # guards against re-loading
}
# Cache of the current base image plus its boolean mask layers, keyed by
# cleaned layer name; avoids re-reading /tmp/data_*.npz on every overlay.
MASK_CACHE = {
    "current_path": None,
    "base_image": None,
    "layers": {}
}
# ADK runner for the active chat session (created by run_analysis).
ACTIVE_RUNNER = None
# --- Dynamic Color Helper ---
def generate_color_palette(n=50):
"""Generates a palette of N distinct colors [0-255]."""
if distinctipy:
print(f"🎨 Generating {n} distinct colors using distinctipy...")
colors = distinctipy.get_colors(n)
return [tuple(int(c * 255) for c in color) for color in colors]
try:
import matplotlib.pyplot as plt
cmap = plt.get_cmap('tab20')
return [tuple(int(c * 255) for c in cmap(i % 20)[:3]) for i in range(n)]
except Exception:
pass
return [
(0, 255, 0), (0, 0, 255), (255, 0, 0), (255, 255, 0),
(0, 255, 255), (255, 0, 255), (255, 128, 0), (128, 0, 255),
(0, 128, 0), (0, 0, 128), (128, 0, 0), (128, 128, 0)
] * (n // 12 + 1)
COLOR_PALETTE = generate_color_palette(50)
def load_models():
    """Initialize the SAM3 model/processor into MODEL_CACHE (idempotent).

    Deferred until after app startup so the UI renders before the heavy
    model download. Returns a human-readable status string for the UI/logs.
    """
    if MODEL_CACHE["loaded"]:
        # FIX: the cached path previously returned None, making the return
        # type inconsistent for callers that display the status message.
        return f"βœ… SAM3 loaded on {MODEL_CACHE['device']}"
    print("--- Loading SAM3 Model ---")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    MODEL_CACHE["device"] = device
    try:
        if Sam3Model is None:
            # The optional import block at the top of the file failed.
            raise ImportError("Sam3Model not found. Please check requirements.")
        MODEL_CACHE["model"] = Sam3Model.from_pretrained("facebook/sam3").to(device)
        MODEL_CACHE["processor"] = Sam3Processor.from_pretrained("facebook/sam3")
        MODEL_CACHE["loaded"] = True
        print(f"βœ… SAM3 loaded on {device}")
        return f"βœ… SAM3 loaded on {device}"
    except Exception as e:
        # Non-fatal: the UI stays up and surfaces the failure message.
        print(f"⚠️ SAM3 load failed: {e}")
        return f"⚠️ Model load failed: {e}"
# --- Helpers ---
def clean_layer_name(filename):
"""
Converts 'data_blue_nuclei.npz' -> 'Nuclei'.
Removes standard color names and underscores.
"""
raw = os.path.basename(filename).replace("data_", "").replace(".npz", "")
parts = raw.split('_')
colors = {
'blue', 'green', 'red', 'yellow', 'cyan', 'magenta',
'orange', 'purple', 'white', 'black', 'gray', 'grey',
'pink', 'brown', 'lime', 'teal'
}
cleaned_parts = [p for p in parts if p.lower() not in colors]
if not cleaned_parts:
return raw.replace("_", " ").title()
return " ".join(cleaned_parts).title()
def load_excel_data(logs_text):
    """Locate the newest .xlsx report and return its sheets as display frames.

    Returns (report_path_or_None, morphology_df, spatial_df, relational_df);
    any sheet that is missing, empty, or unreadable is replaced by a
    placeholder frame. `logs_text` is currently unused.
    """
    fallback = pd.DataFrame({"Status": ["No Data Available"]})
    reports = glob.glob("/tmp/*.xlsx") + glob.glob("*.xlsx")
    if not reports:
        return None, fallback, fallback, fallback
    newest = max(reports, key=os.path.getmtime)
    try:
        workbook = pd.ExcelFile(newest, engine='openpyxl')

        def transpose_sheet(name):
            # Sheets store metrics as rows; transpose so each metric becomes
            # a column, then label the former index column "Metric".
            if name not in workbook.sheet_names:
                return fallback
            frame = pd.read_excel(workbook, name)
            if frame.empty or len(frame.columns) == 0:
                return fallback
            frame = frame.set_index(frame.columns[0]).T.reset_index()
            frame.rename(columns={frame.columns[0]: "Metric"}, inplace=True)
            return frame

        return (newest,
                transpose_sheet("Morphology"),
                transpose_sheet("Spatial"),
                transpose_sheet("Relational"))
    except Exception as e:
        print(f"⚠️ Error reading Excel: {e}")
        return newest, fallback, fallback, fallback
def get_available_layers():
    """Return the sorted, de-duplicated layer names present in /tmp."""
    names = {clean_layer_name(path) for path in glob.glob("/tmp/data_*.npz")}
    return sorted(names)
def update_opacity_sliders(layers):
    """Build gr.update()s for the four fixed opacity sliders.

    Sliders backed by an available layer become visible, relabeled, and
    reset to 0.6; the remaining sliders are hidden. Always returns 4 updates.
    """
    visible = [
        gr.update(visible=True, label=f"{name} Opacity", value=0.6)
        for name in layers[:4]
    ]
    hidden = [gr.update(visible=False)] * (4 - len(visible))
    return visible + hidden
# --- OPTIMIZED OVERLAY GENERATION ---
def generate_overlay(image_path_str, selected_layers, layer_opacities=None, force_reload=False):
"""
Regenerates overlay.
force_reload: If True, clears the layer cache to pick up new files from agent.
"""
if not image_path_str:
return None
# Force reload if requested
if force_reload:
MASK_CACHE["layers"] = {}
# Check cache loading
if MASK_CACHE["current_path"] != image_path_str or MASK_CACHE["base_image"] is None or not MASK_CACHE["layers"]:
print(f"πŸ”„ Caching masks for {os.path.basename(image_path_str)}...")
try:
if MASK_CACHE["current_path"] != image_path_str or MASK_CACHE["base_image"] is None:
base_img = Image.open(image_path_str).convert("RGBA")
MASK_CACHE["base_image"] = base_img
MASK_CACHE["current_path"] = image_path_str
else:
base_img = MASK_CACHE["base_image"]
# Always scan for new layers if we are here
all_layer_files = glob.glob("/tmp/data_*.npz")
base_w, base_h = base_img.size
for file_path in all_layer_files:
layer_name = clean_layer_name(file_path)
# Skip if already cached
if layer_name in MASK_CACHE["layers"]:
continue
try:
data = np.load(file_path)
masks = data['masks'] if 'masks' in data else data[data.files[0]]
if masks.size > 0:
if masks.ndim == 3:
combined_mask = np.max(masks, axis=0)
else:
combined_mask = masks
mask_pil = Image.fromarray(combined_mask.astype(np.uint8) * 255)
if mask_pil.size != (base_w, base_h):
mask_pil = mask_pil.resize((base_w, base_h), Image.Resampling.NEAREST)
MASK_CACHE["layers"][layer_name] = np.array(mask_pil, dtype=bool)
except Exception as e:
print(f"Failed to cache layer {layer_name}: {e}")
except Exception as e:
print(f"Failed to load base image: {e}")
return None
if MASK_CACHE["base_image"] is None:
return None
base_image = MASK_CACHE["base_image"]
overlay_accum = Image.new('RGBA', base_image.size, (0, 0, 0, 0))
# Ensure selected_layers is iterable even if empty
if selected_layers is None:
selected_layers = []
all_known_layers = sorted(MASK_CACHE["layers"].keys())
for layer_name in selected_layers:
if layer_name in MASK_CACHE["layers"]:
mask_bool = MASK_CACHE["layers"][layer_name]
# Use the global generated palette
if layer_name in all_known_layers:
color_idx = all_known_layers.index(layer_name) % len(COLOR_PALETTE)
color = COLOR_PALETTE[color_idx]
else:
color = (255, 255, 0)
opacity = 0.6
if layer_opacities and layer_name in layer_opacities:
opacity = layer_opacities[layer_name]
layer_rgba = np.zeros((mask_bool.shape[0], mask_bool.shape[1], 4), dtype=np.uint8)
layer_rgba[mask_bool] = (*color, int(255 * opacity))
layer_img = Image.fromarray(layer_rgba, 'RGBA')
overlay_accum = Image.alpha_composite(overlay_accum, layer_img)
result = Image.alpha_composite(base_image, overlay_accum)
return result.convert("RGB")
# --- Core Logic ---
async def run_analysis(image_path_str, user_prompt, session_id_state):
# FIX: Use gr.skip() for updates to prevent UI jitter during streaming
skipped_updates = [gr.skip()] * 4
if not MODEL_CACHE["loaded"]:
yield [], None, None, [], None, gr.skip(), gr.skip(), gr.skip(), *skipped_updates
load_models()
if not image_path_str:
yield [], None, None, [], None, gr.skip(), gr.skip(), gr.skip(), *skipped_updates
return
# Cleanup
for f in glob.glob("/tmp/out_*.png") + glob.glob("/tmp/data_*.npz") + glob.glob("/tmp/*.xlsx"):
try: os.remove(f)
except: pass
# Reset Cache
MASK_CACHE["current_path"] = None
MASK_CACHE["base_image"] = None
MASK_CACHE["layers"] = {}
image_path = Path(image_path_str)
if MODEL_CACHE["model"] is None:
error_msg = "❌ Model failed to load. Please check logs."
yield [{"role": "assistant", "content": error_msg}], None, None, [], None, gr.skip(), gr.skip(), gr.skip(), *skipped_updates
return
if AnalysisDeps is None:
error_msg = "❌ Project imports failed. 'AnalysisDeps' is missing. Check your 'cellemetry' package installation."
yield [{"role": "assistant", "content": error_msg}], None, None, [], None, gr.skip(), gr.skip(), gr.skip(), *skipped_updates
return
try:
deps = AnalysisDeps(
sam_model=MODEL_CACHE["model"],
sam_processor=MODEL_CACHE["processor"],
image_path=image_path,
device=MODEL_CACHE["device"],
pixel_size_microns=None
)
global ACTIVE_RUNNER
if root_agent is None:
raise ValueError("Root agent is not loaded.")
ACTIVE_RUNNER = InMemoryRunner(agent=root_agent, app_name="cellemetry_demo")
session = await ACTIVE_RUNNER.session_service.create_session(
app_name="cellemetry_demo",
user_id="demo_user",
state=deps.to_state_dict()
)
session_id = session.id
except Exception as e:
error_msg = f"❌ Agent Initialization Failed: {str(e)}"
print(error_msg)
yield [{"role": "assistant", "content": error_msg}], None, None, [], None, gr.skip(), gr.skip(), gr.skip(), *skipped_updates
return
image_bytes = image_path.read_bytes()
content = types.Content(
role="user",
parts=[
types.Part.from_text(text=user_prompt),
types.Part.from_bytes(data=image_bytes, mime_type="image/png"),
]
)
logs = [f"πŸ”„ **Starting analysis** on {MODEL_CACHE['device']}..."]
display_path = image_path_str.replace(" ", "%20")
def yield_status(log_list):
full_log = "\n\n".join(log_list)
user_msg = f"![](file={display_path})\n\n{user_prompt}"
return [{"role": "user", "content": user_msg}, {"role": "assistant", "content": full_log}]
# FIX: Yield skips instead of updates
yield yield_status(logs), session_id, None, [], None, gr.skip(), gr.skip(), gr.skip(), *skipped_updates
try:
async for event in ACTIVE_RUNNER.run_async(user_id="demo_user", session_id=session.id, new_message=content):
author = event.author
if event.get_function_calls():
for fc in event.get_function_calls():
logs.append(f"πŸ”§ **{author}**: Calling `{fc.name}`")
yield yield_status(logs), session_id, None, [], None, gr.skip(), gr.skip(), gr.skip(), *skipped_updates
if event.content and event.content.parts:
for part in event.content.parts:
if hasattr(part, 'text') and part.text:
if event.partial:
if logs and logs[-1].startswith(f"πŸ’¬ **{author}**"):
logs[-1] = f"πŸ’¬ **{author}**: {part.text}..."
else:
logs.append(f"πŸ’¬ **{author}**: {part.text}...")
else:
if logs and logs[-1].startswith(f"πŸ’¬ **{author}**"):
logs[-1] = f"βœ… **{author}**: {part.text}"
else:
logs.append(f"βœ… **{author}**: {part.text}")
yield yield_status(logs), session_id, None, [], None, gr.skip(), gr.skip(), gr.skip(), *skipped_updates
except Exception as e:
logs.append(f"❌ Error: {e}")
yield yield_status(logs), session_id, None, [], None, gr.skip(), gr.skip(), gr.skip(), *skipped_updates
return
logs.append("\nβœ… **Analysis Complete!** Loading results...")
yield yield_status(logs), session_id, None, [], None, gr.skip(), gr.skip(), gr.skip(), *skipped_updates
await asyncio.sleep(0.5)
full_log_text = "\n".join(logs)
report_file, df_m, df_s, df_r = load_excel_data(full_log_text)
layers = get_available_layers()
initial_overlay = generate_overlay(image_path_str, layers)
completion_msg = f"\n\n---\n\n✨ **Analysis finished!** Found {len(layers)} layer(s). Results are now available in the Segmentation and Quantitative Results tabs."
full_log_text += completion_msg
final_user_msg = f"![](file={display_path})\n\n{user_prompt}"
final_history = [{"role": "user", "content": final_user_msg}, {"role": "assistant", "content": full_log_text}]
slider_updates = update_opacity_sliders(layers)
# Final yield is the ONLY one with real data
yield final_history, session_id, initial_overlay, gr.CheckboxGroup(choices=layers, value=layers), report_file, df_m, df_s, df_r, *slider_updates
async def unified_chat_handler(message, history, session_id, current_img_path):
    """
    Single entry point for the multimodal chat box.

    Dispatches to one of three cases: (1) initial analysis when an image is
    attached or no session exists yet, (2) follow-up chat on an existing
    session, (3) fallback help message. Every yield is a 17-tuple
    positionally bound to chat_input.submit's outputs: (chatbot history,
    session id, image path, overlay, layer checkboxes, report file, 3 result
    dataframes, 4 slider updates, chat_input value, welcome / loading /
    results container visibility updates).
    """
    if history is None:
        history = []
    # MultimodalTextbox delivers a dict {"text": ..., "files": [...]}; be
    # tolerant of a plain string as well.
    user_text = message.get("text", "").strip() if isinstance(message, dict) else str(message).strip()
    files = message.get("files", []) if isinstance(message, dict) else []
    image_path = None
    if files:
        image_path = files[0] if isinstance(files[0], str) else files[0].get("path")
    elif current_img_path:
        image_path = current_img_path
    waiting_df = pd.DataFrame({"Status": ["Waiting..."]})
    # FIX: Prepare skips
    skipped_updates = [gr.skip()] * 4
    # CASE 1: INITIAL ANALYSIS (new image attached, or no session yet)
    if image_path and (not session_id or files):
        if not user_text: user_text = "Analyze this microscopy image."
        display_path = image_path.replace(" ", "%20")
        history.append({"role": "user", "content": f"![](file={display_path})\n\n{user_text}"})
        history.append({"role": "assistant", "content": "πŸ”„ Starting analysis (Model loading may take a moment)..."})
        # Show Loading, Hide Results
        yield history, session_id, image_path, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), *skipped_updates, None, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
        final_result = None
        try:
            # Stream run_analysis's 12-tuples, re-wrapping each into the
            # 17-slot output shape expected by this handler.
            async for result in run_analysis(image_path, user_text, session_id):
                final_result = result
                updated_history = result[0].copy()
                if files and len(updated_history) > 0:
                    # Keep the uploaded image thumbnail pinned to the first message.
                    updated_history[0] = {"role": "user", "content": f"![](file={display_path})\n\n{user_text}"}
                # Pass through the skips/data from run_analysis
                yield (updated_history, result[1], image_path, *result[2:], None, gr.update(), gr.update(), gr.update())
        except Exception as e:
            history.append({"role": "assistant", "content": f"❌ Critical Error: {str(e)}"})
            yield history, session_id, image_path, None, gr.CheckboxGroup(), None, gr.skip(), gr.skip(), gr.skip(), *skipped_updates, None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
            return
        if final_result:
            updated_history = final_result[0].copy()
            if files and len(updated_history) > 0:
                updated_history[0] = {"role": "user", "content": f"![](file={display_path})\n\n{user_text}"}
            # Hide Loading, Show Results
            yield (updated_history, final_result[1], image_path, *final_result[2:], None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=True))
        return
    # CASE 2: FOLLOW-UP ANALYSIS (existing session, text only)
    elif session_id and user_text:
        history.append({"role": "user", "content": user_text})
        history.append({"role": "assistant", "content": "πŸ’­ Thinking..."})
        # Don't show loading overlay for follow-ups
        # FIX: Send gr.skip() to all result components to prevent jitter
        yield history, session_id, current_img_path, gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), *skipped_updates, None, gr.update(), gr.update(), gr.update()
        if not ACTIVE_RUNNER:
            history[-1]["content"] = "⚠️ Session expired or Agent not initialized."
            yield history, None, current_img_path, gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), *skipped_updates, None, gr.update(), gr.update(), gr.update()
            return
        content = types.Content(role="user", parts=[types.Part.from_text(text=user_text)])
        accumulated_response = ""
        try:
            # Stream the agent's reply into the last assistant bubble.
            async for event in ACTIVE_RUNNER.run_async(user_id="demo_user", session_id=session_id, new_message=content):
                if event.content and event.content.parts:
                    for part in event.content.parts:
                        if hasattr(part, 'text') and part.text:
                            accumulated_response += part.text
                            history[-1]["content"] = accumulated_response
                            # FIX: Keep sending gr.skip() during stream
                            yield history, session_id, current_img_path, gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), *skipped_updates, None, gr.update(), gr.update(), gr.update()
        except Exception as e:
            history[-1]["content"] = f"❌ Error: {e}"
            yield history, session_id, current_img_path, gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), *skipped_updates, None, gr.update(), gr.update(), gr.update()
            return
        # A follow-up may have produced new masks/reports; rescan and refresh.
        report_file, df_m, df_s, df_r = load_excel_data("")
        layers = get_available_layers()
        new_overlay = generate_overlay(current_img_path, layers, force_reload=True)
        slider_updates = update_opacity_sliders(layers)
        # Final yield updates the components
        yield (
            history,
            session_id,
            current_img_path,
            new_overlay,
            gr.CheckboxGroup(value=layers, choices=layers),
            report_file,
            df_m, df_s, df_r,
            *slider_updates,
            None,
            gr.update(),
            gr.update(),
            gr.update()
        )
        return
    # CASE 3: nothing actionable — greet or nudge the user.
    else:
        if not history:
            history = [{"role": "assistant", "content": "πŸ‘‹ Welcome! Upload a microscopy image and describe what you'd like to analyze."}]
        else:
            history.append({"role": "assistant", "content": "⚠️ Please provide a question or upload a new image."})
        yield history, session_id, current_img_path, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), *skipped_updates, None, gr.update(), gr.update(), gr.update()
# --- UI Layout ---
# --- UI Layout ---
# CSS injected into gr.Blocks below: page margins, right-panel sizing,
# the bordered results panel, and dataframe overflow handling.
custom_css = """
/* 1. Global Margin Setting */
#main_container {
margin-left: 10% !important;
margin-right: 10% !important;
width: auto !important;
}
/* 2. Fix Panel Width */
.right-panel {
min-width: 600px !important;
flex-grow: 2 !important;
}
/* 3. Consistent Blue Border for Results Panel */
.bordered-panel {
border: 2px solid #3498db !important;
border-radius: 8px !important;
padding: 10px !important;
background: #ffffff !important;
}
/* 4. Fix Table Width Overflow */
.gradio-dataframe td {
white-space: normal !important;
}
.gradio-dataframe {
overflow-x: auto !important;
max-width: 100% !important;
display: block !important;
}
"""
# Build the Blocks UI: chat column on the left, results panel on the right.
with gr.Blocks(title="Cellemetry Agent", css=custom_css) as demo:
    # Per-browser-session state: ADK session id and current image path.
    session_id_state = gr.State(None)
    current_image_path = gr.State(None)
    with gr.Column(elem_id="main_container"):
        with gr.Row():
            # --- LEFT COLUMN (Chat) ---
            with gr.Column(scale=1, min_width=300):
                chatbot = gr.Chatbot(
                    label="Agent Conversation",
                    height=400,
                    value=[{"role": "assistant", "content": "πŸ‘‹ Welcome to Cellemetry! Upload a microscopy image and describe what you'd like to analyze."}],
                    show_label=True
                )
                chat_input = gr.MultimodalTextbox(
                    file_types=["image"],
                    placeholder="Upload an image and describe your analysis...",
                    show_label=False,
                    submit_btn="Send"
                )
                # --- Examples Component ---
                # NOTE: requires an 'examples' folder containing
                # 'sample_1.jpg' and 'sample_2.jpg'.
                example_data = [
                    [{"text": "Analyze this image and describe the cell morphology.", "files": ["examples/sample_1.jpg"]}],
                    [{"text": "Segment the nuclei and calculate spatial distribution.", "files": ["examples/sample_2.jpg"]}],
                ]
                gr.Examples(
                    examples=example_data,
                    inputs=chat_input,
                    label="Try an Example",
                )
            # --- RIGHT COLUMN (Results) ---
            # Exactly one of the three containers below is visible at a time;
            # the chat handler toggles them via gr.update(visible=...).
            with gr.Column(scale=2, elem_classes=["right-panel"]):
                # Welcome overlay (shown until the first analysis starts)
                with gr.Column(visible=True, elem_id="welcome-overlay") as welcome_overlay:
                    gr.HTML(f"""
<div style="display: flex; flex-direction: column; align-items: center; justify-content: center; min-height: 780px; padding: 40px; background: #ffffff; border-radius: 8px; border: 2px solid #3498db;">
<div style="text-align: center;">
<div style='text-align: center;'>
<img src="https://raw.githubusercontent.com/hmgill/Cellemetry/main/logo.png" alt="Logo" style="height:200px; display: block; margin: 0 auto;">
</div>
<h2 style="color: #333; margin: 20px 0 10px; font-weight: 600; font-size: 28px;">Welcome to Cellemetry</h2>
<p style="color: #666; font-size: 16px; max-width: 400px; margin: 0 auto 30px; line-height: 1.6;">Upload a microscopy image to get started with AI-powered cell analysis and segmentation</p>
<div style="padding: 20px; background: #fff; border-radius: 8px; border-left: 4px solid #3498db; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
<p style="color: #555; margin: 0; font-size: 14px;">πŸ‘ˆ Use the chat on the left to begin</p>
</div>
</div>
</div>
""")
                # Loading overlay (spinner shown during an initial analysis)
                with gr.Column(visible=False, elem_id="loading-overlay") as loading_overlay:
                    gr.HTML("""
<div style="display: flex; flex-direction: column; align-items: center; justify-content: center; height: 780px; background: rgba(255, 255, 255, 0.95); border-radius: 8px; box-shadow: 0 4px 6px rgba(0,0,0,0.1);">
<div style="text-align: center;">
<div style="border: 8px solid #f3f3f3; border-top: 8px solid #3498db; border-radius: 50%; width: 60px; height: 60px; animation: spin 1s linear infinite; margin: 0 auto 20px;"></div>
<h3 style="color: #555; margin: 0;">βš™οΈ Analyzing</h3>
<p style="color: #888; margin-top: 10px;">Your image is being processed...</p>
</div>
<style>@keyframes spin { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } }</style>
</div>
""")
                # Results tabs (shown after a successful analysis)
                with gr.Column(visible=False, elem_classes=["bordered-panel"]) as results_container:
                    with gr.Tabs() as results_tabs:
                        with gr.Tab("πŸ” Segmentation"):
                            with gr.Row():
                                with gr.Column(scale=3):
                                    overlay_output = gr.Image(label="Segmentation Result", height=780, type="pil")
                                with gr.Column(scale=1):
                                    gr.Markdown("**Layer Controls**")
                                    layer_checkboxes = gr.CheckboxGroup(label="Visible Layers", choices=[], value=[], interactive=True)
                                    gr.Markdown("**Opacity Controls**")
                                    # Four fixed sliders; update_opacity_sliders()
                                    # shows/relabels as many as there are layers.
                                    opacity_slider_1 = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.1, label="Layer 1 Opacity", visible=False)
                                    opacity_slider_2 = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.1, label="Layer 2 Opacity", visible=False)
                                    opacity_slider_3 = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.1, label="Layer 3 Opacity", visible=False)
                                    opacity_slider_4 = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.1, label="Layer 4 Opacity", visible=False)
                        with gr.Tab("πŸ“Š Quantitative Results"):
                            download_btn = gr.File(label="Download Excel Report")
                            with gr.Tabs():
                                with gr.Tab("Morphology"):
                                    tbl_morph = gr.Dataframe(interactive=False, wrap=True)
                                with gr.Tab("Spatial"):
                                    tbl_spatial = gr.Dataframe(interactive=False, wrap=True)
                                with gr.Tab("Relational"):
                                    tbl_rel = gr.Dataframe(interactive=False, wrap=True)
def regenerate_overlay_with_opacity(img_path, selected_layers, op1, op2, op3, op4):
# FIX: Allow empty selected_layers to pass through (returns just the base image)
if not img_path: return None
if selected_layers is None: selected_layers = []
opacities = {}
opacity_values = [op1, op2, op3, op4]
all_layers = get_available_layers()
for i, layer in enumerate(all_layers[:4]):
opacities[layer] = opacity_values[i]
return generate_overlay(img_path, selected_layers, opacities)
    # Single submit handler: the multimodal textbox drives both the initial
    # analysis (image attached) and follow-up chat. The outputs list order
    # must match the 17-tuples yielded by unified_chat_handler exactly.
    chat_input.submit(
        fn=unified_chat_handler,
        inputs=[chat_input, chatbot, session_id_state, current_image_path],
        outputs=[chatbot, session_id_state, current_image_path, overlay_output, layer_checkboxes, download_btn, tbl_morph, tbl_spatial, tbl_rel, opacity_slider_1, opacity_slider_2, opacity_slider_3, opacity_slider_4, chat_input, welcome_overlay, loading_overlay, results_container]
    )
    # Any change to the layer selection or an opacity slider re-renders the overlay.
    for component in [layer_checkboxes, opacity_slider_1, opacity_slider_2, opacity_slider_3, opacity_slider_4]:
        component.change(
            fn=regenerate_overlay_with_opacity,
            inputs=[current_image_path, layer_checkboxes, opacity_slider_1, opacity_slider_2, opacity_slider_3, opacity_slider_4],
            outputs=[overlay_output]
        )
    # Kick off the heavy model load once the page has rendered.
    demo.load(load_models)
if __name__ == "__main__":
    # FIX: `theme` is a gr.Blocks()/gr.Interface() constructor argument, not
    # a launch() argument; Blocks.launch() has no such parameter, so passing
    # it raises TypeError. Set the theme in gr.Blocks(...) above if desired.
    demo.queue().launch(
        ssr_mode=False,
        server_name="0.0.0.0",
        server_port=7860,
        allowed_paths=[".", "/tmp"]
    )