marcosremar
Initial commit: Ensemble TTS Annotation System
06b4215
"""
Voting strategies for ensemble models.
"""
from typing import List, Dict, Any, Optional
from collections import Counter
import numpy as np
import logging
logger = logging.getLogger(__name__)
class VotingStrategy:
    """Common interface implemented by every ensemble voting strategy."""

    def __init__(self, name: str):
        # Identifier of the strategy, stored exactly as given.
        self.name = name

    def vote(self, predictions: List[Dict[str, Any]], key: str = 'label') -> Dict[str, Any]:
        """
        Merge the predictions of several models into a single result.

        Args:
            predictions: One prediction dict per model.
            key: Name of the prediction field to vote on.

        Returns:
            The combined result.

        Raises:
            NotImplementedError: Always; concrete strategies override this.
        """
        raise NotImplementedError()
class MajorityVoting(VotingStrategy):
    """Simple majority voting - the most common prediction wins."""

    def __init__(self):
        super().__init__("majority")

    def vote(self, predictions: List[Dict[str, Any]], key: str = 'label') -> Dict[str, Any]:
        """
        Select the most common prediction.

        Predictions that do not contain ``key`` are ignored. When several
        values share the highest count, the one encountered first wins
        (this is the ordering ``Counter.most_common`` provides).

        Args:
            predictions: List of prediction dicts, one per model.
            key: Key whose value is voted on.

        Returns:
            Dict with:
                label: Winning value, or "unknown" when nothing could vote.
                confidence: Share of the votes won by ``label`` (0.0-1.0).
                votes: Mapping of each value to its vote count.
                total_votes: Number of predictions that took part.
                agreement: Same value as ``confidence``.
            Every key is present in both the normal and the "unknown" result.
        """
        # A single guard covers both "no predictions" and "no prediction
        # carries `key`"; `or []` keeps a falsy argument from being iterated.
        values = [pred[key] for pred in (predictions or []) if key in pred]
        if not values:
            # Same keys as the normal result so callers can index
            # `total_votes` / `agreement` without special-casing.
            return {
                "label": "unknown",
                "confidence": 0.0,
                "votes": {},
                "total_votes": 0,
                "agreement": 0.0,
            }

        vote_counts = Counter(values)
        most_common_label, count = vote_counts.most_common(1)[0]

        # Confidence is the agreement ratio among the participating models.
        confidence = count / len(values)

        return {
            "label": most_common_label,
            "confidence": float(confidence),
            "votes": dict(vote_counts),
            "total_votes": len(values),
            "agreement": float(confidence),
        }
class WeightedVoting(VotingStrategy):
    """Weighted voting - models with higher weights have more influence."""

    def __init__(self):
        super().__init__("weighted")

    def vote(self, predictions: List[Dict[str, Any]], key: str = 'label') -> Dict[str, Any]:
        """
        Weighted voting based on model weights.

        Each prediction contributes its 'model_weight' (1.0 when the field
        is absent) to the value it holds under ``key``. Predictions without
        ``key`` are ignored. On a tie the label seen first wins.

        Args:
            predictions: List of prediction dicts, optionally carrying
                'model_weight'.
            key: Key whose value is voted on.

        Returns:
            Dict with:
                label: Winning value, or "unknown" when no decision is
                    possible.
                confidence: Winner's share of the total weight.
                weighted_votes: Mapping of each value to its share of the
                    total weight (empty for an "unknown" result).
                total_weight: Sum of the weights of participating
                    predictions.
        """
        weighted_votes: Dict[Any, float] = {}
        total_weight = 0.0

        for pred in (predictions or []):
            if key not in pred:
                continue
            label = pred[key]
            weight = pred.get('model_weight', 1.0)
            weighted_votes[label] = weighted_votes.get(label, 0.0) + weight
            total_weight += weight

        # Nothing voted, or the weights sum to zero (or less): normalising
        # would divide by zero / flip the ranking, so report "unknown"
        # instead of raising ZeroDivisionError.
        if not weighted_votes or total_weight <= 0:
            return {
                "label": "unknown",
                "confidence": 0.0,
                "weighted_votes": {},
                "total_weight": float(total_weight),
            }

        # Normalise so the scores sum to 1.
        normalized_votes = {k: v / total_weight for k, v in weighted_votes.items()}

        winner_label, winner_score = max(normalized_votes.items(), key=lambda x: x[1])

        return {
            "label": winner_label,
            "confidence": float(winner_score),
            "weighted_votes": {k: float(v) for k, v in normalized_votes.items()},
            "total_weight": float(total_weight),
        }
