"""
Simplified API for ensemble annotation.

Example usage:
    from ensemble_tts import EnsembleAnnotator

    annotator = EnsembleAnnotator(mode='balanced', device='cuda')
    result = annotator.annotate('audio.wav')
    print(result['emotion']['label'])
"""

import logging
from collections import Counter
from itertools import islice
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import librosa
import numpy as np

from .models.emotion import EmotionEnsemble
from .models.events import EventEnsemble

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# All underlying models expect 16 kHz mono audio; every input is resampled
# to this rate before prediction.
TARGET_SAMPLE_RATE = 16000


class EnsembleAnnotator:
    """
    Simplified API for ensemble TTS annotation.

    Combines emotion and event detection in a single interface.
    Models are loaded lazily on the first call to an ``annotate*`` method
    (or eagerly via :meth:`load_models`).
    """

    def __init__(self,
                 mode: str = 'balanced',
                 device: str = 'cpu',
                 voting_strategy: str = 'weighted',
                 enable_events: bool = True):
        """
        Initialize ensemble annotator.

        Args:
            mode: 'quick' (2 models), 'balanced' (3 models), or 'full' (5 models)
            device: 'cpu' or 'cuda'
            voting_strategy: 'majority', 'weighted', or 'confidence'
            enable_events: Whether to detect non-verbal events
        """
        self.mode = mode
        self.device = device
        self.voting_strategy = voting_strategy
        self.enable_events = enable_events

        logger.info("Initializing EnsembleAnnotator (mode=%s, device=%s)", mode, device)

        # Initialize ensembles (construction is cheap; weights load later).
        self.emotion_ensemble = EmotionEnsemble(
            mode=mode,
            device=device,
            voting_strategy=voting_strategy
        )

        self.event_ensemble = EventEnsemble(device=device) if enable_events else None

        self._models_loaded = False

    def load_models(self):
        """Load all models. Idempotent: repeat calls are no-ops."""
        if self._models_loaded:
            logger.info("Models already loaded")
            return

        logger.info("Loading models...")
        self.emotion_ensemble.load_models()
        if self.event_ensemble:
            self.event_ensemble.load_models()

        self._models_loaded = True
        logger.info("✅ All models loaded successfully")

    def annotate(self,
                 audio_input: Union[str, Path, np.ndarray],
                 sample_rate: Optional[int] = None) -> Dict[str, Any]:
        """
        Annotate a single audio file or array.

        Args:
            audio_input: Path to audio file or numpy array
            sample_rate: Sample rate (required if audio_input is array;
                assumed to be 16 kHz when omitted)

        Returns:
            Dictionary with emotion and event annotations
        """
        # Load models if not already loaded
        if not self._models_loaded:
            self.load_models()

        # Load audio: files are decoded+resampled by librosa; raw arrays
        # use the caller-supplied rate (default: target rate).
        if isinstance(audio_input, (str, Path)):
            audio, sr = librosa.load(audio_input, sr=TARGET_SAMPLE_RATE)
        else:
            audio = audio_input
            sr = sample_rate or TARGET_SAMPLE_RATE

        # Resample if needed
        if sr != TARGET_SAMPLE_RATE:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=TARGET_SAMPLE_RATE)
            sr = TARGET_SAMPLE_RATE

        # Emotion detection
        emotion_result = self.emotion_ensemble.predict(audio, sr)

        # Event detection (empty stub result when events are disabled)
        if self.event_ensemble:
            event_result = self.event_ensemble.predict(audio, sr)
        else:
            event_result = {"events": [], "confidence": {}}

        # Combine results; .get() defaults keep the output schema stable
        # even if an ensemble omits a field.
        return {
            "emotion": {
                "label": emotion_result.get("label", "unknown"),
                "confidence": emotion_result.get("confidence", 0.0),
                "agreement": emotion_result.get("agreement", 0.0),
                "votes": emotion_result.get("votes", {}),
                "predictions": emotion_result.get("predictions", [])
            },
            "events": {
                "detected": event_result.get("events", []),
                "confidence": event_result.get("confidence", {}),
                "detections": event_result.get("detections", [])
            }
        }

    def annotate_batch(self,
                       audio_list: List[Union[str, Path, np.ndarray]],
                       sample_rates: Optional[List[int]] = None) -> List[Dict[str, Any]]:
        """
        Annotate multiple audio files/arrays.

        Args:
            audio_list: List of audio file paths or numpy arrays
            sample_rates: List of sample rates (for arrays), parallel to
                ``audio_list``; optional for file inputs

        Returns:
            List of annotation dictionaries
        """
        if not self._models_loaded:
            self.load_models()

        results = []
        for i, audio_input in enumerate(audio_list):
            sr = sample_rates[i] if sample_rates else None
            results.append(self.annotate(audio_input, sr))

        return results

    def annotate_dataset(self,
                         dataset,
                         audio_column: str = 'audio',
                         text_column: str = 'text',
                         max_samples: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Annotate HuggingFace dataset.

        Args:
            dataset: HuggingFace dataset (any iterable of row dicts works)
            audio_column: Name of audio column
            text_column: Name of text column
            max_samples: Maximum number of samples to annotate

        Returns:
            List of annotation dictionaries
        """
        if not self._models_loaded:
            self.load_models()

        results = []

        # NOTE: slicing a HuggingFace Dataset (``dataset[:n]``) returns a
        # dict of columns, not rows — iterate rows and bound with islice
        # instead so each ``sample`` is a per-row dict.
        if max_samples is not None:
            samples = islice(dataset, max_samples)
            total = min(max_samples, len(dataset))
        else:
            samples = dataset
            total = len(dataset)

        logger.info("Annotating %d samples...", total)

        for i, sample in enumerate(samples):
            audio_data = sample[audio_column]
            audio_array = audio_data['array']
            sr = audio_data['sampling_rate']

            result = self.annotate(audio_array, sr)

            # Add metadata
            result['sample_id'] = i
            if text_column in sample:
                result['text'] = sample[text_column]

            results.append(result)

            if (i + 1) % 100 == 0:
                logger.info("Processed %d samples...", i + 1)

        logger.info("✅ Annotated %d samples", len(results))
        return results

    def get_stats(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Get statistics from annotation results.

        Args:
            results: List of annotation results

        Returns:
            Dictionary with statistics (all-zero/empty for no results)
        """
        # Guard the empty case: np.mean([]) would emit a warning and
        # produce NaN averages.
        if not results:
            return {
                "total_samples": 0,
                "emotion_distribution": {},
                "avg_confidence": 0.0,
                "avg_agreement": 0.0,
                "event_distribution": {},
                "total_events_detected": 0
            }

        emotions = [r['emotion']['label'] for r in results]
        confidences = [r['emotion']['confidence'] for r in results]
        agreements = [r['emotion']['agreement'] for r in results]

        # Emotion distribution
        emotion_dist = Counter(emotions)

        # Events statistics
        all_events = []
        for r in results:
            all_events.extend(r['events']['detected'])
        event_dist = Counter(all_events)

        return {
            "total_samples": len(results),
            "emotion_distribution": dict(emotion_dist),
            "avg_confidence": float(np.mean(confidences)),
            "avg_agreement": float(np.mean(agreements)),
            "event_distribution": dict(event_dist),
            "total_events_detected": len(all_events)
        }


# Convenience function for quick annotation
def annotate_file(audio_path: str,
                  mode: str = 'balanced',
                  device: str = 'cpu') -> Dict[str, Any]:
    """
    Quickly annotate a single audio file.

    Note: builds a fresh annotator (and loads models) per call; reuse an
    :class:`EnsembleAnnotator` instance for repeated annotation.

    Args:
        audio_path: Path to audio file
        mode: 'quick', 'balanced', or 'full'
        device: 'cpu' or 'cuda'

    Returns:
        Annotation dictionary
    """
    annotator = EnsembleAnnotator(mode=mode, device=device)
    return annotator.annotate(audio_path)