# Source upload metadata: Yeetek — "Upload 43 files", revision 7d5083d (verified)
"""
Model manager for Anthropic models.
This module provides utilities for managing multiple Anthropic models,
environment-based switching, and model performance tracking.
"""
from typing import Dict, Any, Optional, List
from dataclasses import dataclass
from datetime import datetime, timedelta
from config.settings import get_settings, AnthropicModel
from config.logging import get_logger
from core.anthropic_client import AnthropicClient, AnthropicResponse
logger = get_logger(__name__)
@dataclass
class ModelPerformance:
    """Performance metrics for a specific model.

    Counters are mutated in place by ModelManager._update_performance_stats;
    a freshly registered model starts at zero requests with a 100% success rate.
    """
    model_name: str  # model identifier string (the AnthropicModel enum value)
    total_requests: int = 0  # all attempts, successful or failed
    total_tokens: int = 0  # cumulative tokens from successful responses only
    total_errors: int = 0  # failed attempts
    avg_response_time: float = 0.0  # seconds; exponential moving average after the first sample
    last_used: Optional[datetime] = None  # wall-clock time of the most recent attempt
    success_rate: float = 100.0  # percentage: (total_requests - total_errors) / total_requests * 100
class ModelManager:
"""
Manager for multiple Anthropic models with performance tracking.
Features:
- Environment-based model selection
- Performance tracking per model
- Automatic model switching based on performance
- Model health monitoring
"""
def __init__(self):
"""Initialize model manager."""
self.settings = get_settings()
self.current_model = self.settings.anthropic_model
self.performance_stats: Dict[str, ModelPerformance] = {}
self.clients: Dict[str, AnthropicClient] = {}
# Initialize performance tracking for all models
for model in AnthropicModel:
self.performance_stats[model.value] = ModelPerformance(model_name=model.value)
logger.info(f"Model manager initialized with default model: {self.current_model.value}")
async def get_client(self, model: Optional[AnthropicModel] = None) -> AnthropicClient:
"""
Get or create client for specified model.
Args:
model: Model to get client for (uses current if not specified)
Returns:
AnthropicClient instance for the model
"""
target_model = model or self.current_model
model_key = target_model.value
if model_key not in self.clients:
self.clients[model_key] = AnthropicClient(model=target_model)
logger.info(f"Created new client for model: {model_key}")
return self.clients[model_key]
async def generate_completion(
self,
messages: List[Dict[str, str]],
max_tokens: int = 4000,
temperature: float = 0.0,
system_prompt: Optional[str] = None,
model: Optional[AnthropicModel] = None,
fallback_on_error: bool = True
) -> AnthropicResponse:
"""
Generate completion with automatic fallback and performance tracking.
Args:
messages: List of message dictionaries
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
system_prompt: Optional system prompt
model: Model to use (uses current if not specified)
fallback_on_error: Whether to try fallback model on error
Returns:
AnthropicResponse with completion and metadata
"""
target_model = model or self.current_model
start_time = datetime.now()
try:
client = await self.get_client(target_model)
response = await client.generate_completion(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
system_prompt=system_prompt,
model=target_model
)
# Update performance stats
self._update_performance_stats(target_model.value, response, start_time, success=True)
return response
except Exception as e:
# Update error stats
self._update_performance_stats(target_model.value, None, start_time, success=False)
# Try fallback model if enabled and not already using fallback
if fallback_on_error and target_model != self._get_fallback_model():
logger.warning(f"Model {target_model.value} failed, trying fallback: {str(e)}")
try:
fallback_model = self._get_fallback_model()
client = await self.get_client(fallback_model)
response = await client.generate_completion(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
system_prompt=system_prompt,
model=fallback_model
)
# Update performance stats for fallback
self._update_performance_stats(fallback_model.value, response, start_time, success=True)
logger.info(f"Fallback model {fallback_model.value} succeeded")
return response
except Exception as fallback_error:
logger.error(f"Fallback model also failed: {str(fallback_error)}")
self._update_performance_stats(self._get_fallback_model().value, None, start_time, success=False)
# Re-raise original exception if no fallback or fallback failed
raise
def switch_model(self, model: AnthropicModel) -> None:
"""
Switch the current default model.
Args:
model: New model to use as default
"""
old_model = self.current_model.value
self.current_model = model
logger.info(f"Switched default model from {old_model} to {model.value}")
def get_model_from_env(self) -> AnthropicModel:
"""
Get model from environment variable.
Returns:
AnthropicModel based on environment configuration
"""
return self.settings.anthropic_model
def get_performance_stats(self) -> Dict[str, Dict[str, Any]]:
"""
Get performance statistics for all models.
Returns:
Dictionary with performance stats for each model
"""
stats = {}
for model_name, perf in self.performance_stats.items():
stats[model_name] = {
"total_requests": perf.total_requests,
"total_tokens": perf.total_tokens,
"total_errors": perf.total_errors,
"avg_response_time": perf.avg_response_time,
"success_rate": perf.success_rate,
"last_used": perf.last_used.isoformat() if perf.last_used else None
}
return stats
def get_best_performing_model(self) -> AnthropicModel:
"""
Get the best performing model based on success rate and response time.
Returns:
AnthropicModel with best performance
"""
best_model = self.current_model
best_score = 0.0
for model_name, perf in self.performance_stats.items():
if perf.total_requests == 0:
continue
# Score based on success rate and inverse response time
score = perf.success_rate * (1.0 / max(perf.avg_response_time, 0.1))
if score > best_score:
best_score = score
try:
best_model = AnthropicModel(model_name)
except ValueError:
continue
return best_model
def _get_fallback_model(self) -> AnthropicModel:
"""Get fallback model (currently Claude 3.5 Haiku for speed)."""
if self.current_model == AnthropicModel.CLAUDE_3_5_SONNET:
return AnthropicModel.CLAUDE_3_5_HAIKU
else:
return AnthropicModel.CLAUDE_3_5_SONNET
def _update_performance_stats(
self,
model_name: str,
response: Optional[AnthropicResponse],
start_time: datetime,
success: bool
) -> None:
"""Update performance statistics for a model."""
if model_name not in self.performance_stats:
self.performance_stats[model_name] = ModelPerformance(model_name=model_name)
perf = self.performance_stats[model_name]
perf.total_requests += 1
perf.last_used = datetime.now()
if success and response:
# Update token count
perf.total_tokens += response.usage.get('total_tokens', 0)
# Update average response time
response_time = (datetime.now() - start_time).total_seconds()
if perf.avg_response_time == 0:
perf.avg_response_time = response_time
else:
# Exponential moving average
perf.avg_response_time = 0.9 * perf.avg_response_time + 0.1 * response_time
else:
perf.total_errors += 1
# Update success rate
perf.success_rate = ((perf.total_requests - perf.total_errors) / perf.total_requests) * 100
async def health_check(self) -> Dict[str, Any]:
"""
Perform health check on all available models.
Returns:
Dictionary with health status for each model
"""
health_status = {}
for model in AnthropicModel:
try:
client = await self.get_client(model)
is_healthy = await client.validate_connection()
health_status[model.value] = {
"healthy": is_healthy,
"last_check": datetime.now().isoformat(),
"performance": self.performance_stats[model.value].__dict__
}
except Exception as e:
health_status[model.value] = {
"healthy": False,
"error": str(e),
"last_check": datetime.now().isoformat(),
"performance": self.performance_stats[model.value].__dict__
}
return health_status
async def close_all_clients(self) -> None:
"""Close all model clients."""
for client in self.clients.values():
await client.close()
self.clients.clear()
logger.info("All model clients closed")
# Process-wide singleton; created lazily on first access.
_model_manager: Optional[ModelManager] = None


def get_model_manager() -> ModelManager:
    """Return the process-wide ModelManager, creating it on first call."""
    global _model_manager
    manager = _model_manager
    if manager is None:
        manager = ModelManager()
        _model_manager = manager
    return manager
async def close_model_manager() -> None:
    """Shut down the global model manager, releasing all of its clients."""
    global _model_manager
    if _model_manager is None:
        return
    # Close clients first; the global is only cleared on a clean shutdown.
    await _model_manager.close_all_clients()
    _model_manager = None