"""
Voting strategies for ensemble models.
"""

# Standard-library imports first, third-party after (PEP 8 grouping).
import logging
from collections import Counter
from typing import List, Dict, Any, Optional

import numpy as np

# Module-level logger shared by every strategy defined in this file.
logger = logging.getLogger(__name__)
|
|
|
|
class VotingStrategy:
    """Common interface shared by all ensemble voting strategies."""

    def __init__(self, name: str):
        # Identifier of the strategy, stored for callers to inspect.
        self.name = name

    def vote(self, predictions: List[Dict[str, Any]], key: str = 'label') -> Dict[str, Any]:
        """
        Merge the predictions of several models into one result.

        The base implementation is a placeholder; subclasses override it.

        Args:
            predictions: One prediction dict per model
            key: Name of the entry to read from each prediction

        Returns:
            The combined result

        Raises:
            NotImplementedError: Always, in this base class
        """
        raise NotImplementedError()
|
|
|
|
class MajorityVoting(VotingStrategy):
    """Simple majority voting - most common prediction wins."""

    def __init__(self):
        super().__init__("majority")

    def vote(self, predictions: List[Dict[str, Any]], key: str = 'label') -> Dict[str, Any]:
        """
        Select the most common prediction.

        Each prediction that contains ``key`` casts exactly one vote;
        predictions without ``key`` are ignored. When several values share
        the highest count, ``Counter.most_common`` keeps the one that was
        encountered first.

        Args:
            predictions: List of predictions
            key: Key to extract

        Returns:
            Dict with the same keys in every case:
            ``label`` (winning value, or ``"unknown"`` when nobody voted),
            ``confidence`` and ``agreement`` (share of votes won by the
            winner), ``votes`` (count per value) and ``total_votes``.
        """
        # `or []` keeps the original tolerance for a None/empty argument.
        values = [pred[key] for pred in (predictions or []) if key in pred]

        if not values:
            # Same schema as the regular result so callers can read
            # "total_votes" / "agreement" without special-casing.
            return {
                "label": "unknown",
                "confidence": 0.0,
                "votes": {},
                "total_votes": 0,
                "agreement": 0.0,
            }

        vote_counts = Counter(values)
        most_common_label, count = vote_counts.most_common(1)[0]

        # Fraction of the cast votes that went to the winner.
        agreement = float(count / len(values))

        return {
            "label": most_common_label,
            "confidence": agreement,
            "votes": dict(vote_counts),
            "total_votes": len(values),
            "agreement": agreement,
        }
|
|
|
|
class WeightedVoting(VotingStrategy):
    """Weighted voting - models with higher weights have more influence."""

    def __init__(self):
        super().__init__("weighted")

    def vote(self, predictions: List[Dict[str, Any]], key: str = 'label') -> Dict[str, Any]:
        """
        Weighted voting based on model weights.

        Each prediction containing ``key`` contributes its ``'model_weight'``
        to the value stored under ``key``. A missing or ``None`` weight
        counts as 1.0. The per-label sums are normalised by the total weight
        and the label with the largest share wins.

        Args:
            predictions: List of predictions with 'model_weight'
            key: Key to extract

        Returns:
            Dict with ``label``, ``confidence`` (winner's share of the total
            weight), ``weighted_votes`` (share per label) and
            ``total_weight``. ``label`` is ``"unknown"`` when no prediction
            contains ``key``. When the total weight is not positive the
            shares cannot be computed: the most frequently predicted label
            is returned with a confidence of 0.0.
        """
        weighted_votes: Dict[Any, float] = {}
        occurrences: Dict[Any, int] = {}
        total_weight = 0.0

        for pred in predictions or []:
            if key not in pred:
                continue

            label = pred[key]
            weight = pred.get('model_weight')
            if weight is None:
                # An explicit None means "no weight given": use the default.
                weight = 1.0

            weighted_votes[label] = weighted_votes.get(label, 0.0) + weight
            occurrences[label] = occurrences.get(label, 0) + 1
            total_weight += weight

        if not weighted_votes:
            return {
                "label": "unknown",
                "confidence": 0.0,
                "weighted_votes": {},
                "total_weight": 0.0,
            }

        if total_weight <= 0:
            # Dividing by a zero total would raise ZeroDivisionError and a
            # negative one would flip the ranking, so fall back to plain
            # vote counts and report that no weight backs the result.
            fallback_label = max(occurrences, key=occurrences.get)
            return {
                "label": fallback_label,
                "confidence": 0.0,
                "weighted_votes": {k: 0.0 for k in weighted_votes},
                "total_weight": float(total_weight),
            }

        normalized_votes = {k: v / total_weight for k, v in weighted_votes.items()}

        # max() keeps the first label on ties (dict preserves insertion order).
        winner_label, winner_score = max(normalized_votes.items(), key=lambda x: x[1])

        return {
            "label": winner_label,
            "confidence": float(winner_score),
            "weighted_votes": {k: float(v) for k, v in normalized_votes.items()},
            "total_weight": float(total_weight),
        }
