"""Model ensemble for running multiple models and aggregating results.""" import asyncio import logging from dataclasses import dataclass, field from enum import Enum from typing import Any from app.models.providers.base import ( BaseProvider, CompletionResponse, ProviderError, TokenUsage, ) from app.models.router import SmartModelRouter logger = logging.getLogger(__name__) class AggregationStrategy(str, Enum): """Strategy for aggregating ensemble results.""" MAJORITY_VOTE = "majority_vote" # Use most common response CONFIDENCE_WEIGHTED = "confidence_weighted" # Weight by model confidence FIRST_SUCCESS = "first_success" # Use first successful response BEST_QUALITY = "best_quality" # Use response from highest quality model CONCATENATE = "concatenate" # Combine all responses CONSENSUS = "consensus" # Only return if models agree @dataclass class EnsembleResult: """Result from an ensemble run.""" content: str responses: list[CompletionResponse] agreement_score: float # 0-1, how much models agreed strategy: AggregationStrategy selected_model: str | None = None total_cost: float = 0.0 total_tokens: TokenUsage = field(default_factory=TokenUsage) metadata: dict[str, Any] = field(default_factory=dict) def to_dict(self) -> dict[str, Any]: """Convert to dictionary.""" return { "content": self.content, "responses": [r.to_dict() for r in self.responses], "agreement_score": self.agreement_score, "strategy": self.strategy.value, "selected_model": self.selected_model, "total_cost": self.total_cost, "total_tokens": { "prompt": self.total_tokens.prompt_tokens, "completion": self.total_tokens.completion_tokens, "total": self.total_tokens.total_tokens, }, "metadata": self.metadata, } class ModelEnsemble: """Run multiple models and aggregate their results.""" # Model quality tiers for weighted voting MODEL_QUALITY_TIERS: dict[str, float] = { # Tier 1: Highest quality "claude-3-opus-20240229": 1.0, "gpt-4o": 0.98, "claude-3-5-sonnet-20241022": 0.97, "gemini-1.5-pro": 0.95, # Tier 2: High quality "gpt-4-turbo": 0.90, "gpt-4": 0.88, "claude-3-sonnet-20240229": 0.85, "llama-3.3-70b-versatile": 0.83, # Tier 3: Good quality "gpt-4o-mini": 0.75, "claude-3-5-haiku-20241022": 0.73, "gemini-1.5-flash": 0.70, "mixtral-8x7b-32768": 0.68, # Tier 4: Fast/cheap "claude-3-haiku-20240307": 0.60, "llama-3.1-8b-instant": 0.55, "gpt-3.5-turbo": 0.50, } def __init__( self, router: SmartModelRouter, default_models: list[str] | None = None, default_strategy: AggregationStrategy = AggregationStrategy.CONFIDENCE_WEIGHTED, timeout: float = 60.0, ): """Initialize the ensemble. Args: router: SmartModelRouter instance for accessing providers default_models: Default models to use in ensemble default_strategy: Default aggregation strategy timeout: Timeout for each model completion """ self.router = router self.default_models = default_models or [] self.default_strategy = default_strategy self.timeout = timeout async def run( self, messages: list[dict[str, Any]], models: list[str] | None = None, strategy: AggregationStrategy | None = None, min_responses: int = 1, **kwargs: Any, ) -> EnsembleResult: """Run multiple models and aggregate results. 


class ModelEnsemble:
    """Run multiple models and aggregate their results."""

    # Model quality tiers for weighted voting
    MODEL_QUALITY_TIERS: dict[str, float] = {
        # Tier 1: Highest quality
        "claude-3-opus-20240229": 1.0,
        "gpt-4o": 0.98,
        "claude-3-5-sonnet-20241022": 0.97,
        "gemini-1.5-pro": 0.95,
        # Tier 2: High quality
        "gpt-4-turbo": 0.90,
        "gpt-4": 0.88,
        "claude-3-sonnet-20240229": 0.85,
        "llama-3.3-70b-versatile": 0.83,
        # Tier 3: Good quality
        "gpt-4o-mini": 0.75,
        "claude-3-5-haiku-20241022": 0.73,
        "gemini-1.5-flash": 0.70,
        "mixtral-8x7b-32768": 0.68,
        # Tier 4: Fast/cheap
        "claude-3-haiku-20240307": 0.60,
        "llama-3.1-8b-instant": 0.55,
        "gpt-3.5-turbo": 0.50,
    }

    def __init__(
        self,
        router: SmartModelRouter,
        default_models: list[str] | None = None,
        default_strategy: AggregationStrategy = AggregationStrategy.CONFIDENCE_WEIGHTED,
        timeout: float = 60.0,
    ):
        """Initialize the ensemble.

        Args:
            router: SmartModelRouter instance for accessing providers
            default_models: Default models to use in the ensemble
            default_strategy: Default aggregation strategy
            timeout: Timeout in seconds for each model completion
        """
        self.router = router
        self.default_models = default_models or []
        self.default_strategy = default_strategy
        self.timeout = timeout

    async def run(
        self,
        messages: list[dict[str, Any]],
        models: list[str] | None = None,
        strategy: AggregationStrategy | None = None,
        min_responses: int = 1,
        **kwargs: Any,
    ) -> EnsembleResult:
        """Run multiple models and aggregate results.

        Args:
            messages: List of message dicts
            models: List of model IDs to use (uses defaults if not specified)
            strategy: Aggregation strategy (uses default if not specified)
            min_responses: Minimum number of successful responses required
            **kwargs: Additional completion parameters

        Returns:
            EnsembleResult with aggregated content and metadata

        Raises:
            ProviderError: If not enough models respond successfully
        """
        models_to_use = models or self.default_models
        strategy = strategy or self.default_strategy

        if not models_to_use:
            # Fall back to the top 3 available models
            available = self.router.get_available_models()
            models_to_use = [m.id for m in available[:3]]

        if not models_to_use:
            raise ProviderError("No models available for ensemble", "ensemble")

        # Run all models concurrently
        tasks = []
        for model_id in models_to_use:
            provider = self.router.get_provider_for_model(model_id)
            if provider:
                task = self._run_model(provider, model_id, messages, **kwargs)
                tasks.append((model_id, task))

        if not tasks:
            raise ProviderError("No valid models for ensemble", "ensemble")

        # Gather results
        responses: list[CompletionResponse] = []
        errors: list[tuple[str, Exception]] = []

        results = await asyncio.gather(
            *[t[1] for t in tasks],
            return_exceptions=True,
        )

        for (model_id, _), result in zip(tasks, results):
            if isinstance(result, Exception):
                logger.warning(f"Model {model_id} failed: {result}")
                errors.append((model_id, result))
            elif result is not None:
                responses.append(result)

        if len(responses) < min_responses:
            raise ProviderError(
                f"Only {len(responses)} models responded, need {min_responses}. "
                f"Errors: {[str(e) for _, e in errors]}",
                "ensemble",
            )

        # Aggregate results
        return self._aggregate(responses, strategy)

    async def _run_model(
        self,
        provider: BaseProvider,
        model_id: str,
        messages: list[dict[str, Any]],
        **kwargs: Any,
    ) -> CompletionResponse | None:
        """Run a single model with a timeout."""
        try:
            return await asyncio.wait_for(
                provider.complete(messages, model_id, **kwargs),
                timeout=self.timeout,
            )
        except asyncio.TimeoutError:
            # Timeouts are soft failures: return None so the caller skips them
            logger.warning(f"Model {model_id} timed out")
            return None
        except Exception as e:
            # Re-raise so asyncio.gather(return_exceptions=True) records it
            logger.warning(f"Model {model_id} error: {e}")
            raise

    def _aggregate(
        self,
        responses: list[CompletionResponse],
        strategy: AggregationStrategy,
    ) -> EnsembleResult:
        """Aggregate responses based on the chosen strategy."""
        if not responses:
            raise ProviderError("No responses to aggregate", "ensemble")

        # Sum cost and token usage across all responses
        total_cost = sum(r.cost for r in responses)
        total_tokens = TokenUsage()
        for r in responses:
            total_tokens = total_tokens + r.usage

        # Calculate agreement score
        agreement_score = self._calculate_agreement(responses)

        # Select content based on strategy
        if strategy == AggregationStrategy.FIRST_SUCCESS:
            content, selected_model = self._first_success(responses)
        elif strategy == AggregationStrategy.MAJORITY_VOTE:
            content, selected_model = self._majority_vote(responses)
        elif strategy == AggregationStrategy.CONFIDENCE_WEIGHTED:
            content, selected_model = self._confidence_weighted(responses)
        elif strategy == AggregationStrategy.BEST_QUALITY:
            content, selected_model = self._best_quality(responses)
        elif strategy == AggregationStrategy.CONCATENATE:
            content, selected_model = self._concatenate(responses)
        elif strategy == AggregationStrategy.CONSENSUS:
            content, selected_model = self._consensus(responses, agreement_score)
        else:
            content, selected_model = self._first_success(responses)

        return EnsembleResult(
            content=content,
            responses=responses,
            agreement_score=agreement_score,
            strategy=strategy,
            selected_model=selected_model,
            total_cost=total_cost,
            total_tokens=total_tokens,
            metadata={
                "num_responses": len(responses),
                "models_used": [r.model for r in responses],
            },
        )
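
    # Worked example of the pairwise Jaccard similarity used by
    # _calculate_agreement below (hypothetical two-response ensemble, shown
    # for illustration only):
    #
    #     "The cat sat" -> {"the", "cat", "sat"}
    #     "The cat ran" -> {"the", "cat", "ran"}
    #     intersection = 2 ("the", "cat"), union = 4, similarity = 2 / 4 = 0.5
    #
    # With three or more responses, the agreement score is the mean of all
    # pairwise similarities.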

    def _calculate_agreement(self, responses: list[CompletionResponse]) -> float:
        """Calculate agreement score between responses.

        Uses simple similarity based on common words/tokens.
        """
        if len(responses) < 2:
            return 1.0

        # Tokenize responses (simple word-based)
        response_tokens = []
        for r in responses:
            words = set(r.content.lower().split())
            response_tokens.append(words)

        # Calculate pairwise Jaccard similarity
        similarities = []
        for i in range(len(response_tokens)):
            for j in range(i + 1, len(response_tokens)):
                set_i = response_tokens[i]
                set_j = response_tokens[j]
                if not set_i and not set_j:
                    similarities.append(1.0)
                elif not set_i or not set_j:
                    similarities.append(0.0)
                else:
                    intersection = len(set_i & set_j)
                    union = len(set_i | set_j)
                    similarities.append(intersection / union)

        return sum(similarities) / len(similarities) if similarities else 1.0

    def _first_success(
        self, responses: list[CompletionResponse]
    ) -> tuple[str, str | None]:
        """Return the first successful response."""
        r = responses[0]
        return r.content, r.model

    def _majority_vote(
        self, responses: list[CompletionResponse]
    ) -> tuple[str, str | None]:
        """Return the most common response (by content similarity)."""
        if len(responses) == 1:
            return responses[0].content, responses[0].model

        # Find the response most similar to all the others
        best_idx = 0
        best_score = 0.0

        for i, r in enumerate(responses):
            score = 0.0
            words_i = set(r.content.lower().split())
            for j, other in enumerate(responses):
                if i != j:
                    words_j = set(other.content.lower().split())
                    if words_i and words_j:
                        intersection = len(words_i & words_j)
                        union = len(words_i | words_j)
                        score += intersection / union
            if score > best_score:
                best_score = score
                best_idx = i

        return responses[best_idx].content, responses[best_idx].model

    def _confidence_weighted(
        self, responses: list[CompletionResponse]
    ) -> tuple[str, str | None]:
        """Weight responses by model quality/confidence.

        Note: without per-response confidence signals, this reduces to picking
        the response from the highest-ranked model in MODEL_QUALITY_TIERS.
        """
        if len(responses) == 1:
            return responses[0].content, responses[0].model

        # Score each response by its model's quality tier (0.5 if unknown)
        scored = []
        for r in responses:
            quality = self.MODEL_QUALITY_TIERS.get(r.model, 0.5)
            scored.append((quality, r))

        # Sort by quality, highest first
        scored.sort(key=lambda x: x[0], reverse=True)

        # Return the highest-quality response
        best = scored[0][1]
        return best.content, best.model

    def _best_quality(
        self, responses: list[CompletionResponse]
    ) -> tuple[str, str | None]:
        """Return the response from the highest-quality model."""
        best_quality = 0.0
        best_response = responses[0]

        for r in responses:
            quality = self.MODEL_QUALITY_TIERS.get(r.model, 0.5)
            if quality > best_quality:
                best_quality = quality
                best_response = r

        return best_response.content, best_response.model

    def _concatenate(
        self, responses: list[CompletionResponse]
    ) -> tuple[str, str | None]:
        """Concatenate all responses, labeled by model."""
        parts = []
        for r in responses:
            parts.append(f"[{r.model}]:\n{r.content}")

        content = "\n\n---\n\n".join(parts)
        return content, None  # No single model selected

    def _consensus(
        self,
        responses: list[CompletionResponse],
        agreement_score: float,
    ) -> tuple[str, str | None]:
        """Return a result only if models agree (high agreement score)."""
        if agreement_score < 0.5:
            # Low agreement: return the best-quality response with a warning
            content, model = self._best_quality(responses)
            return f"[LOW CONSENSUS - {agreement_score:.2f}]\n{content}", model

        # Good agreement: return the majority vote
        return self._majority_vote(responses)
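
    # Illustrative _consensus output when agreement falls below the 0.5
    # threshold (the score shown is hypothetical):
    #
    #     [LOW CONSENSUS - 0.42]
    #     <content of the best-quality response>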

    async def compare(
        self,
        messages: list[dict[str, Any]],
        models: list[str] | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Compare responses from multiple models side by side.

        Args:
            messages: List of message dicts
            models: List of model IDs to compare
            **kwargs: Additional completion parameters

        Returns:
            Dictionary with comparison data
        """
        result = await self.run(
            messages,
            models,
            strategy=AggregationStrategy.CONCATENATE,
            **kwargs,
        )

        # Build comparison
        comparison: dict[str, Any] = {
            "responses": [],
            "agreement_score": result.agreement_score,
            "total_cost": result.total_cost,
            "total_tokens": {
                "prompt": result.total_tokens.prompt_tokens,
                "completion": result.total_tokens.completion_tokens,
                "total": result.total_tokens.total_tokens,
            },
        }

        for r in result.responses:
            comparison["responses"].append({
                "model": r.model,
                "provider": r.provider,
                "content": r.content,
                "cost": r.cost,
                "latency_ms": r.latency_ms,
                "tokens": {
                    "prompt": r.usage.prompt_tokens,
                    "completion": r.usage.completion_tokens,
                },
                "quality_tier": self.MODEL_QUALITY_TIERS.get(r.model, 0.5),
            })

        return comparison

    async def debate(
        self,
        messages: list[dict[str, Any]],
        models: list[str] | None = None,
        rounds: int = 2,
        **kwargs: Any,
    ) -> EnsembleResult:
        """Run a debate between models where they can respond to each other.

        Args:
            messages: Initial messages
            models: Models to participate in the debate
            rounds: Number of debate rounds
            **kwargs: Additional completion parameters

        Returns:
            Final ensemble result with the debate history
        """
        models_to_use = models or self.default_models[:2]  # Default to 2 models
        if len(models_to_use) < 2:
            raise ProviderError("Debate requires at least 2 models", "ensemble")

        all_responses: list[CompletionResponse] = []
        debate_history: list[dict[str, Any]] = []
        current_messages = messages.copy()
        round_responses: list[CompletionResponse] = []

        for round_num in range(rounds):
            round_responses = []

            for model_id in models_to_use:
                provider = self.router.get_provider_for_model(model_id)
                if not provider:
                    continue

                try:
                    response = await asyncio.wait_for(
                        provider.complete(current_messages, model_id, **kwargs),
                        timeout=self.timeout,
                    )
                    round_responses.append(response)
                    all_responses.append(response)
                    debate_history.append({
                        "round": round_num + 1,
                        "model": model_id,
                        "content": response.content,
                    })
                except Exception as e:
                    logger.warning(
                        f"Model {model_id} failed in round {round_num + 1}: {e}"
                    )

            # Feed this round's responses back in before the next round
            if round_responses and round_num < rounds - 1:
                for r in round_responses:
                    current_messages.append({
                        "role": "assistant",
                        "content": f"[{r.model}]: {r.content}",
                    })
                # Ask for a follow-up
                current_messages.append({
                    "role": "user",
                    "content": "Consider the other perspectives and refine your answer.",
                })

        # Aggregate the final round's responses; fall back to all collected
        # responses if every model failed in the last round
        final_responses = round_responses or all_responses
        result = self._aggregate(final_responses, AggregationStrategy.CONFIDENCE_WEIGHTED)

        # Attach the debate history to the result metadata
        result.metadata["debate_history"] = debate_history
        result.metadata["total_rounds"] = rounds

        return result
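

if __name__ == "__main__":
    # Minimal usage sketch, not a definitive integration: it assumes
    # SmartModelRouter can be constructed with no arguments and that provider
    # API keys are configured elsewhere. The model IDs are examples taken from
    # MODEL_QUALITY_TIERS; adjust everything here to your deployment.
    async def _demo() -> None:
        router = SmartModelRouter()  # assumption: default construction works
        ensemble = ModelEnsemble(
            router,
            default_models=["gpt-4o", "claude-3-5-sonnet-20241022", "gemini-1.5-pro"],
        )
        result = await ensemble.run(
            messages=[{"role": "user", "content": "What causes ocean tides?"}],
            strategy=AggregationStrategy.CONSENSUS,
        )
        print(f"agreement={result.agreement_score:.2f} model={result.selected_model}")
        print(result.content)

    asyncio.run(_demo())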