| """ | |
| Model manager for Anthropic models. | |
| This module provides utilities for managing multiple Anthropic models, | |
| environment-based switching, and model performance tracking. | |
| """ | |
from typing import Dict, Any, Optional, List
from dataclasses import dataclass
from datetime import datetime

from config.settings import get_settings, AnthropicModel
from config.logging import get_logger
from core.anthropic_client import AnthropicClient, AnthropicResponse

logger = get_logger(__name__)
@dataclass
class ModelPerformance:
    """Performance metrics for a specific model."""

    model_name: str
    total_requests: int = 0
    total_tokens: int = 0
    total_errors: int = 0
    avg_response_time: float = 0.0
    last_used: Optional[datetime] = None
    success_rate: float = 100.0

class ModelManager:
    """
    Manager for multiple Anthropic models with performance tracking.

    Features:
    - Environment-based model selection
    - Performance tracking per model
    - Automatic model switching based on performance
    - Model health monitoring

    See the commented usage sketch below this docstring.
    """
    def __init__(self):
        """Initialize model manager."""
        self.settings = get_settings()
        self.current_model = self.settings.anthropic_model
        self.performance_stats: Dict[str, ModelPerformance] = {}
        self.clients: Dict[str, AnthropicClient] = {}

        # Initialize performance tracking for all models
        for model in AnthropicModel:
            self.performance_stats[model.value] = ModelPerformance(model_name=model.value)

        logger.info(f"Model manager initialized with default model: {self.current_model.value}")
    async def get_client(self, model: Optional[AnthropicModel] = None) -> AnthropicClient:
        """
        Get or create client for specified model.

        Args:
            model: Model to get client for (uses current if not specified)

        Returns:
            AnthropicClient instance for the model
        """
        target_model = model or self.current_model
        model_key = target_model.value

        if model_key not in self.clients:
            self.clients[model_key] = AnthropicClient(model=target_model)
            logger.info(f"Created new client for model: {model_key}")

        return self.clients[model_key]
    async def generate_completion(
        self,
        messages: List[Dict[str, str]],
        max_tokens: int = 4000,
        temperature: float = 0.0,
        system_prompt: Optional[str] = None,
        model: Optional[AnthropicModel] = None,
        fallback_on_error: bool = True
    ) -> AnthropicResponse:
        """
        Generate completion with automatic fallback and performance tracking.

        Args:
            messages: List of message dictionaries
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            system_prompt: Optional system prompt
            model: Model to use (uses current if not specified)
            fallback_on_error: Whether to try the fallback model on error

        Returns:
            AnthropicResponse with completion and metadata
        """
        target_model = model or self.current_model
        start_time = datetime.now()

        try:
            client = await self.get_client(target_model)
            response = await client.generate_completion(
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                system_prompt=system_prompt,
                model=target_model
            )

            # Update performance stats
            self._update_performance_stats(target_model.value, response, start_time, success=True)
            return response

        except Exception as e:
            # Update error stats
            self._update_performance_stats(target_model.value, None, start_time, success=False)

            # Try fallback model if enabled and not already using fallback
            fallback_model = self._get_fallback_model()
            if fallback_on_error and target_model != fallback_model:
                logger.warning(f"Model {target_model.value} failed, trying fallback: {str(e)}")

                # Time the fallback attempt separately so the failed primary
                # attempt does not inflate the fallback model's stats
                fallback_start = datetime.now()
                try:
                    client = await self.get_client(fallback_model)
                    response = await client.generate_completion(
                        messages=messages,
                        max_tokens=max_tokens,
                        temperature=temperature,
                        system_prompt=system_prompt,
                        model=fallback_model
                    )

                    # Update performance stats for fallback
                    self._update_performance_stats(fallback_model.value, response, fallback_start, success=True)
                    logger.info(f"Fallback model {fallback_model.value} succeeded")
                    return response

                except Exception as fallback_error:
                    logger.error(f"Fallback model also failed: {str(fallback_error)}")
                    self._update_performance_stats(fallback_model.value, None, fallback_start, success=False)

            # Re-raise original exception if no fallback or fallback failed
            raise
    def switch_model(self, model: AnthropicModel) -> None:
        """
        Switch the current default model.

        Args:
            model: New model to use as default
        """
        old_model = self.current_model.value
        self.current_model = model
        logger.info(f"Switched default model from {old_model} to {model.value}")

    def get_model_from_env(self) -> AnthropicModel:
        """
        Get model from environment-based settings.

        Returns:
            AnthropicModel based on environment configuration
        """
        return self.settings.anthropic_model
    def get_performance_stats(self) -> Dict[str, Dict[str, Any]]:
        """
        Get performance statistics for all models.

        Returns:
            Dictionary with performance stats for each model
        """
        stats = {}
        for model_name, perf in self.performance_stats.items():
            stats[model_name] = {
                "total_requests": perf.total_requests,
                "total_tokens": perf.total_tokens,
                "total_errors": perf.total_errors,
                "avg_response_time": perf.avg_response_time,
                "success_rate": perf.success_rate,
                "last_used": perf.last_used.isoformat() if perf.last_used else None
            }
        return stats
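    # Illustrative shape of the dict returned above (model name and values
    # are placeholders, not real measurements):
    #
    #     {"<model-name>": {"total_requests": 12, "total_tokens": 3456,
    #                       "total_errors": 1, "avg_response_time": 1.8,
    #                       "success_rate": 91.7,
    #                       "last_used": "2025-01-01T12:00:00"}}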
    def get_best_performing_model(self) -> AnthropicModel:
        """
        Get the best performing model based on success rate and response time.

        Returns:
            AnthropicModel with best performance
        """
        best_model = self.current_model
        best_score = 0.0

        for model_name, perf in self.performance_stats.items():
            if perf.total_requests == 0:
                continue

            # Validate the model name before touching best_score, so an
            # unknown stats entry cannot shadow a later valid candidate
            try:
                candidate = AnthropicModel(model_name)
            except ValueError:
                continue

            # Score based on success rate and inverse response time
            score = perf.success_rate * (1.0 / max(perf.avg_response_time, 0.1))
            if score > best_score:
                best_score = score
                best_model = candidate

        return best_model
    def _get_fallback_model(self) -> AnthropicModel:
        """Get fallback model (currently Claude 3.5 Haiku for speed)."""
        if self.current_model == AnthropicModel.CLAUDE_3_5_SONNET:
            return AnthropicModel.CLAUDE_3_5_HAIKU
        else:
            return AnthropicModel.CLAUDE_3_5_SONNET

    def _update_performance_stats(
        self,
        model_name: str,
        response: Optional[AnthropicResponse],
        start_time: datetime,
        success: bool
    ) -> None:
        """Update performance statistics for a model."""
        if model_name not in self.performance_stats:
            self.performance_stats[model_name] = ModelPerformance(model_name=model_name)

        perf = self.performance_stats[model_name]
        perf.total_requests += 1
        perf.last_used = datetime.now()

        if success and response:
            # Update token count
            perf.total_tokens += response.usage.get('total_tokens', 0)

            # Update average response time
            response_time = (datetime.now() - start_time).total_seconds()
            if perf.avg_response_time == 0:
                perf.avg_response_time = response_time
            else:
                # Exponential moving average: 10% weight on the newest sample,
                # e.g. avg 2.0s with a 3.0s sample gives 0.9*2.0 + 0.1*3.0 = 2.1s
                perf.avg_response_time = 0.9 * perf.avg_response_time + 0.1 * response_time
        else:
            perf.total_errors += 1

        # Update success rate (on every request, success or failure)
        perf.success_rate = ((perf.total_requests - perf.total_errors) / perf.total_requests) * 100
    async def health_check(self) -> Dict[str, Any]:
        """
        Perform health check on all available models.

        Returns:
            Dictionary with health status for each model
        """
        health_status = {}

        for model in AnthropicModel:
            try:
                client = await self.get_client(model)
                is_healthy = await client.validate_connection()
                health_status[model.value] = {
                    "healthy": is_healthy,
                    "last_check": datetime.now().isoformat(),
                    "performance": self.performance_stats[model.value].__dict__
                }
            except Exception as e:
                health_status[model.value] = {
                    "healthy": False,
                    "error": str(e),
                    "last_check": datetime.now().isoformat(),
                    "performance": self.performance_stats[model.value].__dict__
                }

        return health_status
    async def close_all_clients(self) -> None:
        """Close all model clients."""
        for client in self.clients.values():
            await client.close()
        self.clients.clear()
        logger.info("All model clients closed")


# Global model manager instance
_model_manager: Optional[ModelManager] = None


def get_model_manager() -> ModelManager:
    """Get or create global model manager instance."""
    global _model_manager
    if _model_manager is None:
        _model_manager = ModelManager()
    return _model_manager


async def close_model_manager() -> None:
    """Close global model manager."""
    global _model_manager
    if _model_manager is not None:
        await _model_manager.close_all_clients()
        _model_manager = None
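
if __name__ == "__main__":
    # Minimal demo sketch, not part of the module's API. Assumptions: the
    # settings loaded by get_settings() carry valid Anthropic credentials,
    # AnthropicClient can reach the API, and the message dict shape matches
    # the Messages-style interface this module passes through.
    import asyncio

    async def _demo() -> None:
        manager = get_model_manager()
        response = await manager.generate_completion(
            messages=[{"role": "user", "content": "Reply with one word."}],
            max_tokens=32,
        )
        print(response)
        print(manager.get_performance_stats())
        print(manager.get_best_performing_model())
        await close_model_manager()

    asyncio.run(_demo())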