hmgill commited on
Commit
6c70288
·
verified ·
1 Parent(s): 46544ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -57
app.py CHANGED
@@ -19,7 +19,6 @@ from google.adk.runners import InMemoryRunner
19
  from google.genai import types
20
 
21
  # Project Imports
22
- # Wrap imports to prevent immediate crash if dependencies are missing
23
  try:
24
  from cellemetry import root_agent
25
  from cellemetry.config import AnalysisDeps
@@ -39,12 +38,17 @@ MODEL_CACHE = {
39
  "loaded": False
40
  }
41
 
42
- # Store the active runner globally
43
- ACTIVE_RUNNER = None
 
 
 
 
44
 
 
45
 
46
  def load_models():
47
- """Initialize SAM3 model. Now called AFTER app startup to prevent timeout."""
48
  if MODEL_CACHE["loaded"]:
49
  return
50
 
@@ -53,7 +57,6 @@ def load_models():
53
  MODEL_CACHE["device"] = device
54
 
55
  try:
56
- # Check if imports succeeded
57
  if Sam3Model is None:
58
  raise ImportError("Sam3Model not found. Please check requirements.")
59
 
@@ -84,9 +87,7 @@ def load_excel_data(logs_text):
84
  if sheet_name in xls.sheet_names:
85
  df = pd.read_excel(xls, sheet_name)
86
  if not df.empty and len(df.columns) > 0:
87
- # Set first column as index, transpose, then reset index to make it a column again
88
  df = df.set_index(df.columns[0]).T.reset_index()
89
- # Rename the new first column (formerly the index) to 'Metric'
90
  df.rename(columns={df.columns[0]: "Metric"}, inplace=True)
91
  return df
92
  return placeholder
@@ -118,64 +119,101 @@ def update_opacity_sliders(layers):
118
  updates.append(gr.update(visible=False))
119
  return updates
120
 
 
121
  def generate_overlay(image_path_str, selected_layers, layer_opacities=None):
 
122
  if not image_path_str:
123
  return None
124
 
125
- base_image = Image.open(image_path_str).convert("RGBA")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
 
 
 
 
127
  color_map = {
128
  "green_cell": (0, 255, 0),
129
  "blue_nucleus": (0, 0, 255),
130
  "cell": (0, 255, 0),
131
  "nucleus": (0, 0, 255),
132
  }
133
-
134
- overlay = Image.new('RGBA', base_image.size, (0, 0, 0, 0))
135
-
136
- for layer_name in selected_layers:
137
- file_path = f"/tmp/data_{layer_name}.npz"
138
- if os.path.exists(file_path):
139
- try:
140
- data = np.load(file_path)
141
- masks = data['masks'] if 'masks' in data else data[data.files[0]]
142
-
143
- if masks.size > 0:
144
- if masks.ndim == 3:
145
- combined_mask = np.max(masks, axis=0)
146
- else:
147
- combined_mask = masks
148
-
149
- h_mask, w_mask = combined_mask.shape
150
- if overlay.size != (w_mask, h_mask):
151
- overlay = overlay.resize((w_mask, h_mask), Image.Resampling.LANCZOS)
152
- base_image = base_image.resize((w_mask, h_mask), Image.Resampling.LANCZOS)
153
 
154
- combined_mask = combined_mask.astype(bool)
155
- color = color_map.get(layer_name.lower(), (255, 255, 0))
156
-
157
- opacity = 0.5
158
- if layer_opacities and layer_name in layer_opacities:
159
- opacity = layer_opacities[layer_name]
160
-
161
- mask_overlay = np.zeros((*combined_mask.shape, 4), dtype=np.uint8)
162
- mask_overlay[combined_mask] = (*color, int(255 * opacity))
163
-
164
- mask_image = Image.fromarray(mask_overlay, 'RGBA')
165
- overlay = Image.alpha_composite(overlay, mask_image)
166
- except Exception as e:
167
- print(f"Error loading layer {layer_name}: {e}")
 
 
 
 
 
 
 
168
 
169
- result = Image.alpha_composite(base_image, overlay)
 
170
  return result.convert("RGB")
171
 
172
  # --- Core Logic ---
173
  async def run_analysis(image_path_str, user_prompt, session_id_state):
174
- """Runs analysis. Triggers model load if not yet ready."""
175
  waiting_df = pd.DataFrame({"Status": ["Waiting..."]})
