hmgill commited on
Commit
1b3ec1b
·
verified ·
1 Parent(s): 42166b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -27
app.py CHANGED
@@ -30,6 +30,13 @@ except ImportError as e:
30
  root_agent = None
31
  AnalysisDeps = None
32
 
 
 
 
 
 
 
 
33
  # --- Global State ---
34
  MODEL_CACHE = {
35
  "model": None,
@@ -38,7 +45,6 @@ MODEL_CACHE = {
38
  "loaded": False
39
  }
40
 
41
- # In-memory cache for masks and base image to prevent disk I/O on slider updates
42
  MASK_CACHE = {
43
  "current_path": None,
44
  "base_image": None,
@@ -47,18 +53,28 @@ MASK_CACHE = {
47
 
48
  ACTIVE_RUNNER = None
49
 
50
- # FIX: Dynamic Color Palette (Removed hard-coded "nucleus" keys)
51
- # Colors: Green, Blue, Red, Yellow, Cyan, Magenta, Orange, Purple
52
- COLOR_PALETTE = [
53
- (0, 255, 0), # Green
54
- (0, 100, 255), # Blue
55
- (255, 0, 0), # Red
56
- (255, 255, 0), # Yellow
57
- (0, 255, 255), # Cyan
58
- (255, 0, 255), # Magenta
59
- (255, 128, 0), # Orange
60
- (128, 0, 255) # Purple
61
- ]
 
 
 
 
 
 
 
 
 
 
62
 
63
  def load_models():
64
  """Initialize SAM3 model. Now called AFTER app startup."""
@@ -84,7 +100,6 @@ def load_models():
84
 
85
  # --- Helpers ---
86
  def load_excel_data(logs_text):
87
- """Finds and loads the Excel report, transposing for better display."""
88
  placeholder = pd.DataFrame({"Status": ["No Data Available"]})
89
  candidates = glob.glob("/tmp/*.xlsx") + glob.glob("*.xlsx")
90
 
@@ -133,24 +148,40 @@ def update_opacity_sliders(layers):
133
  return updates
134
 
135
  # --- OPTIMIZED OVERLAY GENERATION ---
136
- def generate_overlay(image_path_str, selected_layers, layer_opacities=None):
 
 
 
 
137
  if not image_path_str:
138
  return None
139
 
140
- # Check cache
141
- if MASK_CACHE["current_path"] != image_path_str or MASK_CACHE["base_image"] is None:
 
 
 
 
142
  print(f"🔄 Caching masks for {os.path.basename(image_path_str)}...")
143
  try:
144
- base_img = Image.open(image_path_str).convert("RGBA")
145
- MASK_CACHE["base_image"] = base_img
146
- MASK_CACHE["current_path"] = image_path_str
147
- MASK_CACHE["layers"] = {}
 
 
148
 
 
149
  all_layer_files = glob.glob("/tmp/data_*.npz")
150
  base_w, base_h = base_img.size
151
 
152
  for file_path in all_layer_files:
153
  layer_name = os.path.basename(file_path).replace("data_", "").replace(".npz", "")
 
 
 
 
 
154
  try:
155
  data = np.load(file_path)
156
  masks = data['masks'] if 'masks' in data else data[data.files[0]]
@@ -178,20 +209,18 @@ def generate_overlay(image_path_str, selected_layers, layer_opacities=None):
178
  base_image = MASK_CACHE["base_image"]
179
  overlay_accum = Image.new('RGBA', base_image.size, (0, 0, 0, 0))
180
 
181
- # Get all available layers to ensure consistent coloring regardless of selection
182
  all_known_layers = sorted(MASK_CACHE["layers"].keys())
183
 
184
  for layer_name in selected_layers:
185
  if layer_name in MASK_CACHE["layers"]:
186
  mask_bool = MASK_CACHE["layers"][layer_name]
187
 
188
- # FIX: Assign color based on sorted index instead of hard-coded map
189
- # This ensures "Mitochondria" gets a color even if we didn't plan for it
190
  if layer_name in all_known_layers:
191
  color_idx = all_known_layers.index(layer_name) % len(COLOR_PALETTE)
192
  color = COLOR_PALETTE[color_idx]
193
  else:
194
- color = (255, 255, 0) # Fallback Yellow
195
 
196
  opacity = 0.6
197
  if layer_opacities and layer_name in layer_opacities:
@@ -336,12 +365,14 @@ async def unified_chat_handler(message, history, session_id, current_img_path):
336
  waiting_df = pd.DataFrame({"Status": ["Waiting..."]})
337
  empty_slider_updates = [gr.update()] * 4
338
 
 
339
  if image_path and (not session_id or files):
340
  if not user_text: user_text = "Analyze this microscopy image."
341
 
342
  history.append({"role": "user", "content": f"![](file={image_path})\n\n{user_text}"})
343
  history.append({"role": "assistant", "content": "🔄 Starting analysis (Model loading may take a moment)..."})
344
 
 
345
  yield history, session_id, image_path, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), *empty_slider_updates, None, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
346
 
347
  final_result = None
@@ -350,15 +381,20 @@ async def unified_chat_handler(message, history, session_id, current_img_path):
350
  updated_history = result[0].copy()
351
  if files and len(updated_history) > 0:
352
  updated_history[0] = {"role": "user", "content": f"![](file={image_path})\n\n{user_text}"}
 
 
353
  yield (updated_history, result[1], image_path, *result[2:], None, gr.update(), gr.update(), gr.update())
354
 
355
  if final_result:
356
  updated_history = final_result[0].copy()
357
  if files and len(updated_history) > 0:
358
  updated_history[0] = {"role": "user", "content": f"![](file={image_path})\n\n{user_text}"}
 
 
359
  yield (updated_history, final_result[1], image_path, *final_result[2:], None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=True))
360
  return
361
 
 
362
  elif session_id and user_text:
363
  history.append({"role": "user", "content": user_text})
364
  history.append({"role": "assistant", "content": "💭 Thinking..."})
@@ -371,6 +407,8 @@ async def unified_chat_handler(message, history, session_id, current_img_path):
371
 
372
  content = types.Content(role="user", parts=[types.Part.from_text(text=user_text)])
373
  accumulated_response = ""
 
 
374
  try:
375
  async for event in ACTIVE_RUNNER.run_async(user_id="demo_user", session_id=session_id, new_message=content):
376
  if event.content and event.content.parts:
@@ -382,6 +420,32 @@ async def unified_chat_handler(message, history, session_id, current_img_path):
382
  except Exception as e:
383
  history[-1]["content"] = f"❌ Error: {e}"
384
  yield history, session_id, current_img_path, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), *empty_slider_updates, None, gr.update(), gr.update(), gr.update()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
  return
