hmgill commited on
Commit
ef78770
Β·
verified Β·
1 Parent(s): cedd38a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +187 -93
app.py CHANGED
@@ -19,37 +19,52 @@ from google.adk.runners import InMemoryRunner
19
  from google.genai import types
20
 
21
  # Project Imports
22
- from cellemetry import root_agent
23
- from cellemetry.config import AnalysisDeps
24
- from transformers import Sam3Processor, Sam3Model
 
 
 
 
 
 
 
 
25
 
26
  # --- Global State ---
27
  MODEL_CACHE = {
28
  "model": None,
29
  "processor": None,
30
- "device": "cpu"
 
31
  }
32
 
33
- # Store the active runner globally so we can access it in Q&A turns
34
  ACTIVE_RUNNER = None
35
 
36
 
37
  def load_models():
38
- """Initialize SAM3 model."""
39
- if MODEL_CACHE["model"] is not None:
40
  return
 
41
  print("--- Loading SAM3 Model ---")
42
  device = "cuda" if torch.cuda.is_available() else "cpu"
43
  MODEL_CACHE["device"] = device
44
 
45
  try:
 
 
 
 
46
  MODEL_CACHE["model"] = Sam3Model.from_pretrained("facebook/sam3").to(device)
47
  MODEL_CACHE["processor"] = Sam3Processor.from_pretrained("facebook/sam3")
 
48
  print(f"βœ… SAM3 loaded on {device}")
 
49
  except Exception as e:
50
  print(f"⚠️ SAM3 load failed: {e}")
51
-
52
- load_models()
53
 
54
  # --- Helpers ---
55
  def load_excel_data(logs_text):
@@ -65,14 +80,11 @@ def load_excel_data(logs_text):
65
  try:
66
  xls = pd.ExcelFile(report_file, engine='openpyxl')
67
 
68
- # Helper to transpose and fix index so row labels are visible
69
  def process_sheet(sheet_name):
70
  if sheet_name in xls.sheet_names:
71
  df = pd.read_excel(xls, sheet_name)
72
  if not df.empty and len(df.columns) > 0:
73
- # Set first column as index, transpose, then reset index to make it a column again
74
  df = df.set_index(df.columns[0]).T.reset_index()
75
- # Rename the new first column (formerly the index) to 'Metric'
76
  df.rename(columns={df.columns[0]: "Metric"}, inplace=True)
77
  return df
78
  return placeholder
@@ -87,7 +99,6 @@ def load_excel_data(logs_text):
87
  return report_file, placeholder, placeholder, placeholder
88
 
89
  def get_available_layers():
90
- """Scans /tmp for .npz files and returns a list of available layer names."""
91
  files = glob.glob("/tmp/data_*.npz")
92
  layers = []
93
  for f in files:
@@ -96,9 +107,8 @@ def get_available_layers():
96
  return sorted(layers)
97
 
98
  def update_opacity_sliders(layers):
99
- """Returns updated slider configurations based on available layers."""
100
  updates = []
101
- for i in range(4): # We have 4 sliders
102
  if i < len(layers):
103
  layer_name = layers[i].replace("_", " ").title()
104
  updates.append(gr.update(visible=True, label=f"{layer_name} Opacity", value=0.6))
@@ -106,22 +116,12 @@ def update_opacity_sliders(layers):
106
  updates.append(gr.update(visible=False))
107
  return updates
108
 
109
- def collect_layer_opacities(layers, op1, op2, op3, op4):
110
- """Collects opacity values into a dictionary."""
111
- opacities = {}
112
- opacity_values = [op1, op2, op3, op4]
113
- for i, layer in enumerate(layers[:4]): # Only use first 4 layers
114
- opacities[layer] = opacity_values[i]
115
- return opacities
116
-
117
  def generate_overlay(image_path_str, selected_layers, layer_opacities=None):
118
- """Regenerates the overlay image with adjustable opacity for each layer."""
119
  if not image_path_str:
120
  return None
121
 
122
  base_image = Image.open(image_path_str).convert("RGBA")
123
 
124
- # Default colors for different layers (can expand as needed)
125
  color_map = {
126
  "green_cell": (0, 255, 0),
127
  "blue_nucleus": (0, 0, 255),
@@ -129,7 +129,6 @@ def generate_overlay(image_path_str, selected_layers, layer_opacities=None):
129
  "nucleus": (0, 0, 255),
130
  }
