Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| BatchProcessor orchestrates the complete voice extraction pipeline. | |
| This service coordinates VAD filtering, voice identification, speech/nonverbal | |
| classification, quality filtering, and segment extraction to implement the | |
| complete voice profiling workflow. | |
| """ | |
| import logging | |
| from pathlib import Path | |
| from typing import List, Optional, Tuple | |
| import numpy as np | |
| from ..lib.audio_io import extract_segment, read_audio, write_audio | |
| from ..lib.format_converter import m4a_to_wav, wav_to_m4a | |
| from ..models.audio_segment import AudioSegment, SegmentCollection, SegmentType | |
| from ..models.processing_job import ExtractionMode, JobStatus, ProcessingJob | |
| from ..models.voice_profile import VoiceProfile | |
| from .speech_extractor import SpeechExtractor | |
| from .vad_filter import VADFilter | |
| from .voice_identifier import VoiceIdentifier | |
| logger = logging.getLogger(__name__) | |
class BatchProcessor:
    """
    Orchestrates the complete voice extraction pipeline.

    Pipeline stages:
    1. Format conversion (m4a → wav if needed)
    2. VAD pre-filtering (identify voice activity regions)
    3. Voice identification (match target speaker)
    4. Speech/nonverbal classification
    5. Quality filtering
    6. Segment extraction and output generation
    """

    # Rough processing-cost factors (multiples of audio duration) used by
    # estimate_processing_time().
    _VAD_REALTIME_FACTOR = 0.4  # applied to the detected voice duration
    _FULL_REALTIME_FACTOR = 0.8  # applied to the whole file when VAD is off

    def __init__(
        self,
        vad_threshold: float = 0.5,
        voice_similarity_threshold: float = 0.7,
        speech_confidence_threshold: float = 0.6,
        enable_vad: bool = True,
    ):
        """
        Initialize the batch processor.

        Args:
            vad_threshold: VAD confidence threshold (0-1)
            voice_similarity_threshold: Voice matching threshold (0-1)
            speech_confidence_threshold: Speech classification threshold (0-1)
            enable_vad: Whether to use VAD pre-filtering
        """
        self.vad_threshold = vad_threshold
        self.voice_similarity_threshold = voice_similarity_threshold
        self.speech_confidence_threshold = speech_confidence_threshold
        self.enable_vad = enable_vad

        # Collaborating services, each built with its default configuration.
        self.vad_filter = VADFilter()
        self.voice_identifier = VoiceIdentifier()
        self.speech_extractor = SpeechExtractor()

        logger.info("BatchProcessor initialized")

    def process_file(
        self,
        input_file: Path,
        voice_profile: VoiceProfile,
        output_dir: Path,
        extraction_mode: ExtractionMode = ExtractionMode.SPEECH,
        apply_quality_filter: bool = True,
    ) -> Tuple[List[AudioSegment], dict]:
        """
        Process a single audio file through the complete pipeline.

        An ``.m4a`` input is first converted to a temporary wav inside
        ``output_dir``; that temporary file is removed when processing ends,
        whether it succeeded, returned early, or raised.

        Args:
            input_file: Path to input audio file (m4a or wav)
            voice_profile: Reference voice profile to match
            output_dir: Directory for output files (created if missing)
            extraction_mode: What to extract (SPEECH, NONVERBAL, or BOTH)
            apply_quality_filter: Whether to filter by quality thresholds

        Returns:
            Tuple of (extracted_segments, statistics). When the file is
            skipped (insufficient voice activity, or no segment matches the
            profile) the result is ``([], {"error": <reason>})``.
        """
        logger.info(f"Processing file: {input_file}")

        # Ensure output directory exists
        output_dir.mkdir(parents=True, exist_ok=True)

        needs_conversion = input_file.suffix.lower() == ".m4a"
        if needs_conversion:
            working_file = output_dir / f"{input_file.stem}_temp.wav"
        else:
            working_file = input_file

        try:
            if needs_conversion:
                logger.info("Converting m4a to wav")
                m4a_to_wav(str(input_file), str(working_file))
            return self._run_pipeline(
                input_file,
                working_file,
                voice_profile,
                output_dir,
                extraction_mode,
                apply_quality_filter,
            )
        finally:
            # Always drop the converted wav: early returns and exceptions
            # must not leave it behind in the output directory.
            if needs_conversion and working_file.exists():
                working_file.unlink()

    def _run_pipeline(
        self,
        input_file: Path,
        working_file: Path,
        voice_profile: VoiceProfile,
        output_dir: Path,
        extraction_mode: ExtractionMode,
        apply_quality_filter: bool,
    ) -> Tuple[List[AudioSegment], dict]:
        """Run VAD, identification, classification, filtering and export on a wav file."""
        # Load audio
        audio, sample_rate = read_audio(str(working_file))
        total_duration = len(audio) / sample_rate
        logger.info(f"Loaded audio: {total_duration:.2f}s at {sample_rate}Hz")

        # Stage 1: VAD pre-filtering (optional but recommended)
        if self.enable_vad:
            logger.info("Stage 1: VAD pre-filtering")
            vad_stats = self.vad_filter.get_voice_activity_stats(
                audio, sample_rate, self.vad_threshold
            )
            logger.info(
                f"VAD: {vad_stats['voice_percentage']:.1f}% voice activity "
                f"({vad_stats['voice_duration']:.1f}s of {vad_stats['total_duration']:.1f}s)"
            )
            if not vad_stats["worth_processing"]:
                logger.warning("Insufficient voice activity, skipping file")
                return [], {"error": "Insufficient voice activity"}

            # NOTE(review): these regions are computed but no later stage
            # consumes them; voice identification runs on the whole file.
            # Either pass them downstream or drop this call - confirm intent.
            vad_segments = self.vad_filter.detect_voice_activity(  # noqa: F841
                audio, sample_rate, self.vad_threshold
            )
        else:
            # Process entire file
            vad_segments = [(0.0, total_duration)]  # noqa: F841
            vad_stats = {}

        # Stage 2: Voice identification
        logger.info("Stage 2: Voice identification")
        matched_segments = self.voice_identifier.match_voice_profile(
            str(working_file),
            voice_profile,
            similarity_threshold=self.voice_similarity_threshold,
        )
        logger.info(f"Found {len(matched_segments)} segments matching voice profile")
        if not matched_segments:
            logger.warning("No matching voice segments found")
            return [], {"error": "No matching voice segments"}

        # Stage 3: Speech/nonverbal classification
        logger.info(f"Stage 3: Speech/nonverbal classification (mode: {extraction_mode.value})")
        classified_segments = self._classify_segments(
            audio, sample_rate, matched_segments, extraction_mode
        )
        logger.info(f"Classified {len(classified_segments)} segments as {extraction_mode.value}")

        # Stage 4: Quality filtering
        if apply_quality_filter:
            logger.info("Stage 4: Quality filtering")
            filtered_segments = self.speech_extractor.filter_by_quality(
                audio, sample_rate, classified_segments
            )
            logger.info(
                f"Quality filter: {len(filtered_segments)}/{len(classified_segments)} "
                "segments passed"
            )
        else:
            filtered_segments = classified_segments

        # Stage 5: Extract and save segments
        logger.info("Stage 5: Extracting segments")
        extracted_segments = [
            self._export_segment(audio, sample_rate, segment, index, input_file, output_dir)
            for index, segment in enumerate(filtered_segments)
        ]

        # Generate statistics
        collection = SegmentCollection(extracted_segments)
        extracted_duration = collection.total_duration
        # Guard against zero-length audio so the percentage never divides by zero.
        if total_duration > 0:
            extraction_percentage = extracted_duration / total_duration * 100
        else:
            extraction_percentage = 0.0
        statistics = {
            "input_file": str(input_file),
            "total_duration": total_duration,
            "segments_extracted": len(extracted_segments),
            "total_extracted_duration": extracted_duration,
            "extraction_percentage": extraction_percentage,
            "average_segment_duration": collection.average_duration,
            "average_confidence": collection.average_confidence,
            "average_quality_snr": collection.average_quality["snr"],
            "vad_stats": vad_stats,
        }
        logger.info(
            f"Extraction complete: {len(extracted_segments)} segments, "
            f"{statistics['total_extracted_duration']:.2f}s "
            f"({statistics['extraction_percentage']:.1f}%)"
        )
        return extracted_segments, statistics

    def _classify_segments(self, audio, sample_rate, matched_segments, extraction_mode):
        """Narrow matched segments to speech or nonverbal ones; BOTH keeps them all."""
        if extraction_mode == ExtractionMode.SPEECH:
            return self.speech_extractor.extract_speech_segments(
                audio, sample_rate, matched_segments, self.speech_confidence_threshold
            )
        if extraction_mode == ExtractionMode.NONVERBAL:
            return self.speech_extractor.extract_nonverbal_segments(
                audio, sample_rate, matched_segments, self.speech_confidence_threshold
            )
        # BOTH: no classification step, every matched segment is kept.
        return matched_segments

    def _export_segment(
        self,
        audio,
        sample_rate,
        segment: dict,
        index: int,
        input_file: Path,
        output_dir: Path,
    ) -> AudioSegment:
        """Write one segment to an m4a file in ``output_dir`` and return its record."""
        start = segment["start"]
        end = segment["end"]
        segment_audio = extract_segment(audio, sample_rate, start, end)

        # Segments without an explicit type (e.g. BOTH mode) are labelled SPEECH.
        segment_type = segment.get("segment_type", SegmentType.SPEECH)
        output_filename = (
            f"{input_file.stem}_segment_{index + 1:03d}_"
            f"{segment_type.value}_{start:.2f}s-{end:.2f}s.m4a"
        )
        output_path = output_dir / output_filename

        # Save as m4a via an intermediate wav that is removed even if the
        # write or the conversion fails.
        temp_wav = output_dir / f"temp_segment_{index}.wav"
        try:
            write_audio(str(temp_wav), segment_audio, sample_rate)
            wav_to_m4a(str(temp_wav), str(output_path))
        finally:
            if temp_wav.exists():
                temp_wav.unlink()

        return AudioSegment(
            start_time=start,
            end_time=end,
            duration=end - start,
            segment_type=segment_type,
            confidence=segment.get("confidence", 0.0),
            voice_similarity=segment.get("similarity", 0.0),
            snr=segment.get("snr"),
            stoi=segment.get("stoi"),
            pesq=segment.get("pesq"),
            output_file=str(output_path),
        )

    def process_batch(self, job: ProcessingJob) -> ProcessingJob:
        """
        Process multiple files in batch.

        The voice profile is extracted once from ``job.reference_file``; if
        that fails the job is marked failed and no input file is processed.
        A failure on one input file is recorded on the job and does not stop
        the remaining files.

        Args:
            job: ProcessingJob with configuration and file list

        Returns:
            Updated ProcessingJob with results
        """
        logger.info(f"Starting batch job: {job.job_id}")
        job.start()

        # Create voice profile from reference file
        try:
            logger.info(f"Extracting voice profile from: {job.reference_file}")
            voice_profile = self.voice_identifier.extract_voice_profile(str(job.reference_file))
            logger.info(
                f"Voice profile extracted: quality={voice_profile.embedding_quality:.2f}, "
                f"duration={voice_profile.reference_duration:.2f}s"
            )
        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception(f"Failed to extract voice profile: {e}")
            job.fail(f"Voice profile extraction failed: {e}")
            return job

        # Process each input file. The progress counter comes from the loop
        # position so it advances on failures too, independent of how the job
        # object counts processed files.
        total_files = len(job.input_files)
        for position, input_file in enumerate(job.input_files, start=1):
            try:
                logger.info(f"Processing file {position}/{total_files}: {input_file}")
                segments, stats = self.process_file(
                    Path(input_file),
                    voice_profile,
                    Path(job.output_dir),
                    extraction_mode=job.extraction_mode,
                    apply_quality_filter=True,
                )

                # Update job statistics (skipped files contribute zero durations)
                job.add_success(
                    input_duration=stats.get("total_duration", 0),
                    extracted_duration=stats.get("total_extracted_duration", 0),
                )
                if "error" in stats:
                    # process_file skipped the file; do not report it as extracted.
                    logger.warning(f"File skipped: {input_file}: {stats['error']}")
                else:
                    logger.info(
                        f"File processed successfully: {len(segments)} segments extracted"
                    )
            except Exception as e:
                logger.exception(f"Failed to process {input_file}: {e}")
                job.add_failure(str(input_file), str(e))

        # Complete job
        job.complete()
        logger.info(
            f"Batch job complete: {job.files_processed} files processed, {job.files_failed} failed"
        )
        return job

    def estimate_processing_time(self, audio_file: Path, enable_vad: bool = True) -> dict:
        """
        Estimate processing time for an audio file.

        The file is fully loaded to measure its duration. With VAD the
        estimate is ~0.4x the detected voice duration; without VAD it is
        ~0.8x the total duration.

        Args:
            audio_file: Path to audio file
            enable_vad: Whether VAD will be used (independent of the
                instance's own ``enable_vad`` setting)

        Returns:
            Dictionary with ``total_duration``, ``voice_duration``,
            ``estimated_processing_time`` and ``estimated_minutes``
            (the estimate divided by 60).
        """
        # Load audio to get duration
        audio, sample_rate = read_audio(str(audio_file))
        total_duration = len(audio) / sample_rate

        if enable_vad:
            # Quick VAD scan
            stats = self.vad_filter.get_voice_activity_stats(
                audio, sample_rate, self.vad_threshold
            )
            voice_duration = stats["voice_duration"]
            estimated_time = voice_duration * self._VAD_REALTIME_FACTOR
        else:
            # Without VAD the whole file counts as voice.
            voice_duration = total_duration
            estimated_time = total_duration * self._FULL_REALTIME_FACTOR

        return {
            "total_duration": total_duration,
            "voice_duration": voice_duration,
            "estimated_processing_time": estimated_time,
            "estimated_minutes": estimated_time / 60,
        }