"""
CellposeAgent with proper VLM configuration, JPEG compression, and reliable path injection.

Key change: The agent stores the current image path in a global context that tools can access,
preventing the LLM from corrupting file paths when passing them as arguments.
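
Typical usage (a minimal sketch; the path below is a placeholder):
    agent = CellposeAgent()
    agent.run("Segment the cells in this image.", image_path="/path/to/image.png")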
"""
import json
from io import BytesIO
from PIL import Image
from smolagents import ToolCallingAgent, InferenceClientModel
from smolagents.agents import ActionStep
from langfuse import get_client, observe

from config import settings
from utils.gpu import clear_gpu_cache
from tools import all_tools


langfuse = get_client()


# =============================================================================
# GLOBAL IMAGE PATH CONTEXT
# =============================================================================
# This module-level variable stores the current image path reliably.
# Tools can access this directly instead of relying on the LLM to pass the path.
_current_image_context = {
    "image_path": None,
    "output_path": None,
}


def set_current_image_path(path: str) -> None:
    """Set the current image path for tools to access."""
    global _current_image_context
    _current_image_context["image_path"] = path
    _current_image_context["output_path"] = None  # Reset output on new image
    print(f"[Context] Set current image path: {path}")


def get_current_image_path() -> str | None:
    """Get the current image path (used by tools)."""
    return _current_image_context.get("image_path")


def set_current_output_path(path: str) -> None:
    """Set the current output path after segmentation."""
    global _current_image_context
    _current_image_context["output_path"] = path


def get_current_output_path() -> str | None:
    """Get the current output path (used by tools)."""
    return _current_image_context.get("output_path")


def get_image_context() -> dict:
    """Get the full image context dictionary."""
    return _current_image_context.copy()
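

# Tool-side usage (hypothetical sketch): a smolagents @tool can ignore the path
# argument the LLM supplies and read the trusted path from this module instead.
# The real tools live in the `tools` package; this block is illustrative only.
#
#     from smolagents import tool
#
#     @tool
#     def run_cellpose_sam(image_path: str) -> str:
#         """Segment the current image (the image_path argument is ignored)."""
#         real_path = get_current_image_path()
#         if real_path is None:
#             return '{"status": "error", "message": "No image is loaded"}'
#         ...  # segment real_path, then record where the overlay was written:
#         # set_current_output_path(overlay_path)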


class CellposeAgent:

    @staticmethod
    def attach_images_callback(step_log: ActionStep, agent: ToolCallingAgent) -> None:
        """
        Callback to attach actual PIL images for VLM inspection.
        Images are automatically resized and compressed to reduce token consumption.
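
        Recognized observation payloads (JSON strings emitted by the tools):
          - {"status": "success", "image_path": ...} -> attach that image
          - {"status": "ready_for_visual_analysis", "image_paths": {"segmented": ...}}
            -> attach the segmented overlay only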
        """
        if not isinstance(step_log, ActionStep):
            return
    
        if not step_log.observations:
            return
    
        def resize_and_compress_image(img: Image.Image, max_size: int = 512, quality: int = 75) -> Image.Image:
            """
            Resize and compress image to reduce payload size.
            
            Args:
                img: Input PIL Image
                max_size: Maximum dimension (width or height)
                quality: JPEG quality (1-95, lower = smaller file)
            
            Returns:
                Compressed PIL Image
            """
            # Convert to RGB if needed (JPEG doesn't support RGBA)
            if img.mode in ('RGBA', 'LA', 'P'):
                background = Image.new('RGB', img.size, (255, 255, 255))
                if img.mode == 'P':
                    img = img.convert('RGBA')
                background.paste(img, mask=img.split()[-1] if img.mode in ('RGBA', 'LA') else None)
                img = background
            elif img.mode != 'RGB':
                img = img.convert('RGB')
            
            # Resize maintaining aspect ratio
            if max(img.size) > max_size:
                ratio = max_size / max(img.size)
                new_size = tuple(int(dim * ratio) for dim in img.size)
                img = img.resize(new_size, Image.Resampling.LANCZOS)
            
            # Compress using JPEG encoding
            buffer = BytesIO()
            img.save(buffer, format='JPEG', quality=quality, optimize=True)
            buffer.seek(0)
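            # Note: Image.open decodes lazily; the returned Image keeps a reference
            # to the BytesIO buffer, so the data stays readable after this returns.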
            compressed_img = Image.open(buffer)
            
            print(f"    Resized and compressed to {compressed_img.size}, quality={quality}")
            return compressed_img
    
        try:
            obs_data = json.loads(step_log.observations)
        
            # Pattern 1: Single image from get_segmentation_parameters
            if obs_data.get("status") == "success" and "image_path" in obs_data:
                image_path = obs_data["image_path"]
                print(f"[Callback] Attaching image: {image_path}")
            
                try:
                    img = Image.open(image_path)
                    compressed_img = resize_and_compress_image(img, max_size=512, quality=75)
                
                    # Attach compressed PIL Image
                    step_log.observations_images = [compressed_img]
                    
                    # Keep metadata for context
                    obs_data["image_info"] = {
                        "original_dimensions": f"{img.size[0]}x{img.size[1]} pixels",
                        "processed_dimensions": f"{compressed_img.size[0]}x{compressed_img.size[1]} pixels",
                        "mode": compressed_img.mode,
                        "note": "Image compressed for API efficiency (JPEG quality=75)"
                    }
                    step_log.observations = json.dumps(obs_data, indent=2)
                    print(f"[Callback] βœ“ Attached compressed image for VLM inspection")
                except Exception as e:
                    print(f"[Callback] Error attaching image: {e}")
        
            # Pattern 2: Segmented image ONLY from refine_cellpose_sam_segmentation
            elif obs_data.get("status") == "ready_for_visual_analysis":
                paths = obs_data.get("image_paths", {})
                segmented = paths.get("segmented")
            
                if segmented:
                    print(f"[Callback] Attaching segmented image only: {segmented}")
                    try:
                        seg_img = Image.open(segmented)
                    
                        # Compress the segmented image
                        compressed_seg = resize_and_compress_image(seg_img, max_size=512, quality=75)
                    
                        # Attach only the segmented image
                        step_log.observations_images = [compressed_seg]
                    
                        obs_data["images_info"] = {
                            "image_type": "segmented_overlay",
                            "description": "Segmentation result with colored cell masks overlaid on original image",
                            "original_size": f"{seg_img.size[0]}x{seg_img.size[1]}",
                            "processed_size": f"{compressed_seg.size[0]}x{compressed_seg.size[1]}",
                            "note": "Segmented image attached for quality assessment (JPEG quality=75)"
                        }
                        step_log.observations = json.dumps(obs_data, indent=2)
                        print(f"[Callback] βœ“ Attached compressed segmented image for VLM inspection")
                    except Exception as e:
                        print(f"[Callback] Error attaching segmented image: {e}")
    
        except json.JSONDecodeError:
            pass
        except Exception as e:
            print(f"[Callback] Error in attach_images_callback: {e}")


    @staticmethod
    def manage_image_memory(step_log: ActionStep, agent: ToolCallingAgent) -> None:
        """
        Clear images from ALL previous steps at the START of each new step.
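
        Keeping only the latest step's images bounds how many image payloads
        are re-sent to the VLM on each call, so token usage stays roughly flat.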
        """
        if not isinstance(step_log, ActionStep):
            return
    
        # Clear ALL previous step images immediately
        for previous_step in agent.memory.steps:
            if isinstance(previous_step, ActionStep):
                if previous_step.observations_images is not None:
                    print(f"  [Memory] Clearing images from step {previous_step.step_number}")
                    previous_step.observations_images = []  # Use empty list instead of None
                    # Also try to clear any cached references
                    if hasattr(previous_step, '_observations_images'):
                        previous_step._observations_images = []

                    
    def __init__(self):
        self.instructions = """
        You are an assistant for the cellpose-sam segmentation tool.
        
        ## CRITICAL: IMAGE PATH HANDLING ##
        **The image path is automatically available to all tools. You do NOT need to pass the exact path.**
        When calling tools that require an image path, you can pass an empty string "" or any placeholder - 
        the tool will automatically use the correct image path from the system context.
        
        Example: Instead of trying to reproduce the exact path, just call:
        `get_segmentation_parameters(image_path="")`
        `run_cellpose_sam(image_path="")`
        
        The system will automatically use the correct path.
        
        ## PRIMARY WORKFLOW - IMAGE SEGMENTATION 
        When a user provides an image:
        1. Use appropriate tools to review which cellpose-sam parameters are available. 
        2. Use the tool: `get_segmentation_parameters` with image_path=""
           - **IMPORTANT**: After this tool runs, you will receive image metadata (dimensions, properties)
           - Use this information to reason about appropriate parameter values
        3. Carefully analyze the image metadata and matched parameters:
           - Consider cell density based on image dimensions
           - Compare matched parameter values to image characteristics
           - Consider if adjustments would likely improve the segmentation
        4. Be conservative: if you make changes, assess if they should differ significantly from the original values
        5. Provide your final parameter recommendations in a clear, structured format 
        6. Use the parameters to run cellpose_sam through the tool: run_cellpose_sam with image_path=""
        7. After run_cellpose_sam, call the tool: refine_cellpose_sam_segmentation
           - **IMPORTANT**: After this tool runs, you will see the SEGMENTED image (colored masks overlay)
           - Visually inspect the segmentation quality - are cells properly detected and separated?
           - Use the visual analysis checklist provided in the tool output
        8. Based on visual analysis of the segmented image:
           - Assess if cell boundaries are accurate
           - Check if neighboring cells are properly separated or merged
           - Look for false positive detections (noise)
           - Identify any obvious cells that were missed
           - If refinement is needed, use knowledge graph and RAG tools to understand parameter effects
           - Decide which parameters to adjust based on what you observe
           - Re-run run_cellpose_sam with adjusted parameters   
        
        **CRITICAL: Call refine_cellpose_sam_segmentation AT MOST ONCE**
           - Use it once to check the initial segmentation quality
           - If refinement is needed, adjust parameters and re-run run_cellpose_sam
           - NEVER call it a second time - always stop after one refinement check
        
        ## DOCUMENTATION QUERY WORKFLOW ##
        - "What is X": use `search_documentation_vector`
        - "How does X affect Y": use `search_knowledge_graph`  
        - Complex analysis: use `hybrid_search`
        - Parameter relationships: use `get_parameter_relationships`        
        
        ## RESPONSE STYLE ##
        - Be concise and actionable
        - Always explain your reasoning when adjusting parameters
        - If keeping original matched parameters, briefly confirm why it's appropriate
        - Base your decisions on visual observation of the segmented output
        
        **CRITICAL - Final Response Format:**
        When segmentation is complete, you MUST provide a comprehensive text summary that includes:
        1. A brief statement about the segmentation completion
        2. Number of cells detected
        3. The final parameters used (diameter, flow_threshold, cellprob_threshold, min_size)
        4. A quality assessment (e.g., "excellent", "good", "acceptable")
        5. Any observations about the segmentation (e.g., "cells well-separated", "some clustering")
        6. The output file path at the end
        
        Example good final response:
        "Segmentation completed successfully! I detected 42 cells in your image using the following parameters:
        - diameter: 30
        - flow_threshold: 0.6
        - cellprob_threshold: 0
        - min_size: 15
        
        The segmentation quality looks excellent - cell boundaries are well-defined and neighboring cells are properly separated with minimal false positives.
        
        Output saved to: /path/to/image_cellpose_sam_overlay.png"
        
        **NEVER return just a filename like "segmentation_output.png" - always provide the full context above.**
        """
        
        self.model = self._initialize_model()
        self.agent = self._create_agent()
        

    def _initialize_model(self):
        """Initializes the InferenceClientModel for the agent with VLM support."""
        clear_gpu_cache()

        return InferenceClientModel(
            model_id=settings.AGENT_MODEL_ID,
            token=settings.HF_TOKEN,
            timeout=240  # 4-minute timeout for API calls
        )


    def _create_agent(self):
        """Creates the ToolCallingAgent with all available tools and memory management."""
        return ToolCallingAgent(
            model=self.model,
            tools=all_tools,
            instructions=self.instructions,
            max_steps=6,
            step_callbacks=[
                self.attach_images_callback,
                self.manage_image_memory,
            ]
        )

    @observe()
    def run(self, task: str, image_path: str | None = None):
        """
        Runs the agent on a given task with Langfuse tracing.
        
        Args:
            task: The task description/prompt
            image_path: Optional path to the image (will be stored in context for tools)
        """
        print(f"\n{'='*60}\nTASK: {task}\n{'='*60}")
        
        # Store image path in global context for tools to access
        if image_path:
            set_current_image_path(image_path)
        
        try:
            langfuse.update_current_trace(
                input={"task": task, "image_path": image_path},
                user_id="user_001",
                tags=["rag", "cellpose", "knowledge-graph", "vision"],
                metadata={
                    "agent_type": "ToolCallingAgent", 
                    "model_id": settings.AGENT_MODEL_ID,
                    "image_context": get_image_context()
                }
            )
        except Exception as e:
            print(f"Warning: Could not update Langfuse trace: {e}")

        try:
            final_answer = self.agent.run(task)
            print("\n--- Final Answer from Agent ---\n", final_answer)
            
            try:
                langfuse.update_current_trace(output={"final_answer": final_answer})
            except Exception as e:
                print(f"Warning: Could not update Langfuse output: {e}")
                
            return final_answer
        except Exception as e:
            print(f"Agent run failed: {e}")
            try:
                langfuse.update_current_trace(output={"error": str(e)})
            except Exception as log_error:
                print(f"Warning: Could not log error to Langfuse: {log_error}")
            raise
        finally:
            clear_gpu_cache()
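

if __name__ == "__main__":
    # Minimal smoke-test sketch. The task and image path below are placeholders -
    # point image_path at a real microscopy image before running.
    agent = CellposeAgent()
    agent.run(
        task="Segment the cells in the provided image and report the parameters used.",
        image_path="data/example_cells.png",  # hypothetical example path
    )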