131
 
132
- # Create overlay layer
133
  overlay = Image.new('RGBA', base_image.size, (0, 0, 0, 0))
134
 
135
  for layer_name in selected_layers:
@@ -145,58 +144,59 @@ def generate_overlay(image_path_str, selected_layers, layer_opacities=None):
145
  else:
146
  combined_mask = masks
147
 
148
- # Resize if mask dimensions differ from image
149
  h_mask, w_mask = combined_mask.shape
150
  if overlay.size != (w_mask, h_mask):
151
  overlay = overlay.resize((w_mask, h_mask), Image.Resampling.LANCZOS)
152
  base_image = base_image.resize((w_mask, h_mask), Image.Resampling.LANCZOS)
153
 
154
  combined_mask = combined_mask.astype(bool)
 
155
 
156
- # Get color for this layer
157
- color = color_map.get(layer_name.lower(), (255, 255, 0)) # Default to yellow
158
-
159
- # Get opacity (default 0.5)
160
  opacity = 0.5
161
  if layer_opacities and layer_name in layer_opacities:
162
  opacity = layer_opacities[layer_name]
163
 
164
- # Create colored mask with opacity
165
  mask_overlay = np.zeros((*combined_mask.shape, 4), dtype=np.uint8)
166
  mask_overlay[combined_mask] = (*color, int(255 * opacity))
167
 
168
- # Composite onto overlay
169
  mask_image = Image.fromarray(mask_overlay, 'RGBA')
170
  overlay = Image.alpha_composite(overlay, mask_image)
171
-
172
  except Exception as e:
173
  print(f"Error loading layer {layer_name}: {e}")
174
 
175
- # Composite overlay onto base image
176
  result = Image.alpha_composite(base_image, overlay)
177
  return result.convert("RGB")
178
 
179
  # --- Core Logic ---
180
  async def run_analysis(image_path_str, user_prompt, session_id_state):
181
- """
182
- Runs the initial analysis using the Agent.
183
- Updates the global ACTIVE_RUNNER and returns a session ID.
184
- """
185
  waiting_df = pd.DataFrame({"Status": ["Waiting..."]})
186
- empty_slider_updates = [gr.update()] * 4 # Placeholder for 4 sliders
187
 
 
 
 
 
 
 
188
  if not image_path_str:
189
- # Return empty state
190
  yield [], None, None, [], None, waiting_df, waiting_df, waiting_df, *empty_slider_updates
191
  return
192
 
193
- # Cleanup previous run files
194
  for f in glob.glob("/tmp/out_*.png") + glob.glob("/tmp/data_*.npz") + glob.glob("/tmp/*.xlsx"):
195
  try: os.remove(f)
196
  except: pass
197
 
198
- # Setup Dependencies
199
  image_path = Path(image_path_str)
 
 
 
 
 
 
 
200
  deps = AnalysisDeps(
201
  sam_model=MODEL_CACHE["model"],
202
  sam_processor=MODEL_CACHE["processor"],
@@ -205,11 +205,9 @@ async def run_analysis(image_path_str, user_prompt, session_id_state):
205
  pixel_size_microns=None
206
  )
207
 
208
- # Initialize Runner
209
  global ACTIVE_RUNNER
210
  ACTIVE_RUNNER = InMemoryRunner(agent=root_agent, app_name="cellemetry_demo")
211
 
212
- # Create Session
213
  session = await ACTIVE_RUNNER.session_service.create_session(
214
  app_name="cellemetry_demo",
215
  user_id="demo_user",
@@ -217,8 +215,6 @@ async def run_analysis(image_path_str, user_prompt, session_id_state):
217
  )
218
 
219
  session_id = session.id
220
-
221
- # Prepare Input
222
  image_bytes = image_path.read_bytes()