386
 
387
  else:
@@ -396,8 +460,8 @@ async def unified_chat_handler(message, history, session_id, current_img_path):
396
  custom_css = """
397
  /* 1. Global Margin Setting */
398
  #main_container {
399
- margin-left: 10% !important;
400
- margin-right: 10% !important;
401
  width: auto !important;
402
  }
403
 
 
30
  root_agent = None
31
  AnalysisDeps = None
32
 
33
+ # Optional: Distinctipy for better colors
34
+ try:
35
+ from distinctipy import distinctipy
36
+ except ImportError:
37
+ distinctipy = None
38
+ print("⚠️ distinctipy not found. Using fallback colors.")
39
+
40
  # --- Global State ---
41
  MODEL_CACHE = {
42
  "model": None,
 
45
  "loaded": False
46
  }
47
 
 
48
  MASK_CACHE = {
49
  "current_path": None,
50
  "base_image": None,
 
53
 
54
  ACTIVE_RUNNER = None
55
 
56
+ # --- Dynamic Color Helper ---
57
+ def generate_color_palette(n=50):
58
+ """Generates a palette of N distinct colors [0-255]."""
59
+ if distinctipy:
60
+ print(f"🎨 Generating {n} distinct colors using distinctipy...")
61
+ colors = distinctipy.get_colors(n)
62
+ return [tuple(int(c * 255) for c in color) for color in colors]
63
+
64
+ try:
65
+ import matplotlib.pyplot as plt
66
+ cmap = plt.get_cmap('tab20')
67
+ return [tuple(int(c * 255) for c in cmap(i % 20)[:3]) for i in range(n)]
68
+ except Exception:
69
+ pass
70
+
71
+ return [
72
+ (0, 255, 0), (0, 0, 255), (255, 0, 0), (255, 255, 0),
73
+ (0, 255, 255), (255, 0, 255), (255, 128, 0), (128, 0, 255),
74
+ (0, 128, 0), (0, 0, 128), (128, 0, 0), (128, 128, 0)
75
+ ] * (n // 12 + 1)
76
+
77
+ COLOR_PALETTE = generate_color_palette(50)
78
 
79
  def load_models():
80
  """Initialize SAM3 model. Now called AFTER app startup."""
 
100
 
101
  # --- Helpers ---
102
  def load_excel_data(logs_text):
 
103
  placeholder = pd.DataFrame({"Status": ["No Data Available"]})
104
  candidates = glob.glob("/tmp/*.xlsx") + glob.glob("*.xlsx")
105
 
 
148
  return updates
149
 
150
  # --- OPTIMIZED OVERLAY GENERATION ---
151
+ def generate_overlay(image_path_str, selected_layers, layer_opacities=None, force_reload=False):
152
+ """
153
+ Regenerates overlay.
154
+ force_reload: If True, clears the layer cache to pick up new files from agent.
155
+ """
156
  if not image_path_str:
157
  return None
158
 
159
+ # Force reload if requested (e.g., after follow-up analysis)
160
+ if force_reload:
161
+ MASK_CACHE["layers"] = {}
162
+
163
+ # Check cache loading
164
+ if MASK_CACHE["current_path"] != image_path_str or MASK_CACHE["base_image"] is None or not MASK_CACHE["layers"]:
165
  print(f"🔄 Caching masks for {os.path.basename(image_path_str)}...")
166
  try:
167
+ if MASK_CACHE["current_path"] != image_path_str or MASK_CACHE["base_image"] is None:
168
+ base_img = Image.open(image_path_str).convert("RGBA")
169
+ MASK_CACHE["base_image"] = base_img
170
+ MASK_CACHE["current_path"] = image_path_str
171
+ else:
172
+ base_img = MASK_CACHE["base_image"]
173
 
174
+ # Always scan for new layers if we are here
175
  all_layer_files = glob.glob("/tmp/data_*.npz")
176
  base_w, base_h = base_img.size
177
 
178
  for file_path in all_layer_files:
179
  layer_name = os.path.basename(file_path).replace("data_", "").replace(".npz", "")
180
+
181
+ # Skip if already cached
182
+ if layer_name in MASK_CACHE["layers"]:
183
+ continue
184
+
185
  try:
186
  data = np.load(file_path)
187
  masks = data['masks'] if 'masks' in data else data[data.files[0]]
 
209
  base_image = MASK_CACHE["base_image"]
210
  overlay_accum = Image.new('RGBA', base_image.size, (0, 0, 0, 0))
211
 
 
212
  all_known_layers = sorted(MASK_CACHE["layers"].keys())
213
 
214
  for layer_name in selected_layers:
215
  if layer_name in MASK_CACHE["layers"]:
216
  mask_bool = MASK_CACHE["layers"][layer_name]
217
 
218
+ # Use the global generated palette
 
219
  if layer_name in all_known_layers:
220
  color_idx = all_known_layers.index(layer_name) % len(COLOR_PALETTE)
221
  color = COLOR_PALETTE[color_idx]
222
  else:
223
+ color = (255, 255, 0)
224
 
225
  opacity = 0.6
226
  if layer_opacities and layer_name in layer_opacities:
 
365
  waiting_df = pd.DataFrame({"Status": ["Waiting..."]})
366
  empty_slider_updates = [gr.update()] * 4
367
 
368
+ # CASE 1: INITIAL ANALYSIS
369
  if image_path and (not session_id or files):
370
  if not user_text: user_text = "Analyze this microscopy image."
371
 
372
  history.append({"role": "user", "content": f"![](file={image_path})\n\n{user_text}"})
373
  history.append({"role": "assistant", "content": "🔄 Starting analysis (Model loading may take a moment)..."})
374
 
375
+ # Yield 1: Set initial visibility state ONCE
376
  yield history, session_id, image_path, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), *empty_slider_updates, None, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
377
 
378
  final_result = None
 
381
  updated_history = result[0].copy()
382
  if files and len(updated_history) > 0:
383
  updated_history[0] = {"role": "user", "content": f"![](file={image_path})\n\n{user_text}"}
384
+
385
+ # Yield Loop: Pass gr.update() to prevent flickering
386
  yield (updated_history, result[1], image_path, *result[2:], None, gr.update(), gr.update(), gr.update())
387
 
388
  if final_result:
389
  updated_history = final_result[0].copy()
390
  if files and len(updated_history) > 0:
391
  updated_history[0] = {"role": "user", "content": f"![](file={image_path})\n\n{user_text}"}
392
+
393
+ # Yield Final: Show Results
394
  yield (updated_history, final_result[1], image_path, *final_result[2:], None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=True))
395
  return
396
 
397
+ # CASE 2: FOLLOW-UP ANALYSIS
398
  elif session_id and user_text:
399
  history.append({"role": "user", "content": user_text})
400
  history.append({"role": "assistant", "content": "💭 Thinking..."})
 
407
 
408
  content = types.Content(role="user", parts=[types.Part.from_text(text=user_text)])
409
  accumulated_response = ""
410
+
411
+ # 1. Stream response text
412
  try:
413
  async for event in ACTIVE_RUNNER.run_async(user_id="demo_user", session_id=session_id, new_message=content):
414
  if event.content and event.content.parts:
 
420
  except Exception as e:
421
  history[-1]["content"] = f"❌ Error: {e}"
422
  yield history, session_id, current_img_path, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), *empty_slider_updates, None, gr.update(), gr.update(), gr.update()
423
+ return
424
+
425
+ # 2. REFRESH DATA (Tables, Overlays, Layers)
426
+ # Scan for potential new files created by the agent
427
+ report_file, df_m, df_s, df_r = load_excel_data("")
428
+ layers = get_available_layers()
429
+
430
+ # Force overlay generation with new layers (using force_reload=True to clear cache)
431
+ new_overlay = generate_overlay(current_img_path, layers, force_reload=True)
432
+ slider_updates = update_opacity_sliders(layers)
433
+
434
+ # 3. Yield FINAL update with new data
435
+ yield (
436
+ history,
437
+ session_id,
438
+ current_img_path,
439
+ new_overlay, # Updated Image
440
+ gr.CheckboxGroup(value=layers, choices=layers), # Updated Layers
441
+ report_file, # Updated Excel
442
+ df_m, df_s, df_r, # Updated Tables
443
+ *slider_updates, # Updated Sliders
444
+ None, # Clear input
445
+ gr.update(), # Welcome (no change)
446
+ gr.update(), # Loading (no change)
447
+ gr.update() # Results (no change)
448
+ )
449
  return
450
 
451
  else:
 
460
  custom_css = """
461
  /* 1. Global Margin Setting */
462
  #main_container {
463
+ margin-left: 20% !important;
464
+ margin-right: 20% !important;
465
  width: auto !important;
466
  }
467