hmgill committed
Commit 42bf28c · verified · 1 Parent(s): a8b0d0d

Upload 41 files

Files changed (41)
  1. agents/__init__.py +0 -0
  2. agents/__pycache__/__init__.cpython-311.pyc +0 -0
  3. agents/__pycache__/agent.cpython-311.pyc +0 -0
  4. agents/agent.py +225 -0
  5. agents/agent.py~ +224 -0
  6. app.py +184 -10
  7. config/__init__.py +0 -0
  8. config/__pycache__/__init__.cpython-311.pyc +0 -0
  9. config/__pycache__/settings.cpython-311.pyc +0 -0
  10. config/settings.py +60 -0
  11. config/settings.py~ +60 -0
  12. models/__init__.py +0 -0
  13. models/__pycache__/__init__.cpython-311.pyc +0 -0
  14. models/__pycache__/embeddings.cpython-311.pyc +0 -0
  15. models/__pycache__/reranker.cpython-311.pyc +0 -0
  16. models/embeddings.py +36 -0
  17. models/reranker.py +29 -0
  18. stores/__init__.py +15 -0
  19. stores/__pycache__/__init__.cpython-311.pyc +0 -0
  20. stores/__pycache__/chroma_store.cpython-311.pyc +0 -0
  21. stores/__pycache__/neo4j_store.cpython-311.pyc +0 -0
  22. stores/chroma_store.py +22 -0
  23. stores/chroma_store.py~ +22 -0
  24. stores/neo4j_store.py +69 -0
  25. tools/__init__.py +35 -0
  26. tools/__pycache__/__init__.cpython-311.pyc +0 -0
  27. tools/__pycache__/search.cpython-311.pyc +0 -0
  28. tools/__pycache__/segmentation.cpython-311.pyc +0 -0
  29. tools/search.py +101 -0
  30. tools/segmentation.py +532 -0
  31. tools/segmentation.py~ +531 -0
  32. utils/__init__.py +24 -0
  33. utils/__init__.py~ +23 -0
  34. utils/__pycache__/__init__.cpython-311.pyc +0 -0
  35. utils/__pycache__/gpu.cpython-311.pyc +0 -0
  36. utils/__pycache__/image_utils.cpython-311.pyc +0 -0
  37. utils/__pycache__/prechecks.cpython-311.pyc +0 -0
  38. utils/gpu.py +80 -0
  39. utils/image_utils.py +31 -0
  40. utils/prechecks.py +44 -0
  41. utils/prechecks.py~ +44 -0
agents/__init__.py ADDED
File without changes
agents/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (206 Bytes).
 
agents/__pycache__/agent.cpython-311.pyc ADDED
Binary file (12.4 kB).
 
agents/agent.py ADDED
@@ -0,0 +1,225 @@
+ """
+ CellposeAgent with proper VLM configuration
+ """
+ import torch
+ import json
+ from datetime import datetime
+ from PIL import Image
+ from smolagents import ToolCallingAgent, InferenceClientModel
+ from smolagents.agents import ActionStep
+ from langfuse import get_client, observe
+
+ from config import settings
+ from utils.gpu import clear_gpu_cache
+ from tools import all_tools
+
+
+ langfuse = get_client()
+
+
+ class CellposeAgent:
+
+     @staticmethod
+     def attach_images_callback(step_log: ActionStep, agent: ToolCallingAgent) -> None:
+         """
+         Callback to attach actual PIL images for VLM inspection.
+         Images are automatically resized to reduce token consumption.
+         """
+         if not isinstance(step_log, ActionStep):
+             return
+
+         if not step_log.observations:
+             return
+
+         def resize_image(img: Image.Image, max_size: int = 1024) -> Image.Image:
+             """Resize image maintaining aspect ratio, max dimension = max_size."""
+             if max(img.size) <= max_size:
+                 return img
+
+             ratio = max_size / max(img.size)
+             new_size = tuple(int(dim * ratio) for dim in img.size)
+             resized = img.resize(new_size, Image.Resampling.LANCZOS)
+             print(f"  Resized {img.size} → {resized.size}")
+             return resized
+
+         try:
+             obs_data = json.loads(step_log.observations)
+
+             # Pattern 1: Single image from get_segmentation_parameters
+             if obs_data.get("status") == "success" and "image_path" in obs_data:
+                 image_path = obs_data["image_path"]
+                 print(f"[Callback] Attaching image: {image_path}")
+
+                 try:
+                     img = Image.open(image_path)
+                     resized_img = resize_image(img)
+
+                     # Attach resized PIL Image
+                     step_log.observations_images = [resized_img]
+
+                     # Keep metadata for context
+                     obs_data["image_info"] = {
+                         "original_dimensions": f"{img.size[0]}x{img.size[1]} pixels",
+                         "resized_dimensions": f"{resized_img.size[0]}x{resized_img.size[1]} pixels",
+                         "mode": resized_img.mode,
+                         "note": "Image attached for visual inspection (resized for efficiency)"
+                     }
+                     step_log.observations = json.dumps(obs_data, indent=2)
+                     print("[Callback] ✓ Attached resized image for VLM inspection")
+                 except Exception as e:
+                     print(f"[Callback] Error attaching image: {e}")
+
+             # Pattern 2: Multiple images from refine_segmentation
+             elif obs_data.get("status") == "ready_for_visual_analysis":
+                 paths = obs_data.get("image_paths", {})
+                 original = paths.get("original")
+                 segmented = paths.get("segmented")
+
+                 if original and segmented:
+                     print("[Callback] Attaching both original and segmented images")
+                     try:
+                         orig_img = Image.open(original)
+                         seg_img = Image.open(segmented)
+
+                         # Resize both images
+                         resized_orig = resize_image(orig_img)
+                         resized_seg = resize_image(seg_img)
+
+                         # Attach both resized images as list
+                         step_log.observations_images = [resized_orig, resized_seg]
+
+                         obs_data["images_info"] = {
+                             "image_order": ["original", "segmented"],
+                             "original_size": f"{orig_img.size[0]}x{orig_img.size[1]}",
+                             "resized_size": f"{resized_orig.size[0]}x{resized_orig.size[1]}",
+                             "note": "Both images attached for visual comparison (resized for efficiency)"
+                         }
+                         step_log.observations = json.dumps(obs_data, indent=2)
+                         print("[Callback] ✓ Attached both resized images for VLM inspection")
+                     except Exception as e:
+                         print(f"[Callback] Error attaching images: {e}")
+
+         except json.JSONDecodeError:
+             pass
+         except Exception as e:
+             print(f"[Callback] Error in attach_images_callback: {e}")
+
+
+     @staticmethod
+     def manage_image_memory(step_log: ActionStep, agent: ToolCallingAgent) -> None:
+         """
+         Aggressive memory management: keep ONLY the last step's images.
+         All previous steps have their images cleared immediately.
+         """
+         if not isinstance(step_log, ActionStep):
+             return
+
+         current_step = step_log.step_number
+
+         # Clear images from ALL previous steps (keeping only current)
+         for previous_step in agent.memory.steps:
+             if isinstance(previous_step, ActionStep) and \
+                previous_step.step_number < current_step:
+                 if previous_step.observations_images is not None:
+                     print(f"  [Memory] Clearing images from step {previous_step.step_number}")
+                     previous_step.observations_images = None
+
+
+     def __init__(self):
+         self.instructions = """
+ You are an assistant for the cellpose-sam segmentation tool.
+
+ ## PRIMARY WORKFLOW - IMAGE SEGMENTATION
+
+ When a user provides an image:
+ 1. Use appropriate tools to review which cellpose-sam parameters are available.
+ 2. Use the tool: `get_segmentation_parameters`
+    - **IMPORTANT**: After this tool runs, you will receive image metadata (dimensions, properties)
+    - Use this information to reason about appropriate parameter values
+ 3. Carefully analyze the image metadata and matched parameters:
+    - Consider cell density based on image dimensions
+    - Compare matched parameter values to image characteristics
+    - Consider whether adjustments would likely improve the segmentation
+ 4. Be conservative: if you make changes, assess whether they should differ significantly from the original values
+ 5. Provide your final parameter recommendations in a clear, structured format
+ 6. Use the parameters to run cellpose-sam through the tool: run_cellpose_sam
+ 7. After run_cellpose_sam, call the tool: refine_cellpose_sam_segmentation
+    - **IMPORTANT**: After this tool runs, you will receive metadata about both original and segmented images
+    - Use the provided information to assess segmentation quality
+ 8. Based on the metadata and any quality metrics returned:
+    - Identify potential segmentation issues based on reported metrics
+    - If refinement is needed, use knowledge graph and RAG tools to understand parameter effects
+    - Decide which parameters to adjust based on the segmentation analysis
+    - Re-run run_cellpose_sam with adjusted parameters
+
+ **CRITICAL: Call refine_cellpose_sam_segmentation AT MOST 2 TIMES total**
+ - First call: Check initial segmentation quality
+ - Second call (if needed): Verify refinement improved results
+ - NEVER call it a third time - always stop after 2 refinement checks
+
+ ## DOCUMENTATION QUERY WORKFLOW ##
+
+ - "What is X": use `search_documentation_vector`
+ - "How does X affect Y": use `search_knowledge_graph`
+ - Complex analysis: use `hybrid_search`
+ - Parameter relationships: use `get_parameter_relationships`
+
+ ## RESPONSE STYLE ##
+ - Be concise and actionable
+ - Always explain your reasoning when adjusting parameters
+ - If keeping the original matched parameters, briefly confirm why they are appropriate
+ - Base your decisions on the metadata and metrics provided by the tools
+ """
+
+         self.model = self._initialize_model()
+         self.agent = self._create_agent()
+
+
+     def _initialize_model(self):
+         """Initializes the InferenceClientModel for the agent with VLM support."""
+         clear_gpu_cache()
+
+         return InferenceClientModel(
+             model_id=settings.AGENT_MODEL_ID,
+             token=settings.HF_TOKEN
+         )
+
+
+     def _create_agent(self):
+         """Creates the ToolCallingAgent with all available tools and memory management."""
+         return ToolCallingAgent(
+             model=self.model,
+             tools=all_tools,
+             instructions=self.instructions,
+             max_steps=10,
+             step_callbacks=[
+                 self.attach_images_callback,
+                 self.manage_image_memory,
+             ]
+         )
+
+     @observe()
+     def run(self, task: str):
+         """Runs the agent on a given task with Langfuse tracing."""
+         print(f"\n{'='*60}\nTASK: {task}\n{'='*60}")
+
+         langfuse.update_current_trace(
+             input={"task": task},
+             user_id="user_001",
+             tags=["rag", "cellpose", "knowledge-graph", "vision"],
+             metadata={"agent_type": "ToolCallingAgent", "model_id": settings.AGENT_MODEL_ID}
+         )
+
+         try:
+             final_answer = self.agent.run(task)
+             print("\n--- Final Answer from Agent ---\n", final_answer)
+             langfuse.update_current_trace(output={"final_answer": final_answer})
+             return final_answer
+         except Exception as e:
+             print(f"Agent run failed: {e}")
+             langfuse.update_current_trace(output={"error": str(e)})
+             raise
+         finally:
+             clear_gpu_cache()
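For orientation, a minimal usage sketch of the class above (not part of the commit; it assumes the environment variables read by config/settings.py are already set, and "cells.png" is a hypothetical image path):

    from agents.agent import CellposeAgent

    # Builds the InferenceClientModel and the ToolCallingAgent with both step callbacks
    agent = CellposeAgent()

    answer = agent.run("What parameters would work best for my image cells.png?")
    print(answer)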
agents/agent.py~ ADDED
Editor backup file; contents duplicate agents/agent.py above.
app.py CHANGED
@@ -1,14 +1,188 @@
  import gradio as gr
- import spaces
- import torch

- zero = torch.Tensor([0]).cuda()
- print(zero.device) # <-- 'cpu' 🤔

- @spaces.GPU
- def greet(n):
-     print(zero.device) # <-- 'cuda:0' 🤗
-     return f"Hello {zero + n} Tensor"

- demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
- demo.launch()
+ """
+ Gradio web interface for CellposeAgent
+ """
+ import gradio as gr
+ from pathlib import Path
+ from langfuse import get_client
+ from openinference.instrumentation.smolagents import SmolagentsInstrumentor
+
+ from config import settings
+ from agents.agent import CellposeAgent
+ from stores import neo4j_store
+ from utils.prechecks import check_hf_persistent_storage
+
+
+ def setup_observability():
+     """Initializes Langfuse and Smolagents instrumentation."""
+     get_client()
+     SmolagentsInstrumentor().instrument()
+     print("✓ Observability and instrumentation initialized.")
+
+
+ def initialize_app():
+     """Initialize the application and verify prerequisites."""
+     print("\n--- Initializing Cellpose Agent Application ---")
+
+     # Set up observability
+     setup_observability()
+
+     # Configure LlamaIndex
+     settings.configure_llama_index()
+
+     # Check for cellpose-db
+     check_hf_persistent_storage(
+         repo_id="hmgill/Cellpose-DB",
+         target="cellpose_db",
+         file_or_folder="folder"
+     )
+
+     # Check for the cellpose-sam checkpoint
+     check_hf_persistent_storage(
+         repo_id="hmgill/Cellpose-SAM-Checkpoint",
+         target="sam_vit_h_4b8939.pth",
+         file_or_folder="file"
+     )
+
+     # Verify knowledge graph is ready
+     try:
+         node_count, _ = neo4j_store.check_graph_status()
+         if node_count == 0:
+             print("\n❌ WARNING: The knowledge graph is empty.")
+             print("Please run the setup script to build the knowledge graph:")
+             print("\n    python setup_kg.py\n")
+             return False
+         print(f"✓ Knowledge graph is ready with {node_count} nodes.")
+     except Exception as e:
+         print(f"❌ ERROR: Could not connect to Neo4j: {e}")
+         print("Please ensure Neo4j is running and accessible.")
+         return False
+
+     return True
+
+
+ def process_image_task(image_path: str, task_text: str, agent: CellposeAgent) -> str:
+     """
+     Process a user task with the CellposeAgent.
+
+     Args:
+         image_path: Path to the uploaded image file
+         task_text: User's text prompt/question
+         agent: Initialized CellposeAgent instance
+
+     Returns:
+         str: Agent's response
+     """
+     if not image_path:
+         return "⚠️ Please upload an image first."
+
+     if not task_text:
+         task_text = f"What parameters would work best for my image {image_path}?"
+
+     try:
+         result = agent.run(task_text)
+         get_client().flush()
+         return result
+     except Exception as e:
+         return f"❌ Error processing task: {str(e)}"
+
+
+ def create_gradio_interface():
+     """Creates and configures the Gradio interface."""
+
+     # Initialize the agent once at startup
+     if not initialize_app():
+         raise RuntimeError("Failed to initialize application. Please check logs.")
+
+     agent = CellposeAgent()
+     print("✓ CellposeAgent initialized and ready.")
+
+     with gr.Blocks(title="Cellpose-SAM Agent", theme=gr.themes.Soft()) as demo:
+         gr.Markdown(
+             """
+             # 🔬 Cellpose-SAM Segmentation Agent
+
+             Upload a microscopy image and ask the AI agent to recommend optimal segmentation parameters,
+             run segmentation, or answer questions about the cellpose-sam pipeline.
+             """
+         )
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 # Image upload
+                 image_input = gr.Image(
+                     label="Upload Microscopy Image",
+                     type="filepath",
+                     height=300
+                 )
+
+                 # Task input
+                 task_input = gr.Textbox(
+                     label="Task / Question",
+                     placeholder="e.g., 'What parameters would work best for this image?' or leave empty for default",
+                     lines=3
+                 )
+
+                 # Submit button
+                 submit_btn = gr.Button("Run Agent", variant="primary", size="lg")
+
+                 # Example tasks
+                 gr.Markdown("### 💡 Example Tasks")
+                 gr.Examples(
+                     examples=[
+                         ["What parameters would work best for this image?"],
+                         ["Analyze this image and run segmentation with optimal parameters."],
+                         ["What is the flow_threshold parameter and how does it affect segmentation?"],
+                         ["Run segmentation with diameter=30, flow_threshold=0.5, cellprob_threshold=0, min_size=20"],
+                     ],
+                     inputs=task_input,
+                     label="Click to use:"
+                 )
+
+             with gr.Column(scale=1):
+                 # Output
+                 output = gr.Textbox(
+                     label="Agent Response",
+                     lines=20,
+                     max_lines=30,
+                     show_copy_button=True
+                 )
+
+         # Event handler
+         submit_btn.click(
+             fn=lambda img, task: process_image_task(img, task, agent),
+             inputs=[image_input, task_input],
+             outputs=output
+         )
+
+         gr.Markdown(
+             """
+             ---
+             ### 📚 What can this agent do?
+
+             - **Parameter Recommendation**: Analyzes your image and suggests optimal segmentation parameters
+             - **Automated Segmentation**: Runs the full cellpose-sam pipeline with parameter refinement
+             - **Visual Analysis**: Uses vision-language models to assess segmentation quality
+             - **Documentation Search**: Answers questions about parameters using RAG and knowledge graphs
+             - **Iterative Refinement**: Automatically adjusts parameters based on visual feedback
+
+             ### 🔍 How it works
+
+             1. Upload your microscopy image
+             2. The agent finds similar images and recommends parameters
+             3. Visually analyzes your image to validate recommendations
+             4. Runs segmentation and checks quality
+             5. Refines parameters if needed (up to 2 iterations)
+             """
+         )
+
+     return demo
+
+
+ def main():
+     """Launch the Gradio application."""
+     demo = create_gradio_interface()
+     demo.launch()
+
+
+ if __name__ == "__main__":
+     main()
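As a sketch of how the interface above is wired together (not part of the commit; it assumes Neo4j, the ChromaDB path, and the HF token are all reachable, and server_name is an ordinary Gradio launch option shown only as an example):

    from app import create_gradio_interface

    demo = create_gradio_interface()      # runs initialize_app() and constructs the agent once
    demo.launch(server_name="0.0.0.0")    # e.g., to listen on all interfaces in a container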
config/__init__.py ADDED
File without changes
config/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (206 Bytes).
 
config/__pycache__/settings.cpython-311.pyc ADDED
Binary file (2.53 kB).
 
config/settings.py ADDED
@@ -0,0 +1,60 @@
+ """
+ Configuration settings: model IDs, environment paths, and LlamaIndex setup.
+ """
+
+ import os
+ import torch
+ from llama_index.core import Settings
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+ from llama_index.llms.huggingface import HuggingFaceLLM
+ from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
+ from llama_index.core.prompts import PromptTemplate
+
+
+ # --- Model IDs ---
+ AGENT_MODEL_ID = "google/gemma-3-12b-it"
+ EMBEDDING_MODEL_ID = "clip-ViT-B-32"
+
+
+ # --- Environment & Paths ---
+ CHROMADB = os.getenv("CHROMADB")
+ CELLPOSE_SAM = os.getenv("CELLPOSE_SAM")
+ HF_TOKEN = os.getenv("HF_TOKEN")
+
+ NEO4J_URI = os.getenv("NEO4J_URI")
+ NEO4J_USERNAME = os.getenv("NEO4J_USERNAME")
+ NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
+ NEO4J_DATABASE = os.getenv("NEO4J_DATABASE")
+
+
+ # --- LlamaIndex Global Settings ---
+ def configure_llama_index():
+     """
+     Configures global LlamaIndex settings for the embedding model and the LLM.
+     """
+     print("✓ Configuring LlamaIndex settings...")
+
+     # Query wrapper prompt (Llama-3-style chat template; defined but not currently used)
+     query_wrapper_prompt = PromptTemplate(
+         "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{query_str}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+     )
+
+     llm = HuggingFaceInferenceAPI(
+         model_name=AGENT_MODEL_ID,
+         token=HF_TOKEN,
+         provider="auto"
+     )
+
+     Settings.llm = llm
+
+     Settings.embed_model = HuggingFaceEmbedding(
+         model_name=f"sentence-transformers/{EMBEDDING_MODEL_ID}"
+     )
+
+     Settings.chunk_size = 512
+     Settings.chunk_overlap = 50
+
+     print("✓ LlamaIndex configured to use local Embedding Model and LLM.")
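One subtlety worth noting: the environment variables above are read at import time, so they must be set before `config` is first imported. A minimal sketch (the token value is a placeholder, not a real credential):

    import os
    os.environ.setdefault("HF_TOKEN", "hf_xxx")  # placeholder; export a real token instead

    from config import settings
    settings.configure_llama_index()  # installs Settings.llm, Settings.embed_model, and chunking globally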
config/settings.py~ ADDED
Editor backup file; duplicates config/settings.py above, except that the os.getenv calls carry hardcoded fallback defaults (local paths and Neo4j credentials) instead of returning None.
models/__init__.py ADDED
File without changes
models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (206 Bytes).
 
models/__pycache__/embeddings.cpython-311.pyc ADDED
Binary file (1.78 kB).
 
models/__pycache__/reranker.cpython-311.pyc ADDED
Binary file (960 Bytes).
 
models/embeddings.py ADDED
@@ -0,0 +1,36 @@
+ """
+ CLIP image-embedding helpers.
+ """
+
+ from PIL import Image
+ from sentence_transformers import SentenceTransformer
+ from config import settings
+
+ # --- Global Singleton for Embedding Model ---
+ _embedding_model = None
+
+ def get_embedding_model():
+     """
+     Initializes and returns the SentenceTransformer model (singleton pattern).
+     """
+     global _embedding_model
+     if _embedding_model is None:
+         print("Initializing embedding model...")
+         _embedding_model = SentenceTransformer(settings.EMBEDDING_MODEL_ID)
+         print(f"✓ Embedding model initialized ({settings.EMBEDDING_MODEL_ID})")
+     return _embedding_model
+
+ def get_image_embedding(image_path: str) -> list[float]:
+     """
+     Generates a CLIP embedding for a given image file.
+
+     Args:
+         image_path (str): The path to the image file.
+
+     Returns:
+         list[float]: The image embedding as a list of floats.
+     """
+     model = get_embedding_model()
+     img = Image.open(image_path).convert("RGB")
+     embedding = model.encode(img, convert_to_numpy=True)
+     return embedding.tolist()
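A quick usage sketch of the helper above ("example.png" is a hypothetical path; the clip-ViT-B-32 model produces 512-dimensional vectors):

    from models.embeddings import get_image_embedding

    vec = get_image_embedding("example.png")
    print(len(vec))  # 512 for clip-ViT-B-32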
models/reranker.py ADDED
@@ -0,0 +1,29 @@
+ """
+ Cross-encoder reranker (singleton).
+ """
+
+ from llama_index.core.postprocessor import SentenceTransformerRerank
+
+ # --- Global Singleton for Reranker Model ---
+ _reranker_model = None
+
+ def get_reranker():
+     """
+     Initializes and returns the SentenceTransformerRerank model (singleton pattern).
+     This model will download on first use.
+     """
+     global _reranker_model
+
+     if _reranker_model is None:
+         print("Initializing Cross-Encoder Reranker model...")
+
+         # A popular, lightweight, and effective cross-encoder
+         _reranker_model = SentenceTransformerRerank(
+             model="cross-encoder/ms-marco-MiniLM-L-6-v2",
+             top_n=3  # The number of documents to return after reranking
+         )
+         print("✓ Reranker model initialized.")
+
+     return _reranker_model
stores/__init__.py ADDED
@@ -0,0 +1,15 @@
+ from .chroma_store import get_client as get_chroma_client
+ from .neo4j_store import (
+     get_graph_store,
+     check_graph_status,
+     initialize_knowledge_graph,
+     get_kg_index
+ )
+
+ __all__ = [
+     "get_chroma_client",
+     "get_graph_store",
+     "check_graph_status",
+     "initialize_knowledge_graph",
+     "get_kg_index"
+ ]
stores/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (543 Bytes).
 
stores/__pycache__/chroma_store.cpython-311.pyc ADDED
Binary file (932 Bytes).
 
stores/__pycache__/neo4j_store.cpython-311.pyc ADDED
Binary file (4.27 kB).
 
stores/chroma_store.py ADDED
@@ -0,0 +1,22 @@
+ """
+ ChromaDB persistent client (singleton).
+ """
+
+ import chromadb
+ from config import settings
+
+ # --- Global Singleton for ChromaDB Client ---
+ _chroma_client = None
+
+
+ def get_client():
+     """
+     Initializes and returns the ChromaDB persistent client (singleton pattern).
+     """
+     global _chroma_client
+     if _chroma_client is None:
+         print("Initializing ChromaDB client...")
+         _chroma_client = chromadb.PersistentClient(path=settings.CHROMADB)
+         print(f"✓ ChromaDB client connected to path: {settings.CHROMADB}")
+     return _chroma_client
stores/chroma_store.py~ ADDED
Editor backup file; duplicates stores/chroma_store.py above, except its final print statement references settings.CHROMA_DB_PATH rather than settings.CHROMADB.
stores/neo4j_store.py ADDED
@@ -0,0 +1,69 @@
+ """
+ Neo4j graph store and KnowledgeGraphIndex helpers.
+ """
+
+ from llama_index.core import Document, KnowledgeGraphIndex, StorageContext
+ from llama_index.graph_stores.neo4j import Neo4jGraphStore
+ from neo4j import GraphDatabase
+
+ from config import settings
+ from stores import chroma_store
+
+ # --- Global Singleton for KG Index ---
+ _kg_index = None
+
+ def get_graph_store():
+     """Initializes and returns the Neo4jGraphStore."""
+     return Neo4jGraphStore(
+         username=settings.NEO4J_USERNAME,
+         password=settings.NEO4J_PASSWORD,
+         url=settings.NEO4J_URI,
+         database=settings.NEO4J_DATABASE,
+     )
+
+ def check_graph_status():
+     """Checks if the Neo4j graph contains any nodes or relationships."""
+     driver = GraphDatabase.driver(
+         settings.NEO4J_URI,
+         auth=(settings.NEO4J_USERNAME, settings.NEO4J_PASSWORD)
+     )
+     with driver.session(database=settings.NEO4J_DATABASE) as session:
+         nodes_result = session.run("MATCH (n) RETURN count(n) as count")
+         node_count = nodes_result.single()['count']
+         rels_result = session.run("MATCH ()-[r]->() RETURN count(r) as count")
+         rel_count = rels_result.single()['count']
+     driver.close()
+     return node_count, rel_count
+
+ def initialize_knowledge_graph():
+     """Builds the knowledge graph from documents in ChromaDB and stores it in Neo4j."""
+     print("\n--- Building Knowledge Graph in Neo4j ---")
+     chroma_client = chroma_store.get_client()
+     doc_collection = chroma_client.get_collection(name='cellpose_docs')
+     doc_data = doc_collection.get()
+
+     documents = [
+         Document(text=text, metadata=meta)
+         for text, meta in zip(doc_data['documents'], doc_data['metadatas'])
+     ]
+
+     storage_context = StorageContext.from_defaults(graph_store=get_graph_store())
+
+     KnowledgeGraphIndex.from_documents(
+         documents,
+         storage_context=storage_context,
+         max_triplets_per_chunk=3,
+         include_embeddings=True,
+         show_progress=True
+     )
+     print("✓ Knowledge Graph built and stored in Neo4j successfully.")
+
+ def get_kg_index():
+     """Loads the KnowledgeGraphIndex from the existing Neo4j graph store."""
+     global _kg_index
+     if _kg_index is None:
+         print("Loading Knowledge Graph index from Neo4j...")
+         storage_context = StorageContext.from_defaults(graph_store=get_graph_store())
+         _kg_index = KnowledgeGraphIndex(nodes=[], storage_context=storage_context)
+         print("✓ Knowledge Graph index loaded.")
+     return _kg_index
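A short sketch of the intended call order for these helpers (assuming the NEO4J_* settings point at a reachable instance):

    from stores import check_graph_status, initialize_knowledge_graph, get_kg_index

    nodes, rels = check_graph_status()
    print(f"Graph has {nodes} nodes and {rels} relationships")
    if nodes == 0:
        initialize_knowledge_graph()   # one-time build from the 'cellpose_docs' collection
    kg_index = get_kg_index()          # wraps the existing graph; no rebuild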
tools/__init__.py ADDED
@@ -0,0 +1,35 @@
+ from .segmentation import (
+     get_segmentation_parameters,
+     run_cellpose_sam,
+     refine_cellpose_sam_segmentation
+ )
+ from .search import (
+     list_all_collections,
+     search_documentation_vector,
+     search_knowledge_graph,
+     hybrid_search,
+     get_parameter_relationships,
+ )
+
+ all_tools = [
+     get_segmentation_parameters,
+     run_cellpose_sam,
+     refine_cellpose_sam_segmentation,
+     list_all_collections,
+     search_documentation_vector,
+     search_knowledge_graph,
+     hybrid_search,
+     get_parameter_relationships,
+ ]
+
+ __all__ = [
+     "all_tools",
+     "get_segmentation_parameters",
+     "run_cellpose_sam",
+     "refine_cellpose_sam_segmentation",
+     "list_all_collections",
+     "search_documentation_vector",
+     "search_knowledge_graph",
+     "hybrid_search",
+     "get_parameter_relationships",
+ ]
tools/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (780 Bytes).
 
tools/__pycache__/search.cpython-311.pyc ADDED
Binary file (5.62 kB).
 
tools/__pycache__/segmentation.cpython-311.pyc ADDED
Binary file (25.1 kB).
 
tools/search.py ADDED
@@ -0,0 +1,101 @@
+ """
+ Search tools: vector, knowledge graph, and hybrid retrieval.
+ """
+
+ # project/tools/search.py
+ from smolagents import tool
+ from langfuse import get_client
+ from llama_index.core import VectorStoreIndex, StorageContext
+ from llama_index.vector_stores.chroma import ChromaVectorStore
+
+ from stores import get_chroma_client, get_kg_index
+ from models.reranker import get_reranker
+
+ langfuse = get_client()
+
+ @tool
+ def list_all_collections() -> list[str]:
+     """Lists the names of all available collections in the ChromaDB database."""
+     # No Args section needed in the docstring because this tool takes no arguments.
+     print("\n--- TOOL CALLED: list_all_collections ---")
+     client = get_chroma_client()
+     collections = client.list_collections()
+     return [c.name for c in collections]
+
+
+ @tool
+ def search_documentation_vector(query: str) -> str:
+     """
+     Searches cellpose documentation using vector search followed by a reranking step.
+
+     Args:
+         query (str): The question or search term to look up in the documentation.
+     """
+     print(f"\n--- TOOL CALLED: search_documentation_vector (with Reranker) for '{query}' ---")
+     try:
+         client = get_chroma_client()
+         collection = client.get_collection(name='cellpose_docs')
+         vector_store = ChromaVectorStore(chroma_collection=collection)
+         vector_index = VectorStoreIndex.from_vector_store(vector_store=vector_store)
+
+         query_engine = vector_index.as_query_engine(
+             similarity_top_k=25,
+             node_postprocessors=[get_reranker()]
+         )
+         response = query_engine.query(query)
+         return str(response)
+     except Exception as e:
+         return f"Error searching documentation: {e}"
+
+
+ @tool
+ def search_knowledge_graph(query: str) -> str:
+     """
+     Searches using knowledge graph relationships (Neo4j). Best for "how" and "why" questions.
+
+     Args:
+         query (str): The question about relationships between concepts (e.g., parameters).
+     """
+     print(f"\n--- TOOL CALLED: search_knowledge_graph for '{query}' ---")
+     try:
+         kg_index = get_kg_index()
+         query_engine = kg_index.as_query_engine(
+             include_text=True, response_mode="tree_summarize"
+         )
+         response = query_engine.query(query)
+         return str(response)
+     except Exception as e:
+         return f"Error querying knowledge graph: {e}."
+
+
+ @tool
+ def get_parameter_relationships(parameter_name: str) -> str:
+     """
+     Gets information about how a parameter relates to others using the knowledge graph.
+
+     Args:
+         parameter_name (str): The specific parameter name to investigate (e.g., 'flow_threshold').
+     """
+     print(f"\n--- TOOL CALLED: get_parameter_relationships for '{parameter_name}' ---")
+     query = f"What is {parameter_name} and how does it relate to other parameters?"
+     return search_knowledge_graph(query)
+
+
+ @tool
+ def hybrid_search(query: str) -> str:
+     """
+     Combines reranked vector search and knowledge graph search for complex questions.
+
+     Args:
+         query (str): The complex question that may require both semantic and relational understanding.
+     """
+     print(f"\n--- TOOL CALLED: hybrid_search (with Reranker) for '{query}' ---")
+     try:
+         vector_response_str = search_documentation_vector(query)
+         kg_response = search_knowledge_graph(query)
+
+         return f"Vector Search Results (Reranked):\n{vector_response_str}\n\nKnowledge Graph Insights:\n{kg_response}"
+     except Exception as e:
+         print(f"--- Hybrid search failed, falling back to vector search: {e} ---")
+         return search_documentation_vector(query)
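Since smolagents `@tool` objects remain directly callable (hybrid_search above relies on this), the tools can also be exercised outside the agent loop. A sketch, assuming the 'cellpose_docs' collection exists in ChromaDB:

    from tools.search import search_documentation_vector, get_parameter_relationships

    print(search_documentation_vector("What does flow_threshold control?"))
    print(get_parameter_relationships("flow_threshold"))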
tools/segmentation.py ADDED
@@ -0,0 +1,532 @@
+ """
+ Segmentation tools for cellpose-sam pipeline with proper smolagents VLM integration.
+ """
+ import base64
+ import json
+ import re
+ from typing import Any, Dict, TYPE_CHECKING
+ import numpy as np
+ import cv2
+ import torch
+ from PIL import Image
+ from skimage.measure import regionprops
+ from cellpose import models
+ from segment_anything import sam_model_registry, SamPredictor
+
+ from smolagents import tool
+ from smolagents.agents import ActionStep
+ from langfuse import get_client
+
+ from stores import chroma_store
+ from models.embeddings import get_image_embedding
+ from utils.image_utils import resize_and_encode_image
+ from config import settings
+
+
+ langfuse = get_client()
+
+
+ # --- Global State and Caching ---
+ _image_cache: Dict[str, tuple[str, str]] = {}
+ _cellpose_model = None
+ _sam_predictor = None
+
+
+ def get_cellpose_model():
+     """Initialize Cellpose model (singleton)"""
+     global _cellpose_model
+     if _cellpose_model is None:
+         _cellpose_model = models.CellposeModel(gpu=torch.cuda.is_available())
+     return _cellpose_model
+
+
+ def get_sam_predictor():
+     """Initialize SAM predictor (singleton)"""
+     global _sam_predictor
+     if _sam_predictor is None:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         sam = sam_model_registry["vit_h"](checkpoint=settings.CELLPOSE_SAM)
+         sam.to(device=device)
+         _sam_predictor = SamPredictor(sam)
+     return _sam_predictor
+
+ def _get_cached_image(image_path: str) -> tuple[str, str] | None:
+     """Helper to retrieve an image from the cache."""
+     if image_path in _image_cache:
+         return _image_cache[image_path]
+     return None
+
+ def _load_and_cache_image(image_path: str) -> tuple[str, str]:
+     """Helper to load, encode, and cache an image."""
+     image_base64, media_type = resize_and_encode_image(image_path)
+     _image_cache[image_path] = (image_base64, media_type)
+     return image_base64, media_type
+
+
+ def parse_parameters_from_text(param_text: str) -> dict:
+     """Extract parameter values from parameter text string."""
+     defaults = {
+         'diameter': 25,
+         'flow_threshold': 0.6,
+         'cellprob_threshold': 0,
+         'min_size': 15
+     }
+
+     params = defaults.copy()
+
+     patterns = {
+         'diameter': r'diameter[=:]\s*(\d+)',
+         'flow_threshold': r'flow_threshold[=:]\s*([\d.]+)',
+         'cellprob_threshold': r'cellprob_threshold[=:]\s*([-\d.]+)',
+         'min_size': r'min_size[=:]\s*(\d+)'
+     }
+
+     for param_name, pattern in patterns.items():
+         match = re.search(pattern, param_text, re.IGNORECASE)
+         if match:
+             value = match.group(1)
+             if param_name in ['diameter', 'min_size']:
+                 params[param_name] = int(value)
+             else:
+                 params[param_name] = float(value)
+
+     return params
+
+
+ @tool
+ def get_segmentation_parameters(image_path: str, agent: Any = None) -> str:
+     """
+     Finds the best cellpose-sam segmentation parameters for an image using vector similarity.
+     The image will be visible to the VLM for visual analysis.
+
+     Args:
+         image_path (str): Path to the image file to segment.
+         agent (Any, optional): The agent instance, passed automatically by smol-agents.
+
+     Returns:
+         str: JSON string containing recommended parameters and analysis context
+              (NO base64 to avoid GPU OOM)
+     """
+     print(f"\n--- TOOL CALLED: get_segmentation_parameters for '{image_path}' ---")
+
+     try:
+         # Load and cache image (for internal use)
+         image_base64, media_type = _get_cached_image(image_path) or _load_and_cache_image(image_path)
+     except Exception as e:
+         print(f"Warning: Could not read/resize image: {e}")
+         return json.dumps({"error": f"Could not read image: {e}"})
+
+     try:
+         # Get similar parameters from ChromaDB
+         client = chroma_store.get_client()
+         collection = client.get_collection(name='cellpose-sam_parameters_by_image_similarity')
+         query_embedding = get_image_embedding(image_path)
+
+         results = collection.query(query_embeddings=[query_embedding], n_results=1)
+
+         if not (results['metadatas'] and results['metadatas'][0]):
+             return json.dumps({"error": "No similar images found in the database."})
+
+         matched_parameters = results['metadatas'][0][0].get('parameter_text', 'N/A')
+         matched_image = results['metadatas'][0][0].get('image_name', 'N/A')
+         distance = results['distances'][0][0]
+
+         print(f"Most similar: {matched_image} (distance: {distance:.3f})")
+         print(f"Recommended: {matched_parameters}")
+
+         # Parse parameters
+         params = parse_parameters_from_text(matched_parameters)
+
+         # Analyze image
+         image = np.array(Image.open(image_path).convert("RGB"))
+         image_shape = image.shape
+         stats = {
+             'size': (image_shape[0] * image_shape[1]),
+             'mean_intensity': float(np.mean(image)),
+             'stdev_intensity': float(np.std(image)),
+             'min_intensity': int(np.min(image)),
+             'max_intensity': int(np.max(image)),
+         }
+
+         # Log to Langfuse WITH image (for observability)
+         try:
+             langfuse.update_current_trace(
+                 input={
+                     "image_path": image_path,
+                     "query_image": {
+                         "type": "image_url",
+                         "image_url": {
+                             "url": f"data:{media_type};base64,{image_base64}"
+                         }
+                     },
+                     "image_stats": stats
+                 },
+                 metadata={
+                     "matched_image": matched_image,
+                     "similarity_distance": float(distance),
+                     "matched_parameters": matched_parameters,
+                     "parsed_parameters": params
+                 }
+             )
+         except Exception as log_error:
+             print(f"Warning: Could not log to Langfuse: {log_error}")
+
+         # Determine confidence level
+         if distance < 0.2:
+             confidence = "high"
+             confidence_note = "Very similar image found. Parameters should work well as-is."
+         elif distance < 0.4:
+             confidence = "medium"
+             confidence_note = "Similar image found. Parameters are a good starting point but may need minor adjustments."
+         else:
+             confidence = "low"
+             confidence_note = "No very similar images found. Parameters may need significant adjustment based on visual inspection."
+
+         # Return WITHOUT base64 (image already attached to ActionStep)
+         response = {
+             "status": "success",
+             "image_path": image_path,
+             "recommended_parameters": params,
+             "matched_image": matched_image,
+             "similarity_distance": float(distance),
+             "confidence": confidence,
+             "image_stats": stats,
+             "raw_parameter_text": matched_parameters,
+             "visual_guidance": "IMAGE NOW VISIBLE: The input image is now attached to this step. "
+                                "Please visually inspect the image to assess cell morphology, density, "
+                                "and boundaries before deciding whether to adjust the recommended parameters.",
+             "recommendation": f"{confidence_note}\n\nRecommended parameters:\n"
+                               f"- diameter: {params['diameter']}\n"
+                               f"- flow_threshold: {params['flow_threshold']}\n"
+                               f"- cellprob_threshold: {params['cellprob_threshold']}\n"
+                               f"- min_size: {params['min_size']}\n\n"
+                               f"Image stats: {image_shape[0]}x{image_shape[1]} pixels, "
+                               f"mean intensity {stats['mean_intensity']:.1f}\n\n"
+                               f"To run segmentation, use: run_cellpose_sam(image_path='{image_path}', "
+                               f"diameter={params['diameter']}, flow_threshold={params['flow_threshold']}, "
+                               f"cellprob_threshold={params['cellprob_threshold']}, min_size={params['min_size']})"
+         }
+
+         return json.dumps(response, indent=2)
+
+     except Exception as e:
+         return json.dumps({"error": str(e)})
+
+
+ @tool
+ def run_cellpose_sam(
+     image_path: str,
+     diameter: int = None,
+     flow_threshold: float = None,
+     cellprob_threshold: float = None,
+     min_size: int = None,
+     output_path: str = None,
+     use_recommended_params: bool = True,
+     agent: Any = None
+ ) -> str:
+     """
+     Runs cellpose-sam segmentation pipeline on an image with specified parameters.
+     Returns results WITHOUT base64 images to prevent GPU memory issues.
+
+     Args:
+         image_path (str): Path to the image file to segment
+         diameter (int): Expected diameter of cells in pixels
+         flow_threshold (float): Flow error threshold (range: 0-1)
+         cellprob_threshold (float): Cell probability threshold (range: -6 to 6)
+         min_size (int): Minimum cell size in pixels
+         output_path (str): Optional path to save the overlay image
+         use_recommended_params (bool): If True and params not provided, get recommendations
+         agent (Any, optional): The agent instance
+
+     Returns:
+         str: JSON string with segmentation results (paths and stats, NO base64)
+     """
+     print(f"\n--- TOOL CALLED: run_cellpose_sam for '{image_path}' ---")
+
+     try:
+         # Load and cache input image
+         input_image_base64, input_media_type = _get_cached_image(image_path) or _load_and_cache_image(image_path)
+     except Exception as e:
+         return json.dumps({"error": f"Could not read input image: {e}"})
+
+     # Auto-fetch recommended parameters if needed
+     if use_recommended_params and all(p is None for p in [diameter, flow_threshold, cellprob_threshold, min_size]):
+         print("No parameters provided. Fetching recommended parameters...")
+         param_response = get_segmentation_parameters(image_path, agent=agent)
+
+         try:
+             param_data = json.loads(param_response)
+             if param_data.get("status") == "success":
+                 rec_params = param_data["recommended_parameters"]
+                 diameter = diameter or rec_params.get('diameter', 25)
+                 flow_threshold = flow_threshold or rec_params.get('flow_threshold', 0.6)
+                 cellprob_threshold = cellprob_threshold or rec_params.get('cellprob_threshold', 0)
+                 min_size = min_size or rec_params.get('min_size', 15)
+             else:
+                 diameter, flow_threshold, cellprob_threshold, min_size = 25, 0.6, 0, 15
+         except json.JSONDecodeError:
+             diameter, flow_threshold, cellprob_threshold, min_size = 25, 0.6, 0, 15
+     else:
+         diameter = diameter if diameter is not None else 25
+         flow_threshold = flow_threshold if flow_threshold is not None else 0.6
+         cellprob_threshold = cellprob_threshold if cellprob_threshold is not None else 0
+         min_size = min_size if min_size is not None else 15
+
+     print(f"Final parameters: diameter={diameter}, flow_threshold={flow_threshold}, "
+           f"cellprob_threshold={cellprob_threshold}, min_size={min_size}")
+
+     try:
+         # Read image
+         img = cv2.imread(image_path)
+         if img is None:
+             return json.dumps({"error": f"Could not read image at {image_path}"})
+
+         img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+         cellpose_model = get_cellpose_model()
+         sam_predictor = get_sam_predictor()
+
+         # Run Cellpose
+         print("Running Cellpose...")
+         masks_cellpose, flows, styles = cellpose_model.eval(
+             img_rgb,
+             diameter=diameter,
+             flow_threshold=flow_threshold,
+             cellprob_threshold=cellprob_threshold,
+             min_size=min_size
+         )
+
+         if masks_cellpose.max() == 0:
+             return json.dumps({
+                 "status": "no_cells_detected",
+                 "message": "No cells detected. Try adjusting parameters.",
+                 "parameters": {
+                     "diameter": diameter,
+                     "flow_threshold": flow_threshold,
+                     "cellprob_threshold": cellprob_threshold,
+                     "min_size": min_size
+                 }
+             })
+
+         print(f"Cellpose detected {masks_cellpose.max()} regions")
+
+         # SAM refinement
+         sam_predictor.set_image(img_rgb)
+         props = regionprops(masks_cellpose)
+         boxes = np.array([prop.bbox for prop in props])
+         # Reorder (min_row, min_col, max_row, max_col) bboxes to SAM's (x1, y1, x2, y2)
+         boxes = boxes[:, [1, 0, 3, 2]]
+
+         print(f"Refining {len(boxes)} masks with SAM...")
+
+         combined_masks = np.zeros(img_rgb.shape[:2], dtype=np.uint16)
+         colored_overlay = img_rgb.copy().astype(np.float32)
+
+         for i, box in enumerate(boxes):
+             masks, scores, _ = sam_predictor.predict(box=box, multimask_output=True)
+             best_mask = masks[np.argmax(scores)]
+             combined_masks[best_mask] = i + 1
+             color = np.random.randint(0, 255, 3)
+             colored_overlay[best_mask] = colored_overlay[best_mask] * 0.6 + color * 0.4
+
+         # Generate output path
+         if output_path is None:
+             base_name = image_path.rsplit('.', 1)[0]
+             output_path = f"{base_name}_cellpose_sam_overlay.png"
+
+         # Save output
+         cv2.imwrite(output_path, cv2.cvtColor(colored_overlay.astype(np.uint8), cv2.COLOR_RGB2BGR))
+
+         # Load and cache output image
+         output_image_base64, output_media_type = _load_and_cache_image(output_path)
+
+         # Log to Langfuse WITH both images
+         try:
+             langfuse.update_current_trace(
+                 input={
+                     "image_path": image_path,
+                     "input_image": {
+                         "type": "image_url",
+                         "image_url": {"url": f"data:{input_media_type};base64,{input_image_base64}"}
+                     }
+                 },
+                 output={
+                     "cell_count": int(masks_cellpose.max()),
+                     "output_image": {
+                         "type": "image_url",
+                         "image_url": {"url": f"data:{output_media_type};base64,{output_image_base64}"}
+                     },
+                     "output_path": output_path
+                 },
+                 metadata={
+                     "parameters": {
+                         "diameter": diameter,
+                         "flow_threshold": flow_threshold,
+                         "cellprob_threshold": cellprob_threshold,
+                         "min_size": min_size
+                     }
+                 }
+             )
+         except Exception as log_error:
+             print(f"Warning: Could not log output to Langfuse: {log_error}")
+
+         # Return WITHOUT base64
+         result = {
+             "status": "success",
+             "cell_count": int(masks_cellpose.max()),
+             "output_path": output_path,
+             "input_path": image_path,
+             "parameters": {
+                 "diameter": diameter,
+                 "flow_threshold": flow_threshold,
+                 "cellprob_threshold": cellprob_threshold,
+                 "min_size": min_size
+             },
+             "summary": f"Detected {masks_cellpose.max()} cells. Output saved to: {output_path}",
+             "next_step": "Call refine_cellpose_sam_segmentation to visually analyze the segmentation quality and decide if parameter adjustments are needed."
+         }
+
+         return json.dumps(result, indent=2)
+
+     except Exception as e:
+         return json.dumps({"error": f"Error during segmentation: {e}"})
+
+
+ @tool
+ def refine_cellpose_sam_segmentation(
+     original_image_path: str,
+     segmentation_output_path: str,
+     current_parameters: dict,
+     agent: Any = None,
+ ) -> str:
+     """
+     Provides both original and segmented images to the VLM for visual quality assessment.
+     The VLM will be able to see both images and provide informed analysis.
+
+     Use this tool after run_cellpose_sam to check segmentation quality. The tool attaches
+     both images to the current step so you can visually compare them.
+
+     Before calling, consider using search_knowledge_graph or hybrid_search to refresh
+     your understanding of how cellpose parameters affect segmentation.
+
+     Common issues and fixes:
+     - Under-segmentation (cells merged): decrease flow_threshold or diameter
+     - Over-segmentation (cells fragmented): increase flow_threshold or min_size
+     - Too few cells: decrease cellprob_threshold or flow_threshold
+     - Too many false positives: increase cellprob_threshold or min_size
+
+     Args:
+         original_image_path: Path to the original input image
+         segmentation_output_path: Path to the segmented overlay image
+         current_parameters: Dict with current diameter, flow_threshold, cellprob_threshold, min_size
+         agent: The agent instance (passed automatically)
+
+     Returns:
+         str: JSON with guidance for VLM analysis (NO base64 images)
+     """
+     print(f"\n--- TOOL CALLED: refine_cellpose_sam_segmentation ---")
428
+ print(f"Original image: {original_image_path}")
429
+ print(f"Segmented image: {segmentation_output_path}")
430
+ print(f"Current parameters: {current_parameters}")
431
+
432
+ try:
433
+ # Load both images (for cache)
434
+ original_b64, original_type = _get_cached_image(original_image_path) or _load_and_cache_image(original_image_path)
435
+ segmented_b64, segmented_type = _get_cached_image(segmentation_output_path) or _load_and_cache_image(segmentation_output_path)
436
+
437
+ # CRITICAL: Attach BOTH images to ActionStep so VLM can see them
438
+ if agent is not None and hasattr(agent, 'memory') and hasattr(agent.memory, 'steps'):
439
+ current_steps = [s for s in agent.memory.steps if isinstance(s, ActionStep)]
440
+ if current_steps:
441
+ current_step = current_steps[-1]
442
+
443
+ # Load both as PIL Images
444
+ original_img = Image.open(original_image_path).convert("RGB")
445
+ segmented_img = Image.open(segmentation_output_path).convert("RGB")
446
+
447
+ # CRITICAL: Use .copy() for both images
448
+ current_step.observations_images = [original_img.copy(), segmented_img.copy()]
449
+ print(f"βœ“ Attached both images to ActionStep for VLM comparison")
450
+
451
+ # Get image dimensions for context
452
+ original_img_array = np.array(Image.open(original_image_path).convert("RGB"))
453
+ img_size = original_img_array.shape[0] * original_img_array.shape[1]
454
+
455
+ # Log to Langfuse WITH both images
456
+ try:
457
+ langfuse.update_current_trace(
458
+ input={
459
+ "tool": "refine_cellpose_sam_segmentation",
460
+ "original_image": {
461
+ "type": "image_url",
462
+ "image_url": {"url": f"data:{original_type};base64,{original_b64}"}
463
+ },
464
+ "segmented_image": {
465
+ "type": "image_url",
466
+ "image_url": {"url": f"data:{segmented_type};base64,{segmented_b64}"}
467
+ },
468
+ "current_parameters": current_parameters
469
+ },
470
+ metadata={
471
+ "original_path": original_image_path,
472
+ "segmented_path": segmentation_output_path
473
+ }
474
+ )
475
+ except Exception as log_error:
476
+ print(f"Warning: Could not log to Langfuse: {log_error}")
477
+
478
+ # Return analysis guidance WITHOUT base64
479
+ analysis = {
480
+ "status": "ready_for_visual_analysis",
481
+ "images_attached": "BOTH IMAGES NOW VISIBLE: The first image is the original input, "
482
+ "the second is the segmented overlay. Compare them visually to assess quality.",
483
+ "image_paths": {
484
+ "original": original_image_path,
485
+ "segmented": segmentation_output_path
486
+ },
487
+ "current_parameters": current_parameters,
488
+ "image_info": {
489
+ "dimensions": f"{original_img_array.shape[1]}x{original_img_array.shape[0]}",
490
+ "total_pixels": img_size
491
+ },
492
+ "visual_analysis_checklist": [
493
+ "1. Do the colored masks accurately cover entire cells without extending beyond boundaries?",
494
+ "2. Are neighboring cells properly separated, or are they merged together?",
495
+ "3. Are there many small false positive detections (noise)?",
496
+ "4. Are any large, obvious cells being missed completely?",
497
+ "5. Overall quality assessment: excellent, good, needs_refinement, or poor?"
498
+ ],
499
+ "parameter_adjustment_guide": {
500
+ "under_segmentation": {
501
+ "symptoms": "Masks don't reach cell edges, cells appear merged",
502
+ "solution": "Decrease flow_threshold by 0.1-0.2 OR decrease diameter by 10-20%"
503
+ },
504
+ "over_segmentation": {
505
+ "symptoms": "Masks extend past boundaries, cells fragmented into pieces",
506
+ "solution": "Increase flow_threshold by 0.1-0.2 OR increase min_size to 2-3x current value"
507
+ },
508
+ "too_few_cells": {
509
+ "symptoms": "Obvious cells in image are not being detected",
510
+ "solution": "Decrease cellprob_threshold by 1-2 OR decrease flow_threshold by 0.1-0.2"
511
+ },
512
+ "too_many_false_positives": {
513
+ "symptoms": "Many tiny spurious detections, background noise detected as cells",
514
+ "solution": "Increase cellprob_threshold by 1-2 OR increase min_size to 2-3x current value"
515
+ }
516
+ },
517
+ "next_steps": {
518
+ "if_good": "If segmentation looks accurate, inform the user of success and provide the output_path.",
519
+ "if_needs_refinement": "Based on your visual analysis, adjust the appropriate parameters and call run_cellpose_sam again with the new values.",
520
+ "important": "You can only call refine_cellpose_sam_segmentation AT MOST 2 TIMES total. If this is your second call, you must make a final decision."
521
+ }
522
+ }
523
+
524
+ return json.dumps(analysis, indent=2)
525
+
526
+ except Exception as e:
527
+ error_result = {
528
+ "status": "error",
529
+ "error": str(e),
530
+ "message": "Could not load images for refinement. Check that both file paths are valid."
531
+ }
532
+ return json.dumps(error_result, indent=2)
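Taken together, the three tools in this file form a recommend β†’ segment β†’ review loop. Below is a minimal sketch of that sequence called directly in Python, outside the agent; it assumes the @tool wrappers stay callable (smolagents Tool objects implement __call__), `example.png` is a placeholder path, and the happy-path JSON keys match the returns shown above.

import json

# 1. Look up parameters from the most similar reference image
params = json.loads(get_segmentation_parameters("example.png"))["recommended_parameters"]

# 2. Segment with those parameters
run_result = json.loads(run_cellpose_sam(
    image_path="example.png",
    diameter=params["diameter"],
    flow_threshold=params["flow_threshold"],
    cellprob_threshold=params["cellprob_threshold"],
    min_size=params["min_size"],
))

# 3. Attach both images for the VLM's visual quality review
review = json.loads(refine_cellpose_sam_segmentation(
    original_image_path="example.png",
    segmentation_output_path=run_result["output_path"],
    current_parameters=run_result["parameters"],
))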
tools/segmentation.py~ ADDED
utils/__init__.py ADDED
@@ -0,0 +1,24 @@
+ from .gpu import (
+     clear_gpu_cache,
+     get_max_memory,
+     monitor_and_clear_cache
+ )
+
+ from .image_utils import (
+     resize_and_encode_image
+ )
+
+ from .prechecks import (
+     check_hf_persistent_storage
+ )
+
+ __all__ = [
+     # GPU utilities
+     "clear_gpu_cache",
+     "get_max_memory",
+     "monitor_and_clear_cache",
+     # Image utilities
+     "resize_and_encode_image",
+     # precheck
+     "check_hf_persistent_storage"
+ ]
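With these re-exports in place, downstream modules can import the helpers from the package root. A small sketch, where the `utils` package name is taken from the repo layout and `example.png` is a placeholder:

from utils import clear_gpu_cache, get_max_memory, resize_and_encode_image

max_memory = get_max_memory(memory_fraction=0.85)             # per-device memory budget
img_b64, media_type = resize_and_encode_image("example.png")  # base64 payload for traces
clear_gpu_cache()                                             # free cache between heavy steps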
utils/__init__.py~ ADDED
@@ -0,0 +1,23 @@
+ from .gpu import (
+     clear_gpu_cache,
+     get_max_memory,
+     monitor_and_clear_cache
+ )
+
+ from .image_utils import (
+     resize_and_encode_image
+ )
+
+ from .precheck import (
+     ""
+ )
+
+ __all__ = __all__ = [
+     # GPU utilities
+     "clear_gpu_cache",
+     "get_max_memory",
+     "monitor_and_clear_cache",
+     # Image utilities
+     "resize_and_encode_image",
+
+ ]
utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (599 Bytes). View file
 
utils/__pycache__/gpu.cpython-311.pyc ADDED
Binary file (4.17 kB). View file
 
utils/__pycache__/image_utils.cpython-311.pyc ADDED
Binary file (1.6 kB). View file
 
utils/__pycache__/prechecks.cpython-311.pyc ADDED
Binary file (2.01 kB). View file
 
utils/gpu.py ADDED
@@ -0,0 +1,80 @@
+ """
+ GPU memory utilities: cache clearing, per-device memory budgeting, and usage monitoring.
+ """
+
+ import torch
+ import gc
+
+ def clear_gpu_cache():
+     """Frees up GPU memory by clearing cache and collecting garbage."""
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+         torch.cuda.synchronize()
+         gc.collect()
+         print("βœ“ GPU cache cleared.")
+
+
+ def get_max_memory(memory_fraction=0.85, cpu_memory="50GB"):
+     """
+     Automatically configure max memory per GPU.
+
+     When used with device_map="auto", this tells the model loader how much memory
+     it CAN use per GPU during the INITIAL model loading phase. If a model's layers
+     don't fit on one GPU with this limit, the loader will automatically split the
+     model across multiple GPUs.
+
+     Args:
+         memory_fraction: Fraction of GPU memory to allocate (0.0-1.0).
+             Default 0.85 leaves 15% headroom.
+         cpu_memory: Maximum CPU memory to use as offload space.
+
+     Returns:
+         dict: Memory limits per device, or None if no CUDA available
+     """
+     if not torch.cuda.is_available():
+         print("⚠ No CUDA GPUs available")
+         return None
+
+     max_memory = {}
+     total_available = 0
+
+     for i in range(torch.cuda.device_count()):
+         props = torch.cuda.get_device_properties(i)
+         total_memory = props.total_memory
+         usable_memory = int(total_memory * memory_fraction)
+         max_memory[i] = usable_memory
+         total_available += usable_memory
+
+         print(f"GPU {i} ({props.name}): "
+               f"{usable_memory / 1024**3:.2f}GB / {total_memory / 1024**3:.2f}GB "
+               f"({memory_fraction*100:.0f}% limit)")
+
+     # CPU memory for offloading if needed
+     max_memory["cpu"] = cpu_memory
+
+     print(f"βœ“ Total GPU memory available for models: {total_available / 1024**3:.2f}GB")
+     print(f"βœ“ CPU offload memory: {cpu_memory}")
+
+     return max_memory
+
+ def monitor_and_clear_cache(threshold=0.90):
+     """
+     Monitor GPU memory and clear cache if usage exceeds threshold.
+     Call this periodically during long-running operations.
+
+     Args:
+         threshold: Memory usage fraction (0.0-1.0) that triggers cache clearing
+     """
+     if not torch.cuda.is_available():
+         return
+
+     for i in range(torch.cuda.device_count()):
+         props = torch.cuda.get_device_properties(i)
+         allocated = torch.cuda.memory_allocated(i)
+         total = props.total_memory
+         usage = allocated / total
+
+         if usage > threshold:
+             print(f"⚠ GPU {i} usage at {usage*100:.1f}%, clearing cache...")
+             torch.cuda.empty_cache()
+             gc.collect()
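The dict returned by get_max_memory is shaped for the `max_memory` argument accepted by accelerate-backed loaders: integer byte counts keyed by GPU index plus a "cpu" entry. A hedged sketch of wiring it into a transformers load; the model id is a placeholder, and transformers/accelerate are assumed to be installed:

from transformers import AutoModelForCausalLM
from utils.gpu import get_max_memory, monitor_and_clear_cache

max_memory = get_max_memory(memory_fraction=0.85, cpu_memory="50GB")

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-model",   # placeholder repo id
    device_map="auto",       # let accelerate shard layers across the budgeted devices
    max_memory=max_memory,   # per-device caps computed above
)

# Then, inside a long-running inference loop:
monitor_and_clear_cache(threshold=0.90)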
utils/image_utils.py ADDED
@@ -0,0 +1,31 @@
+ """
+ Image utilities for encoding and resizing
+ """
+ import base64
+ from io import BytesIO
+ from PIL import Image
+
+
+ def resize_and_encode_image(image_path: str, size: tuple = (512, 512)) -> tuple[str, str]:
+     """
+     Resize an image to the specified size and encode it as base64.
+
+     Args:
+         image_path (str): Path to the image file
+         size (tuple): Target size as (width, height), default (512, 512)
+
+     Returns:
+         tuple: (base64_string, media_type)
+     """
+     # Open and convert to RGB
+     img = Image.open(image_path).convert("RGB")
+
+     # Resize with high-quality resampling
+     img_resized = img.resize(size, Image.Resampling.LANCZOS)
+
+     # Encode to base64
+     buffered = BytesIO()
+     img_resized.save(buffered, format="PNG")
+     img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
+
+     return img_base64, "image/png"
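The (base64 string, media type) pair returned here is exactly what the segmentation tools splice into Langfuse data URLs. A short sketch with a placeholder path:

from utils.image_utils import resize_and_encode_image

img_b64, media_type = resize_and_encode_image("example.png", size=(512, 512))
data_url = f"data:{media_type};base64,{img_b64}"  # the form embedded in the traces above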
utils/prechecks.py ADDED
@@ -0,0 +1,44 @@
+ """
+ Startup prechecks: download required files or folders into persistent storage if missing.
+ """
+
+ from pathlib import Path
+ from huggingface_hub import hf_hub_download, snapshot_download
+
+ def check_hf_persistent_storage(
+     repo_id: str = None,
+     repo_type: str = "model",
+     file_or_folder="file",
+     target: str = None,
+     destination: str = "/data/"
+ ):
+     """Download `target` from `repo_id` into `destination` unless it already exists there."""
+     file_path = Path(destination) / target
+
+     def _download_file():
+         try:
+             if file_or_folder == "file":
+                 hf_hub_download(
+                     repo_id=repo_id,
+                     repo_type=repo_type,
+                     filename=target,
+                     local_dir=destination
+                 )
+             elif file_or_folder == "folder":
+                 snapshot_download(
+                     repo_id=repo_id,
+                     repo_type=repo_type,
+                     allow_patterns=f"{target}/**",
+                     local_dir=destination
+                 )
+
+             print(f"Successfully downloaded '{target}' to '{destination}'.")
+         except Exception as e:
+             print(f"An error occurred during the download: {e}")
+
+     # Check if the file exists at the specified path
+     if not file_path.exists():
+         _download_file()
+     else:
+         print(f"File '{file_path}' already exists. No download needed.")
+
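A plausible startup use of this precheck is fetching the SAM checkpoint that get_sam_predictor() expects (`sam_vit_h_4b8939.pth`). The repo_id below is a placeholder, not a confirmed source for that file:

from utils.prechecks import check_hf_persistent_storage

check_hf_persistent_storage(
    repo_id="some-org/sam-checkpoints",  # placeholder: wherever the checkpoint is hosted
    repo_type="model",
    file_or_folder="file",
    target="sam_vit_h_4b8939.pth",
    destination="/data/",
)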
utils/prechecks.py~ ADDED
@@ -0,0 +1,44 @@
+ """
+
+ """
+
+ from pathlib import Path
+ from huggingface_hub import hf_hub_download, snapshot_download
+
+ def check_hf_persistent_storage(
+     repo_id: str = None,
+     repo_type: str = "model",
+     file_or_folder="file",
+     target: str = None,
+     destination: str = "./data/"
+ ):
+
+     file_path = Path(destination) / target
+
+     def _download_file():
+         try:
+             if file_or_folder == "file":
+                 hf_hub_download(
+                     repo_id=repo_id,
+                     repo_type=repo_type,
+                     filename=target,
+                     local_dir=destination
+                 )
+             elif file_or_folder == "folder":
+                 snapshot_download(
+                     repo_id=repo_id,
+                     repo_type=repo_type,
+                     allow_patterns=f"{target}/**",
+                     local_dir=destination
+                 )
+
+             print(f"Successfully downloaded '{target}' to '{destination}'.")
+         except Exception as e:
+             print(f"An error occurred during the download: {e}")
+
+     # Check if the file exists at the specified path
+     if not file_path.exists():
+         _download_file()
+     else:
+         print(f"File '{file_path}' already exists. No download needed.")
+