# Source: Hugging Face upload by hmgill — "Upload 41 files" (commit 42bf28c, verified)
"""
CellposeAgent with proper VLM configuration
"""
import torch
import json
from datetime import datetime
from PIL import Image
from smolagents import ToolCallingAgent, InferenceClientModel
from smolagents.agents import ActionStep
from langfuse import get_client, observe
from config import settings
from utils.gpu import clear_gpu_cache
from tools import all_tools
langfuse = get_client()
class CellposeAgent:
    """Tool-calling agent for Cellpose-SAM image segmentation.

    Wraps a smolagents ``ToolCallingAgent`` (vision-capable) with two step
    callbacks — one that attaches tool-produced PIL images for VLM
    inspection, one that aggressively frees image memory — and traces each
    run with Langfuse.
    """

    @staticmethod
    def _resize_image(img: Image.Image, max_size: int = 1024) -> Image.Image:
        """Resize *img* so its longest side is at most ``max_size`` pixels.

        Aspect ratio is preserved; images already within the limit are
        returned unchanged. Resizing reduces VLM token consumption.
        """
        if max(img.size) <= max_size:
            return img
        ratio = max_size / max(img.size)
        new_size = tuple(int(dim * ratio) for dim in img.size)
        resized = img.resize(new_size, Image.Resampling.LANCZOS)
        print(f" Resized {img.size} → {resized.size}")
        return resized

    @staticmethod
    def _attach_single_image(step_log: ActionStep, obs_data: dict) -> None:
        """Pattern 1: attach the single image returned by ``get_segmentation_parameters``."""
        image_path = obs_data["image_path"]
        print(f"[Callback] Attaching image: {image_path}")
        try:
            img = Image.open(image_path)
            resized_img = CellposeAgent._resize_image(img)
            # Attach the resized PIL image so the VLM can inspect it.
            step_log.observations_images = [resized_img]
            # Keep size/mode metadata in the textual observations for context.
            obs_data["image_info"] = {
                "original_dimensions": f"{img.size[0]}x{img.size[1]} pixels",
                "resized_dimensions": f"{resized_img.size[0]}x{resized_img.size[1]} pixels",
                "mode": resized_img.mode,
                "note": "Image attached for visual inspection (resized for efficiency)"
            }
            step_log.observations = json.dumps(obs_data, indent=2)
            print("[Callback] ✓ Attached resized image for VLM inspection")
        except Exception as e:
            # Best-effort: a missing/corrupt image must not abort the agent step.
            print(f"[Callback] Error attaching image: {e}")

    @staticmethod
    def _attach_image_pair(step_log: ActionStep, obs_data: dict) -> None:
        """Pattern 2: attach original + segmented images from ``refine_segmentation``."""
        paths = obs_data.get("image_paths", {})
        original = paths.get("original")
        segmented = paths.get("segmented")
        if not (original and segmented):
            return
        print("[Callback] Attaching both original and segmented images")
        try:
            orig_img = Image.open(original)
            seg_img = Image.open(segmented)
            resized_orig = CellposeAgent._resize_image(orig_img)
            resized_seg = CellposeAgent._resize_image(seg_img)
            # Order matters: "image_order" below tells the model which is which.
            step_log.observations_images = [resized_orig, resized_seg]
            obs_data["images_info"] = {
                "image_order": ["original", "segmented"],
                "original_size": f"{orig_img.size[0]}x{orig_img.size[1]}",
                "resized_size": f"{resized_orig.size[0]}x{resized_orig.size[1]}",
                "note": "Both images attached for visual comparison (resized for efficiency)"
            }
            step_log.observations = json.dumps(obs_data, indent=2)
            print("[Callback] ✓ Attached both resized images for VLM inspection")
        except Exception as e:
            print(f"[Callback] Error attaching images: {e}")

    @staticmethod
    def attach_images_callback(step_log: ActionStep, agent: ToolCallingAgent) -> None:
        """Step callback: attach actual PIL images for VLM inspection.

        Parses the step's JSON observations and dispatches on their shape:
        a single image (``status == "success"`` with an ``image_path``) or an
        original/segmented pair (``status == "ready_for_visual_analysis"``).
        Non-JSON observations are silently ignored; any other error is
        logged but never propagated, so the agent step always continues.
        """
        if not isinstance(step_log, ActionStep):
            return
        if not step_log.observations:
            return
        try:
            obs_data = json.loads(step_log.observations)
            if obs_data.get("status") == "success" and "image_path" in obs_data:
                CellposeAgent._attach_single_image(step_log, obs_data)
            elif obs_data.get("status") == "ready_for_visual_analysis":
                CellposeAgent._attach_image_pair(step_log, obs_data)
        except json.JSONDecodeError:
            pass  # Plain-text observations: nothing to attach.
        except Exception as e:
            print(f"[Callback] Error in attach_images_callback: {e}")

    @staticmethod
    def manage_image_memory(step_log: ActionStep, agent: ToolCallingAgent) -> None:
        """Step callback: keep ONLY the current step's images.

        Clears ``observations_images`` from every earlier ActionStep in the
        agent's memory so at most one step's images are alive at a time.
        """
        if not isinstance(step_log, ActionStep):
            return
        current_step = step_log.step_number
        # Clear images from ALL previous steps (keeping only the current one).
        for previous_step in agent.memory.steps:
            if (isinstance(previous_step, ActionStep)
                    and previous_step.step_number < current_step
                    and previous_step.observations_images is not None):
                print(f" [Memory] Clearing images from step {previous_step.step_number}")
                previous_step.observations_images = None

    def __init__(self):
        # System instructions for the segmentation workflow; kept verbatim —
        # the agent's behavior depends on this exact prompt text.
        self.instructions = """
You are an assistant for the cellpose-sam segmentation tool.
## PRIMARY WORKFLOW - IMAGE SEGMENTATION
When a user provides an image:
1. use appropriate tools to review which cellpose-sam parameters are available.
2. use the tool: `get_segmentation_parameters`
- **IMPORTANT**: After this tool runs, you will receive image metadata (dimensions, properties)
- Use this information to reason about appropriate parameter values
3. carefully analyze the image metadata and matched parameters:
- consider cell density based on image dimensions
- compare matched parameter values to image characteristics
- consider if adjustments would likely improve the segmentation
4. Be conservative: if you make changes, assess if they should differ significantly from the original values
5. Provide your final parameter recommendations in a clear, structured format
6. Use the parameters to run cellpose_sam through the tool: run_cellpose_sam
7. after run_cellpose_sam, call the tool: refine_cellpose_sam_segmentation
- **IMPORTANT**: After this tool runs, you will receive metadata about both original and segmented images
- Use the provided information to assess segmentation quality
8. Based on the metadata and any quality metrics returned:
- Identify potential segmentation issues based on reported metrics
- If refinement is needed, use knowledge graph and RAG tools to understand parameter effects
- Decide which parameters to adjust based on the segmentation analysis
- Re-run run_cellpose_sam with adjusted parameters
**CRITICAL: Call refine_cellpose_sam_segmentation AT MOST 2 TIMES total**
- First call: Check initial segmentation quality
- Second call (if needed): Verify refinement improved results
- NEVER call it a third time - always stop after 2 refinement checks
## DOCUMENTATION QUERY WORKFLOW ##
- "What is X": use `search_documentation_vector`
- "How does X affect Y": use `search_knowledge_graph`
- Complex analysis: use `hybrid_search`
- Parameter relationships: use `get_parameter_relationships`
## RESPONSE STYLE ##
- Be concise and actionable
- Always explain your reasoning when adjusting parameters
- If keeping original matched parameters, briefly confirm why it's appropriate
- Base your decisions on the metadata and metrics provided by the tools
"""
        self.model = self._initialize_model()
        self.agent = self._create_agent()

    def _initialize_model(self):
        """Initializes the InferenceClientModel for the agent with VLM support."""
        clear_gpu_cache()
        return InferenceClientModel(
            model_id=settings.AGENT_MODEL_ID,
            token=settings.HF_TOKEN,
        )

    def _create_agent(self):
        """Creates the ToolCallingAgent with all available tools and memory management."""
        return ToolCallingAgent(
            model=self.model,
            tools=all_tools,
            instructions=self.instructions,
            max_steps=10,
            # Order matters: attach this step's images first, then purge
            # images from all earlier steps.
            step_callbacks=[
                self.attach_images_callback,
                self.manage_image_memory,
            ]
        )

    @observe()
    def run(self, task: str):
        """Runs the agent on *task* with Langfuse tracing.

        Returns the agent's final answer; re-raises any agent failure after
        recording it on the trace. GPU cache is cleared in all cases.
        """
        print(f"\n{'='*60}\nTASK: {task}\n{'='*60}")
        langfuse.update_current_trace(
            input={"task": task},
            user_id="user_001",
            tags=["rag", "cellpose", "knowledge-graph", "vision"],
            metadata={"agent_type": "ToolCallingAgent", "model_id": settings.AGENT_MODEL_ID}
        )
        try:
            final_answer = self.agent.run(task)
            print("\n--- Final Answer from Agent ---\n", final_answer)
            langfuse.update_current_trace(output={"final_answer": final_answer})
            return final_answer
        except Exception as e:
            print(f"Agent run failed: {e}")
            langfuse.update_current_trace(output={"error": str(e)})
            raise
        finally:
            clear_gpu_cache()