223
  content = types.Content(
224
  role="user",
@@ -230,15 +226,10 @@ async def run_analysis(image_path_str, user_prompt, session_id_state):
230
 
231
  logs = [f"πŸ”„ **Starting analysis** on {MODEL_CACHE['device']}..."]
232
 
233
- # Helper to format output for the chatbot (UPDATED for Gradio 5.0 Messages format)
234
  def yield_status(log_list):
235
  full_log = "\n\n".join(log_list)
236
- # Use Markdown to render the image in the chat
237
  user_msg = f"![](file={image_path_str})\n\n{user_prompt}"
238
- return [
239
- {"role": "user", "content": user_msg},
240
- {"role": "assistant", "content": full_log}
241
- ]
242
 
243
  yield yield_status(logs), session_id, None, [], None, waiting_df, waiting_df, waiting_df, *empty_slider_updates
244
 
@@ -264,7 +255,6 @@ async def run_analysis(image_path_str, user_prompt, session_id_state):
264
  logs[-1] = f"βœ… **{author}**: {part.text}"
265
  else:
266
  logs.append(f"βœ… **{author}**: {part.text}")
267
-
268
  yield yield_status(logs), session_id, None, [], None, waiting_df, waiting_df, waiting_df, *empty_slider_updates
269
 
270
  except Exception as e:
@@ -272,99 +262,203 @@ async def run_analysis(image_path_str, user_prompt, session_id_state):
272
  yield yield_status(logs), session_id, None, [], None, waiting_df, waiting_df, waiting_df, *empty_slider_updates
273
  return
274
 
275
- # Finalize
276
  logs.append("\nβœ… **Analysis Complete!** Loading results...")
277
  yield yield_status(logs), session_id, None, [], None, waiting_df, waiting_df, waiting_df, *empty_slider_updates
278
 
279
  await asyncio.sleep(0.5)
280
 
281
- # Load Data
282
  full_log_text = "\n".join(logs)
283
  report_file, df_m, df_s, df_r = load_excel_data(full_log_text)
284
  layers = get_available_layers()
285
  initial_overlay = generate_overlay(image_path_str, layers)
286
 
287
- # Add completion summary
288
  completion_msg = f"\n\n---\n\n✨ **Analysis finished!** Found {len(layers)} layer(s). Results are now available in the Segmentation and Quantitative Results tabs."
289
  full_log_text += completion_msg
290
 
291
- # Final Yield with all data
292
  final_user_msg = f"![](file={image_path_str})\n\n{user_prompt}"
293
- final_history = [
294
- {"role": "user", "content": final_user_msg},
295
- {"role": "assistant", "content": full_log_text}
296
- ]
297
  slider_updates = update_opacity_sliders(layers)
 
298
  yield final_history, session_id, initial_overlay, gr.CheckboxGroup(choices=layers, value=layers), report_file, df_m, df_s, df_r, *slider_updates
299
 
300
  async def unified_chat_handler(message, history, session_id, current_img_path):
301
- """
302
- Unified handler for both initial analysis and follow-up questions.
303
- message: dict with 'text' and optionally 'files' keys (Gradio MultimodalTextbox format)
304
- """
305
- waiting_df = pd.DataFrame({"Status": ["Waiting..."]})
306
- empty_slider_updates = [gr.update()] * 4
307
-
308
- # Ensure history is a list
309
  if history is None:
310
  history = []
311
 
312
- # Extract text and files from message
313
  user_text = message.get("text", "").strip() if isinstance(message, dict) else str(message).strip()
314
  files = message.get("files", []) if isinstance(message, dict) else []
315
 
316
- # Determine if we have an image
317
  image_path = None
318
  if files:
319
  image_path = files[0] if isinstance(files[0], str) else files[0].get("path")
320
  elif current_img_path:
321
  image_path = current_img_path
322
 
323
- # Case 1: Initial analysis (has image, no session or new image)
 
 
 
324
  if image_path and (not session_id or files):
325
- # This is a new analysis request
326
- if not user_text:
327
- user_text = "Analyze this microscopy image."
328
 
329
- # Add user message with Markdown image preview
330
  history.append({"role": "user", "content": f"![](file={image_path})\n\n{user_text}"})
331
-
332
- # Add loading message
333
- history.append({"role": "assistant", "content": "πŸ”„ Starting analysis..."})
334
 
335
- # Hide welcome, show loading, hide results
336
  yield history, session_id, image_path, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), *empty_slider_updates, None, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
337
 
338
- # Run full analysis
339
  final_result = None
340
  async for result in run_analysis(image_path, user_text, session_id):
341
- # result = (history, session_id, overlay, checkboxes, download, df_m, df_s, df_r, slider1-4)
342
  final_result = result
