Spaces:
Runtime error
Runtime error
| """ | |
| CellposeAgent with proper VLM configuration | |
| """ | |
| import torch | |
| import json | |
| from datetime import datetime | |
| from PIL import Image | |
| from smolagents import ToolCallingAgent, InferenceClientModel | |
| from smolagents.agents import ActionStep | |
| from langfuse import get_client, observe | |
| from config import settings | |
| from utils.gpu import clear_gpu_cache | |
| from tools import all_tools | |
| langfuse = get_client() | |
class CellposeAgent:
    """VLM-enabled agent for the cellpose-sam segmentation workflow.

    Wraps a smolagents ``ToolCallingAgent`` with:
      * step callbacks that attach tool-produced images as PIL objects so the
        vision model can inspect them (resized to limit token consumption),
      * aggressive per-step image memory management,
      * Langfuse tracing around each run.
    """

    # NOTE: the two callbacks below are deliberately @staticmethod. They are
    # registered via `self.attach_images_callback` in _create_agent; as plain
    # instance methods, `self` would be bound into the `step_log` parameter,
    # the isinstance(step_log, ActionStep) guard would always fail, and both
    # callbacks would silently do nothing.
    @staticmethod
    def attach_images_callback(step_log: ActionStep, agent: ToolCallingAgent) -> None:
        """
        Callback to attach actual PIL images for VLM inspection.

        Parses the step's JSON observations for known tool-output patterns and
        sets ``step_log.observations_images``. Images are automatically resized
        to reduce token consumption; size metadata is merged back into the
        observations JSON so the model retains textual context.
        """
        if not isinstance(step_log, ActionStep):
            return
        if not step_log.observations:
            return

        def resize_image(img: Image.Image, max_size: int = 1024) -> Image.Image:
            """Resize image maintaining aspect ratio, max dimension = max_size."""
            if max(img.size) <= max_size:
                return img
            ratio = max_size / max(img.size)
            new_size = tuple(int(dim * ratio) for dim in img.size)
            resized = img.resize(new_size, Image.Resampling.LANCZOS)
            print(f"  Resized {img.size} -> {resized.size}")
            return resized

        try:
            obs_data = json.loads(step_log.observations)

            # Pattern 1: single image from get_segmentation_parameters.
            if obs_data.get("status") == "success" and "image_path" in obs_data:
                image_path = obs_data["image_path"]
                print(f"[Callback] Attaching image: {image_path}")
                try:
                    img = Image.open(image_path)
                    resized_img = resize_image(img)
                    # Attach resized PIL Image for the VLM to inspect.
                    step_log.observations_images = [resized_img]
                    # Keep metadata for textual context alongside the image.
                    obs_data["image_info"] = {
                        "original_dimensions": f"{img.size[0]}x{img.size[1]} pixels",
                        "resized_dimensions": f"{resized_img.size[0]}x{resized_img.size[1]} pixels",
                        "mode": resized_img.mode,
                        "note": "Image attached for visual inspection (resized for efficiency)"
                    }
                    step_log.observations = json.dumps(obs_data, indent=2)
                    print("[Callback] Attached resized image for VLM inspection")
                except Exception as e:
                    # Best-effort: a missing/corrupt image must not abort the step.
                    print(f"[Callback] Error attaching image: {e}")

            # Pattern 2: multiple images from refine_segmentation.
            elif obs_data.get("status") == "ready_for_visual_analysis":
                paths = obs_data.get("image_paths", {})
                original = paths.get("original")
                segmented = paths.get("segmented")
                if original and segmented:
                    print("[Callback] Attaching both original and segmented images")
                    try:
                        orig_img = Image.open(original)
                        seg_img = Image.open(segmented)
                        # Resize both images before attaching.
                        resized_orig = resize_image(orig_img)
                        resized_seg = resize_image(seg_img)
                        # Order matters: downstream prompts assume [original, segmented].
                        step_log.observations_images = [resized_orig, resized_seg]
                        obs_data["images_info"] = {
                            "image_order": ["original", "segmented"],
                            "original_size": f"{orig_img.size[0]}x{orig_img.size[1]}",
                            "resized_size": f"{resized_orig.size[0]}x{resized_orig.size[1]}",
                            "note": "Both images attached for visual comparison (resized for efficiency)"
                        }
                        step_log.observations = json.dumps(obs_data, indent=2)
                        print("[Callback] Attached both resized images for VLM inspection")
                    except Exception as e:
                        print(f"[Callback] Error attaching images: {e}")
        except json.JSONDecodeError:
            # Observations are not JSON -> nothing to attach for this step.
            pass
        except Exception as e:
            print(f"[Callback] Error in attach_images_callback: {e}")

    @staticmethod
    def manage_image_memory(step_log: ActionStep, agent: ToolCallingAgent) -> None:
        """
        Aggressive memory management: keep ONLY the last step's images.
        All previous steps have their images cleared immediately.
        """
        if not isinstance(step_log, ActionStep):
            return
        current_step = step_log.step_number
        # Clear images from ALL previous steps (keeping only the current one).
        for previous_step in agent.memory.steps:
            if isinstance(previous_step, ActionStep) and \
                    previous_step.step_number < current_step:
                if previous_step.observations_images is not None:
                    print(f"  [Memory] Clearing images from step {previous_step.step_number}")
                    previous_step.observations_images = None

    def __init__(self) -> None:
        """Build the system prompt, the inference model, and the tool agent."""
        self.instructions = """
You are an assistant for the cellpose-sam segmentation tool.
## PRIMARY WORKFLOW - IMAGE SEGMENTATION
When a user provides an image:
1. use appropriate tools to review which cellpose-sam parameters are available.
2. use the tool: `get_segmentation_parameters`
- **IMPORTANT**: After this tool runs, you will receive image metadata (dimensions, properties)
- Use this information to reason about appropriate parameter values
3. carefully analyze the image metadata and matched parameters:
- consider cell density based on image dimensions
- compare matched parameter values to image characteristics
- consider if adjustments would likely improve the segmentation
4. Be conservative: if you make changes, assess if they should differ significantly from the original values
5. Provide your final parameter recommendations in a clear, structured format
6. Use the parameters to run cellpose_sam through the tool: run_cellpose_sam
7. after run_cellpose_sam, call the tool: refine_cellpose_sam_segmentation
- **IMPORTANT**: After this tool runs, you will receive metadata about both original and segmented images
- Use the provided information to assess segmentation quality
8. Based on the metadata and any quality metrics returned:
- Identify potential segmentation issues based on reported metrics
- If refinement is needed, use knowledge graph and RAG tools to understand parameter effects
- Decide which parameters to adjust based on the segmentation analysis
- Re-run run_cellpose_sam with adjusted parameters
**CRITICAL: Call refine_cellpose_sam_segmentation AT MOST 2 TIMES total**
- First call: Check initial segmentation quality
- Second call (if needed): Verify refinement improved results
- NEVER call it a third time - always stop after 2 refinement checks
## DOCUMENTATION QUERY WORKFLOW ##
- "What is X": use `search_documentation_vector`
- "How does X affect Y": use `search_knowledge_graph`
- Complex analysis: use `hybrid_search`
- Parameter relationships: use `get_parameter_relationships`
## RESPONSE STYLE ##
- Be concise and actionable
- Always explain your reasoning when adjusting parameters
- If keeping original matched parameters, briefly confirm why it's appropriate
- Base your decisions on the metadata and metrics provided by the tools
"""
        self.model = self._initialize_model()
        self.agent = self._create_agent()

    def _initialize_model(self) -> InferenceClientModel:
        """Initializes the InferenceClientModel for the agent with VLM support."""
        clear_gpu_cache()
        return InferenceClientModel(
            model_id=settings.AGENT_MODEL_ID,
            token=settings.HF_TOKEN
        )

    def _create_agent(self) -> ToolCallingAgent:
        """Creates the ToolCallingAgent with all available tools and memory management."""
        return ToolCallingAgent(
            model=self.model,
            tools=all_tools,
            instructions=self.instructions,
            max_steps=10,
            step_callbacks=[
                self.attach_images_callback,
                self.manage_image_memory,
            ]
        )

    # @observe() opens a Langfuse observation so update_current_trace below
    # has an active trace to attach to (the import was previously unused).
    @observe()
    def run(self, task: str):
        """Runs the agent on a given task with Langfuse tracing.

        Returns the agent's final answer; re-raises any agent failure after
        recording it on the trace. GPU cache is cleared in all cases.
        """
        print(f"\n{'='*60}\nTASK: {task}\n{'='*60}")
        langfuse.update_current_trace(
            input={"task": task},
            user_id="user_001",
            tags=["rag", "cellpose", "knowledge-graph", "vision"],
            metadata={"agent_type": "ToolCallingAgent", "model_id": settings.AGENT_MODEL_ID}
        )
        try:
            final_answer = self.agent.run(task)
            print("\n--- Final Answer from Agent ---\n", final_answer)
            langfuse.update_current_trace(output={"final_answer": final_answer})
            return final_answer
        except Exception as e:
            print(f"Agent run failed: {e}")
            langfuse.update_current_trace(output={"error": str(e)})
            raise
        finally:
            clear_gpu_cache()