176
  empty_slider_updates = [gr.update()] * 4
177
 
178
- # Lazy Load: Ensure model is loaded before inference
179
  if not MODEL_CACHE["loaded"]:
180
  yield [], None, None, [], None, waiting_df, waiting_df, waiting_df, *empty_slider_updates
181
  load_models()
@@ -184,15 +222,19 @@ async def run_analysis(image_path_str, user_prompt, session_id_state):
184
  yield [], None, None, [], None, waiting_df, waiting_df, waiting_df, *empty_slider_updates
185
  return
186
 
187
- # Cleanup
188
  for f in glob.glob("/tmp/out_*.png") + glob.glob("/tmp/data_*.npz") + glob.glob("/tmp/*.xlsx"):
189
  try: os.remove(f)
190
  except: pass
 
 
 
 
 
191
 
192
  # Setup
193
  image_path = Path(image_path_str)
194
 
195
- # Check if model loaded successfully
196
  if MODEL_CACHE["model"] is None:
197
  error_msg = "❌ Model failed to load. Please check logs."
198
  yield [{"role": "assistant", "content": error_msg}], None, None, [], None, waiting_df, waiting_df, waiting_df, *empty_slider_updates
@@ -229,7 +271,6 @@ async def run_analysis(image_path_str, user_prompt, session_id_state):
229
 
230
  def yield_status(log_list):
231
  full_log = "\n\n".join(log_list)
232
- # Use Markdown to render the image in the chat
233
  user_msg = f"![](file={image_path_str})\n\n{user_prompt}"
234
  return [{"role": "user", "content": user_msg}, {"role": "assistant", "content": full_log}]
235
 
@@ -272,6 +313,8 @@ async def run_analysis(image_path_str, user_prompt, session_id_state):
272
  full_log_text = "\n".join(logs)
273
  report_file, df_m, df_s, df_r = load_excel_data(full_log_text)
274
  layers = get_available_layers()
 
 
275
  initial_overlay = generate_overlay(image_path_str, layers)
276
 
277
  completion_msg = f"\n\n---\n\n✨ **Analysis finished!** Found {len(layers)} layer(s). Results are now available in the Segmentation and Quantitative Results tabs."
@@ -303,7 +346,6 @@ async def unified_chat_handler(message, history, session_id, current_img_path):
303
  if image_path and (not session_id or files):
304
  if not user_text: user_text = "Analyze this microscopy image."
305
 
306
- # Add preview image using Markdown
307
  history.append({"role": "user", "content": f"![](file={image_path})\n\n{user_text}"})
308
  history.append({"role": "assistant", "content": "🔄 Starting analysis (Model loading may take a moment)..."})
309
 
@@ -361,15 +403,11 @@ async def unified_chat_handler(message, history, session_id, current_img_path):
361
 
362
  # CSS to handle margins and width consistency