343
- # Update the history from run_analysis but preserve our image indicator
344
  updated_history = result[0].copy()
345
  if files and len(updated_history) > 0:
346
  updated_history[0] = {"role": "user", "content": f"![](file={image_path})\n\n{user_text}"}
347
-
348
- # Keep loading overlay visible during processing
349
  yield (updated_history, result[1], image_path, *result[2:], None, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False))
350
 
351
- # Final yield: hide welcome, hide loading, show results
352
  if final_result:
353
  updated_history = final_result[0].copy()
354
  if files and len(updated_history) > 0:
355
  updated_history[0] = {"role": "user", "content": f"![](file={image_path})\n\n{user_text}"}
356
  yield (updated_history, final_result[1], image_path, *final_result[2:], None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=True))
357
-
358
  return
359
 
360
- # Case 2: Follow-up question (has session, no new image)
361
  elif session_id and user_text:
362
- # Add user message to history
363
  history.append({"role": "user", "content": user_text})
364
  history.append({"role": "assistant", "content": "πŸ’­ Thinking..."})
365
-
366
- # Keep all overlays in their current state
367
  yield history, session_id, current_img_path, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), *empty_slider_updates, None, gr.update(), gr.update(), gr.update()
368
 
369
  if not ACTIVE_RUNNER:
370
- history[-1]["content"] = "⚠️ Session expired. Please upload an image and start a new analysis."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  from google.genai import types
20
 
21
  # Project Imports
22
+ # Wrap imports to prevent immediate crash if dependencies are missing
23
+ try:
24
+ from cellemetry import root_agent
25
+ from cellemetry.config import AnalysisDeps
26
+ from transformers import Sam3Processor, Sam3Model
27
+ except ImportError as e:
28
+ print(f"⚠️ Import Error (Non-fatal for UI startup): {e}")
29
+ Sam3Model = None
30
+ Sam3Processor = None
31
+ root_agent = None
32
+ AnalysisDeps = None
33
 
34
  # --- Global State ---
35
  MODEL_CACHE = {
36
  "model": None,
37
  "processor": None,
38
+ "device": "cpu",
39
+ "loaded": False
40
  }
41
 
42
+ # Store the active runner globally
43
  ACTIVE_RUNNER = None
44
 
45
 
46
  def load_models():
47
+ """Initialize SAM3 model. Now called AFTER app startup."""
48
+ if MODEL_CACHE["loaded"]:
49
  return
50
+
51
  print("--- Loading SAM3 Model ---")
52
  device = "cuda" if torch.cuda.is_available() else "cpu"
53
  MODEL_CACHE["device"] = device
54
 
55
  try:
56
+ # Check if imports succeeded
57
+ if Sam3Model is None:
58
+ raise ImportError("Sam3Model not found. Please check requirements.")
59
+
60
  MODEL_CACHE["model"] = Sam3Model.from_pretrained("facebook/sam3").to(device)
61
  MODEL_CACHE["processor"] = Sam3Processor.from_pretrained("facebook/sam3")
62
+ MODEL_CACHE["loaded"] = True
63
  print(f"βœ… SAM3 loaded on {device}")
64
+ return f"βœ… SAM3 loaded on {device}"
65
  except Exception as e:
66
  print(f"⚠️ SAM3 load failed: {e}")
67
+ return f"⚠️ Model load failed: {e}"
 
68
 
69
  # --- Helpers ---
70
  def load_excel_data(logs_text):
 
80
  try:
81
  xls = pd.ExcelFile(report_file, engine='openpyxl')
82
 
 
83
  def process_sheet(sheet_name):
84
  if sheet_name in xls.sheet_names:
85
  df = pd.read_excel(xls, sheet_name)
86
  if not df.empty and len(df.columns) > 0:
 
87
  df = df.set_index(df.columns[0]).T.reset_index()
 
88
  df.rename(columns={df.columns[0]: "Metric"}, inplace=True)
89
  return df
90
  return placeholder
 
99
  return report_file, placeholder, placeholder, placeholder
100
 
101
  def get_available_layers():
 
102
  files = glob.glob("/tmp/data_*.npz")
103
  layers = []
104
  for f in files:
 
107
  return sorted(layers)
108
 
109
  def update_opacity_sliders(layers):
 
110
  updates = []