|
|
|
|
class ConfidenceVoting(VotingStrategy):
    """Voting weighted by model confidence scores."""

    def __init__(self, confidence_key: str = 'confidence'):
        super().__init__("confidence")
        # Name of the entry holding each prediction's confidence score.
        self.confidence_key = confidence_key

    def vote(self, predictions: List[Dict[str, Any]], key: str = 'label') -> Dict[str, Any]:
        """
        Vote weighted by prediction confidence.

        Only predictions that contain both ``key`` and a non-``None`` value
        under ``self.confidence_key`` take part. Confidences are summed per
        label and normalised by the overall sum; the label with the largest
        share wins.

        Args:
            predictions: List of predictions with confidence scores
            key: Key to extract label

        Returns:
            Dict with ``label``, ``confidence`` (winner's share of the
            summed confidence) and ``confidence_votes`` (share per label).
            ``label`` is ``"unknown"`` when no prediction takes part. When
            the summed confidence is not positive (e.g. every model reports
            0.0) the most frequently predicted label is returned with a
            confidence of 0.0.
        """
        confidence_votes: Dict[Any, float] = {}
        occurrences: Dict[Any, int] = {}
        total_confidence = 0.0

        for pred in predictions or []:
            if key not in pred:
                continue

            confidence = pred.get(self.confidence_key)
            if confidence is None:
                # Missing and explicit-None scores are both unusable.
                continue

            label = pred[key]
            confidence_votes[label] = confidence_votes.get(label, 0.0) + confidence
            occurrences[label] = occurrences.get(label, 0) + 1
            total_confidence += confidence

        if not confidence_votes:
            return {"label": "unknown", "confidence": 0.0, "confidence_votes": {}}

        if total_confidence <= 0:
            # Normalising would divide by zero; fall back to plain vote
            # counts and report that no confidence backs the result.
            fallback_label = max(occurrences, key=occurrences.get)
            return {
                "label": fallback_label,
                "confidence": 0.0,
                "confidence_votes": {k: 0.0 for k in confidence_votes},
            }

        normalized_votes = {k: v / total_confidence for k, v in confidence_votes.items()}

        # max() keeps the first label on ties (dict preserves insertion order).
        winner_label, winner_confidence = max(normalized_votes.items(), key=lambda x: x[1])

        return {
            "label": winner_label,
            "confidence": float(winner_confidence),
            "confidence_votes": {k: float(v) for k, v in normalized_votes.items()},
        }
|
|
|
|
class MetaLearning(VotingStrategy):
    """
    Strategy intended to learn how model predictions should be combined.

    Note: Requires training on ground truth data. Training is not
    implemented yet, so voting currently delegates to weighted voting.
    """

    def __init__(self, model_path: Optional[str] = None):
        super().__init__("meta")
        # No meta-model exists until one is trained or assigned.
        self.meta_model = None
        self.model_path = model_path

    def train(self, ensemble_predictions: List[List[Dict[str, Any]]], ground_truth: List[str]):
        """
        Placeholder for fitting the meta-model.

        Currently only logs; neither argument is read and ``meta_model``
        is left untouched.

        Args:
            ensemble_predictions: List of prediction lists from ensemble
            ground_truth: List of correct labels
        """
        logger.info("Training meta-learning model...")
        logger.warning("Meta-learning training not implemented yet. Using weighted voting as fallback.")

    def vote(self, predictions: List[Dict[str, Any]], key: str = 'label') -> Dict[str, Any]:
        """
        Combine predictions, delegating to weighted voting while untrained.

        Args:
            predictions: List of predictions
            key: Key to extract

        Returns:
            The weighted-voting result when no meta-model is set

        Raises:
            NotImplementedError: If a meta-model has been set, because
                meta-model prediction is not implemented
        """
        if self.meta_model is not None:
            raise NotImplementedError("Meta-model prediction not implemented")

        logger.warning("Meta-model not trained. Falling back to weighted voting.")
        return WeightedVoting().vote(predictions, key)
|
|
|
|
def get_voting_strategy(strategy: str) -> VotingStrategy:
    """
    Build a voting strategy from its name.

    An unrecognised name is logged as a warning and replaced by the
    weighted strategy.

    Args:
        strategy: Name of strategy ('majority', 'weighted', 'confidence', 'meta')

    Returns:
        A new VotingStrategy instance, constructed with default arguments
    """
    registry = {
        'majority': MajorityVoting,
        'weighted': WeightedVoting,
        'confidence': ConfidenceVoting,
        'meta': MetaLearning,
    }

    strategy_cls = registry.get(strategy)
    if strategy_cls is None:
        logger.warning(f"Unknown strategy '{strategy}'. Using 'weighted' as default.")
        strategy_cls = registry['weighted']

    return strategy_cls()
|
|