363
  custom_css = """
364
- /* 1. Global Margin Setting */
365
  #main_container {
366
- margin-left: 20% !important;
367
- margin-right: 20% !important;
368
  width: auto !important;
369
  }
370
-
371
- /* 2. Fix Tab Width Consistency (Right Panel) */
372
- /* We enforce a minimum width so the panel doesn't shrink when switching to empty dataframes */
373
  .right-panel {
374
  min-width: 600px !important;
375
  flex-grow: 2 !important;
@@ -380,7 +418,6 @@ with gr.Blocks(title="Cellemetry Agent", css=custom_css) as demo:
380
  session_id_state = gr.State(None)
381
  current_image_path = gr.State(None)
382
 
383
- # WRAPPER to apply 20% margins
384
  with gr.Column(elem_id="main_container"):
385
 
386
  with gr.Row():
@@ -400,7 +437,6 @@ with gr.Blocks(title="Cellemetry Agent", css=custom_css) as demo:
400
  )
401
 
402
  # --- RIGHT COLUMN (Results) ---
403
- # Added 'right-panel' class for CSS width enforcement
404
  with gr.Column(scale=2, elem_classes=["right-panel"]):
405
 
406
  # Welcome overlay
 
19
  from google.genai import types
20
 
21
  # Project Imports
 
22
  try:
23
  from cellemetry import root_agent
24
  from cellemetry.config import AnalysisDeps
 
38
  "loaded": False
39
  }
40
 
41
+ # OPTIMIZATION: In-memory cache for masks and base image to prevent disk I/O on slider updates
42
+ MASK_CACHE = {
43
+ "current_path": None,
44
+ "base_image": None, # PIL RGBA Image
45
+ "layers": {} # Dict of 'layer_name': numpy_boolean_mask
46
+ }
47
 
48
+ ACTIVE_RUNNER = None
49
 
50
  def load_models():
51
+ """Initialize SAM3 model. Now called AFTER app startup."""
52
  if MODEL_CACHE["loaded"]:
53
  return
54
 
 
57
  MODEL_CACHE["device"] = device
58
 
59
  try:
 
60
  if Sam3Model is None:
61
  raise ImportError("Sam3Model not found. Please check requirements.")
62
 
 
87
  if sheet_name in xls.sheet_names:
88
  df = pd.read_excel(xls, sheet_name)
89
  if not df.empty and len(df.columns) > 0:
 
90
  df = df.set_index(df.columns[0]).T.reset_index()
 
91
  df.rename(columns={df.columns[0]: "Metric"}, inplace=True)
92
  return df
93
  return placeholder
 
119
  updates.append(gr.update(visible=False))
120
  return updates
121
 
122
+ # --- OPTIMIZED OVERLAY GENERATION ---
123
  def generate_overlay(image_path_str, selected_layers, layer_opacities=None):
124
+ """Regenerates overlay using in-memory caching for speed."""
125
  if not image_path_str:
126
  return None
127
 
128
+ # 1. Check if we need to load data into cache
129
+ # We reload if the path changes OR if the cache is empty
130
+ if MASK_CACHE["current_path"] != image_path_str or MASK_CACHE["base_image"] is None:
131
+ print(f"🔄 Caching masks for {os.path.basename(image_path_str)}...")
132
+ try:
133
+ # Load Base Image
134
+ base_img = Image.open(image_path_str).convert("RGBA")
135
+ MASK_CACHE["base_image"] = base_img
136
+ MASK_CACHE["current_path"] = image_path_str
137
+ MASK_CACHE["layers"] = {} # Clear old layers
138
+
139
+ # Pre-load and resize ALL available layers to match base image
140
+ all_layer_files = glob.glob("/tmp/data_*.npz")
141
+ base_w, base_h = base_img.size
142
+
143
+ for file_path in all_layer_files:
144
+ layer_name = os.path.basename(file_path).replace("data_", "").replace(".npz", "")
145
+ try:
146
+ data = np.load(file_path)
147
+ masks = data['masks'] if 'masks' in data else data[data.files[0]]
148
+
149
+ if masks.size > 0:
150
+ if masks.ndim == 3:
151
+ combined_mask = np.max(masks, axis=0)
152
+ else:
153
+ combined_mask = masks
154
+
155
+ # Resize boolean mask to match base image ONCE
156
+ # We use PIL for high-quality resizing of the mask
157
+ mask_pil = Image.fromarray(combined_mask.astype(np.uint8) * 255)
158
+ if mask_pil.size != (base_w, base_h):
159
+ mask_pil = mask_pil.resize((base_w, base_h), Image.Resampling.NEAREST)
160
+
161
+ # Store as boolean numpy array for fast processing
162
+ MASK_CACHE["layers"][layer_name] = np.array(mask_pil, dtype=bool)
163
+ except Exception as e:
164
+ print(f"Failed to cache layer {layer_name}: {e}")
165
+ except Exception as e:
166
+ print(f"Failed to load base image: {e}")
167
+ return None
168
+
169
+ # 2. Fast Composition using Cache
170
+ if MASK_CACHE["base_image"] is None:
171
+ return None
172
 
173
+ base_image = MASK_CACHE["base_image"]
174
+ # Start with a transparent overlay layer
175
+ overlay_accum = Image.new('RGBA', base_image.size, (0, 0, 0, 0))
176
+
177
  color_map = {
178
  "green_cell": (0, 255, 0),
179
  "blue_nucleus": (0, 0, 255),
180
  "cell": (0, 255, 0),
181
  "nucleus": (0, 0, 255),
182
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
+ # Iterate through requested layers
185
+ for layer_name in selected_layers:
186
+ if layer_name in MASK_CACHE["layers"]:
187
+ mask_bool = MASK_CACHE["layers"][layer_name]
188
+
189
+ # Get settings
190
+ color = color_map.get(layer_name.lower(), (255, 255, 0))
191
+ opacity = 0.6 # Default
192
+ if layer_opacities and layer_name in layer_opacities:
193
+ opacity = layer_opacities[layer_name]
194
+
195
+ # Create a solid color block
196
+ # Optimization: We construct the RGBA buffer directly
197
+ layer_rgba = np.zeros((mask_bool.shape[0], mask_bool.shape[1], 4), dtype=np.uint8)
198
+
199
+ # Apply color and opacity only where mask is True
200
+ layer_rgba[mask_bool] = (*color, int(255 * opacity))
201
+
202
+ # Composite using PIL (fast C implementation)
203
+ layer_img = Image.fromarray(layer_rgba, 'RGBA')
204
+ overlay_accum = Image.alpha_composite(overlay_accum, layer_img)
205
 
206
+ # Final Composite
207
+ result = Image.alpha_composite(base_image, overlay_accum)
208
  return result.convert("RGB")
209
 
210
  # --- Core Logic ---
211
  async def run_analysis(image_path_str, user_prompt, session_id_state):
212
+ """Runs analysis."""
213
  waiting_df = pd.DataFrame({"Status": ["Waiting..."]})
214
  empty_slider_updates = [gr.update()] * 4
215
 
216
+ # Lazy Load
217
  if not MODEL_CACHE["loaded"]:
218
  yield [], None, None, [], None, waiting_df, waiting_df, waiting_df, *empty_slider_updates
219
  load_models()
 
222
  yield [], None, None, [], None, waiting_df, waiting_df, waiting_df, *empty_slider_updates
223
  return
224
 
225
+ # Cleanup Files AND Cache
226
  for f in glob.glob("/tmp/out_*.png") + glob.glob("/tmp/data_*.npz") + glob.glob("/tmp/*.xlsx"):
227
  try: os.remove(f)
228
  except: pass
229
+
230
+ # Reset Cache for new run
231
+ MASK_CACHE["current_path"] = None
232
+ MASK_CACHE["base_image"] = None
233
+ MASK_CACHE["layers"] = {}
234
 
235
  # Setup
236
  image_path = Path(image_path_str)
237
 
 
238
  if MODEL_CACHE["model"] is None:
239
  error_msg = "❌ Model failed to load. Please check logs."
240
  yield [{"role": "assistant", "content": error_msg}], None, None, [], None, waiting_df, waiting_df, waiting_df, *empty_slider_updates
 
271
 
272
  def yield_status(log_list):
273
  full_log = "\n\n".join(log_list)
 
274
  user_msg = f"![](file={image_path_str})\n\n{user_prompt}"
275
  return [{"role": "user", "content": user_msg}, {"role": "assistant", "content": full_log}]
276
 
 
313
  full_log_text = "\n".join(logs)
314
  report_file, df_m, df_s, df_r = load_excel_data(full_log_text)
315
  layers = get_available_layers()
316
+
317
+ # Initial generation (will trigger cache population)
318
  initial_overlay = generate_overlay(image_path_str, layers)
319
 
320
  completion_msg = f"\n\n---\n\n✨ **Analysis finished!** Found {len(layers)} layer(s). Results are now available in the Segmentation and Quantitative Results tabs."
 
346
  if image_path and (not session_id or files):
347
  if not user_text: user_text = "Analyze this microscopy image."
348
 
 
349
  history.append({"role": "user", "content": f"![](file={image_path})\n\n{user_text}"})
350
  history.append({"role": "assistant", "content": "🔄 Starting analysis (Model loading may take a moment)..."})
351
 
 
403
 
404
  # CSS to handle margins and width consistency
405
  custom_css = """
 
406
  #main_container {
407
+ margin-left: 10% !important;
408
+ margin-right: 10% !important;
409
  width: auto !important;
410
  }
 
 
 
411
  .right-panel {
412
  min-width: 600px !important;
413
  flex-grow: 2 !important;
 
418
  session_id_state = gr.State(None)
419
  current_image_path = gr.State(None)
420
 
 
421
  with gr.Column(elem_id="main_container"):
422
 
423
  with gr.Row():
 
437
  )
438
 
439
  # --- RIGHT COLUMN (Results) ---
 
440
  with gr.Column(scale=2, elem_classes=["right-panel"]):
441
 
442
  # Welcome overlay