# Provenance (was raw scrape residue, which broke the module with a SyntaxError):
# commit fe63a26 by marcosremar — "Add complete infrastructure and updated README for OPTION A"
"""
Simplified API for ensemble annotation.
Example usage:
from ensemble_tts import EnsembleAnnotator
annotator = EnsembleAnnotator(mode='balanced', device='cuda')
result = annotator.annotate('audio.wav')
print(result['emotion']['label'])
"""
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import librosa
import numpy as np

from .models.emotion import EmotionEnsemble
from .models.events import EventEnsemble
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EnsembleAnnotator:
    """
    Simplified API for ensemble TTS annotation.

    Combines emotion and event detection behind a single interface.
    Models are loaded lazily on the first annotation call, or explicitly
    via :meth:`load_models`.
    """

    # All audio is normalized to this sample rate before inference.
    TARGET_SR = 16000

    def __init__(self,
                 mode: str = 'balanced',
                 device: str = 'cpu',
                 voting_strategy: str = 'weighted',
                 enable_events: bool = True):
        """
        Initialize ensemble annotator.

        Args:
            mode: 'quick' (2 models), 'balanced' (3 models), or 'full' (5 models)
            device: 'cpu' or 'cuda'
            voting_strategy: 'majority', 'weighted', or 'confidence'
            enable_events: Whether to detect non-verbal events
        """
        self.mode = mode
        self.device = device
        self.voting_strategy = voting_strategy
        self.enable_events = enable_events
        logger.info(f"Initializing EnsembleAnnotator (mode={mode}, device={device})")
        self.emotion_ensemble = EmotionEnsemble(
            mode=mode,
            device=device,
            voting_strategy=voting_strategy
        )
        # Keep the attribute present even when events are disabled so callers
        # can uniformly test `if annotator.event_ensemble: ...`.
        self.event_ensemble = EventEnsemble(device=device) if enable_events else None
        self._models_loaded = False

    def load_models(self):
        """Load all underlying models. Idempotent: a no-op after the first call."""
        if self._models_loaded:
            logger.info("Models already loaded")
            return
        logger.info("Loading models...")
        self.emotion_ensemble.load_models()
        if self.event_ensemble:
            self.event_ensemble.load_models()
        self._models_loaded = True
        logger.info("✅ All models loaded successfully")

    def _prepare_audio(self, audio_input, sample_rate):
        """Return (audio, sr) normalized to TARGET_SR.

        Accepts a file path (decoded via librosa at TARGET_SR) or a raw
        numpy array. If an array is given without `sample_rate`, TARGET_SR
        is assumed — callers passing arrays at other rates must supply it.
        """
        if isinstance(audio_input, (str, Path)):
            audio, sr = librosa.load(audio_input, sr=self.TARGET_SR)
        else:
            audio = audio_input
            sr = sample_rate or self.TARGET_SR
        if sr != self.TARGET_SR:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=self.TARGET_SR)
            sr = self.TARGET_SR
        return audio, sr

    def annotate(self,
                 audio_input: Union[str, Path, np.ndarray],
                 sample_rate: Optional[int] = None) -> Dict[str, Any]:
        """
        Annotate a single audio file or array.

        Args:
            audio_input: Path to audio file or numpy array
            sample_rate: Sample rate (required if audio_input is an array
                recorded at a rate other than 16 kHz)

        Returns:
            Dictionary with 'emotion' and 'events' annotation sub-dicts.
        """
        # Lazy model loading on first use.
        if not self._models_loaded:
            self.load_models()
        audio, sr = self._prepare_audio(audio_input, sample_rate)
        emotion_result = self.emotion_ensemble.predict(audio, sr)
        if self.event_ensemble:
            event_result = self.event_ensemble.predict(audio, sr)
        else:
            event_result = {"events": [], "confidence": {}}
        # Combine into a stable output schema; .get(...) defaults shield us
        # from ensembles that omit optional keys.
        return {
            "emotion": {
                "label": emotion_result.get("label", "unknown"),
                "confidence": emotion_result.get("confidence", 0.0),
                "agreement": emotion_result.get("agreement", 0.0),
                "votes": emotion_result.get("votes", {}),
                "predictions": emotion_result.get("predictions", [])
            },
            "events": {
                "detected": event_result.get("events", []),
                "confidence": event_result.get("confidence", {}),
                "detections": event_result.get("detections", [])
            }
        }

    def annotate_batch(self,
                       audio_list: List[Union[str, Path, np.ndarray]],
                       sample_rates: Optional[List[int]] = None) -> List[Dict[str, Any]]:
        """
        Annotate multiple audio files/arrays.

        Args:
            audio_list: List of audio file paths or numpy arrays
            sample_rates: Optional parallel list of sample rates (for arrays)

        Returns:
            List of annotation dictionaries, one per input.
        """
        if not self._models_loaded:
            self.load_models()
        results = []
        for i, audio_input in enumerate(audio_list):
            sr = sample_rates[i] if sample_rates else None
            results.append(self.annotate(audio_input, sr))
        return results

    def annotate_dataset(self,
                         dataset,
                         audio_column: str = 'audio',
                         text_column: str = 'text',
                         max_samples: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Annotate a HuggingFace dataset.

        Args:
            dataset: HuggingFace dataset (rows must carry `audio_column`
                with 'array' and 'sampling_rate' fields)
            audio_column: Name of audio column
            text_column: Name of text column
            max_samples: Maximum number of samples to annotate (None = all)

        Returns:
            List of annotation dictionaries.
        """
        if not self._models_loaded:
            self.load_models()
        from itertools import islice  # local import, matching get_stats' style
        # BUG FIX: the previous `dataset[:max_samples]` slice returns a dict
        # of columns on a HuggingFace Dataset, so iterating it yields column
        # *names*, not rows. Iterate rows and cap with islice instead.
        if max_samples is not None:
            total = min(max_samples, len(dataset))
        else:
            total = len(dataset)
        logger.info(f"Annotating {total} samples...")
        results = []
        for i, sample in enumerate(islice(dataset, max_samples)):
            audio_data = sample[audio_column]
            result = self.annotate(audio_data['array'], audio_data['sampling_rate'])
            # Attach per-sample metadata.
            result['sample_id'] = i
            if text_column in sample:
                result['text'] = sample[text_column]
            results.append(result)
            if (i + 1) % 100 == 0:
                logger.info(f"Processed {i + 1} samples...")
        logger.info(f"✅ Annotated {len(results)} samples")
        return results

    def get_stats(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Get aggregate statistics from annotation results.

        Args:
            results: List of annotation results (as produced by `annotate`)

        Returns:
            Dictionary with emotion/event distributions and averages.
        """
        from collections import Counter
        # Guard the empty case: np.mean([]) would warn and return NaN.
        if not results:
            return {
                "total_samples": 0,
                "emotion_distribution": {},
                "avg_confidence": 0.0,
                "avg_agreement": 0.0,
                "event_distribution": {},
                "total_events_detected": 0
            }
        emotions = [r['emotion']['label'] for r in results]
        confidences = [r['emotion']['confidence'] for r in results]
        agreements = [r['emotion']['agreement'] for r in results]
        emotion_dist = Counter(emotions)
        all_events = []
        for r in results:
            all_events.extend(r['events']['detected'])
        event_dist = Counter(all_events)
        return {
            "total_samples": len(results),
            "emotion_distribution": dict(emotion_dist),
            "avg_confidence": float(np.mean(confidences)),
            "avg_agreement": float(np.mean(agreements)),
            "event_distribution": dict(event_dist),
            "total_events_detected": len(all_events)
        }
# Convenience function for quick annotation
def annotate_file(audio_path: str,
                  mode: str = 'balanced',
                  device: str = 'cpu') -> Dict[str, Any]:
    """
    Quickly annotate a single audio file.

    Builds a throwaway :class:`EnsembleAnnotator` per call; use the class
    directly when annotating many files, to avoid reloading models.

    Args:
        audio_path: Path to audio file
        mode: 'quick', 'balanced', or 'full'
        device: 'cpu' or 'cuda'

    Returns:
        Annotation dictionary
    """
    return EnsembleAnnotator(mode=mode, device=device).annotate(audio_path)