""" Audio Enhancement and Transcription Pipeline Handles real-time processing of user-uploaded audio """ import sys from pathlib import Path import io # Project root = ClearSpeech PROJECT_ROOT = Path(__file__).resolve().parents[1] if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) from enhancement_model.model import UNetAudioEnhancer import torch import numpy as np import librosa import soundfile as sf import whisper from typing import Union, Dict, Tuple import warnings # Suppress librosa warnings warnings.filterwarnings('ignore', category=UserWarning) def get_default_device() -> str: """Auto-detect best available device""" if torch.cuda.is_available(): return "cuda" elif torch.backends.mps.is_available(): return "mps" else: return "cpu" class AudioProcessor: """ Handles audio preprocessing (in-memory) Reuses logic from preprocessing.py but for single files """ def __init__(self, sample_rate=16000, n_fft=1024, hop_length=256, n_mels=128): self.sample_rate = sample_rate self.n_fft = n_fft self.hop_length = hop_length self.n_mels = n_mels self.fmax = 8000 def load_audio(self, audio_file: Union[str, Path, bytes, io.BytesIO]) -> np.ndarray: """ Load audio from file or bytes Args: audio_file: File path, file object, or bytes Returns: audio: Numpy array of audio samples """ try: if isinstance(audio_file, (str, Path)): audio, _ = librosa.load(audio_file, sr=self.sample_rate, mono=True) elif isinstance(audio_file, bytes): audio, _ = librosa.load(io.BytesIO(audio_file), sr=self.sample_rate, mono=True) else: audio, _ = librosa.load(audio_file, sr=self.sample_rate, mono=True) return audio except Exception as e: raise ValueError(f"Failed to load audio: {e}") def normalize_audio(self, audio: np.ndarray, target_db: float = -20.0) -> np.ndarray: """ Normalize audio with RMS-based approach for consistency Args: audio: Input audio target_db: Target RMS level in dB Returns: Normalized audio """ # RMS-based normalization (better than peak normalization) rms = np.sqrt(np.mean(audio**2)) if rms > 0: target_rms = 10 ** (target_db / 20) audio = audio * (target_rms / rms) # Clip to prevent distortion audio = np.clip(audio, -1.0, 1.0) return audio def audio_to_spectrogram(self, audio: np.ndarray) -> np.ndarray: """ Convert audio to mel-spectrogram Args: audio: Audio waveform Returns: mel_spec_db: Mel-spectrogram in dB scale """ mel_spec = librosa.feature.melspectrogram( y=audio, sr=self.sample_rate, n_fft=self.n_fft, hop_length=self.hop_length, n_mels=self.n_mels, fmax=self.fmax ) # Convert to dB with proper reference mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max) # Ensure valid range mel_spec_db = np.clip(mel_spec_db, -80.0, 0.0) return mel_spec_db def spectrogram_to_audio(self, mel_spec_db: np.ndarray, n_iter: int = 60) -> np.ndarray: """ Convert mel-spectrogram back to audio using Griffin-Lim Args: mel_spec_db: Mel-spectrogram in dB n_iter: Griffin-Lim iterations (more = better quality) Returns: audio: Reconstructed waveform """ # Ensure valid dB range mel_spec_db = np.clip(mel_spec_db, -80.0, 0.0) mel_spec_db = np.nan_to_num(mel_spec_db, nan=-80.0, posinf=0.0, neginf=-80.0) # Convert from dB to power mel_spec = librosa.db_to_power(mel_spec_db) # Ensure non-negative power values mel_spec = np.maximum(mel_spec, 1e-10) # Convert to audio using Griffin-Lim with more iterations audio = librosa.feature.inverse.mel_to_audio( mel_spec, sr=self.sample_rate, n_fft=self.n_fft, hop_length=self.hop_length, n_iter=n_iter ) # Handle any NaN or Inf in audio audio = np.nan_to_num(audio, 

class EnhancementPipeline:
    """
    Complete audio enhancement and transcription pipeline.
    """

    def __init__(
        self,
        cnn_checkpoint_path: str,
        whisper_model_name: str = "base",
        device: str = None,
        use_fp16: bool = False
    ):
        """
        Initialize the pipeline with models.

        Args:
            cnn_checkpoint_path: Path to trained CNN model
            whisper_model_name: Whisper model size (tiny, base, small, medium, large)
            device: 'cuda', 'mps', or 'cpu'
            use_fp16: Use half precision for Whisper (faster on GPU)
        """
        if device is None:
            device = get_default_device()

        self.device = torch.device(device)
        self.use_fp16 = use_fp16 and (device == "cuda")
        print(f"🖥️ Using device: {self.device}")

        self.audio_processor = AudioProcessor()

        # Load CNN enhancement model
        print("📥 Loading U-Net enhancement model...")
        self.cnn_model = UNetAudioEnhancer(in_channels=1, out_channels=1)
        try:
            checkpoint = torch.load(cnn_checkpoint_path, map_location=self.device)
            self.cnn_model.load_state_dict(checkpoint['model_state_dict'])
            self.cnn_model.to(self.device)
            self.cnn_model.eval()
            epoch = checkpoint.get('epoch', 'unknown')
            val_loss = checkpoint.get('val_loss', 'unknown')
            print(f"✅ U-Net loaded (epoch {epoch}, val_loss: {val_loss})")
        except Exception as e:
            raise RuntimeError(f"Failed to load CNN model: {e}")

        # Load Whisper model
        print(f"📥 Loading Whisper model ({whisper_model_name})...")
        try:
            self.whisper_model = whisper.load_model(whisper_model_name, device=str(self.device))
            print("✅ Whisper model loaded")
        except Exception as e:
            raise RuntimeError(f"Failed to load Whisper model: {e}")

    def enhance_audio(self, audio: np.ndarray) -> np.ndarray:
        """
        Enhance audio using the U-Net model.

        Args:
            audio: Raw audio waveform

        Returns:
            enhanced_audio: Cleaned audio waveform
        """
        # Convert to spectrogram (dB scale: [-80, 0])
        noisy_spec = self.audio_processor.audio_to_spectrogram(audio)

        # Normalize to [-1, 1] (matching the training normalization)
        noisy_spec_norm = (noisy_spec + 80.0) / 80.0   # [0, 1]
        noisy_spec_norm = noisy_spec_norm * 2.0 - 1.0  # [-1, 1]

        # Add batch and channel dimensions: (1, 1, H, W)
        noisy_spec_tensor = torch.FloatTensor(noisy_spec_norm).unsqueeze(0).unsqueeze(0)
        noisy_spec_tensor = noisy_spec_tensor.to(self.device)

        # Run U-Net inference
        with torch.no_grad():
            clean_spec_tensor = self.cnn_model(noisy_spec_tensor)

            # Handle NaN/Inf immediately after the model output
            clean_spec_tensor = torch.nan_to_num(clean_spec_tensor, nan=0.0, posinf=1.0, neginf=-1.0)
            clean_spec_tensor = torch.clamp(clean_spec_tensor, -1.0, 1.0)

        # Convert back to numpy
        clean_spec_norm = clean_spec_tensor.squeeze().cpu().numpy()

        # Denormalize: [-1, 1] -> [0, 1] -> [-80, 0] dB
        clean_spec_norm = (clean_spec_norm + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        clean_spec_db = clean_spec_norm * 80.0 - 80.0    # [0, 1] -> [-80, 0]

        # Ensure valid dB range
        clean_spec_db = np.nan_to_num(clean_spec_db, nan=-80.0, posinf=0.0, neginf=-80.0)
        clean_spec_db = np.clip(clean_spec_db, -80.0, 0.0)

        # Convert the spectrogram back to audio (more iterations = better quality)
        enhanced_audio = self.audio_processor.spectrogram_to_audio(clean_spec_db, n_iter=60)

        # Normalize and clip
        enhanced_audio = self.audio_processor.normalize_audio(enhanced_audio)
        enhanced_audio = np.clip(enhanced_audio, -1.0, 1.0)

        return enhanced_audio
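    # Worked example of the normalization round-trip above (for reference;
    # the values are illustrative): the forward map is
    #   x_norm = ((x_db + 80) / 80) * 2 - 1
    # and the inverse is
    #   x_db = ((x_norm + 1) / 2) * 80 - 80.
    # A -40 dB bin maps to 0.0 and back to -40 dB exactly; -80 dB <-> -1.0 and
    # 0 dB <-> +1.0. This mapping must match whatever normalization the U-Net
    # saw during training, or the model receives inputs outside its learned range.
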
    def transcribe_audio(self, audio: np.ndarray, language: str = 'en') -> Dict:
        """
        Transcribe audio using Whisper.

        Args:
            audio: Audio waveform (numpy array)
            language: Language code (e.g., 'en', 'es', 'fr')

        Returns:
            result: Dictionary with transcription and metadata
        """
        # Whisper expects float32 audio normalized to [-1, 1]
        audio = audio.astype(np.float32)

        # Whisper transcribes in 30 s windows and handles longer inputs by
        # sliding that window internally; warn so the extra latency is expected
        max_length = 30 * self.audio_processor.sample_rate
        if len(audio) > max_length:
            print("⚠️ Audio longer than 30s, processing in chunks...")

        result = self.whisper_model.transcribe(
            audio,
            language=language if language else None,
            fp16=self.use_fp16,
            verbose=False
        )
        return result

    def process(
        self,
        audio_file: Union[str, Path, bytes, io.BytesIO],
        language: str = 'en',
        skip_enhancement: bool = False
    ) -> Dict:
        """
        Complete processing pipeline.

        Args:
            audio_file: Input audio (file path, bytes, or file object)
            language: Target language for transcription
            skip_enhancement: Skip the enhancement step (use original audio)

        Returns:
            result: Dictionary containing:
                - transcript: Text transcription
                - enhanced_audio: Cleaned audio (numpy array)
                - sample_rate: Sample rate of the enhanced audio
                - duration: Audio duration in seconds
                - language: Detected language
                - segments: Timestamped segments
        """
        # Load and preprocess
        print("🎵 Loading audio...")
        audio = self.audio_processor.load_audio(audio_file)
        audio = self.audio_processor.normalize_audio(audio)
        duration = len(audio) / self.audio_processor.sample_rate
        print(f"   Duration: {duration:.2f}s")

        # Enhance with U-Net
        if not skip_enhancement:
            print("🧹 Enhancing audio with U-Net...")
            enhanced_audio = self.enhance_audio(audio)
        else:
            print("⏭️ Skipping enhancement...")
            enhanced_audio = audio

        # Transcribe with Whisper
        print("📝 Transcribing with Whisper...")
        transcription_result = self.transcribe_audio(enhanced_audio, language=language)

        # Compile results
        result = {
            'transcript': transcription_result['text'].strip(),
            'enhanced_audio': enhanced_audio,
            'sample_rate': self.audio_processor.sample_rate,
            'duration': duration,
            'language': transcription_result.get('language', language),
            'segments': transcription_result.get('segments', [])
        }

        print("✅ Processing complete!")
        return result


def test_pipeline():
    """Test the pipeline with a sample audio file."""
    print("=" * 70)
    print("🧪 TESTING AUDIO ENHANCEMENT PIPELINE")
    print("=" * 70)

    # Paths
    cnn_checkpoint = PROJECT_ROOT / "enhancement_model/checkpoints/best_model.pt"
    test_audio = PROJECT_ROOT / "data/audio_raw/noisy_0000.wav"
    output_audio = PROJECT_ROOT / "enhanced_test_output.wav"

    if not test_audio.exists():
        print(f"❌ Test audio not found: {test_audio}")
        return

    # Initialize pipeline
    pipeline = EnhancementPipeline(
        cnn_checkpoint_path=str(cnn_checkpoint),
        whisper_model_name="base",
        device=get_default_device()
    )

    # Process audio
    result = pipeline.process(test_audio)

    # Print results
    print("\n" + "=" * 70)
    print("📊 RESULTS")
    print("=" * 70)
    print(f"Transcript: {result['transcript']}")
    print(f"Duration: {result['duration']:.2f}s")
    print(f"Language: {result['language']}")
    print(f"Segments: {len(result.get('segments', []))}")

    # Save enhanced audio
    sf.write(output_audio, result['enhanced_audio'], result['sample_rate'])
    print(f"\n💾 Enhanced audio saved to: {output_audio}")
    print("=" * 70)


if __name__ == "__main__":
    test_pipeline()
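
# Example usage from another module (illustrative sketch; the import path and
# file names below are hypothetical, not part of the project):
#
#   from pipeline import EnhancementPipeline
#
#   pipe = EnhancementPipeline("enhancement_model/checkpoints/best_model.pt")
#   with open("meeting_recording.wav", "rb") as f:
#       result = pipe.process(f.read(), language="en")
#   print(result["transcript"])
#   sf.write("meeting_clean.wav", result["enhanced_audio"], result["sample_rate"])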