111
+ for i in range(4):
112
  if i < len(layers):
113
  layer_name = layers[i].replace("_", " ").title()
114
  updates.append(gr.update(visible=True, label=f"{layer_name} Opacity", value=0.6))
 
116
  updates.append(gr.update(visible=False))
117
  return updates
118
 
 
 
 
 
 
 
 
 
119
  def generate_overlay(image_path_str, selected_layers, layer_opacities=None):
 
120
  if not image_path_str:
121
  return None
122
 
123
  base_image = Image.open(image_path_str).convert("RGBA")
124
 
 
125
  color_map = {
126
  "green_cell": (0, 255, 0),
127
  "blue_nucleus": (0, 0, 255),
 
129
  "nucleus": (0, 0, 255),
130
  }
131
 
 
132
  overlay = Image.new('RGBA', base_image.size, (0, 0, 0, 0))
133
 
134
  for layer_name in selected_layers:
 
144
  else:
145
  combined_mask = masks
146
 
 
147
  h_mask, w_mask = combined_mask.shape
148
  if overlay.size != (w_mask, h_mask):
149
  overlay = overlay.resize((w_mask, h_mask), Image.Resampling.LANCZOS)
150
  base_image = base_image.resize((w_mask, h_mask), Image.Resampling.LANCZOS)
151
 
152
  combined_mask = combined_mask.astype(bool)
153
+ color = color_map.get(layer_name.lower(), (255, 255, 0))
154
 
 
 
 
 
155
  opacity = 0.5
156
  if layer_opacities and layer_name in layer_opacities:
157
  opacity = layer_opacities[layer_name]
158
 
 
159
  mask_overlay = np.zeros((*combined_mask.shape, 4), dtype=np.uint8)
160
  mask_overlay[combined_mask] = (*color, int(255 * opacity))
161
 
 
162
  mask_image = Image.fromarray(mask_overlay, 'RGBA')
163
  overlay = Image.alpha_composite(overlay, mask_image)
 
164
  except Exception as e:
165
  print(f"Error loading layer {layer_name}: {e}")
166
 
 
167
  result = Image.alpha_composite(base_image, overlay)
168
  return result.convert("RGB")
169
 
170
  # --- Core Logic ---
171
  async def run_analysis(image_path_str, user_prompt, session_id_state):
172
+ """Runs analysis. Triggers model load if not yet ready."""
 
 
 
173
  waiting_df = pd.DataFrame({"Status": ["Waiting..."]})
174
+ empty_slider_updates = [gr.update()] * 4
175
 
176
+ # Lazy Load: Ensure model is loaded before inference
177
+ if not MODEL_CACHE["loaded"]:
178
+ yield [], None, None, [], None, waiting_df, waiting_df, waiting_df, *empty_slider_updates
179
+ # We can't yield a log easily here without breaking the tuple structure, so we just wait
180
+ load_models()
181
+
182
  if not image_path_str:
 
183
  yield [], None, None, [], None, waiting_df, waiting_df, waiting_df, *empty_slider_updates
184
  return
185
 
186
+ # Cleanup
187
  for f in glob.glob("/tmp/out_*.png") + glob.glob("/tmp/data_*.npz") + glob.glob("/tmp/*.xlsx"):
188
  try: os.remove(f)
189
  except: pass
190
 
191
+ # Setup
192
  image_path = Path(image_path_str)
193
+
194
+ # Check if model loaded successfully
195
+ if MODEL_CACHE["model"] is None:
196
+ error_msg = "❌ Model failed to load. Please check logs."
197
+ yield [{"role": "assistant", "content": error_msg}], None, None, [], None, waiting_df, waiting_df, waiting_df, *empty_slider_updates
198
+ return
199
+
200
  deps = AnalysisDeps(
201
  sam_model=MODEL_CACHE["model"],
202
  sam_processor=MODEL_CACHE["processor"],
 
205
  pixel_size_microns=None
206
  )
207
 
 
208
  global ACTIVE_RUNNER
209
  ACTIVE_RUNNER = InMemoryRunner(agent=root_agent, app_name="cellemetry_demo")
210
 
 
211
  session = await ACTIVE_RUNNER.session_service.create_session(
212
  app_name="cellemetry_demo",
213
  user_id="demo_user",
 
215
  )
216
 
217
  session_id = session.id
 
 
