# Hugging Face Spaces: Running on Zero
"""
SpeechExtractor: Speech vs nonverbal classification using AST.
Classifies audio segments as speech or nonverbal sounds and filters by quality.
"""
import logging
from typing import List, Optional, Tuple

import numpy as np

logger = logging.getLogger(__name__)
class SpeechExtractorError(Exception):
    """Raised when speech classification, extraction, or quality filtering fails."""
class SpeechExtractor:
    """
    Speech extraction service using the Audio Spectrogram Transformer (AST).

    Classifies audio segments as speech, nonverbal, or other using an AST
    model trained on AudioSet, and filters segments by simple quality metrics.
    Models are loaded lazily on first use via the injected ModelManager.
    """

    def __init__(self, model_manager=None):
        """
        Initialize speech extractor.

        Args:
            model_manager: ModelManager instance (a new one is created lazily
                on first classification if None).
        """
        # Loaded lazily by _load_models(); None until first use.
        self.feature_extractor = None
        self.classifier = None
        self.model_manager = model_manager
        # AudioSet class labels treated as "speech".
        self.speech_labels = [
            "Speech",
            "Narration, monologue",
            "Conversation",
            "Speech synthesizer",
            "Male speech, man speaking",
            "Female speech, woman speaking",
            "Child speech, kid speaking",
        ]
        # AudioSet class labels treated as "nonverbal" vocalizations.
        self.nonverbal_labels = [
            "Sigh",
            "Laughter",
            "Gasp",
            "Groan",
            "Moan",
            "Grunt",
            "Humming",
            "Crying, sobbing",
            "Screaming",
            "Whimpering",
            "Chuckle, chortle",
            "Panting",
            "Breathing",
            "Wheeze",
            "Whispering",
        ]

    def _load_models(self, progress_callback=None):
        """Load the AST feature extractor and classifier if not already loaded.

        Args:
            progress_callback: Optional callback(progress: float, message: str).
        """
        if self.classifier is not None:
            return
        if self.model_manager is None:
            # Deferred import avoids a hard dependency at module import time.
            from src.services.model_manager import ModelManager

            self.model_manager = ModelManager()
        if progress_callback:
            progress_callback(0.0, "Loading audio classifier model...")
        self.feature_extractor, self.classifier = self.model_manager.load_ast_classifier()
        if progress_callback:
            progress_callback(1.0, "Audio classifier loaded")

    def classify_segment(self, audio: np.ndarray, sample_rate: int, top_k: int = 5) -> dict:
        """
        Classify an audio segment as speech or nonverbal.

        Args:
            audio: Audio array (1D numpy array)
            sample_rate: Sample rate in Hz
            top_k: Number of top predictions to return

        Returns:
            Dictionary with keys: segment_type, confidence, primary_label,
            speech_score, nonverbal_score, top_predictions.

        Raises:
            SpeechExtractorError: If classification fails
        """
        try:
            self._load_models()
            import torch

            # AST expects 16 kHz input; resample if needed.
            if sample_rate != 16000:
                from src.lib.audio_io import resample_audio

                audio = resample_audio(audio, sample_rate, 16000)
                sample_rate = 16000

            inputs = self.feature_extractor(audio, sampling_rate=sample_rate, return_tensors="pt")
            with torch.no_grad():
                outputs = self.classifier(**inputs)
                logits = outputs.logits
                probs = torch.nn.functional.softmax(logits, dim=-1)[0]

            top_probs, top_indices = torch.topk(probs, k=top_k)
            labels = self.classifier.config.id2label
            predictions = [
                {"label": labels[idx.item()], "score": prob.item()}
                for prob, idx in zip(top_probs, top_indices)
            ]

            # Aggregate probability mass over each label group (top-k only,
            # so scores understate the true class-group probability).
            speech_score = sum(p["score"] for p in predictions if p["label"] in self.speech_labels)
            nonverbal_score = sum(
                p["score"] for p in predictions if p["label"] in self.nonverbal_labels
            )

            # Ties go to "nonverbal" (strict > for speech), matching prior behavior.
            if speech_score > nonverbal_score:
                segment_type = "speech"
                confidence = speech_score
                primary_label = next(
                    (p["label"] for p in predictions if p["label"] in self.speech_labels), "Speech"
                )
            else:
                segment_type = "nonverbal"
                confidence = nonverbal_score
                primary_label = next(
                    (p["label"] for p in predictions if p["label"] in self.nonverbal_labels),
                    "Nonverbal",
                )

            return {
                "segment_type": segment_type,
                "confidence": confidence,
                "primary_label": primary_label,
                "speech_score": speech_score,
                "nonverbal_score": nonverbal_score,
                "top_predictions": predictions,
            }
        except SpeechExtractorError:
            # Already a domain error (e.g. from _load_models); don't double-wrap.
            raise
        except Exception as e:
            raise SpeechExtractorError(f"Failed to classify segment: {str(e)}") from e

    def _extract_by_type(
        self,
        audio: np.ndarray,
        sample_rate: int,
        segments: List[dict],
        target_type: str,
        min_confidence: float,
        progress_callback,
    ) -> List[dict]:
        """Classify each segment and keep those of target_type above min_confidence.

        Shared implementation behind extract_speech_segments and
        extract_nonverbal_segments (previously duplicated code).
        """
        from src.lib.audio_io import extract_segment

        kept = []
        total = len(segments)
        for i, segment in enumerate(segments):
            if progress_callback:
                progress_callback((i + 1) / total, f"Classifying segment {i + 1}/{total}")
            segment_audio = extract_segment(audio, sample_rate, segment["start"], segment["end"])
            classification = self.classify_segment(segment_audio, sample_rate)
            if (
                classification["segment_type"] == target_type
                and classification["confidence"] >= min_confidence
            ):
                kept.append(
                    {
                        "start": segment["start"],
                        "end": segment["end"],
                        "duration": segment["end"] - segment["start"],
                        "classification": classification,
                        # Pass through optional diarization metadata if present.
                        "speaker": segment.get("speaker"),
                        "similarity": segment.get("similarity"),
                    }
                )
        logger.info(
            f"Extracted {len(kept)}/{total} {target_type} segments "
            f"(min_confidence={min_confidence})"
        )
        return kept

    def extract_speech_segments(
        self,
        audio: np.ndarray,
        sample_rate: int,
        segments: List[dict],
        min_confidence: float = 0.5,
        progress_callback=None,
    ) -> List[dict]:
        """
        Extract speech segments from audio.

        Args:
            audio: Full audio array
            sample_rate: Sample rate in Hz
            segments: List of segment dicts with 'start' and 'end' times
            min_confidence: Minimum confidence threshold
            progress_callback: Optional callback(progress: float, message: str)

        Returns:
            List of speech segments with classifications

        Raises:
            SpeechExtractorError: If extraction fails
        """
        try:
            self._load_models()
            return self._extract_by_type(
                audio, sample_rate, segments, "speech", min_confidence, progress_callback
            )
        except SpeechExtractorError:
            raise
        except Exception as e:
            raise SpeechExtractorError(f"Failed to extract speech segments: {str(e)}") from e

    def extract_nonverbal_segments(
        self,
        audio: np.ndarray,
        sample_rate: int,
        segments: List[dict],
        min_confidence: float = 0.5,
        progress_callback=None,
    ) -> List[dict]:
        """
        Extract nonverbal segments from audio.

        Args:
            audio: Full audio array
            sample_rate: Sample rate in Hz
            segments: List of segment dicts with 'start' and 'end' times
            min_confidence: Minimum confidence threshold
            progress_callback: Optional callback(progress: float, message: str)

        Returns:
            List of nonverbal segments with classifications

        Raises:
            SpeechExtractorError: If extraction fails
        """
        try:
            self._load_models()
            return self._extract_by_type(
                audio, sample_rate, segments, "nonverbal", min_confidence, progress_callback
            )
        except SpeechExtractorError:
            raise
        except Exception as e:
            raise SpeechExtractorError(f"Failed to extract nonverbal segments: {str(e)}") from e

    def filter_by_quality(
        self,
        audio: np.ndarray,
        sample_rate: int,
        segments: List[dict],
        min_snr: float = 15.0,
        min_stoi: float = 0.70,
        progress_callback=None,
    ) -> Tuple[List[dict], List[dict]]:
        """
        Filter segments by audio quality thresholds.

        NOTE(review): only the SNR threshold is currently enforced. STOI needs
        a clean reference signal, which is unavailable here, so "stoi" is set
        to a fixed conservative 0.8 and `min_stoi` is reported in logs but not
        applied — confirm this is intentional before tightening.

        Args:
            audio: Full audio array
            sample_rate: Sample rate in Hz
            segments: List of segments to filter
            min_snr: Minimum SNR in dB
            min_stoi: Minimum STOI score (currently not enforced; see NOTE)
            progress_callback: Optional callback(progress: float, message: str)

        Returns:
            Tuple of (passing_segments, filtered_segments); each segment dict
            is augmented with "snr", "stoi", and "passes_quality" keys when the
            metric computation succeeds.

        Raises:
            SpeechExtractorError: If filtering fails
        """
        try:
            from src.lib.audio_io import extract_segment
            from src.lib.quality_metrics import calculate_snr_segmental

            passing = []
            filtered = []
            total = len(segments)
            for i, segment in enumerate(segments):
                if progress_callback:
                    progress_callback((i + 1) / total, f"Checking quality {i + 1}/{total}")
                segment_audio = extract_segment(
                    audio, sample_rate, segment["start"], segment["end"]
                )
                try:
                    snr = calculate_snr_segmental(segment_audio, sample_rate)
                    # No reference signal available for STOI; use a fixed
                    # conservative estimate (see NOTE in docstring).
                    stoi_score = 0.8
                    passes = snr >= min_snr
                    segment_with_quality = segment.copy()
                    segment_with_quality["snr"] = snr
                    segment_with_quality["stoi"] = stoi_score
                    segment_with_quality["passes_quality"] = passes
                    if passes:
                        passing.append(segment_with_quality)
                    else:
                        filtered.append(segment_with_quality)
                except Exception as e:
                    # Best-effort: keep the segment if metrics can't be computed,
                    # rather than dropping usable audio on a metric failure.
                    logger.warning(f"Quality check failed for segment: {e}")
                    passing.append(segment)
            logger.info(
                f"Quality filter: {len(passing)} passed, {len(filtered)} filtered "
                f"(min_snr={min_snr}, min_stoi={min_stoi})"
            )
            return passing, filtered
        except Exception as e:
            raise SpeechExtractorError(f"Failed to filter by quality: {str(e)}") from e

    def get_extraction_statistics(self, segments: List[dict]) -> dict:
        """
        Get statistics about extracted segments.

        Args:
            segments: List of extracted segments (each with "duration" and,
                optionally, a "classification" dict carrying "confidence")

        Returns:
            Dictionary with total_segments, total_duration, avg_duration, and
            avg/min/max confidence (confidence keys default to 0.0 when no
            segment carries a classification; min/max are omitted only in the
            empty-input case, matching prior behavior).
        """
        if not segments:
            return {
                "total_segments": 0,
                "total_duration": 0.0,
                "avg_duration": 0.0,
                "avg_confidence": 0.0,
            }
        total_duration = sum(seg["duration"] for seg in segments)
        confidences = [
            seg["classification"]["confidence"] for seg in segments if "classification" in seg
        ]
        return {
            "total_segments": len(segments),
            "total_duration": total_duration,
            "avg_duration": total_duration / len(segments),
            "avg_confidence": np.mean(confidences) if confidences else 0.0,
            "min_confidence": np.min(confidences) if confidences else 0.0,
            "max_confidence": np.max(confidences) if confidences else 0.0,
        }