# anton-microscopy/anton/vlm/interface.py
# Author: pskeshu
# Commit message: Fix Gemini VLM truncation and improve 4-stage analysis prompts
# Commit: 818797a
"""VLM interface for Anton's microscopy phenotype analysis."""
from pathlib import Path
from typing import Dict, List, Optional, Union, Any
import logging
import json
import os
import base64
import asyncio
from io import BytesIO
import numpy as np
from PIL import Image
logger = logging.getLogger(__name__)


class VLMInterface:
    """Interface for Vision Language Model (VLM) interactions.

    Wraps three providers ("claude", "gemini", "openai") behind one async API
    and exposes the four analysis stages (global scene, object detection,
    feature analysis, population insights) plus a free-form reasoning call.
    """

    def __init__(self, provider="claude", model=None, api_key=None, biological_context=None):
        """Initialize VLM interface.

        Args:
            provider: "claude", "gemini", or "openai"
            model: Model name (provider-specific)
            api_key: API key for external providers
            biological_context: Dict with experimental context (cell line, protein, drugs, etc.)

        Raises:
            ValueError: If the provider is unsupported, or a Gemini/OpenAI key is missing.
            ImportError: If the Gemini/OpenAI SDK is not installed.
        """
        self.provider = provider
        self.model = model or self._get_default_model(provider)
        self.client = self._setup_client(api_key)
        self.biological_context = biological_context or {}
        self.prompts = self._load_prompts()

    def _get_default_model(self, provider: str) -> str:
        """Get default model for provider (falls back to the Claude default)."""
        defaults = {
            "claude": "claude-3-sonnet-20240229",
            "gemini": "gemini-1.5-flash",
            "openai": "gpt-4-vision-preview",
        }
        return defaults.get(provider, "claude-3-sonnet-20240229")

    def _setup_client(self, api_key: Optional[str]):
        """Set up the VLM client based on provider.

        Raises:
            ValueError: If ``self.provider`` is not one of the supported names.
        """
        if self.provider == "claude":
            # May return None: Claude is the only provider allowed to run
            # without a client (see _call_claude for the fallback path).
            return self._create_claude_client(api_key)
        elif self.provider == "gemini":
            return self._create_gemini_client(api_key)
        elif self.provider == "openai":
            return self._create_openai_client(api_key)
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

    def _create_claude_client(self, api_key: Optional[str]):
        """Create a synchronous Anthropic client, or return None if unavailable.

        The key is taken from ``api_key`` or the ANTHROPIC_API_KEY environment
        variable. Failures are logged rather than raised so that the fallback
        path in ``_call_claude`` can be attempted later.
        """
        api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
        if api_key:
            try:
                import anthropic
                client = anthropic.Anthropic(api_key=api_key)
                # Store for potential direct calls
                self._anthropic_client = client
                logger.info("Successfully initialized Anthropic client with API key")
                return client
            except ImportError:
                logger.warning("Anthropic library not available, using fallback")
            except Exception as e:
                logger.warning("Failed to initialize Anthropic client: %s", e)
        else:
            # Only claim the key is missing when it actually is.
            logger.info("No API key provided, using enhanced fallback responses")
        return None

    def _create_gemini_client(self, api_key: Optional[str]):
        """Create Gemini client.

        Raises:
            ValueError: If no key is given and GOOGLE_API_KEY is unset.
            ImportError: If google-generativeai is not installed.
        """
        api_key = api_key or os.getenv("GOOGLE_API_KEY")
        if not api_key:
            raise ValueError("Gemini API key required")
        try:
            import google.generativeai as genai
        except ImportError as e:
            raise ImportError("google-generativeai library required for Gemini") from e
        genai.configure(api_key=api_key)
        return genai.GenerativeModel(self.model)

    def _create_openai_client(self, api_key: Optional[str]):
        """Create OpenAI client.

        Raises:
            ValueError: If no key is given and OPENAI_API_KEY is unset.
            ImportError: If the openai package is not installed.
        """
        api_key = api_key or os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OpenAI API key required")
        try:
            import openai
        except ImportError as e:
            raise ImportError("openai library required for OpenAI") from e
        return openai.OpenAI(api_key=api_key)

    def _load_prompts(self) -> Dict[str, str]:
        """Load ``*.txt`` prompts from the prompts directory.

        Returns:
            Mapping of file stem to stripped file content. Empty when the
            directory does not exist; unreadable files are logged and skipped.
        """
        prompts_dir = Path(__file__).parent.parent.parent / 'prompts'
        if not prompts_dir.exists():
            logger.warning(f"Prompts directory not found: {prompts_dir}")
            return {}
        prompts = {}
        for prompt_file in prompts_dir.glob('*.txt'):
            try:
                prompts[prompt_file.stem] = prompt_file.read_text(encoding='utf-8').strip()
            except (OSError, UnicodeDecodeError) as e:
                logger.error("Failed to load prompt %s: %s", prompt_file, e)
        return prompts

    @staticmethod
    def _to_uint8(array: np.ndarray) -> np.ndarray:
        """Convert an image array to uint8 without integer wrap-around.

        - uint8 input is returned unchanged.
        - bool input maps False/True to 0/255.
        - float input already inside [0, 1] is scaled by 255 (original behaviour).
        - anything else (e.g. uint16 camera data, floats outside [0, 1]) is
          min-max normalised to 0..255; a constant image becomes all zeros.

        Raises:
            ValueError: If the array is empty.
        """
        if array.dtype == np.uint8:
            return array
        if array.size == 0:
            raise ValueError("Cannot encode an empty image array")
        if array.dtype == np.bool_:
            return array.astype(np.uint8) * 255
        low, high = array.min(), array.max()
        if np.issubdtype(array.dtype, np.floating) and low >= 0.0 and high <= 1.0:
            return (array * 255).astype(np.uint8)
        if high == low:
            return np.zeros(array.shape, dtype=np.uint8)
        # Work in float64 so the subtraction/scale cannot overflow the source dtype.
        data = array.astype(np.float64)
        return ((data - low) / (float(high) - float(low)) * 255).astype(np.uint8)

    @staticmethod
    def _encode_png(pil_image) -> bytes:
        """Serialize a PIL image to PNG bytes."""
        buffer = BytesIO()
        pil_image.save(buffer, format='PNG')
        return buffer.getvalue()

    def _prepare_image(self, image_path: "Union[str, Path, np.ndarray, Image.Image]") -> str:
        """Prepare image data for VLM analysis.

        Args:
            image_path: A file path, a numpy array, or a PIL image.

        Returns:
            Base64 text of the image bytes. Everything except an existing
            ``.png`` file is (re-)encoded as PNG, because every provider call
            in this class declares the payload as ``image/png``.

        Raises:
            ValueError: If the input type is unsupported.
        """
        if isinstance(image_path, (str, Path)):
            if Path(image_path).suffix.lower() == '.png':
                with open(image_path, 'rb') as f:
                    image_data = f.read()
            else:
                # Non-PNG files (JPEG, TIFF, ...) would otherwise be sent with a
                # mismatching media type.
                with Image.open(image_path) as loaded:
                    image_data = self._encode_png(loaded)
        elif isinstance(image_path, np.ndarray):
            image_data = self._encode_png(Image.fromarray(self._to_uint8(image_path)))
        elif isinstance(image_path, Image.Image):
            image_data = self._encode_png(image_path)
        else:
            raise ValueError(f"Unsupported image type: {type(image_path)}")
        return base64.b64encode(image_data).decode('utf-8')

    @staticmethod
    def _join_values(value: Any) -> str:
        """Render a context value as text.

        A plain string is returned as-is (joining it would split it into
        characters); other iterables are comma-joined with each item passed
        through ``str``; non-iterables fall back to ``str(value)``.
        """
        if isinstance(value, str):
            return value
        try:
            return ", ".join(str(item) for item in value)
        except TypeError:
            return str(value)

    def _format_biological_context(self) -> str:
        """Format biological context for injection into prompts.

        Returns:
            "" when no context is set, otherwise an "EXPERIMENTAL CONTEXT:"
            header followed by one line per known key that is present.
        """
        context = self.biological_context
        if not context:
            return ""
        # (key, label, is_list) in the order the lines should appear.
        fields = (
            ('experiment_type', "Experiment", False),
            ('cell_line', "Cell line", False),
            ('protein', "Protein", False),
            ('drugs', "Drug treatments", True),
            ('readout', "Expected phenotype", False),
            ('channels', "Image channels", True),
        )
        context_lines = ["EXPERIMENTAL CONTEXT:"]
        for key, label, is_list in fields:
            if key in context:
                value = self._join_values(context[key]) if is_list else context[key]
                context_lines.append(f"- {label}: {value}")
        return "\n".join(context_lines)

    def _with_biological_context(self, prompt: str) -> str:
        """Prefix ``prompt`` with the formatted biological context, if any."""
        if not self.biological_context:
            return prompt
        return f"{self._format_biological_context()}\n\n{prompt}"

    async def analyze_global_scene(self, image: Any, channels: Optional[List[int]] = None) -> Dict:
        """Stage 1: Global scene understanding.

        Args:
            image: Path, numpy array or PIL image.
            channels: Optional channel indices to mention in the prompt.

        Returns:
            Parsed model response (see ``_parse_stage1_response``).
        """
        try:
            image_data = self._prepare_image(image)
            prompt = self._with_biological_context(
                self.prompts.get('stage1_global', 'Analyze this microscopy image.')
            )
            if channels:
                prompt += f" Focus on channels: {channels}"
            response = await self._call_vlm(prompt, image_data)
            return self._parse_stage1_response(response)
        except Exception as e:
            logger.error("Global scene analysis failed: %s", e)
            raise

    async def detect_objects_and_guide(self, image: Any, global_context: Dict) -> Dict:
        """Stage 2: Detect objects and provide segmentation guidance.

        Args:
            image: Path, numpy array or PIL image.
            global_context: Stage 1 output; appended to the prompt as JSON.
        """
        try:
            image_data = self._prepare_image(image)
            prompt = self._with_biological_context(
                self.prompts.get('stage2_objects', 'Detect objects in this image.')
            )
            context_str = json.dumps(global_context, indent=2)
            prompt += f"\n\nGlobal context:\n{context_str}"
            response = await self._call_vlm(prompt, image_data)
            return self._parse_stage2_response(response)
        except Exception as e:
            logger.error("Object detection failed: %s", e)
            raise

    async def analyze_features(self, image: Any, detected_objects: List[Dict]) -> Dict:
        """Stage 3: Analyze features for detected objects.

        Args:
            image: Path, numpy array or PIL image.
            detected_objects: Stage 2 objects; appended to the prompt as JSON.
        """
        try:
            image_data = self._prepare_image(image)
            prompt = self.prompts.get('stage3_features', 'Analyze features in this image.')
            objects_str = json.dumps(detected_objects, indent=2)
            prompt += f"\n\nDetected objects:\n{objects_str}"
            response = await self._call_vlm(prompt, image_data)
            return self._parse_stage3_response(response)
        except Exception as e:
            logger.error("Feature analysis failed: %s", e)
            raise

    async def generate_population_insights(self, feature_analyses: List[Dict]) -> Dict:
        """Stage 4: Generate population-level insights (text-only call)."""
        try:
            prompt = self.prompts.get('stage4_population', 'Generate population insights.')
            features_str = json.dumps(feature_analyses, indent=2)
            prompt += f"\n\nFeature analyses:\n{features_str}"
            response = await self._call_vlm(prompt)
            return self._parse_stage4_response(response)
        except Exception as e:
            logger.error("Population analysis failed: %s", e)
            raise

    async def analyze_biological_reasoning(self, validation_prompt: str) -> str:
        """Analyze biological reasoning for CMPO mapping validation.

        Best-effort: on any failure a default "VALID: ..." string is returned
        instead of raising.
        """
        try:
            return await self._call_vlm(validation_prompt)
        except Exception as e:
            logger.warning(f"Biological reasoning analysis failed: {e}")
            return "VALID: Default validation - reasoning: VLM validation unavailable, using ontology mapping"

    async def _call_vlm(self, prompt: str, image_data: Optional[str] = None) -> str:
        """Call VLM with prompt and optional base64 image, dispatching on provider.

        Raises:
            ValueError: If ``self.provider`` is not supported.
        """
        if self.provider == "claude":
            return await self._call_claude(prompt, image_data)
        elif self.provider == "gemini":
            return await self._call_gemini(prompt, image_data)
        elif self.provider == "openai":
            return await self._call_openai(prompt, image_data)
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

    @staticmethod
    async def _run_blocking(func, *args, **kwargs):
        """Run a blocking SDK call in a worker thread and return its result.

        The clients built by this class are synchronous, so awaiting their
        methods directly fails and would block the event loop. If an async
        client was injected instead, the call yields a coroutine, which is
        awaited here so both kinds work.
        """
        result = await asyncio.to_thread(func, *args, **kwargs)
        if asyncio.iscoroutine(result):
            result = await result
        return result

    @staticmethod
    def _claude_content(prompt: str, image_data: Optional[str]) -> List[Dict[str, Any]]:
        """Build the Anthropic message content: text plus optional base64 PNG."""
        content: List[Dict[str, Any]] = [{"type": "text", "text": prompt}]
        if image_data:
            content.append({
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/png",
                    "data": image_data,
                },
            })
        return content

    async def _call_claude(self, prompt: str, image_data: Optional[str] = None) -> str:
        """Call Claude API.

        Without a configured client the Claude Code fallback chain is tried.

        Raises:
            RuntimeError: If no client exists and every fallback failed.
            Exception: Whatever the SDK raises for a failed request.
        """
        if self.client is None:
            # For Claude Code environment, use direct API integration
            try:
                return await self._call_claude_code_direct(prompt, image_data)
            except Exception as e:
                logger.error("Claude API call failed: %s", e)
                raise RuntimeError(
                    "No working Claude API integration available. Please provide ANTHROPIC_API_KEY."
                ) from e
        try:
            response = await self._run_blocking(
                self.client.messages.create,
                model=self.model,
                max_tokens=4000,
                messages=[{"role": "user", "content": self._claude_content(prompt, image_data)}],
            )
            return response.content[0].text
        except Exception as e:
            logger.error("Claude API call failed: %s", e)
            raise

    async def _call_gemini(self, prompt: str, image_data: Optional[str] = None) -> str:
        """Call Gemini API with explicit generation limits."""
        try:
            # Imported lazily so the module loads without the Gemini SDK.
            import google.generativeai as genai
            generation_config = genai.types.GenerationConfig(
                max_output_tokens=2000,
                temperature=0.1,
                top_p=0.8,
                top_k=20,
            )
            if image_data:
                # Gemini takes a PIL image rather than base64 text.
                pil_image = Image.open(BytesIO(base64.b64decode(image_data)))
                request: Any = [prompt, pil_image]
            else:
                request = prompt
            response = await asyncio.to_thread(
                self.client.generate_content, request, generation_config=generation_config
            )
            return response.text
        except Exception as e:
            logger.error("Gemini API call failed: %s", e)
            raise

    async def _call_openai(self, prompt: str, image_data: Optional[str] = None) -> str:
        """Call OpenAI chat completions API with text and optional PNG data URL."""
        try:
            content: List[Dict[str, Any]] = [{"type": "text", "text": prompt}]
            if image_data:
                content.append({
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{image_data}"},
                })
            response = await self._run_blocking(
                self.client.chat.completions.create,
                model=self.model,
                messages=[{"role": "user", "content": content}],
                max_tokens=4000,
            )
            return response.choices[0].message.content
        except Exception as e:
            logger.error("OpenAI API call failed: %s", e)
            raise

    @staticmethod
    def _extract_json(response: Any) -> Optional[Dict]:
        """Return the JSON object contained in ``response``, or None.

        Accepts raw JSON as well as JSON wrapped in a Markdown code fence
        (```` ``` ```` or ```` ```json ````). Non-string input, invalid JSON
        and JSON that is not an object all yield None so callers can fall
        back to their text-based defaults.
        """
        if not isinstance(response, str):
            return None
        text = response.strip()
        if text.startswith("```"):
            text = text[3:]
            if text.endswith("```"):
                text = text[:-3]
            text = text.strip()
            if text[:4].lower() == "json":
                text = text[4:]
        try:
            parsed = json.loads(text)
        except ValueError:  # json.JSONDecodeError is a ValueError
            return None
        return parsed if isinstance(parsed, dict) else None

    def _parse_stage1_response(self, response: str) -> Dict:
        """Parse Stage 1 response; non-JSON text becomes the description."""
        parsed = self._extract_json(response)
        if parsed is not None:
            return parsed
        return {
            "description": response,
            "quality_score": 0.8,  # Default
            "recommended_analysis": "standard",
        }

    def _parse_stage2_response(self, response: str) -> Dict:
        """Parse Stage 2 response.

        NOTE(review): the fallback objects are fixed placeholders, not
        detections derived from the response.
        """
        parsed = self._extract_json(response)
        if parsed is not None:
            return parsed
        return {
            "detected_objects": [
                {"id": 1, "type": "nucleus", "confidence": 0.8},
                {"id": 2, "type": "cell", "confidence": 0.7},
            ],
            "segmentation_guidance": response,
            "object_count_estimate": 2,
        }

    def _parse_stage3_response(self, response: str) -> Dict:
        """Parse Stage 3 response.

        NOTE(review): the fallback object analyses are fixed placeholders.
        """
        parsed = self._extract_json(response)
        if parsed is not None:
            return parsed
        return {
            "object_analyses": [
                {"object_id": 1, "features": ["round", "bright"], "confidence": 0.8},
                {"object_id": 2, "features": ["elongated", "dim"], "confidence": 0.7},
            ],
            "feature_descriptions": [response],
            "cmpo_mappings": [],
        }

    def _parse_stage4_response(self, response: str) -> Dict:
        """Parse Stage 4 response; non-JSON text becomes the summary."""
        parsed = self._extract_json(response)
        if parsed is not None:
            return parsed
        return {
            "population_summary": response,
            "quantitative_metrics": {},
            "cmpo_prevalence": {},
        }

    async def _call_claude_code_direct(self, prompt: str, image_data: Optional[str] = None) -> str:
        """Direct Claude API call for Claude Code environment.

        Uses the stored Anthropic client when one exists, otherwise tries the
        environment-specific fallbacks.

        Raises:
            NotImplementedError: If no fallback produced a response.
        """
        client = getattr(self, '_anthropic_client', None)
        if client:
            try:
                response = await self._run_blocking(
                    client.messages.create,
                    model=self.model,
                    max_tokens=4000,
                    messages=[{"role": "user", "content": self._claude_content(prompt, image_data)}],
                )
                logger.info("Successfully called Claude API directly")
                return response.content[0].text
            except Exception as e:
                logger.error("Direct Anthropic API call failed: %s", e)
                raise
        logger.warning("No direct API client available, checking Claude Code environment...")
        try:
            # Speculative: depends on what the Claude Code environment exposes.
            return await self._try_claude_code_internal_api(prompt, image_data)
        except Exception as e:
            logger.warning("Claude Code internal API failed: %s", e)
            raise NotImplementedError("Claude Code direct API integration not yet implemented") from e

    async def _try_claude_code_internal_api(self, prompt: str, image_data: Optional[str] = None) -> str:
        """Try to use Claude Code internal APIs if available.

        Attempts, in order: the ``claude`` CLI, Claude-related environment
        variables, and a local HTTP endpoint. Each failure is logged at debug
        level and the next method is tried.

        Raises:
            NotImplementedError: If every method failed.
        """
        import shutil

        # Method 1: CLI tool. shutil.which is portable and does not spawn a process.
        if shutil.which('claude'):
            logger.info("Found claude CLI tool")
            try:
                return await self._call_claude_cli(prompt, image_data)
            except Exception as e:
                logger.debug("Claude CLI method failed: %s", e)

        # Method 2: Claude-specific environment variables (names only are logged).
        claude_env_vars = [key for key in os.environ if 'CLAUDE' in key.upper()]
        if claude_env_vars:
            logger.info(f"Found Claude environment variables: {claude_env_vars}")
            try:
                return await self._call_claude_with_env_vars(prompt, image_data)
            except Exception as e:
                logger.debug("Claude environment-variable method failed: %s", e)

        # Method 3: local HTTP endpoints.
        try:
            return await self._call_claude_local_api(prompt, image_data)
        except Exception as e:
            logger.debug("Claude local API method failed: %s", e)

        raise NotImplementedError(
            "Claude Code internal API not available. "
            "Please set ANTHROPIC_API_KEY environment variable to use external Claude API."
        )

    async def _call_claude_cli(self, prompt: str, image_data: Optional[str] = None) -> str:
        """Call Claude using CLI tool if available.

        The prompt (and optional image) are written to temp files that are
        always removed, including when the subprocess fails or times out.
        NOTE(review): the ``--file``/``--image`` flags are assumed; verify
        against the installed CLI.

        Raises:
            RuntimeError: If the CLI exits with a non-zero status.
        """
        import subprocess
        import tempfile

        temp_paths: List[str] = []
        try:
            with tempfile.NamedTemporaryFile(
                mode='w', suffix='.txt', delete=False, encoding='utf-8'
            ) as f:
                temp_paths.append(f.name)
                f.write(prompt)
            cmd = ['claude', '--file', temp_paths[0]]
            if image_data:
                with tempfile.NamedTemporaryFile(mode='wb', suffix='.png', delete=False) as f:
                    temp_paths.append(f.name)
                    f.write(base64.b64decode(image_data))
                cmd.extend(['--image', temp_paths[-1]])
            result = await asyncio.to_thread(
                subprocess.run, cmd, capture_output=True, text=True, timeout=30
            )
            if result.returncode != 0:
                raise RuntimeError(f"Claude CLI failed: {result.stderr}")
            logger.info("Successfully called Claude CLI")
            return result.stdout.strip()
        except Exception as e:
            logger.error("Claude CLI call failed: %s", e)
            raise
        finally:
            for temp_path in temp_paths:
                try:
                    os.unlink(temp_path)
                except OSError:
                    # Cleanup is best-effort; never mask the real error.
                    pass

    async def _call_claude_with_env_vars(self, prompt: str, image_data: Optional[str] = None) -> str:
        """Try to use Claude with environment variables (not implemented).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Environment variable method not implemented")

    async def _call_claude_local_api(self, prompt: str, image_data: Optional[str] = None) -> str:
        """Try to call a local Claude API endpoint.

        Posts ``{'prompt': ..., 'image': ...}`` to a fixed list of localhost
        URLs and returns the body of the first HTTP 200 response.

        Raises:
            ImportError: If aiohttp is not installed.
            RuntimeError: If no endpoint answered with status 200.
        """
        import aiohttp

        endpoints = (
            'http://localhost:8080/claude',
            'http://127.0.0.1:8080/claude',
            'http://localhost:3000/api/claude',
        )
        payload = {'prompt': prompt}
        if image_data:
            payload['image'] = image_data
        # One session for all attempts; 30 s total budget per request.
        timeout = aiohttp.ClientTimeout(total=30)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            for endpoint in endpoints:
                try:
                    async with session.post(endpoint, json=payload) as response:
                        if response.status == 200:
                            result = await response.text()
                            logger.info(f"Successfully called local Claude API at {endpoint}")
                            return result
                        logger.debug("Local Claude API %s returned HTTP %s", endpoint, response.status)
                except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as e:
                    logger.debug("Local Claude API %s unreachable: %s", endpoint, e)
        raise RuntimeError("No local Claude API endpoints found")