Upload inference_pipeline.py

inference_pipeline.py (ADDED, +382 -0)
"""
Audio Enhancement and Transcription Pipeline
Handles real-time processing of user-uploaded audio
"""

import sys
from pathlib import Path
import io

# Project root = ClearSpeech
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from enhancement_model.model import UNetAudioEnhancer

import torch
import numpy as np
import librosa
import soundfile as sf
import whisper
from typing import Union, Dict, Optional
import warnings

# Suppress librosa warnings
warnings.filterwarnings('ignore', category=UserWarning)
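
# Project layout assumed by the imports and paths in this file (inferred from
# the code, not documented in the source tree itself):
#   ClearSpeech/
#       enhancement_model/
#           model.py                     defines UNetAudioEnhancer
#           checkpoints/best_model.pt    trained weights used below
#       data/audio_raw/                  sample noisy clips for testing
#       <app dir>/inference_pipeline.py  this file; PROJECT_ROOT = parents[1]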


def get_default_device() -> str:
    """Auto-detect best available device"""
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"


class AudioProcessor:
    """
    Handles audio preprocessing (in-memory)
    Reuses logic from preprocessing.py but for single files
    """

    def __init__(self, sample_rate=16000, n_fft=1024, hop_length=256, n_mels=128):
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels
        self.fmax = 8000

    def load_audio(self, audio_file: Union[str, Path, bytes, io.BytesIO]) -> np.ndarray:
        """
        Load audio from file or bytes

        Args:
            audio_file: File path, file object, or bytes

        Returns:
            audio: Numpy array of audio samples
        """
        try:
            if isinstance(audio_file, (str, Path)):
                audio, _ = librosa.load(audio_file, sr=self.sample_rate, mono=True)
            elif isinstance(audio_file, bytes):
                audio, _ = librosa.load(io.BytesIO(audio_file), sr=self.sample_rate, mono=True)
            else:
                audio, _ = librosa.load(audio_file, sr=self.sample_rate, mono=True)
            return audio
        except Exception as e:
            raise ValueError(f"Failed to load audio: {e}")

    def normalize_audio(self, audio: np.ndarray, target_db: float = -20.0) -> np.ndarray:
        """
        Normalize audio with RMS-based approach for consistency

        Args:
            audio: Input audio
            target_db: Target RMS level in dB

        Returns:
            Normalized audio
        """
        # RMS-based normalization (better than peak normalization)
        rms = np.sqrt(np.mean(audio**2))
        if rms > 0:
            target_rms = 10 ** (target_db / 20)
            audio = audio * (target_rms / rms)

        # Clip to prevent distortion
        audio = np.clip(audio, -1.0, 1.0)
        return audio
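
    # Worked example of the gain above: with the default target_db = -20.0,
    # target_rms = 10 ** (-20 / 20) = 0.1, so a signal whose measured RMS is
    # 0.05 is scaled by 0.1 / 0.05 = 2.0 before the final clip.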

    def audio_to_spectrogram(self, audio: np.ndarray) -> np.ndarray:
        """
        Convert audio to mel-spectrogram

        Args:
            audio: Audio waveform

        Returns:
            mel_spec_db: Mel-spectrogram in dB scale
        """
        mel_spec = librosa.feature.melspectrogram(
            y=audio,
            sr=self.sample_rate,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            n_mels=self.n_mels,
            fmax=self.fmax
        )

        # Convert to dB with proper reference
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

        # Ensure valid range
        mel_spec_db = np.clip(mel_spec_db, -80.0, 0.0)
        return mel_spec_db
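
    # Shape note: the result is (n_mels, n_frames); with librosa's default
    # center=True padding, n_frames = 1 + len(audio) // hop_length, so one
    # second of 16 kHz audio yields a (128, 63) array here.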

    def spectrogram_to_audio(self, mel_spec_db: np.ndarray, n_iter: int = 60) -> np.ndarray:
        """
        Convert mel-spectrogram back to audio using Griffin-Lim

        Args:
            mel_spec_db: Mel-spectrogram in dB
            n_iter: Griffin-Lim iterations (more = better quality)

        Returns:
            audio: Reconstructed waveform
        """
        # Ensure valid dB range
        mel_spec_db = np.clip(mel_spec_db, -80.0, 0.0)
        mel_spec_db = np.nan_to_num(mel_spec_db, nan=-80.0, posinf=0.0, neginf=-80.0)

        # Convert from dB to power
        mel_spec = librosa.db_to_power(mel_spec_db)

        # Ensure non-negative power values
        mel_spec = np.maximum(mel_spec, 1e-10)

        # Convert to audio using Griffin-Lim with more iterations
        audio = librosa.feature.inverse.mel_to_audio(
            mel_spec,
            sr=self.sample_rate,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            n_iter=n_iter
        )

        # Handle any NaN or Inf in audio
        audio = np.nan_to_num(audio, nan=0.0, posinf=1.0, neginf=-1.0)

        return audio
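
    # Round-trip sketch (illustrative only; "sample.wav" is a placeholder):
    #     proc = AudioProcessor()
    #     wav = proc.load_audio("sample.wav")
    #     recon = proc.spectrogram_to_audio(proc.audio_to_spectrogram(wav))
    # Griffin-Lim estimates phase iteratively, so recon approximates wav
    # rather than matching it sample for sample.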


class EnhancementPipeline:
    """
    Complete audio enhancement and transcription pipeline
    """

    def __init__(
        self,
        cnn_checkpoint_path: str,
        whisper_model_name: str = "base",
        device: Optional[str] = None,
        use_fp16: bool = False
    ):
        """
        Initialize the pipeline with models

        Args:
            cnn_checkpoint_path: Path to trained CNN model
            whisper_model_name: Whisper model size (tiny, base, small, medium, large)
            device: 'cuda', 'mps', or 'cpu' (auto-detected when None)
            use_fp16: Use half precision for Whisper (faster on GPU)
        """
        if device is None:
            device = get_default_device()
        self.device = torch.device(device)
        self.use_fp16 = use_fp16 and (device == "cuda")

        print(f"🖥️ Using device: {self.device}")

        self.audio_processor = AudioProcessor()

        # Load CNN enhancement model
        print("📥 Loading U-Net enhancement model...")
        self.cnn_model = UNetAudioEnhancer(in_channels=1, out_channels=1)

        try:
            checkpoint = torch.load(cnn_checkpoint_path, map_location=self.device)
            self.cnn_model.load_state_dict(checkpoint['model_state_dict'])
            self.cnn_model.to(self.device)
            self.cnn_model.eval()

            epoch = checkpoint.get('epoch', 'unknown')
            val_loss = checkpoint.get('val_loss', 'unknown')
            print(f"✅ U-Net loaded (epoch {epoch}, val_loss: {val_loss})")
        except Exception as e:
            raise RuntimeError(f"Failed to load CNN model: {e}")

        # Load Whisper model
        print(f"📥 Loading Whisper model ({whisper_model_name})...")
        try:
            self.whisper_model = whisper.load_model(whisper_model_name, device=str(self.device))
            print("✅ Whisper model loaded")
        except Exception as e:
            raise RuntimeError(f"Failed to load Whisper model: {e}")

    def enhance_audio(self, audio: np.ndarray) -> np.ndarray:
        """
        Enhance audio using U-Net model

        Args:
            audio: Raw audio waveform

        Returns:
            enhanced_audio: Cleaned audio waveform
        """
        # Convert to spectrogram (dB scale: [-80, 0])
        noisy_spec = self.audio_processor.audio_to_spectrogram(audio)

        # Normalize to [-1, 1] (matching training normalization)
        noisy_spec_norm = (noisy_spec + 80.0) / 80.0   # [0, 1]
        noisy_spec_norm = noisy_spec_norm * 2.0 - 1.0  # [-1, 1]

        # Add batch and channel dimensions: (1, 1, H, W)
        noisy_spec_tensor = torch.FloatTensor(noisy_spec_norm).unsqueeze(0).unsqueeze(0)
        noisy_spec_tensor = noisy_spec_tensor.to(self.device)

        # Run U-Net inference
        with torch.no_grad():
            clean_spec_tensor = self.cnn_model(noisy_spec_tensor)
            # Handle NaN/Inf immediately after model output
            clean_spec_tensor = torch.nan_to_num(clean_spec_tensor, nan=0.0, posinf=1.0, neginf=-1.0)
            clean_spec_tensor = torch.clamp(clean_spec_tensor, -1.0, 1.0)

        # Convert back to numpy
        clean_spec_norm = clean_spec_tensor.squeeze().cpu().numpy()

        # Denormalize: [-1, 1] -> [0, 1] -> [-80, 0] dB
        clean_spec_norm = (clean_spec_norm + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        clean_spec_db = clean_spec_norm * 80.0 - 80.0    # [0, 1] -> [-80, 0]

        # Ensure valid dB range
        clean_spec_db = np.nan_to_num(clean_spec_db, nan=-80.0, posinf=0.0, neginf=-80.0)
        clean_spec_db = np.clip(clean_spec_db, -80.0, 0.0)

        # Convert spectrogram to audio (more iterations for better quality)
        enhanced_audio = self.audio_processor.spectrogram_to_audio(clean_spec_db, n_iter=60)

        # Normalize and clip
        enhanced_audio = self.audio_processor.normalize_audio(enhanced_audio)
        enhanced_audio = np.clip(enhanced_audio, -1.0, 1.0)

        return enhanced_audio
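
    # Checking the normalization round trip by hand: -40 dB maps to
    # (-40 + 80) / 80 * 2 - 1 = 0.0 on the model's [-1, 1] scale, and
    # (0.0 + 1) / 2 * 80 - 80 = -40 dB on the way back out.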

    def transcribe_audio(self, audio: np.ndarray, language: str = 'en') -> Dict:
        """
        Transcribe audio using Whisper

        Args:
            audio: Audio waveform (numpy array)
            language: Language code (e.g., 'en', 'es', 'fr')

        Returns:
            result: Dictionary with transcription and metadata
        """
        # Whisper expects float32 audio normalized to [-1, 1]
        audio = audio.astype(np.float32)

        # Whisper transcribes long inputs in sliding 30-second windows;
        # warn so longer runtimes are expected
        max_length = 30 * self.audio_processor.sample_rate
        if len(audio) > max_length:
            print("⚠️ Audio longer than 30s, processing in chunks...")

        result = self.whisper_model.transcribe(
            audio,
            language=language if language else None,
            fp16=self.use_fp16,
            verbose=False
        )
        return result
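
    # The dict returned by Whisper's transcribe() includes 'text' (the full
    # transcript), 'language' (detected language), and 'segments' (a list of
    # dicts with 'start', 'end', and 'text'), which process() relies on below.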

    def process(
        self,
        audio_file: Union[str, Path, bytes, io.BytesIO],
        language: str = 'en',
        skip_enhancement: bool = False
    ) -> Dict:
        """
        Complete processing pipeline

        Args:
            audio_file: Input audio (file path, bytes, or file object)
            language: Target language for transcription
            skip_enhancement: Skip enhancement step (use original audio)

        Returns:
            result: Dictionary containing:
                - transcript: Text transcription
                - enhanced_audio: Cleaned audio (numpy array)
                - duration: Audio duration in seconds
                - language: Detected language
                - segments: Timestamped segments
        """
        # Load and preprocess
        print("🎵 Loading audio...")
        audio = self.audio_processor.load_audio(audio_file)
        audio = self.audio_processor.normalize_audio(audio)

        duration = len(audio) / self.audio_processor.sample_rate
        print(f"   Duration: {duration:.2f}s")

        # Enhance with U-Net
        if not skip_enhancement:
            print("🧹 Enhancing audio with U-Net...")
            enhanced_audio = self.enhance_audio(audio)
        else:
            print("⏭️ Skipping enhancement...")
            enhanced_audio = audio

        # Transcribe with Whisper
        print("📝 Transcribing with Whisper...")
        transcription_result = self.transcribe_audio(enhanced_audio, language=language)

        # Compile results
        result = {
            'transcript': transcription_result['text'].strip(),
            'enhanced_audio': enhanced_audio,
            'sample_rate': self.audio_processor.sample_rate,
            'duration': duration,
            'language': transcription_result.get('language', language),
            'segments': transcription_result.get('segments', [])
        }

        print("✅ Processing complete!")
        return result
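
    # Usage sketch for the bytes input path (assumes `pipeline` is an
    # initialized EnhancementPipeline; the file name is a placeholder):
    #     with open("recording.wav", "rb") as f:
    #         result = pipeline.process(f.read(), language="en")
    #     print(result["transcript"])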


def test_pipeline():
    """Test the pipeline with a sample audio file"""
    print("=" * 70)
    print("🧪 TESTING AUDIO ENHANCEMENT PIPELINE")
    print("=" * 70)

    # Paths
    cnn_checkpoint = PROJECT_ROOT / "enhancement_model/checkpoints/best_model.pt"
    test_audio = PROJECT_ROOT / "data/audio_raw/noisy_0000.wav"
    output_audio = PROJECT_ROOT / "enhanced_test_output.wav"

    if not test_audio.exists():
        print(f"❌ Test audio not found: {test_audio}")
        return

    # Initialize pipeline
    pipeline = EnhancementPipeline(
        cnn_checkpoint_path=str(cnn_checkpoint),
        whisper_model_name="base",
        device=get_default_device()
    )

    # Process audio
    result = pipeline.process(test_audio)

    # Print results
    print("\n" + "=" * 70)
    print("📊 RESULTS")
    print("=" * 70)
    print(f"Transcript: {result['transcript']}")
    print(f"Duration: {result['duration']:.2f}s")
    print(f"Language: {result['language']}")
    print(f"Segments: {len(result.get('segments', []))}")

    # Save enhanced audio
    sf.write(output_audio, result['enhanced_audio'], result['sample_rate'])
    print(f"\n💾 Enhanced audio saved to: {output_audio}")
    print("=" * 70)


if __name__ == "__main__":
    test_pipeline()