marcosremar
Initial commit: Ensemble TTS Annotation System
06b4215
"""
Base classes for ensemble models.
"""
from abc import ABC, abstractmethod
from collections import Counter
import logging
from typing import Dict, List, Any, Optional

import numpy as np
# Module-level logger shared by the base classes below.
# NOTE(review): basicConfig configures the *root* logger (INFO level) as a side
# effect of importing this module, and does nothing if the root logger already
# has handlers; libraries usually leave logging configuration to the application.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
class BaseModel(ABC):
    """Abstract interface for a single model that takes part in an ensemble.

    Concrete subclasses must implement :meth:`load` and :meth:`predict`; the
    class cannot be instantiated until both are overridden.
    """

    def __init__(self, name: str, weight: float = 1.0, device: str = 'cpu'):
        # Identifier and voting weight, both reported by __repr__.
        self.name, self.weight = name, weight
        self.device = device
        # Placeholder for the underlying model object; presumably assigned by a
        # subclass's load() (not enforced here).
        self.model = None
        self.is_loaded = False

    @abstractmethod
    def load(self):
        """Load the underlying model. Must be overridden by subclasses."""

    @abstractmethod
    def predict(self, audio: np.ndarray, sample_rate: int = 16000) -> Dict[str, Any]:
        """
        Run the model on one audio clip.

        Args:
            audio: Audio samples as a NumPy array.
            sample_rate: Sampling rate of ``audio`` in Hz (defaults to 16000).

        Returns:
            A dictionary holding this model's predictions.
        """

    def __repr__(self):
        cls_name = type(self).__name__
        return f"{cls_name}(name='{self.name}', weight={self.weight})"
class BaseEnsemble(ABC):
    """Base class for an ensemble of models.

    Subclasses implement :meth:`predict`. ``voting_strategy`` is stored for
    them to interpret; this base class only reports it in ``__repr__``.
    """

    def __init__(self, models: List[BaseModel], voting_strategy: str = 'weighted'):
        """
        Args:
            models: Models participating in the ensemble (stored by reference,
                not copied).
            voting_strategy: Name of the combination strategy.
        """
        self.models = models
        self.voting_strategy = voting_strategy
        self.is_loaded = False

    def load_models(self):
        """
        Load every model in order, stopping at the first failure.

        ``is_loaded`` is set to True only after all models load successfully.

        Raises:
            Exception: Re-raises whatever the failing model's ``load()`` raised,
                after logging it.
        """
        logger.info(f"Loading {len(self.models)} models...")
        for model in self.models:
            try:
                model.load()
            except Exception as e:
                logger.error(f"❌ Failed to load {model.name}: {e}")
                raise
            logger.info(f"✅ Loaded: {model.name}")
        self.is_loaded = True
        logger.info("All models loaded successfully!")

    @abstractmethod
    def predict(self, audio: np.ndarray, sample_rate: int = 16000) -> Dict[str, Any]:
        """
        Make ensemble prediction.

        Args:
            audio: Audio samples as a NumPy array.
            sample_rate: Sampling rate of ``audio`` in Hz.

        Returns:
            Dictionary with ensemble predictions.
        """

    def predict_all(self, audio: np.ndarray, sample_rate: int = 16000) -> List[Dict[str, Any]]:
        """
        Collect predictions from every model in the ensemble.

        Args:
            audio: Audio samples as a NumPy array.
            sample_rate: Sampling rate of ``audio`` in Hz.

        Returns:
            One dict per model that predicted successfully, in model order. Each
            dict is a shallow copy of the model's output with ``model_name`` and
            ``model_weight`` added. Models that raise are logged and skipped, so
            the list may be shorter than ``self.models`` (or empty).

        Raises:
            RuntimeError: If :meth:`load_models` has not completed successfully.
        """
        if not self.is_loaded:
            raise RuntimeError("Models not loaded. Call load_models() first.")
        predictions = []
        for model in self.models:
            try:
                raw = model.predict(audio, sample_rate)
                # Build a new dict instead of mutating the model's return value,
                # which may be a cached/shared object owned by the model.
                pred = {**raw, 'model_name': model.name, 'model_weight': model.weight}
            except Exception as e:
                # Best effort: one failing model must not sink the whole ensemble.
                logger.warning(f"Prediction failed for {model.name}: {e}")
                continue
            predictions.append(pred)
        return predictions

    def calculate_agreement(self, predictions: List[Dict[str, Any]], key: str) -> float:
        """
        Calculate how strongly the models agree on the value stored under ``key``.

        Only predictions that contain ``key`` are considered. The score is the
        fraction of those predictions whose value equals the most common value.

        Args:
            predictions: List of prediction dicts (e.g. from :meth:`predict_all`).
            key: Key whose values are compared across predictions.

        Returns:
            Agreement score in [0, 1]. Returns 0.0 when ``predictions`` is empty
            or None, or when no prediction contains ``key``.

        Raises:
            TypeError: If the values under ``key`` are unhashable (e.g. lists or
                NumPy arrays), since they are tallied with ``Counter``.
        """
        if not predictions:
            return 0.0
        values = [pred[key] for pred in predictions if key in pred]
        if not values:
            return 0.0
        # most_common(1) yields (modal_value, count). On ties the count is the
        # same whichever tied value is returned, so the score is unaffected.
        _, top_count = Counter(values).most_common(1)[0]
        return top_count / len(values)

    def __repr__(self):
        return f"{self.__class__.__name__}(models={len(self.models)}, strategy='{self.voting_strategy}')"