class ConfidenceVoting(VotingStrategy):
    """Voting weighted by model confidence scores."""

    def __init__(self, confidence_key: str = 'confidence'):
        super().__init__("confidence")
        # Name of the prediction field that holds the confidence score.
        self.confidence_key = confidence_key

    def vote(self, predictions: List[Dict[str, Any]], key: str = 'label') -> Dict[str, Any]:
        """
        Vote weighted by prediction confidence.

        Each prediction adds its confidence to the value it holds under
        ``key``. Predictions missing either ``key`` or the confidence field
        are ignored. On a tie the label seen first wins.

        Args:
            predictions: List of prediction dicts with confidence scores.
            key: Key whose value is voted on.

        Returns:
            Dict with:
                label: Winning value, or "unknown" when no decision is
                    possible.
                confidence: Winner's share of the summed confidence.
                confidence_votes: Mapping of each value to its share of
                    the summed confidence (empty for an "unknown" result).
        """
        confidence_votes: Dict[Any, float] = {}
        total_confidence = 0.0

        for pred in (predictions or []):
            if key not in pred or self.confidence_key not in pred:
                continue
            label = pred[key]
            confidence = pred[self.confidence_key]
            confidence_votes[label] = confidence_votes.get(label, 0.0) + confidence
            total_confidence += confidence

        # Nothing voted, or every participating model reported zero
        # confidence: there is no evidence to normalise, and dividing would
        # raise ZeroDivisionError, so report "unknown".
        if not confidence_votes or total_confidence <= 0:
            return {"label": "unknown", "confidence": 0.0, "confidence_votes": {}}

        # Normalise so the scores sum to 1.
        normalized_votes = {k: v / total_confidence for k, v in confidence_votes.items()}

        winner_label, winner_confidence = max(normalized_votes.items(), key=lambda x: x[1])

        return {
            "label": winner_label,
            "confidence": float(winner_confidence),
            "confidence_votes": {k: float(v) for k, v in normalized_votes.items()},
        }
class MetaLearning(VotingStrategy):
    """
    Meta-learning strategy - intended to learn how to combine predictions.

    Note: Requires training on ground truth data. Neither training nor
    meta-model prediction is implemented yet; nothing in this class assigns
    ``meta_model``, so ``vote`` delegates to weighted voting unless the
    attribute is set from outside.
    """

    def __init__(self, model_path: Optional[str] = None):
        super().__init__("meta")
        self.meta_model = None
        self.model_path = model_path

    def train(self, ensemble_predictions: List[List[Dict[str, Any]]], ground_truth: List[str]):
        """
        Placeholder for fitting the meta-model.

        Currently only emits log messages; the arguments are not used.

        Args:
            ensemble_predictions: List of prediction lists from ensemble
            ground_truth: List of correct labels
        """
        logger.info("Training meta-learning model...")

        # TODO: Implement meta-model training
        # A simple classifier (LogisticRegression, RF, etc.) would typically
        # be fitted here to learn the optimal combination of predictions.
        logger.warning("Meta-learning training not implemented yet. Using weighted voting as fallback.")

    def vote(self, predictions: List[Dict[str, Any]], key: str = 'label') -> Dict[str, Any]:
        """
        Combine predictions, via the meta-model once one is available.

        Args:
            predictions: List of predictions
            key: Key to extract

        Returns:
            The weighted-voting result while no meta-model is set.

        Raises:
            NotImplementedError: If ``meta_model`` has been set, because
                meta-model prediction is not implemented.
        """
        if self.meta_model is not None:
            # TODO: Implement meta-model prediction
            raise NotImplementedError("Meta-model prediction not implemented")

        logger.warning("Meta-model not trained. Falling back to weighted voting.")
        return WeightedVoting().vote(predictions, key)
def get_voting_strategy(strategy: str) -> VotingStrategy:
    """
    Build the voting strategy registered under the given name.

    An unrecognised name is logged as a warning and replaced by 'weighted'.

    Args:
        strategy: Name of strategy ('majority', 'weighted', 'confidence', 'meta')

    Returns:
        A new VotingStrategy instance, constructed with default arguments.
    """
    registry = {
        'majority': MajorityVoting,
        'weighted': WeightedVoting,
        'confidence': ConfidenceVoting,
        'meta': MetaLearning,
    }

    factory = registry.get(strategy)
    if factory is None:
        logger.warning(f"Unknown strategy '{strategy}'. Using 'weighted' as default.")
        factory = registry['weighted']

    return factory()