218
  image_bytes = image_path.read_bytes()
219
  content = types.Content(
220
  role="user",
 
226
 
227
  logs = [f"πŸ”„ **Starting analysis** on {MODEL_CACHE['device']}..."]
228
 
 
229
  def yield_status(log_list):
230
  full_log = "\n\n".join(log_list)
 
231
  user_msg = f"![](file={image_path_str})\n\n{user_prompt}"
232
+ return [{"role": "user", "content": user_msg}, {"role": "assistant", "content": full_log}]
 
 
 
233
 
234
  yield yield_status(logs), session_id, None, [], None, waiting_df, waiting_df, waiting_df, *empty_slider_updates
235
 
 
255
  logs[-1] = f"βœ… **{author}**: {part.text}"
256
  else:
257
  logs.append(f"βœ… **{author}**: {part.text}")
 
258
  yield yield_status(logs), session_id, None, [], None, waiting_df, waiting_df, waiting_df, *empty_slider_updates
259
 
260
  except Exception as e:
 
262
  yield yield_status(logs), session_id, None, [], None, waiting_df, waiting_df, waiting_df, *empty_slider_updates
263
  return
264
 
 
265
  logs.append("\nβœ… **Analysis Complete!** Loading results...")
266
  yield yield_status(logs), session_id, None, [], None, waiting_df, waiting_df, waiting_df, *empty_slider_updates
267
 
268
  await asyncio.sleep(0.5)
269
 
 
270
  full_log_text = "\n".join(logs)
271
  report_file, df_m, df_s, df_r = load_excel_data(full_log_text)
272
  layers = get_available_layers()
273
  initial_overlay = generate_overlay(image_path_str, layers)
274
 
 
275
  completion_msg = f"\n\n---\n\n✨ **Analysis finished!** Found {len(layers)} layer(s). Results are now available in the Segmentation and Quantitative Results tabs."
276
  full_log_text += completion_msg
277
 
 
278
  final_user_msg = f"![](file={image_path_str})\n\n{user_prompt}"
279
+ final_history = [{"role": "user", "content": final_user_msg}, {"role": "assistant", "content": full_log_text}]
 
 
 
280
  slider_updates = update_opacity_sliders(layers)
281
+
282
  yield final_history, session_id, initial_overlay, gr.CheckboxGroup(choices=layers, value=layers), report_file, df_m, df_s, df_r, *slider_updates
283
 
284
  async def unified_chat_handler(message, history, session_id, current_img_path):
 
 
 
 
 
 
 
 
285
  if history is None:
286
  history = []
287
 
 
288
  user_text = message.get("text", "").strip() if isinstance(message, dict) else str(message).strip()
289
  files = message.get("files", []) if isinstance(message, dict) else []
290
 
 
291
  image_path = None
292
  if files:
293
  image_path = files[0] if isinstance(files[0], str) else files[0].get("path")
294
  elif current_img_path:
295
  image_path = current_img_path
296
 
297
+ waiting_df = pd.DataFrame({"Status": ["Waiting..."]})
298
+ empty_slider_updates = [gr.update()] * 4
299
+
300
+ # Case 1: Initial analysis
301
  if image_path and (not session_id or files):
302
+ if not user_text: user_text = "Analyze this microscopy image."
 
 
303
 
 
304
  history.append({"role": "user", "content": f"![](file={image_path})\n\n{user_text}"})
305
+ history.append({"role": "assistant", "content": "πŸ”„ Starting analysis (Model loading may take a moment)..."})
 
 
306
 
 
307
  yield history, session_id, image_path, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), *empty_slider_updates, None, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
308
 
 
309
  final_result = None
310
  async for result in run_analysis(image_path, user_text, session_id):
 
311
  final_result = result
 
312
  updated_history = result[0].copy()
313
  if files and len(updated_history) > 0:
314
  updated_history[0] = {"role": "user", "content": f"![](file={image_path})\n\n{user_text}"}
 
 
315
  yield (updated_history, result[1], image_path, *result[2:], None, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False))
316
 
 
317
  if final_result:
318
  updated_history = final_result[0].copy()
319
  if files and len(updated_history) > 0:
320
  updated_history[0] = {"role": "user", "content": f"![](file={image_path})\n\n{user_text}"}
321
  yield (updated_history, final_result[1], image_path, *final_result[2:], None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=True))
 
