| import torch
|
| import torchaudio
|
| import numpy as np
|
| import logging
|
| import tempfile
|
| import os
|
| import threading
|
| from typing import List, Tuple, Dict, Optional, Any
|
| import silero_vad
|
| import soundfile as sf
|
| import librosa
|
|
|
# Chunking parameters shared by AudioChunker.
# Target chunk length in seconds for VAD chunking and its static fallbacks.
TARGET_CHUNK_DURATION: float = 30.0
# Minimum chunk length in seconds used by static chunking.
MIN_CHUNK_DURATION: float = 5.0
# Sample rate in Hz that the chunker requires for its input audio.
SAMPLE_RATE: int = 16000

logger = logging.getLogger(__name__)
|
|
|
|
|
class AudioChunker:
    """
    Handles audio chunking with different strategies:

    - 'none': Single chunk (no chunking)
    - 'vad': VAD-based intelligent chunking
    - 'static': Fixed-duration time-based chunking

    The class is a singleton: every ``AudioChunker()`` call returns the same
    instance, and the Silero VAD model is loaded once, on first construction
    (``vad_model`` is ``None`` if loading failed).

    Every chunking method returns a list of dicts with the keys
    ``start_time``, ``end_time``, ``duration`` (all in seconds),
    ``audio_data`` (the corresponding slice of the input tensor),
    ``sample_rate`` and ``chunk_index``.
    """

    _instance = None
    _instance_lock = threading.Lock()

    # Silero VAD model, or None when it could not be loaded.
    vad_model: Optional[Any]

    def __new__(cls):
        # Unlocked fast path once the singleton exists.
        if cls._instance is None:
            with cls._instance_lock:
                # Re-check under the lock: another thread may have created it.
                if cls._instance is None:
                    instance = super().__new__(cls)
                    # Finish initialising BEFORE publishing the instance.
                    # Otherwise a thread taking the unlocked fast path could
                    # receive an object without a ``vad_model`` attribute
                    # while the (slow) model load is still running.
                    instance.vad_model = cls.load_vad_model()
                    cls._instance = instance
        return cls._instance

    @staticmethod
    def load_vad_model():
        """Load silero VAD model with error handling.

        Returns:
            The loaded model, or ``None`` if loading raised (the error is
            logged and VAD chunking will fall back to static chunking).
        """
        try:
            logger.info("Loading Silero VAD model...")
            vad_model = silero_vad.load_silero_vad()
            logger.info("✓ VAD model loaded successfully")
            return vad_model
        except Exception as e:
            logger.error(f"Failed to load VAD model: {e}")
            logger.warning("VAD chunking will fall back to time-based chunking")
            return None

    @torch.inference_mode()
    def chunk_audio(
        self,
        audio_tensor: torch.Tensor,
        sample_rate: int = SAMPLE_RATE,
        mode: str = "vad",
        chunk_duration: float = 30.0,
    ) -> List[Dict[str, Any]]:
        """
        Chunk audio tensor using specified strategy.

        Errors raised while validating or chunking (including an unknown
        ``mode``) do not propagate: they are logged with a traceback and the
        whole input is returned as a single chunk instead.

        Args:
            audio_tensor: Audio tensor (1D waveform).
            sample_rate: Sample rate of the audio tensor; must equal SAMPLE_RATE.
            mode: Chunking mode - 'none', 'vad', or 'static'.
            chunk_duration: Target duration in seconds for 'static' chunking and
                for the static fallback used when the VAD model is unavailable.
                VAD chunking itself targets TARGET_CHUNK_DURATION.

        Returns:
            List of chunk info dicts with uniform format:
            - start_time: Start time in seconds
            - end_time: End time in seconds
            - duration: Duration in seconds
            - audio_data: Audio tensor for this chunk
            - sample_rate: Sample rate
            - chunk_index: Index of this chunk
        """
        logger.info(f"Chunking audio tensor: {audio_tensor.shape} at {sample_rate}Hz (mode: {mode})")

        try:
            # Explicit checks instead of ``assert``: asserts are stripped under
            # ``python -O``, which would let bad input reach the chunkers.
            if len(audio_tensor.shape) != 1:
                raise ValueError(f"Expected 1D audio tensor, got shape {audio_tensor.shape}")
            if sample_rate != SAMPLE_RATE:
                raise ValueError(f"Expected {SAMPLE_RATE}Hz sample rate, got {sample_rate}Hz")

            if mode == "none":
                return self._create_single_chunk(audio_tensor, sample_rate)
            if mode == "vad":
                if self.vad_model is not None:
                    return self._chunk_with_vad(audio_tensor)
                logger.warning("VAD model not available, falling back to static chunking")
                return self._chunk_static(audio_tensor, chunk_duration)
            if mode == "static":
                return self._chunk_static(audio_tensor, chunk_duration)
            raise ValueError(f"Unknown chunking mode: {mode}")

        except Exception as e:
            # Best-effort contract: keep the traceback, degrade to one chunk.
            logger.exception(f"Error chunking audio tensor: {e}")
            return self._create_single_chunk(audio_tensor, sample_rate)

    def _create_single_chunk(self, waveform: torch.Tensor, sample_rate: int = SAMPLE_RATE) -> List[Dict[str, Any]]:
        """Create a single chunk containing the entire audio."""
        duration = len(waveform) / sample_rate
        return [{
            "start_time": 0.0,
            "end_time": duration,
            "duration": duration,
            "audio_data": waveform,
            "sample_rate": sample_rate,
            "chunk_index": 0,
        }]

    def _chunk_static(self, waveform: torch.Tensor, chunk_duration: float) -> List[Dict[str, Any]]:
        """Split ``waveform`` into consecutive fixed-duration chunks.

        The chunks tile the whole waveform; no audio is discarded. A trailing
        partial chunk shorter than MIN_CHUNK_DURATION is merged into the chunk
        before it. If the waveform yields only one chunk, that chunk is kept
        even when it is short. An empty waveform yields an empty list.

        Args:
            waveform: 1-D waveform at SAMPLE_RATE.
            chunk_duration: Desired chunk length in seconds.

        Raises:
            ValueError: if ``chunk_duration`` is shorter than one sample (a
                zero-length step would never advance through the waveform).
        """
        total_samples = len(waveform)
        target_samples = int(chunk_duration * SAMPLE_RATE)
        if target_samples <= 0:
            raise ValueError(
                f"chunk_duration must be at least one sample long, got {chunk_duration}"
            )
        min_samples = int(MIN_CHUNK_DURATION * SAMPLE_RATE)

        # [start, end) sample boundaries of each fixed-size step.
        bounds = [
            (start, min(start + target_samples, total_samples))
            for start in range(0, total_samples, target_samples)
        ]

        if len(bounds) > 1:
            tail_start, tail_end = bounds[-1]
            # Only a *partial* tail (shorter than a full step) that is also
            # shorter than the minimum gets folded into its predecessor.
            if tail_end - tail_start < min(min_samples, target_samples):
                bounds.pop()
                bounds[-1] = (bounds[-1][0], tail_end)

        chunks = [
            {
                "start_time": start / SAMPLE_RATE,
                "end_time": end / SAMPLE_RATE,
                "duration": (end - start) / SAMPLE_RATE,
                "audio_data": waveform[start:end],
                "sample_rate": SAMPLE_RATE,
                "chunk_index": idx,
            }
            for idx, (start, end) in enumerate(bounds)
        ]

        logger.info(f"Created {len(chunks)} static chunks of ~{chunk_duration}s each")
        return chunks

    def _chunk_fallback(self, audio_path: str) -> List[Dict[str, Any]]:
        """Ultimate fallback - create single chunk using librosa (for file-based legacy method).

        Loads ``audio_path`` with ``sr=SAMPLE_RATE`` and returns it as one
        chunk; returns an empty list if loading fails.
        """
        try:
            logger.warning("Using librosa fallback for chunking")
            data, _ = librosa.load(audio_path, sr=SAMPLE_RATE)
            return self._create_single_chunk(torch.from_numpy(data), SAMPLE_RATE)
        except Exception as e:
            logger.error(f"All chunking methods failed: {e}")
            return []

    def _chunk_with_vad(self, waveform: torch.Tensor) -> List[Dict[str, Any]]:
        """Chunk audio using VAD for speech detection with uniform return format.

        Falls back to static chunking of TARGET_CHUNK_DURATION seconds if the
        VAD call or chunk construction raises.
        """
        try:
            # VAD runs on a CPU copy; the chunks are still sliced from the
            # original ``waveform`` below, so they stay on the caller's device.
            vad_waveform = waveform.cpu() if waveform.is_cuda else waveform

            # Timestamps are requested in samples (return_seconds=False), which
            # is the unit the boundary search below works in.
            # NOTE: ``window_size_samples`` is intentionally not passed; in the
            # silero-vad pip package (v5+) it is deprecated and has no effect.
            speech_timestamps = silero_vad.get_speech_timestamps(
                vad_waveform,
                self.vad_model,
                sampling_rate=SAMPLE_RATE,
                min_speech_duration_ms=500,
                min_silence_duration_ms=300,
                speech_pad_ms=100,
                return_seconds=False,
            )
            logger.info(f"Found {len(speech_timestamps)} speech segments")

            chunks = self._create_chunks_from_speech_segments(waveform, speech_timestamps)
            logger.info(f"Created {len(chunks)} audio chunks using VAD")
            return chunks

        except Exception as e:
            logger.exception(f"VAD chunking failed: {e}")
            return self._chunk_static(waveform, TARGET_CHUNK_DURATION)

    def _create_chunks_from_speech_segments(
        self, waveform: torch.Tensor, speech_segments: List[Dict]
    ) -> List[Dict[str, Any]]:
        """Create chunks that respect speech boundaries and target duration with uniform format.

        Chunks are contiguous and cover the whole waveform. Each boundary is
        snapped to the end of a speech segment near the TARGET_CHUNK_DURATION
        mark when one exists. A final remainder shorter than 30% of the target
        is absorbed into the last chunk. Falls back to static chunking when
        ``speech_segments`` is empty.

        Args:
            waveform: 1-D waveform at SAMPLE_RATE.
            speech_segments: VAD output; dicts with "start"/"end" sample indices.
        """
        if not speech_segments:
            logger.warning(
                "No speech segments found, falling back to static chunking"
            )
            return self._chunk_static(waveform, TARGET_CHUNK_DURATION)

        target_samples = int(TARGET_CHUNK_DURATION * SAMPLE_RATE)
        # Remainders shorter than this are not worth their own chunk.
        min_tail_samples = target_samples * 0.3
        total_samples = len(waveform)

        chunks: List[Dict[str, Any]] = []
        chunk_start = 0
        while chunk_start < total_samples:
            target_end = chunk_start + target_samples
            if target_end >= total_samples or (total_samples - target_end) < min_tail_samples:
                chunk_end = total_samples
            else:
                # Always returns a value > chunk_start, so the loop advances.
                chunk_end = self._find_best_chunk_end_continuous(
                    speech_segments,
                    chunk_start,
                    target_end,
                    total_samples,
                )

            chunk_idx = len(chunks)
            chunk_audio = waveform[chunk_start:chunk_end]
            duration = len(chunk_audio) / SAMPLE_RATE
            chunks.append({
                "start_time": chunk_start / SAMPLE_RATE,
                "end_time": chunk_end / SAMPLE_RATE,
                "duration": duration,
                "audio_data": chunk_audio,
                "sample_rate": SAMPLE_RATE,
                "chunk_index": chunk_idx,
            })
            logger.info(
                f"Created chunk {chunk_idx + 1}: {chunk_start/SAMPLE_RATE:.2f}s - {chunk_end/SAMPLE_RATE:.2f}s ({duration:.2f}s)"
            )
            chunk_start = chunk_end

        # Sanity check: contiguous chunks should add up to the full duration.
        total_audio_duration = len(waveform) / SAMPLE_RATE
        total_chunks_duration = sum(chunk["duration"] for chunk in chunks)
        logger.info(
            f"Audio chunking complete: {len(chunks)} chunks covering {total_chunks_duration:.2f}s of {total_audio_duration:.2f}s total audio"
        )
        if abs(total_chunks_duration - total_audio_duration) > 0.01:
            logger.error(
                f"Duration mismatch: chunks={total_chunks_duration:.2f}s, original={total_audio_duration:.2f}s"
            )
        else:
            logger.info("✓ Perfect audio coverage achieved")

        return chunks

    def _find_best_chunk_end_continuous(
        self,
        speech_segments: List[Dict],
        chunk_start: int,
        target_end: int,
        total_samples: int,
    ) -> int:
        """Find the best place to end a chunk while ensuring continuous coverage.

        Picks the speech-segment end closest to ``target_end`` within +/-3 s
        (and not before ``chunk_start``). On ties, the earliest such segment in
        ``speech_segments`` wins. Falls back to ``target_end`` (clamped to
        ``total_samples``) when none qualifies.

        Returns:
            A sample index in ``(chunk_start, total_samples]``.
        """
        target_end = min(target_end, total_samples)

        search_window = SAMPLE_RATE * 3
        search_start = max(chunk_start, target_end - search_window)
        search_end = min(total_samples, target_end + search_window)

        best_end = target_end
        best_score = 0.0
        for segment in speech_segments:
            segment_end = segment["end"]
            if not (search_start <= segment_end <= search_end):
                continue
            # Linear preference: 1.0 at the target, 0.0 at the window edge.
            score = 1.0 - abs(segment_end - target_end) / search_window
            if score > best_score:
                best_score, best_end = score, segment_end

        best_end = min(int(best_end), total_samples)

        # A boundary at or before chunk_start would stall the caller's loop.
        if best_end <= chunk_start:
            best_end = target_end

        return best_end

    def _find_best_chunk_end(
        self,
        speech_segments: List[Dict],
        start_idx: int,
        chunk_start: int,
        target_end: int,
    ) -> int:
        """Find the best place to end a chunk (at silence, near target duration).

        Legacy helper; not called by the other methods of this class.
        ``chunk_start`` is accepted for signature compatibility but unused.
        Scans segments from ``start_idx`` and returns the first boundary rule
        that fires, otherwise ``target_end``.
        """
        for i in range(start_idx, len(speech_segments)):
            segment = speech_segments[i]
            segment_start = segment["start"]
            segment_end = segment["end"]

            # Next speech starts after the target: cut exactly at the target.
            if segment_start > target_end:
                return int(target_end)

            # Speech ends within 5 s of the target: cut at that boundary.
            if abs(segment_end - target_end) < SAMPLE_RATE * 5:
                return int(segment_end)

            # Speech runs more than 10 s past the target: cut at the target.
            if segment_end > target_end + SAMPLE_RATE * 10:
                return int(target_end)

        return int(target_end)

    def save_chunk_to_file(self, chunk: Dict[str, Any], output_path: str) -> str:
        """Write a chunk's audio to ``output_path`` using soundfile.

        Args:
            chunk: A chunk dict as produced by the chunking methods; reads its
                "audio_data" (tensor or array) and "sample_rate".
            output_path: Destination file path.

        Returns:
            ``output_path``.

        Raises:
            Any error from conversion or writing, re-raised after logging.
        """
        try:
            audio_data = chunk["audio_data"]
            if isinstance(audio_data, torch.Tensor):
                # detach() so tensors that require grad can still be converted.
                audio_data = audio_data.detach().cpu().numpy()

            sf.write(output_path, audio_data, chunk["sample_rate"])
            return output_path

        except Exception as e:
            logger.exception(f"Failed to save chunk to file: {e}")
            raise
|
|
|