| | """ |
| | Simplified API for ensemble annotation. |
| | |
| | Example usage: |
| | from ensemble_tts import EnsembleAnnotator |
| | |
| | annotator = EnsembleAnnotator(mode='balanced', device='cuda') |
| | result = annotator.annotate('audio.wav') |
| | print(result['emotion']['label']) |
| | """ |
| |
|
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import librosa
import numpy as np

from .models.emotion import EmotionEnsemble
from .models.events import EventEnsemble
| |
|
# NOTE(review): calling basicConfig() at import time is unusual for a library
# module — it configures the application-wide root logger as a side effect of
# importing this file; normally configuration is left to the application.
# Kept as-is to preserve behavior.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
| |
|
| |
|
class EnsembleAnnotator:
    """
    Simplified API for ensemble TTS annotation.

    Combines emotion and event detection in a single interface.  Models are
    loaded lazily: construction is cheap, and ``load_models()`` is invoked
    automatically by the first annotation call.
    """

    # All ensemble models consume 16 kHz audio; every input is resampled
    # to this rate before prediction.
    TARGET_SR = 16000

    def __init__(self,
                 mode: str = 'balanced',
                 device: str = 'cpu',
                 voting_strategy: str = 'weighted',
                 enable_events: bool = True):
        """
        Initialize ensemble annotator.

        Args:
            mode: 'quick' (2 models), 'balanced' (3 models), or 'full' (5 models)
            device: 'cpu' or 'cuda'
            voting_strategy: 'majority', 'weighted', or 'confidence'
            enable_events: Whether to detect non-verbal events
        """
        self.mode = mode
        self.device = device
        self.voting_strategy = voting_strategy
        self.enable_events = enable_events

        logger.info(f"Initializing EnsembleAnnotator (mode={mode}, device={device})")

        self.emotion_ensemble = EmotionEnsemble(
            mode=mode,
            device=device,
            voting_strategy=voting_strategy
        )

        # A None event ensemble means event detection is disabled.
        self.event_ensemble = EventEnsemble(device=device) if enable_events else None

        self._models_loaded = False

    def load_models(self):
        """Load all models.  Idempotent: repeat calls are no-ops."""
        if self._models_loaded:
            logger.info("Models already loaded")
            return

        logger.info("Loading models...")
        self.emotion_ensemble.load_models()

        if self.event_ensemble:
            self.event_ensemble.load_models()

        self._models_loaded = True
        logger.info("✅ All models loaded successfully")

    def _prepare_audio(self,
                       audio_input: Union[str, Path, np.ndarray],
                       sample_rate: Optional[int]) -> np.ndarray:
        """Load (if given a path) and resample input to TARGET_SR mono audio."""
        if isinstance(audio_input, (str, Path)):
            # librosa resamples to the requested rate at load time.
            audio, sr = librosa.load(audio_input, sr=self.TARGET_SR)
        else:
            audio = audio_input
            # NOTE(review): a raw array without a sample_rate is assumed to
            # already be 16 kHz (original behavior) — callers should pass the
            # true rate to avoid silently skipping the resample below.
            sr = sample_rate or self.TARGET_SR

        if sr != self.TARGET_SR:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=self.TARGET_SR)

        return audio

    def annotate(self,
                 audio_input: Union[str, Path, np.ndarray],
                 sample_rate: Optional[int] = None) -> Dict[str, Any]:
        """
        Annotate a single audio file or array.

        Args:
            audio_input: Path to audio file or numpy array
            sample_rate: Sample rate (required if audio_input is an array
                that is not already 16 kHz)

        Returns:
            Dictionary with emotion and event annotations
        """
        if not self._models_loaded:
            self.load_models()

        audio = self._prepare_audio(audio_input, sample_rate)

        emotion_result = self.emotion_ensemble.predict(audio, self.TARGET_SR)

        if self.event_ensemble:
            event_result = self.event_ensemble.predict(audio, self.TARGET_SR)
        else:
            # Events disabled: emit an empty-but-well-formed section so the
            # result schema is stable for downstream consumers.
            event_result = {"events": [], "confidence": {}}

        return {
            "emotion": {
                "label": emotion_result.get("label", "unknown"),
                "confidence": emotion_result.get("confidence", 0.0),
                "agreement": emotion_result.get("agreement", 0.0),
                "votes": emotion_result.get("votes", {}),
                "predictions": emotion_result.get("predictions", [])
            },
            "events": {
                "detected": event_result.get("events", []),
                "confidence": event_result.get("confidence", {}),
                "detections": event_result.get("detections", [])
            }
        }

    def annotate_batch(self,
                       audio_list: List[Union[str, Path, np.ndarray]],
                       sample_rates: Optional[List[int]] = None) -> List[Dict[str, Any]]:
        """
        Annotate multiple audio files/arrays.

        Args:
            audio_list: List of audio file paths or numpy arrays
            sample_rates: Optional parallel list of sample rates (for arrays)

        Returns:
            List of annotation dictionaries, in input order
        """
        if not self._models_loaded:
            self.load_models()

        results = []
        for i, audio_input in enumerate(audio_list):
            sr = sample_rates[i] if sample_rates else None
            results.append(self.annotate(audio_input, sr))

        return results

    def annotate_dataset(self,
                         dataset,
                         audio_column: str = 'audio',
                         text_column: str = 'text',
                         max_samples: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Annotate a HuggingFace dataset.

        Args:
            dataset: HuggingFace dataset (iterable of dict-like rows)
            audio_column: Name of audio column (dict with 'array' and
                'sampling_rate' keys)
            text_column: Name of text column (copied into results if present)
            max_samples: Maximum number of samples to annotate (None = all)

        Returns:
            List of annotation dictionaries
        """
        if not self._models_loaded:
            self.load_models()

        # BUGFIX: the previous `dataset[:max_samples]` slice returns a
        # *columnar* dict on HuggingFace datasets, so iterating it yields
        # column names instead of rows. Iterate rows and stop early instead.
        if max_samples is not None:
            logger.info(f"Annotating up to {max_samples} samples...")
        else:
            logger.info("Annotating all samples...")

        results = []
        for i, sample in enumerate(dataset):
            if max_samples is not None and i >= max_samples:
                break

            audio_data = sample[audio_column]
            result = self.annotate(audio_data['array'], audio_data['sampling_rate'])

            result['sample_id'] = i
            if text_column in sample:
                result['text'] = sample[text_column]

            results.append(result)

            # Periodic progress log for long runs.
            if (i + 1) % 100 == 0:
                logger.info(f"Processed {i + 1} samples...")

        logger.info(f"✅ Annotated {len(results)} samples")
        return results

    def get_stats(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Get statistics from annotation results.

        Args:
            results: List of annotation results

        Returns:
            Dictionary with statistics (all-zero/empty for an empty list)
        """
        from collections import Counter

        # Guard: np.mean([]) emits a RuntimeWarning and returns NaN.
        if not results:
            return {
                "total_samples": 0,
                "emotion_distribution": {},
                "avg_confidence": 0.0,
                "avg_agreement": 0.0,
                "event_distribution": {},
                "total_events_detected": 0
            }

        emotions = [r['emotion']['label'] for r in results]
        confidences = [r['emotion']['confidence'] for r in results]
        agreements = [r['emotion']['agreement'] for r in results]

        all_events = []
        for r in results:
            all_events.extend(r['events']['detected'])

        return {
            "total_samples": len(results),
            "emotion_distribution": dict(Counter(emotions)),
            "avg_confidence": float(np.mean(confidences)),
            "avg_agreement": float(np.mean(agreements)),
            "event_distribution": dict(Counter(all_events)),
            "total_events_detected": len(all_events)
        }
| |
|
| |
|
| | |
def annotate_file(audio_path: str,
                  mode: str = 'balanced',
                  device: str = 'cpu') -> Dict[str, Any]:
    """
    One-shot convenience wrapper: build an annotator and run it on one file.

    Args:
        audio_path: Path to audio file
        mode: 'quick', 'balanced', or 'full'
        device: 'cpu' or 'cuda'

    Returns:
        Annotation dictionary
    """
    return EnsembleAnnotator(mode=mode, device=device).annotate(audio_path)
| |
|