thecodeworm committed
Commit 9153525 · verified · 1 Parent(s): dbe2f4e

Upload inference_pipeline.py

Files changed (1):
  1. inference_pipeline.py +382 -0
inference_pipeline.py ADDED
"""
Audio Enhancement and Transcription Pipeline
Handles real-time processing of user-uploaded audio
"""

import sys
from pathlib import Path
import io

# Project root = ClearSpeech
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from enhancement_model.model import UNetAudioEnhancer

import torch
import numpy as np
import librosa
import soundfile as sf
import whisper
from typing import Union, Dict
import warnings

# Suppress librosa warnings
warnings.filterwarnings('ignore', category=UserWarning)


def get_default_device() -> str:
    """Auto-detect the best available device."""
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"

class AudioProcessor:
    """
    Handles audio preprocessing (in-memory).
    Reuses logic from preprocessing.py, but for single files.
    """

    def __init__(self, sample_rate=16000, n_fft=1024, hop_length=256, n_mels=128):
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels
        self.fmax = 8000

    def load_audio(self, audio_file: Union[str, Path, bytes, io.BytesIO]) -> np.ndarray:
        """
        Load audio from a file path, file-like object, or raw bytes.

        Args:
            audio_file: File path, file object, or bytes

        Returns:
            audio: NumPy array of audio samples
        """
        try:
            if isinstance(audio_file, (str, Path)):
                audio, _ = librosa.load(audio_file, sr=self.sample_rate, mono=True)
            elif isinstance(audio_file, bytes):
                audio, _ = librosa.load(io.BytesIO(audio_file), sr=self.sample_rate, mono=True)
            else:
                audio, _ = librosa.load(audio_file, sr=self.sample_rate, mono=True)
            return audio
        except Exception as e:
            raise ValueError(f"Failed to load audio: {e}")

    def normalize_audio(self, audio: np.ndarray, target_db: float = -20.0) -> np.ndarray:
        """
        Normalize audio to a target RMS level for consistency.

        Args:
            audio: Input audio
            target_db: Target RMS level in dB

        Returns:
            Normalized audio
        """
        # RMS-based normalization (more consistent than peak normalization)
        rms = np.sqrt(np.mean(audio**2))
        if rms > 0:
            target_rms = 10 ** (target_db / 20)
            audio = audio * (target_rms / rms)

        # Clip to prevent distortion
        audio = np.clip(audio, -1.0, 1.0)
        return audio

    def audio_to_spectrogram(self, audio: np.ndarray) -> np.ndarray:
        """
        Convert audio to a mel-spectrogram.

        Args:
            audio: Audio waveform

        Returns:
            mel_spec_db: Mel-spectrogram in dB scale
        """
        mel_spec = librosa.feature.melspectrogram(
            y=audio,
            sr=self.sample_rate,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            n_mels=self.n_mels,
            fmax=self.fmax
        )

        # Convert to dB relative to the peak power
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

        # Ensure valid range
        mel_spec_db = np.clip(mel_spec_db, -80.0, 0.0)
        return mel_spec_db

    def spectrogram_to_audio(self, mel_spec_db: np.ndarray, n_iter: int = 60) -> np.ndarray:
        """
        Convert a mel-spectrogram back to audio using Griffin-Lim.

        Args:
            mel_spec_db: Mel-spectrogram in dB
            n_iter: Griffin-Lim iterations (more = better quality)

        Returns:
            audio: Reconstructed waveform
        """
        # Ensure valid dB range
        mel_spec_db = np.clip(mel_spec_db, -80.0, 0.0)
        mel_spec_db = np.nan_to_num(mel_spec_db, nan=-80.0, posinf=0.0, neginf=-80.0)

        # Convert from dB to power
        mel_spec = librosa.db_to_power(mel_spec_db)

        # Ensure non-negative power values
        mel_spec = np.maximum(mel_spec, 1e-10)

        # Invert to audio using Griffin-Lim
        audio = librosa.feature.inverse.mel_to_audio(
            mel_spec,
            sr=self.sample_rate,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            n_iter=n_iter
        )

        # Handle any NaN or Inf in the reconstructed audio
        audio = np.nan_to_num(audio, nan=0.0, posinf=1.0, neginf=-1.0)

        return audio

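For reference, a minimal round trip through AudioProcessor looks like the sketch below. The input filename is hypothetical, and since Griffin-Lim inversion is lossy, the restored waveform is an approximation of the input rather than a bit-exact copy.

    import soundfile as sf

    proc = AudioProcessor()
    audio = proc.load_audio("example.wav")        # resampled to 16 kHz mono
    audio = proc.normalize_audio(audio)           # RMS-normalized toward -20 dB
    mel_db = proc.audio_to_spectrogram(audio)     # (128, frames), values in [-80, 0] dB
    restored = proc.spectrogram_to_audio(mel_db)  # Griffin-Lim reconstruction
    sf.write("roundtrip.wav", restored, proc.sample_rate)
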
class EnhancementPipeline:
    """
    Complete audio enhancement and transcription pipeline.
    """

    def __init__(
        self,
        cnn_checkpoint_path: str,
        whisper_model_name: str = "base",
        device: str = None,
        use_fp16: bool = False
    ):
        """
        Initialize the pipeline with models.

        Args:
            cnn_checkpoint_path: Path to the trained CNN model
            whisper_model_name: Whisper model size (tiny, base, small, medium, large)
            device: 'cuda', 'mps', or 'cpu'
            use_fp16: Use half precision for Whisper (faster on GPU)
        """
        if device is None:
            device = get_default_device()
        self.device = torch.device(device)
        self.use_fp16 = use_fp16 and (device == "cuda")

        print(f"🖥️ Using device: {self.device}")

        self.audio_processor = AudioProcessor()

        # Load CNN enhancement model
        print("📥 Loading U-Net enhancement model...")
        self.cnn_model = UNetAudioEnhancer(in_channels=1, out_channels=1)

        try:
            checkpoint = torch.load(cnn_checkpoint_path, map_location=self.device)
            self.cnn_model.load_state_dict(checkpoint['model_state_dict'])
            self.cnn_model.to(self.device)
            self.cnn_model.eval()

            epoch = checkpoint.get('epoch', 'unknown')
            val_loss = checkpoint.get('val_loss', 'unknown')
            print(f"✅ U-Net loaded (epoch {epoch}, val_loss: {val_loss})")
        except Exception as e:
            raise RuntimeError(f"Failed to load CNN model: {e}")

        # Load Whisper model
        print(f"📥 Loading Whisper model ({whisper_model_name})...")
        try:
            self.whisper_model = whisper.load_model(whisper_model_name, device=str(self.device))
            print("✅ Whisper model loaded")
        except Exception as e:
            raise RuntimeError(f"Failed to load Whisper model: {e}")

    def enhance_audio(self, audio: np.ndarray) -> np.ndarray:
        """
        Enhance audio using the U-Net model.

        Args:
            audio: Raw audio waveform

        Returns:
            enhanced_audio: Cleaned audio waveform
        """
        # Convert to spectrogram (dB scale: [-80, 0])
        noisy_spec = self.audio_processor.audio_to_spectrogram(audio)

        # Normalize to [-1, 1] (matching the training normalization)
        noisy_spec_norm = (noisy_spec + 80.0) / 80.0   # [0, 1]
        noisy_spec_norm = noisy_spec_norm * 2.0 - 1.0  # [-1, 1]

        # Add batch and channel dimensions: (1, 1, H, W)
        noisy_spec_tensor = torch.FloatTensor(noisy_spec_norm).unsqueeze(0).unsqueeze(0)
        noisy_spec_tensor = noisy_spec_tensor.to(self.device)

        # Run U-Net inference
        with torch.no_grad():
            clean_spec_tensor = self.cnn_model(noisy_spec_tensor)
            # Handle NaN/Inf immediately after model output
            clean_spec_tensor = torch.nan_to_num(clean_spec_tensor, nan=0.0, posinf=1.0, neginf=-1.0)
            clean_spec_tensor = torch.clamp(clean_spec_tensor, -1.0, 1.0)

        # Convert back to NumPy
        clean_spec_norm = clean_spec_tensor.squeeze().cpu().numpy()

        # Denormalize: [-1, 1] -> [0, 1] -> [-80, 0] dB
        clean_spec_norm = (clean_spec_norm + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        clean_spec_db = clean_spec_norm * 80.0 - 80.0    # [0, 1] -> [-80, 0]

        # Ensure valid dB range
        clean_spec_db = np.nan_to_num(clean_spec_db, nan=-80.0, posinf=0.0, neginf=-80.0)
        clean_spec_db = np.clip(clean_spec_db, -80.0, 0.0)

        # Convert the spectrogram back to audio (more iterations = better quality)
        enhanced_audio = self.audio_processor.spectrogram_to_audio(clean_spec_db, n_iter=60)

        # Normalize and clip
        enhanced_audio = self.audio_processor.normalize_audio(enhanced_audio)
        enhanced_audio = np.clip(enhanced_audio, -1.0, 1.0)

        return enhanced_audio

    def transcribe_audio(self, audio: np.ndarray, language: str = 'en') -> Dict:
        """
        Transcribe audio using Whisper.

        Args:
            audio: Audio waveform (NumPy array)
            language: Language code (e.g., 'en', 'es', 'fr')

        Returns:
            result: Dictionary with transcription and metadata
        """
        # Whisper expects float32 audio normalized to [-1, 1]
        audio = audio.astype(np.float32)

        # Whisper's transcribe handles long clips internally in 30-second windows
        max_length = 30 * self.audio_processor.sample_rate
        if len(audio) > max_length:
            print("⚠️ Audio longer than 30s, processing in chunks...")

        result = self.whisper_model.transcribe(
            audio,
            language=language if language else None,
            fp16=self.use_fp16,
            verbose=False
        )
        return result

    def process(
        self,
        audio_file: Union[str, Path, bytes, io.BytesIO],
        language: str = 'en',
        skip_enhancement: bool = False
    ) -> Dict:
        """
        Complete processing pipeline.

        Args:
            audio_file: Input audio (file path, bytes, or file object)
            language: Target language for transcription
            skip_enhancement: Skip the enhancement step (use the original audio)

        Returns:
            result: Dictionary containing:
                - transcript: Text transcription
                - enhanced_audio: Cleaned audio (NumPy array)
                - duration: Audio duration in seconds
                - language: Detected language
                - segments: Timestamped segments
        """
        # Load and preprocess
        print("🎵 Loading audio...")
        audio = self.audio_processor.load_audio(audio_file)
        audio = self.audio_processor.normalize_audio(audio)

        duration = len(audio) / self.audio_processor.sample_rate
        print(f" Duration: {duration:.2f}s")

        # Enhance with U-Net
        if not skip_enhancement:
            print("🧹 Enhancing audio with U-Net...")
            enhanced_audio = self.enhance_audio(audio)
        else:
            print("⏭️ Skipping enhancement...")
            enhanced_audio = audio

        # Transcribe with Whisper
        print("📝 Transcribing with Whisper...")
        transcription_result = self.transcribe_audio(enhanced_audio, language=language)

        # Compile results
        result = {
            'transcript': transcription_result['text'].strip(),
            'enhanced_audio': enhanced_audio,
            'sample_rate': self.audio_processor.sample_rate,
            'duration': duration,
            'language': transcription_result.get('language', language),
            'segments': transcription_result.get('segments', [])
        }

        print("✅ Processing complete!")
        return result

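The normalization in enhance_audio maps the [-80, 0] dB spectrogram onto [-1, 1] for the network, then inverts that mapping on the model output. A quick standalone check of the round trip (not part of the uploaded file):

    import numpy as np

    spec_db = np.array([-80.0, -40.0, 0.0])       # dB values in [-80, 0]
    norm = (spec_db + 80.0) / 80.0 * 2.0 - 1.0    # forward: -> [-1, 0, 1]
    recovered = (norm + 1.0) / 2.0 * 80.0 - 80.0  # inverse: back to dB
    assert np.allclose(recovered, spec_db)
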
def test_pipeline():
    """Test the pipeline with a sample audio file."""
    print("="*70)
    print("🧪 TESTING AUDIO ENHANCEMENT PIPELINE")
    print("="*70)

    # Paths
    cnn_checkpoint = PROJECT_ROOT / "enhancement_model/checkpoints/best_model.pt"
    test_audio = PROJECT_ROOT / "data/audio_raw/noisy_0000.wav"
    output_audio = PROJECT_ROOT / "enhanced_test_output.wav"

    if not test_audio.exists():
        print(f"❌ Test audio not found: {test_audio}")
        return

    # Initialize pipeline
    pipeline = EnhancementPipeline(
        cnn_checkpoint_path=str(cnn_checkpoint),
        whisper_model_name="base",
        device=get_default_device()
    )

    # Process audio
    result = pipeline.process(test_audio)

    # Print results
    print("\n" + "="*70)
    print("📊 RESULTS")
    print("="*70)
    print(f"Transcript: {result['transcript']}")
    print(f"Duration: {result['duration']:.2f}s")
    print(f"Language: {result['language']}")
    print(f"Segments: {len(result.get('segments', []))}")

    # Save enhanced audio
    sf.write(output_audio, result['enhanced_audio'], result['sample_rate'])
    print(f"\n💾 Enhanced audio saved to: {output_audio}")
    print("="*70)


if __name__ == "__main__":
    test_pipeline()
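
Since process accepts raw bytes as well as paths, the pipeline also works on in-memory uploads. A usage sketch, assuming the checkpoint path used by test_pipeline and a hypothetical recording upload.wav:

    with open("upload.wav", "rb") as f:
        audio_bytes = f.read()

    pipeline = EnhancementPipeline(
        cnn_checkpoint_path="enhancement_model/checkpoints/best_model.pt",
        whisper_model_name="base",
    )
    result = pipeline.process(audio_bytes, language="en")
    print(result["transcript"])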