"""
SpeechExtractor: Speech vs nonverbal classification using AST.

Classifies audio segments as speech, nonverbal sounds, or other audio, and
filters segments by quality.
"""

import logging
from typing import List, Optional, Tuple

import numpy as np

logger = logging.getLogger(__name__)

# Sample rate fed to the AST feature extractor; audio at any other rate is
# resampled before classification.
_AST_SAMPLE_RATE = 16000

# STOI value recorded for every segment by ``filter_by_quality``. A real STOI
# needs a clean reference signal, which is not available there.
_PLACEHOLDER_STOI = 0.8


class SpeechExtractorError(Exception):
    """Custom exception for speech extraction errors."""


class SpeechExtractor:
    """
    Speech extraction service using Audio Spectrogram Transformer (AST).

    Classifies segments as speech, nonverbal, or other using an AST model
    trained on AudioSet, and filters segments by audio quality.
    """

    def __init__(self, model_manager=None):
        """
        Initialize speech extractor. Models are loaded lazily on first use.

        Args:
            model_manager: ModelManager instance (a new one is created on
                first use if None)
        """
        self.feature_extractor = None
        self.classifier = None
        self.model_manager = model_manager

        # AudioSet class labels counted towards the "speech" category.
        self.speech_labels = [
            "Speech",
            "Narration, monologue",
            "Conversation",
            "Speech synthesizer",
            "Male speech, man speaking",
            "Female speech, woman speaking",
            "Child speech, kid speaking",
        ]
        # AudioSet class labels counted towards the "nonverbal" category.
        self.nonverbal_labels = [
            "Sigh",
            "Laughter",
            "Gasp",
            "Groan",
            "Moan",
            "Grunt",
            "Humming",
            "Crying, sobbing",
            "Screaming",
            "Whimpering",
            "Chuckle, chortle",
            "Panting",
            "Breathing",
            "Wheeze",
            "Whispering",
        ]

    def _load_models(self, progress_callback=None):
        """
        Load the AST feature extractor and classifier if not already loaded.

        Args:
            progress_callback: Optional callback(progress: float, message: str)
        """
        if self.classifier is not None:
            return

        if self.model_manager is None:
            # Imported lazily, only when no manager was injected.
            from src.services.model_manager import ModelManager

            self.model_manager = ModelManager()

        if progress_callback:
            progress_callback(0.0, "Loading audio classifier model...")

        self.feature_extractor, self.classifier = self.model_manager.load_ast_classifier()

        if progress_callback:
            progress_callback(1.0, "Audio classifier loaded")

    def classify_segment(self, audio: np.ndarray, sample_rate: int, top_k: int = 5) -> dict:
        """
        Classify an audio segment as speech, nonverbal, or other.

        Scores are softmax probabilities of the ``top_k`` highest-ranked
        labels. ``speech_score`` / ``nonverbal_score`` are the summed
        probabilities of the top-k labels found in ``self.speech_labels`` /
        ``self.nonverbal_labels``. If none of the top-k labels belongs to
        either list, the segment is reported as ``"other"``.

        Args:
            audio: Audio array (1D numpy array)
            sample_rate: Sample rate in Hz (resampled to 16 kHz if different)
            top_k: Number of top predictions to consider (clamped to the
                number of labels the model has)

        Returns:
            Dictionary with keys ``segment_type`` ("speech", "nonverbal" or
            "other"), ``confidence`` (summed probability of the chosen
            category's top-k labels), ``primary_label``, ``speech_score``,
            ``nonverbal_score`` and ``top_predictions`` (list of
            ``{"label", "score"}`` dicts, highest score first).

        Raises:
            SpeechExtractorError: If model loading, feature extraction or
                inference fails.
        """
        try:
            self._load_models()

            import torch

            # AST expects 16 kHz input.
            if sample_rate != _AST_SAMPLE_RATE:
                from src.lib.audio_io import resample_audio

                audio = resample_audio(audio, sample_rate, _AST_SAMPLE_RATE)
                sample_rate = _AST_SAMPLE_RATE

            inputs = self.feature_extractor(audio, sampling_rate=sample_rate, return_tensors="pt")

            with torch.no_grad():
                logits = self.classifier(**inputs).logits

            # NOTE(review): AudioSet is a multi-label dataset, so per-label
            # sigmoid may be more appropriate than softmax; softmax is kept so
            # existing min_confidence thresholds keep their meaning.
            probs = torch.nn.functional.softmax(logits, dim=-1)[0]

            # Clamp so torch.topk does not fail when top_k exceeds the label count.
            k = min(top_k, probs.shape[-1])
            top_probs, top_indices = torch.topk(probs, k=k)

            id2label = self.classifier.config.id2label
            predictions = [
                {"label": id2label[idx.item()], "score": prob.item()}
                for prob, idx in zip(top_probs, top_indices)
            ]

            return self._summarize_predictions(predictions)
        except Exception as e:
            raise SpeechExtractorError(f"Failed to classify segment: {str(e)}") from e

    def _summarize_predictions(self, predictions: List[dict]) -> dict:
        """Reduce ranked top-k predictions to a speech / nonverbal / other decision."""
        speech_preds = [p for p in predictions if p["label"] in self.speech_labels]
        nonverbal_preds = [p for p in predictions if p["label"] in self.nonverbal_labels]
        speech_score = sum(p["score"] for p in speech_preds)
        nonverbal_score = sum(p["score"] for p in nonverbal_preds)

        if not speech_preds and not nonverbal_preds:
            # No evidence for either category (e.g. music, silence): report
            # "other" rather than mislabelling the segment as nonverbal.
            segment_type = "other"
            confidence = sum(p["score"] for p in predictions)
            primary_label = predictions[0]["label"] if predictions else "Other"
        elif speech_score > nonverbal_score:
            segment_type = "speech"
            confidence = speech_score
            # predictions are ranked, so the first match is the strongest.
            primary_label = speech_preds[0]["label"] if speech_preds else "Speech"
        else:
            # Ties resolve to nonverbal.
            segment_type = "nonverbal"
            confidence = nonverbal_score
            primary_label = nonverbal_preds[0]["label"] if nonverbal_preds else "Nonverbal"

        return {
            "segment_type": segment_type,
            "confidence": confidence,
            "primary_label": primary_label,
            "speech_score": speech_score,
            "nonverbal_score": nonverbal_score,
            "top_predictions": predictions,
        }

    def _extract_segments_of_type(
        self,
        audio: np.ndarray,
        sample_rate: int,
        segments: List[dict],
        target_type: str,
        min_confidence: float,
        progress_callback=None,
    ) -> List[dict]:
        """
        Classify each segment; keep those of ``target_type`` with
        confidence >= ``min_confidence``.

        Exceptions propagate unchanged; the public wrappers translate them.
        """
        self._load_models()

        from src.lib.audio_io import extract_segment

        kept = []
        total = len(segments)
        for i, segment in enumerate(segments, start=1):
            if progress_callback:
                progress_callback(i / total, f"Classifying segment {i}/{total}")

            segment_audio = extract_segment(audio, sample_rate, segment["start"], segment["end"])
            classification = self.classify_segment(segment_audio, sample_rate)

            if (
                classification["segment_type"] == target_type
                and classification["confidence"] >= min_confidence
            ):
                kept.append(
                    {
                        "start": segment["start"],
                        "end": segment["end"],
                        "duration": segment["end"] - segment["start"],
                        "classification": classification,
                        "speaker": segment.get("speaker"),
                        "similarity": segment.get("similarity"),
                    }
                )

        logger.info(
            "Extracted %d/%d %s segments (min_confidence=%s)",
            len(kept),
            total,
            target_type,
            min_confidence,
        )
        return kept

    def extract_speech_segments(
        self,
        audio: np.ndarray,
        sample_rate: int,
        segments: List[dict],
        min_confidence: float = 0.5,
        progress_callback=None,
    ) -> List[dict]:
        """
        Extract speech segments from audio.

        Args:
            audio: Full audio array
            sample_rate: Sample rate in Hz
            segments: List of segment dicts with 'start' and 'end' times
            min_confidence: Minimum confidence threshold
            progress_callback: Optional callback(progress: float, message: str)

        Returns:
            List of speech segments (start, end, duration, classification,
            speaker, similarity)

        Raises:
            SpeechExtractorError: If extraction fails
        """
        try:
            return self._extract_segments_of_type(
                audio, sample_rate, segments, "speech", min_confidence, progress_callback
            )
        except SpeechExtractorError:
            raise
        except Exception as e:
            raise SpeechExtractorError(f"Failed to extract speech segments: {str(e)}") from e

    def extract_nonverbal_segments(
        self,
        audio: np.ndarray,
        sample_rate: int,
        segments: List[dict],
        min_confidence: float = 0.5,
        progress_callback=None,
    ) -> List[dict]:
        """
        Extract nonverbal segments from audio.

        Args:
            audio: Full audio array
            sample_rate: Sample rate in Hz
            segments: List of segment dicts with 'start' and 'end' times
            min_confidence: Minimum confidence threshold
            progress_callback: Optional callback(progress: float, message: str)

        Returns:
            List of nonverbal segments (start, end, duration, classification,
            speaker, similarity)

        Raises:
            SpeechExtractorError: If extraction fails
        """
        try:
            return self._extract_segments_of_type(
                audio, sample_rate, segments, "nonverbal", min_confidence, progress_callback
            )
        except SpeechExtractorError:
            raise
        except Exception as e:
            raise SpeechExtractorError(f"Failed to extract nonverbal segments: {str(e)}") from e

    def filter_by_quality(
        self,
        audio: np.ndarray,
        sample_rate: int,
        segments: List[dict],
        min_snr: float = 15.0,
        min_stoi: float = 0.70,
        progress_callback=None,
    ) -> Tuple[List[dict], List[dict]]:
        """
        Filter segments by audio quality thresholds.

        Only the segmental SNR is used to decide. STOI needs a clean
        reference signal, which is not available here, so every segment is
        assigned a fixed placeholder STOI of 0.8 and ``min_stoi`` is
        currently NOT applied (it is kept for API compatibility).

        Args:
            audio: Full audio array
            sample_rate: Sample rate in Hz
            segments: List of segments to filter
            min_snr: Minimum SNR in dB
            min_stoi: Minimum STOI score (currently not applied, see above)
            progress_callback: Optional callback(progress: float, message: str)

        Returns:
            Tuple of (passing_segments, filtered_segments). Each returned
            segment is a copy of the input dict with ``snr``, ``stoi`` and
            ``passes_quality`` added. If the SNR could not be computed, the
            segment is kept in passing_segments with those keys set to None.

        Raises:
            SpeechExtractorError: If audio extraction or metric import fails
        """
        try:
            from src.lib.audio_io import extract_segment
            from src.lib.quality_metrics import calculate_snr_segmental

            passing = []
            filtered = []
            total = len(segments)

            for i, segment in enumerate(segments, start=1):
                if progress_callback:
                    progress_callback(i / total, f"Checking quality {i}/{total}")

                segment_audio = extract_segment(
                    audio, sample_rate, segment["start"], segment["end"]
                )

                try:
                    snr = calculate_snr_segmental(segment_audio, sample_rate)
                    # Plain bool so results stay JSON-serializable.
                    passes = bool(snr >= min_snr)
                except Exception as e:
                    logger.warning("Quality check failed for segment: %s", e)
                    # Best-effort: keep the segment (conservative), but mark
                    # its quality as unknown instead of omitting the keys.
                    unchecked = segment.copy()
                    unchecked.update(snr=None, stoi=None, passes_quality=None)
                    passing.append(unchecked)
                    continue

                checked = segment.copy()
                checked["snr"] = snr
                checked["stoi"] = _PLACEHOLDER_STOI
                checked["passes_quality"] = passes
                (passing if passes else filtered).append(checked)

            logger.info(
                "Quality filter: %d passed, %d filtered (min_snr=%s, min_stoi=%s)",
                len(passing),
                len(filtered),
                min_snr,
                min_stoi,
            )
            return passing, filtered

        except Exception as e:
            raise SpeechExtractorError(f"Failed to filter by quality: {str(e)}") from e

    def get_extraction_statistics(self, segments: List[dict]) -> dict:
        """
        Get statistics about extracted segments.

        Args:
            segments: List of extracted segments. ``duration`` is used when
                present, otherwise ``end - start``.

        Returns:
            Dictionary with total_segments, total_duration, avg_duration and
            avg/min/max_confidence (confidence stats over segments that carry
            a ``classification``; 0.0 when none do). The same keys are
            returned for an empty input.
        """
        if not segments:
            return {
                "total_segments": 0,
                "total_duration": 0.0,
                "avg_duration": 0.0,
                "avg_confidence": 0.0,
                "min_confidence": 0.0,
                "max_confidence": 0.0,
            }

        durations = [
            seg["duration"] if "duration" in seg else seg["end"] - seg["start"]
            for seg in segments
        ]
        total_duration = sum(durations)
        confidences = [
            seg["classification"]["confidence"] for seg in segments if "classification" in seg
        ]

        return {
            "total_segments": len(segments),
            "total_duration": total_duration,
            "avg_duration": total_duration / len(segments),
            "avg_confidence": float(np.mean(confidences)) if confidences else 0.0,
            "min_confidence": float(np.min(confidences)) if confidences else 0.0,
            "max_confidence": float(np.max(confidences)) if confidences else 0.0,
        }