| """ | |
| Base classes for ensemble models. | |
| """ | |
| from abc import ABC, abstractmethod | |
| from typing import Dict, List, Any, Optional | |
| import numpy as np | |
| import logging | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Configure the root logger at INFO level as an import-time side effect.
# basicConfig() is a no-op when the root logger already has handlers.
logging.basicConfig(level=logging.INFO)
class BaseModel(ABC):
    """Common interface for a single member model of the ensemble.

    The class stores the model's settings and load state. ``load`` and
    ``predict`` are do-nothing placeholders that subclasses are expected to
    override. No method is marked abstract, so the class itself can still be
    instantiated directly.
    """

    def __init__(self, name: str, weight: float = 1.0, device: str = 'cpu'):
        """Record the model's settings; nothing is loaded at this point.

        Args:
            name: Identifier for the model (shown in ``repr``).
            weight: Weight stored on the instance (defaults to 1.0).
            device: Device string stored on the instance (defaults to 'cpu').
        """
        # No underlying model object exists until load() provides one.
        self.model = None
        self.is_loaded = False
        self.name, self.weight, self.device = name, weight, device

    def load(self):
        """Load the model.

        The base implementation does nothing and returns ``None``.
        """
        return None

    def predict(self, audio: np.ndarray, sample_rate: int = 16000) -> Dict[str, Any]:
        """Make a prediction on audio.

        The base implementation does nothing and returns ``None``.

        Args:
            audio: Audio array.
            sample_rate: Sample rate of ``audio`` (defaults to 16000).

        Returns:
            Dictionary with predictions (in subclasses that implement it).
        """
        return None

    def __repr__(self):
        """Return ``ClassName(name='...', weight=...)`` for debugging."""
        cls_name = type(self).__name__
        return "{}(name='{}', weight={})".format(cls_name, self.name, self.weight)
class BaseEnsemble(ABC):
    """Base class for an ensemble of models.

    Holds the member models, loads them, collects their individual
    predictions and measures how strongly they agree on a given key.
    ``predict`` is a do-nothing placeholder that concrete ensembles are
    expected to override.
    """

    def __init__(self, models: List["BaseModel"], voting_strategy: str = 'weighted'):
        """Store the member models and the voting strategy name.

        Args:
            models: Models that make up the ensemble.
            voting_strategy: Name of the voting strategy (defaults to
                'weighted'); it is only stored here, not interpreted.
        """
        self.models = models
        self.voting_strategy = voting_strategy
        self.is_loaded = False

    def load_models(self):
        """Load all models in the ensemble.

        Models are loaded in order. The first failure is logged and
        re-raised, leaving ``is_loaded`` False.

        Raises:
            Exception: Whatever a model's ``load()`` raised.
        """
        logger.info(f"Loading {len(self.models)} models...")
        for model in self.models:
            try:
                model.load()
                logger.info(f"✅ Loaded: {model.name}")
            except Exception as e:
                logger.error(f"❌ Failed to load {model.name}: {e}")
                raise
        self.is_loaded = True
        logger.info("All models loaded successfully!")

    def predict(self, audio: np.ndarray, sample_rate: int = 16000) -> Dict[str, Any]:
        """Make an ensemble prediction.

        The base implementation does nothing and returns ``None``.

        Args:
            audio: Audio array.
            sample_rate: Sample rate of ``audio`` (defaults to 16000).

        Returns:
            Dictionary with ensemble predictions (in subclasses that
            implement it).
        """
        return None

    def predict_all(self, audio: np.ndarray, sample_rate: int = 16000) -> List[Dict[str, Any]]:
        """Get predictions from all models.

        Each result is a new dict holding the model's prediction plus
        ``'model_name'`` and ``'model_weight'``. The dict returned by the
        model itself is left untouched. A model whose prediction fails is
        logged and skipped, so the result may be shorter than
        ``self.models``.

        Args:
            audio: Audio array.
            sample_rate: Sample rate of ``audio`` (defaults to 16000).

        Returns:
            List of predictions, one per model that succeeded.

        Raises:
            RuntimeError: If ``is_loaded`` is False.
        """
        if not self.is_loaded:
            raise RuntimeError("Models not loaded. Call load_models() first.")

        predictions = []
        for model in self.models:
            try:
                raw = model.predict(audio, sample_rate)
                # Copy instead of annotating in place: the model may keep or
                # reuse the dict it returned. A non-mapping result (e.g. None)
                # raises TypeError here and is skipped like any other failure.
                pred = {
                    **raw,
                    'model_name': model.name,
                    'model_weight': model.weight,
                }
            except Exception as e:
                logger.warning(f"Prediction failed for {model.name}: {e}")
                continue
            predictions.append(pred)
        return predictions

    def calculate_agreement(self, predictions: List[Dict[str, Any]], key: str) -> float:
        """Calculate the agreement score between models.

        The score is the fraction of predictions containing ``key`` whose
        value equals the most common value. Predictions without ``key`` are
        ignored. Ties for the most common value go to the value seen first.

        Values may be unhashable (e.g. lists); they only need to support
        ``==`` yielding a plain truth value.

        Args:
            predictions: List of predictions.
            key: Key to check for agreement.

        Returns:
            Agreement score (0-1); 0.0 when no prediction contains ``key``.
        """
        # Local import keeps this method independent of the file's import block.
        from collections import Counter

        if not predictions:
            return 0.0

        values = [pred[key] for pred in predictions if key in pred]
        if not values:
            return 0.0

        try:
            most_common_value = Counter(values).most_common(1)[0][0]
        except TypeError:
            # Counter needs hashable items. For unhashable values fall back
            # to equality counting; max() keeps the first of tied candidates,
            # matching Counter.most_common().
            most_common_value = max(
                values,
                key=lambda candidate: sum(1 for v in values if v == candidate),
            )

        matching = sum(1 for v in values if v == most_common_value)
        return matching / len(values)

    def __repr__(self):
        """Return ``ClassName(models=N, strategy='...')`` for debugging."""
        return f"{self.__class__.__name__}(models={len(self.models)}, strategy='{self.voting_strategy}')"