322
  return
323
 
324
+ # Case 2: Follow-up
325
  elif session_id and user_text:
 
326
  history.append({"role": "user", "content": user_text})
327
  history.append({"role": "assistant", "content": "πŸ’­ Thinking..."})
 
 
328
  yield history, session_id, current_img_path, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), *empty_slider_updates, None, gr.update(), gr.update(), gr.update()
329
 
330
  if not ACTIVE_RUNNER:
331
+ history[-1]["content"] = "⚠️ Session expired."
332
+ yield history, None, current_img_path, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), *empty_slider_updates, None, gr.update(), gr.update(), gr.update()
333
+ return
334
+
335
+ content = types.Content(role="user", parts=[types.Part.from_text(text=user_text)])
336
+ accumulated_response = ""
337
+ try:
338
+ async for event in ACTIVE_RUNNER.run_async(user_id="demo_user", session_id=session_id, new_message=content):
339
+ if event.content and event.content.parts:
340
+ for part in event.content.parts:
341
+ if hasattr(part, 'text') and part.text:
342
+ accumulated_response += part.text
343
+ history[-1]["content"] = accumulated_response
344
+ yield history, session_id, current_img_path, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), *empty_slider_updates, None, gr.update(), gr.update(), gr.update()
345
+ except Exception as e:
346
+ history[-1]["content"] = f"❌ Error: {e}"
347
+ yield history, session_id, current_img_path, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), *empty_slider_updates, None, gr.update(), gr.update(), gr.update()
348
+ return
349
+
350
+ else:
351
+ if not history:
352
+ history = [{"role": "assistant", "content": "πŸ‘‹ Welcome! Upload a microscopy image and describe what you'd like to analyze."}]
353
+ else:
354
+ history.append({"role": "assistant", "content": "⚠️ Please provide a question or upload a new image."})
355
+ yield history, session_id, current_img_path, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), *empty_slider_updates, None, gr.update(), gr.update(), gr.update()
356
+
357
+ # --- UI Layout ---
358
+ with gr.Blocks(title="Cellemetry Agent") as demo:
359
+ session_id_state = gr.State(None)
360
+ current_image_path = gr.State(None)
361
+
362
+ with gr.Row():
363
+ with gr.Column(scale=1):
364
+ chatbot = gr.Chatbot(
365
+ label="Agent Conversation",
366
+ height=600,
367
+ value=[{"role": "assistant", "content": "πŸ‘‹ Welcome to Cellemetry! Upload a microscopy image and describe what you'd like to analyze."}],
368
+ show_label=True,
369
+ type="messages"
370
+ )
371
+ chat_input = gr.MultimodalTextbox(
372
+ file_types=["image"],
373
+ placeholder="Upload an image and describe your analysis...",
374
+ show_label=False,
375
+ submit_btn="Send"
376
+ )
377
+
378
+ with gr.Column(scale=2):
379
+ with gr.Column(visible=True, elem_id="welcome-overlay") as welcome_overlay:
380
+ gr.HTML(f"""
381
+ <div style="display: flex; flex-direction: column; align-items: center; justify-content: center; min-height: 780px; padding: 40px; background: #f8f9fa; border-radius: 8px; border: 2px solid #3498db;">
382
+ <div style="text-align: center;">
383
+ <div style='text-align: center;'>
384
+ <img src="https://raw.githubusercontent.com/hmgill/Cellemetry/main/logo.png" alt="Logo" style="height:200px; display: block; margin: 0 auto;">
385
+ </div>
386
+ <h2 style="color: #333; margin: 20px 0 10px; font-weight: 600; font-size: 28px;">Welcome to Cellemetry</h2>
387
+ <p style="color: #666; font-size: 16px; max-width: 400px; margin: 0 auto 30px; line-height: 1.6;">Upload a microscopy image to get started with AI-powered cell analysis and segmentation</p>
388
+ <div style="padding: 20px; background: #fff; border-radius: 8px; border-left: 4px solid #3498db; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
389
+ <p style="color: #555; margin: 0; font-size: 14px;">πŸ‘ˆ Use the chat on the left to begin</p>
390
+ </div>
391
+ </div>
392
+ </div>
393
+ """
394
+ )
395
+
396
+ with gr.Column(visible=False, elem_id="loading-overlay") as loading_overlay:
397
+ gr.HTML("""
398
+ <div style="display: flex; flex-direction: column; align-items: center; justify-content: center; height: 780px; background: rgba(255, 255, 255, 0.95); border-radius: 8px; box-shadow: 0 4px 6px rgba(0,0,0,0.1);">
399
+ <div style="text-align: center;">
400
+ <div style="border: 8px solid #f3f3f3; border-top: 8px solid #3498db; border-radius: 50%; width: 60px; height: 60px; animation: spin 1s linear infinite; margin: 0 auto 20px;"></div>
401
+ <h3 style="color: #555; margin: 0;">βš™οΈ Analysis in Progress</h3>
402
+ <p style="color: #888; margin-top: 10px;">Please wait while we process your microscopy image...</p>
403
+ </div>
404
+ <style>@keyframes spin { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } }</style>
405
+ </div>
406
+ """)
407
+
408
+ with gr.Tabs(visible=False) as results_tabs:
409
+ with gr.Tab("πŸ” Segmentation"):
410
+ with gr.Row():
411
+ with gr.Column(scale=3):
412
+ overlay_output = gr.Image(label="Segmentation Result", height=780, type="pil")
413
+ with gr.Column(scale=1):
414
+ gr.Markdown("**Layer Controls**")
415
+ layer_checkboxes = gr.CheckboxGroup(label="Visible Layers", choices=[], value=[], interactive=True)
416
+ gr.Markdown("**Opacity Controls**")
417
+ opacity_slider_1 = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.1, label="Layer 1 Opacity", visible=False)
418
+ opacity_slider_2 = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.1, label="Layer 2 Opacity", visible=False)
419
+ opacity_slider_3 = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.1, label="Layer 3 Opacity", visible=False)
420
+ opacity_slider_4 = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.1, label="Layer 4 Opacity", visible=False)
421
+
422
+ with gr.Tab("πŸ“Š Quantitative Results"):
423
+ download_btn = gr.File(label="Download Excel Report")
424
+ with gr.Tabs():
425
+ with gr.Tab("Morphology"):
426
+ tbl_morph = gr.Dataframe(interactive=False, wrap=True)
427
+ with gr.Tab("Spatial"):
428
+ tbl_spatial = gr.Dataframe(interactive=False, wrap=True)
429
+ with gr.Tab("Relational"):
430
+ tbl_rel = gr.Dataframe(interactive=False, wrap=True)
431
+
432
+ def regenerate_overlay_with_opacity(img_path, selected_layers, op1, op2, op3, op4):
433
+ if not img_path or not selected_layers: return None
434
+ opacities = {}
435
+ opacity_values = [op1, op2, op3, op4]
436
+ all_layers = get_available_layers()
437
+ for i, layer in enumerate(all_layers[:4]):
438
+ opacities[layer] = opacity_values[i]
439
+ return generate_overlay(img_path, selected_layers, opacities)
440
+
441
+ chat_input.submit(
442
+ fn=unified_chat_handler,
443
+ inputs=[chat_input, chatbot, session_id_state, current_image_path],
444
+ outputs=[chatbot, session_id_state, current_image_path, overlay_output, layer_checkboxes, download_btn, tbl_morph, tbl_spatial, tbl_rel, opacity_slider_1, opacity_slider_2, opacity_slider_3, opacity_slider_4, chat_input, welcome_overlay, loading_overlay, results_tabs]
445
+ )
446
+
447
+ for component in [layer_checkboxes, opacity_slider_1, opacity_slider_2, opacity_slider_3, opacity_slider_4]:
448
+ component.change(
449
+ fn=regenerate_overlay_with_opacity,
450
+ inputs=[current_image_path, layer_checkboxes, opacity_slider_1, opacity_slider_2, opacity_slider_3, opacity_slider_4],
451
+ outputs=[overlay_output]
452
+ )
453
+
454
+ # TRIGGER MODEL LOAD AFTER UI LAUNCH
455
+ demo.load(load_models)
456
+
457
+ if __name__ == "__main__":
458
+ demo.queue().launch(
459
+ ssr_mode=False,
460
+ theme=gr.themes.Soft(),
461
+ server_name="0.0.0.0",
462
+ server_port=7860,
463
+ allowed_paths=["."]
464
+ )