jcudit HF Staff commited on
Commit
3ff2f18
·
1 Parent(s): 0456b70

fix: also correct lib/ in gitignore to only exclude root-level, add src/lib package

Browse files
.gitignore CHANGED
@@ -10,8 +10,8 @@ dist/
10
  downloads/
11
  eggs/
12
  .eggs/
13
- lib/
14
- lib64/
15
  parts/
16
  sdist/
17
  var/
 
10
  downloads/
11
  eggs/
12
  .eggs/
13
+ /lib/
14
+ /lib64/
15
  parts/
16
  sdist/
17
  var/
src/lib/__init__.py ADDED
File without changes
src/lib/audio_io.py ADDED
@@ -0,0 +1,703 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Audio I/O utilities: Read, write, and validate audio files.
3
+
4
+ Handles m4a and wav formats with format validation and error handling.
5
+ """
6
+
7
+ import logging
8
+ from pathlib import Path
9
+ from typing import Optional, Tuple
10
+
11
+ import numpy as np
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class AudioIOError(Exception):
    """Raised for any audio read/write/validation failure in this module."""
20
+
21
+
22
def read_audio(file_path: str, target_sr: Optional[int] = None) -> Tuple[np.ndarray, int]:
    """
    Read an audio file and return its waveform and sample rate.

    WAV (and other soundfile-supported formats) are read directly; m4a/aac/mp4
    fall back to an FFmpeg conversion through a temporary WAV. Stereo input is
    downmixed to mono.

    Args:
        file_path: Path to audio file.
        target_sr: Target sample rate (resamples if different), None = keep original.

    Returns:
        Tuple of (audio_array, sample_rate):
        - audio_array: 1D numpy array of audio samples (float32, mono).
        - sample_rate: Sample rate in Hz.

    Raises:
        AudioIOError: If the file is missing, unreadable, or FFmpeg is
            unavailable / fails during the m4a fallback.
    """
    import subprocess
    import tempfile

    import soundfile as sf

    file_path = Path(file_path)

    if not file_path.exists():
        raise AudioIOError(f"Audio file not found: {file_path}")

    try:
        # Try reading directly with soundfile
        audio, sr = sf.read(str(file_path), dtype="float32")

    except Exception as e:
        # If the format is not an FFmpeg-convertible container, surface the
        # original read error wrapped in our module exception.
        if file_path.suffix.lower() not in [".m4a", ".aac", ".mp4"]:
            raise AudioIOError(f"Failed to read audio file {file_path}: {str(e)}")

        logger.debug(f"Converting {file_path.suffix} to WAV for reading...")

        # Create temporary WAV file (delete=False so FFmpeg can write to it)
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav:
            tmp_wav_path = tmp_wav.name

        try:
            # Convert M4A to WAV using FFmpeg (mono, at the requested rate)
            target_rate = target_sr if target_sr else 44100
            cmd = [
                "ffmpeg",
                "-i",
                str(file_path),
                "-ar",
                str(target_rate),
                "-ac",
                "1",  # Mono
                "-y",  # Overwrite
                tmp_wav_path,
            ]

            subprocess.run(cmd, capture_output=True, text=True, check=True)

            # Read the converted WAV file
            audio, sr = sf.read(tmp_wav_path, dtype="float32")

            logger.debug(f"Converted and read {file_path.name} via FFmpeg")

        except FileNotFoundError:
            # FIX: previously a missing FFmpeg binary leaked a raw
            # FileNotFoundError; wrap it like the sibling converters do.
            raise AudioIOError(
                "FFmpeg not found. Please install FFmpeg: https://ffmpeg.org/download.html"
            )
        except subprocess.CalledProcessError as conv_err:
            # FIX: wrap conversion failures so callers only need to catch
            # AudioIOError (consistent with write_audio / convert_m4a_to_wav).
            raise AudioIOError(f"FFmpeg conversion failed for {file_path}: {conv_err.stderr}")
        finally:
            # Clean up temporary file
            if Path(tmp_wav_path).exists():
                Path(tmp_wav_path).unlink()

    # Convert stereo to mono if needed (in case FFmpeg didn't do it)
    if audio.ndim > 1:
        audio = audio.mean(axis=1)

    # Resample if target sample rate specified and not already done
    if target_sr is not None and sr != target_sr:
        audio = resample_audio(audio, sr, target_sr)
        sr = target_sr

    logger.debug(f"Read audio: {file_path.name} ({len(audio) / sr:.1f}s at {sr}Hz)")
    return audio, sr
104
+
105
+
106
def write_audio(
    file_path: str, audio: np.ndarray, sample_rate: int, format: Optional[str] = None
) -> None:
    """
    Write audio array to file.

    WAV and other soundfile-supported formats are written directly;
    m4a/aac/mp4 are staged as a temporary WAV and encoded with FFmpeg.

    Args:
        file_path: Output file path.
        audio: Audio array (1D numpy array, float32).
        sample_rate: Sample rate in Hz.
        format: Audio format ('wav', 'm4a', etc.), auto-detected from extension if None.

    Raises:
        AudioIOError: If file cannot be written.
    """
    import subprocess
    import tempfile

    import soundfile as sf

    file_path = Path(file_path)
    file_path.parent.mkdir(parents=True, exist_ok=True)

    # Collapse any extra dimensions so a (n, 1) array writes as mono.
    if audio.ndim > 1:
        audio = audio.squeeze()

    # Fall back to the file extension when no explicit format was given.
    if format is None:
        format = file_path.suffix.lstrip(".")

    try:
        if format.lower() not in ["m4a", "aac", "mp4"]:
            # soundfile handles WAV/FLAC/etc. natively.
            sf.write(str(file_path), audio, sample_rate, format=format)
            logger.debug(
                f"Wrote audio: {file_path.name} ({len(audio) / sample_rate:.1f}s at {sample_rate}Hz)"
            )
            return

        logger.debug(f"Converting to {format.upper()} via FFmpeg...")

        # Stage the samples in a temporary WAV for FFmpeg to encode.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav:
            tmp_wav_path = tmp_wav.name

        try:
            sf.write(tmp_wav_path, audio, sample_rate, format="wav")

            # AAC-in-MP4 caps the sample rate at 48kHz.
            output_sr = min(sample_rate, 48000)
            bitrate = "192k"  # Good quality for voice

            encode_cmd = [
                "ffmpeg",
                "-i",
                tmp_wav_path,
                "-ar",
                str(output_sr),
                "-b:a",
                bitrate,
                "-c:a",
                "aac",
                "-y",  # Overwrite
                str(file_path),
            ]
            subprocess.run(encode_cmd, capture_output=True, text=True, check=True)

            logger.debug(
                f"Wrote audio: {file_path.name} ({len(audio) / sample_rate:.1f}s at {output_sr}Hz, {bitrate})"
            )
        finally:
            # Remove the staging WAV whether or not the encode succeeded.
            if Path(tmp_wav_path).exists():
                Path(tmp_wav_path).unlink()

    except subprocess.CalledProcessError as e:
        raise AudioIOError(f"FFmpeg conversion failed for {file_path}: {e.stderr}")
    except Exception as e:
        raise AudioIOError(f"Failed to write audio file {file_path}: {str(e)}")
193
+
194
+
195
def validate_audio_file(file_path: str, min_duration: float = 0.1) -> Tuple[bool, Optional[str]]:
    """
    Validate that file is a readable audio file with comprehensive checks.

    Checks existence, non-emptiness, extension, readable metadata
    (ffprobe for m4a/aac/mp4, soundfile otherwise), and minimum duration.

    Args:
        file_path: Path to audio file.
        min_duration: Minimum duration in seconds (default: 0.1).

    Returns:
        Tuple of (is_valid, error_message):
        - is_valid: True if file is valid audio.
        - error_message: Description of validation failure, None if valid.
    """
    try:
        file_path = Path(file_path)

        # Basic filesystem checks first: they are cheap and conclusive.
        if not file_path.exists():
            return False, f"File not found: {file_path}"

        if file_path.stat().st_size == 0:
            return False, f"File is empty: {file_path}"

        # Extension whitelist.
        valid_extensions = {".m4a", ".wav", ".mp3", ".flac", ".ogg", ".aac", ".mp4"}
        if file_path.suffix.lower() not in valid_extensions:
            return (
                False,
                f"Unsupported format: {file_path.suffix}. Supported formats: {', '.join(valid_extensions)}",
            )

        import subprocess

        import soundfile as sf

        try:
            if file_path.suffix.lower() in [".m4a", ".aac", ".mp4"]:
                # Compressed containers: soundfile cannot open them, ask ffprobe.
                probe = subprocess.run(
                    [
                        "ffprobe",
                        "-v",
                        "error",
                        "-show_entries",
                        "format=duration,bit_rate:stream=codec_name,sample_rate,channels",
                        "-of",
                        "json",
                        str(file_path),
                    ],
                    capture_output=True,
                    text=True,
                    check=True,
                )

                import json

                probe_data = json.loads(probe.stdout)

                if "format" not in probe_data or "duration" not in probe_data["format"]:
                    return False, "Invalid audio file: Cannot read metadata"

                duration = float(probe_data["format"]["duration"])
                if duration < min_duration:
                    return False, f"Audio too short: {duration:.2f}s (minimum: {min_duration}s)"
            else:
                # WAV and friends: soundfile reads the header directly.
                meta = sf.info(str(file_path))

                if meta.samplerate <= 0:
                    return False, f"Invalid sample rate: {meta.samplerate}"

                if meta.frames <= 0:
                    return False, "No audio frames in file"

                duration = meta.frames / meta.samplerate
                if duration < min_duration:
                    return False, f"Audio too short: {duration:.2f}s (minimum: {min_duration}s)"

        except FileNotFoundError:
            return False, "FFmpeg/FFprobe not found. Please install FFmpeg for M4A support."
        except subprocess.CalledProcessError as e:
            return False, f"Cannot read audio metadata: {e.stderr}"
        except Exception as e:
            return False, f"Invalid audio file: {str(e)}"

        return True, None

    except Exception as e:
        return False, f"Validation error: {str(e)}"
288
+
289
+
290
def get_audio_duration(file_path: str) -> float:
    """
    Get duration of audio file in seconds.

    Uses ffprobe for m4a/aac/mp4 containers (which soundfile cannot open)
    and soundfile header metadata for everything else.

    Args:
        file_path: Path to audio file.

    Returns:
        Duration in seconds.

    Raises:
        AudioIOError: If file cannot be read.
    """
    try:
        suffix = Path(file_path).suffix.lower()

        if suffix not in [".m4a", ".aac", ".mp4"]:
            # Header-based duration for soundfile-supported formats.
            import soundfile as sf

            meta = sf.info(str(file_path))
            return meta.frames / meta.samplerate

        # Compressed container: ask ffprobe for the format-level duration.
        import subprocess

        probe = subprocess.run(
            [
                "ffprobe",
                "-v",
                "error",
                "-show_entries",
                "format=duration",
                "-of",
                "default=noprint_wrappers=1:nokey=1",
                str(file_path),
            ],
            capture_output=True,
            text=True,
            check=True,
        )
        return float(probe.stdout.strip())

    except Exception as e:
        raise AudioIOError(f"Failed to get audio duration for {file_path}: {str(e)}")
333
+
334
+
335
def get_audio_info(file_path: str) -> dict:
    """
    Get detailed information about audio file.

    NOTE(review): relies solely on soundfile, so compressed containers
    (m4a/aac) it cannot open will raise — unlike get_audio_duration,
    which has an ffprobe path. Confirm intended for wav-only inputs.

    Args:
        file_path: Path to audio file.

    Returns:
        Dictionary with keys: duration, sample_rate, channels, format,
        subtype, frames.

    Raises:
        AudioIOError: If file cannot be read.
    """
    try:
        import soundfile as sf

        meta = sf.info(str(file_path))

        return {
            "duration": meta.frames / meta.samplerate,
            "sample_rate": meta.samplerate,
            "channels": meta.channels,
            "format": meta.format,
            "subtype": meta.subtype,
            "frames": meta.frames,
        }

    except Exception as e:
        raise AudioIOError(f"Failed to get audio info for {file_path}: {str(e)}")
364
+
365
+
366
def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """
    Resample audio to target sample rate.

    Args:
        audio: Audio array.
        orig_sr: Original sample rate.
        target_sr: Target sample rate.

    Returns:
        Resampled audio array; the input array is returned unchanged when
        the rates already match.

    Raises:
        AudioIOError: If resampling fails (e.g. librosa is unavailable).
    """
    # FIX: short-circuit BEFORE importing librosa. Previously the identity
    # case (orig_sr == target_sr) still required librosa to be installed.
    if orig_sr == target_sr:
        return audio

    try:
        import librosa

        return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)

    except Exception as e:
        raise AudioIOError(f"Failed to resample audio: {str(e)}")
389
+
390
+
391
def normalize_audio(audio: np.ndarray, target_db: float = -20.0) -> np.ndarray:
    """
    Normalize audio to a target RMS level expressed in dB.

    Args:
        audio: Audio array.
        target_db: Target level in dB (default: -20dB).

    Returns:
        Normalized audio array; silent input is returned unchanged.
    """
    current_rms = np.sqrt(np.mean(audio**2))

    # Silence cannot be scaled to a target level — return it untouched.
    if current_rms == 0:
        return audio

    # Convert the dB target to a linear RMS amplitude, then apply the gain.
    desired_rms = 10 ** (target_db / 20)
    scaled = audio * (desired_rms / current_rms)

    # If the gain pushed samples past full scale, pull back just under it.
    peak = np.abs(scaled).max()
    return scaled / peak * 0.99 if peak > 1.0 else scaled
421
+
422
+
423
def extract_segment(
    audio: np.ndarray, sample_rate: int, start_time: float, end_time: float
) -> np.ndarray:
    """
    Extract a time-bounded segment from an audio array.

    Args:
        audio: Audio array.
        sample_rate: Sample rate in Hz.
        start_time: Start time in seconds.
        end_time: End time in seconds.

    Returns:
        Audio segment array (empty if the requested range is out of bounds).
    """
    # Convert times to sample indices, clamped into [0, len(audio)].
    first = max(0, int(start_time * sample_rate))
    last = min(len(audio), int(end_time * sample_rate))
    return audio[first:last]
446
+
447
+
448
def split_audio_chunks(
    audio: np.ndarray, sample_rate: int, chunk_duration: float, overlap: float = 0.0
) -> list:
    """
    Split audio into (optionally overlapping) chunks for processing.

    Args:
        audio: Audio array.
        sample_rate: Sample rate in Hz.
        chunk_duration: Chunk duration in seconds.
        overlap: Overlap between chunks in seconds (must be < chunk_duration).

    Returns:
        List of (chunk_audio, start_time, end_time) tuples. The final chunk
        may be shorter than chunk_duration.

    Raises:
        ValueError: If overlap >= chunk_duration (the window would never
            advance).
    """
    chunk_samples = int(chunk_duration * sample_rate)
    overlap_samples = int(overlap * sample_rate)
    step_samples = chunk_samples - overlap_samples

    # FIX: the original looped forever when overlap >= chunk_duration
    # (non-positive step). Fail fast with a clear error instead.
    if step_samples <= 0:
        raise ValueError(
            f"overlap ({overlap}s) must be smaller than chunk_duration ({chunk_duration}s)"
        )

    chunks = []
    position = 0

    while position < len(audio):
        chunk_end = min(position + chunk_samples, len(audio))
        chunk = audio[position:chunk_end]

        start_time = position / sample_rate
        end_time = chunk_end / sample_rate

        chunks.append((chunk, start_time, end_time))

        position += step_samples

        # Stop once the final (possibly partial) chunk reached the end.
        if chunk_end >= len(audio):
            break

    return chunks
486
+
487
+
488
+ # ===== M4A/WAV Conversion Utilities (T007-T008) =====
489
+
490
+
491
def convert_m4a_to_wav(
    input_path: str, output_path: Optional[str] = None, sample_rate: int = 16000
) -> str:
    """
    Convert M4A/AAC audio file to WAV format using FFmpeg.

    This is required for pyannote.audio processing which expects WAV input.

    Args:
        input_path: Path to input M4A/AAC file.
        output_path: Path for output WAV file (auto-generated if None).
        sample_rate: Target sample rate in Hz (default: 16000 for pyannote).

    Returns:
        Path to converted WAV file.

    Raises:
        AudioIOError: If conversion fails or FFmpeg is not available.
    """
    import subprocess
    from pathlib import Path

    source = Path(input_path)

    if not source.exists():
        raise AudioIOError(f"Input file not found: {source}")

    # Default destination: same name with a .wav suffix.
    destination = source.with_suffix(".wav") if output_path is None else Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    try:
        # Decode to mono WAV at the requested rate.
        subprocess.run(
            [
                "ffmpeg",
                "-i",
                str(source),
                "-ar",
                str(sample_rate),  # Resample to target rate
                "-ac",
                "1",  # Convert to mono
                "-y",  # Overwrite output
                str(destination),
            ],
            capture_output=True,
            text=True,
            check=True,
        )

        logger.info(f"Converted {source.name} to WAV at {sample_rate}Hz")
        return str(destination)

    except FileNotFoundError:
        raise AudioIOError(
            "FFmpeg not found. Please install FFmpeg: https://ffmpeg.org/download.html"
        )
    except subprocess.CalledProcessError as e:
        raise AudioIOError(f"FFmpeg conversion failed: {e.stderr}")
552
+
553
+
554
def convert_wav_to_m4a(
    input_path: str, output_path: str, sample_rate: int = 44100, bitrate: str = "192k"
) -> str:
    """
    Convert WAV audio file to M4A/AAC format using FFmpeg.

    Used for exporting final processed audio in M4A format.

    Args:
        input_path: Path to input WAV file.
        output_path: Path for output M4A file.
        sample_rate: Target sample rate in Hz (default: 44100, max 48000 for M4A).
        bitrate: Target bitrate (default: "192k").

    Returns:
        Path to converted M4A file.

    Raises:
        AudioIOError: If conversion fails or FFmpeg is not available.
    """
    import subprocess
    from pathlib import Path

    source = Path(input_path)
    destination = Path(output_path)

    if not source.exists():
        raise AudioIOError(f"Input file not found: {source}")

    # M4A/AAC tops out at 48kHz; clamp and warn rather than fail.
    if sample_rate > 48000:
        logger.warning(f"Sample rate {sample_rate}Hz exceeds M4A limit, using 48000Hz")
        sample_rate = 48000

    destination.parent.mkdir(parents=True, exist_ok=True)

    try:
        subprocess.run(
            [
                "ffmpeg",
                "-i",
                str(source),
                "-ar",
                str(sample_rate),  # Resample to target rate
                "-b:a",
                bitrate,  # Set bitrate
                "-c:a",
                "aac",  # Use AAC codec
                "-y",  # Overwrite output
                str(destination),
            ],
            capture_output=True,
            text=True,
            check=True,
        )

        logger.info(f"Converted {source.name} to M4A at {sample_rate}Hz, {bitrate}")
        return str(destination)

    except FileNotFoundError:
        raise AudioIOError(
            "FFmpeg not found. Please install FFmpeg: https://ffmpeg.org/download.html"
        )
    except subprocess.CalledProcessError as e:
        raise AudioIOError(f"FFmpeg conversion failed: {e.stderr}")
618
+
619
+
620
+ # ===== Audio Quality Validation (T009) =====
621
+
622
+
623
def validate_audio_quality(
    audio: np.ndarray, sample_rate: int, file_path: Optional[str] = None
) -> dict:
    """
    Validate audio quality and return metrics.

    Heuristics checked:
    - SNR estimate: 10th vs 90th percentile of |amplitude|
    - Clipping: fraction of samples above 0.99 full scale
    - RMS energy: too-quiet detection
    - Duration minimums

    Args:
        audio: Audio array.
        sample_rate: Sample rate in Hz.
        file_path: Optional file path for logging.

    Returns:
        Dictionary with quality metrics and validation results:
        snr_db, is_clipped, clipping_ratio, rms_energy, is_too_quiet,
        duration, is_valid, warnings.
    """
    warnings = []
    duration = len(audio) / sample_rate

    # SNR estimate: robust "noise floor" (10th percentile of magnitude)
    # against a robust "signal peak" (90th percentile); epsilon avoids /0.
    magnitudes = np.abs(audio)
    noise_floor = np.percentile(magnitudes, 10)
    signal_peak = np.percentile(magnitudes, 90)
    snr_db = 20 * np.log10(signal_peak / (noise_floor + 1e-10))

    if snr_db < 15:
        warnings.append(f"Low SNR ({snr_db:.1f} dB < 15 dB)")

    # Clipping: more than 1% of samples near full scale counts as clipped.
    clipping_threshold = 0.99
    clipping_ratio = np.sum(magnitudes > clipping_threshold) / len(audio)
    is_clipped = clipping_ratio > 0.01
    if is_clipped:
        warnings.append(f"Audio has clipping ({clipping_ratio * 100:.1f}% of samples)")

    # Energy: very low RMS usually means a near-silent recording.
    rms_energy = np.sqrt(np.mean(audio**2))
    is_too_quiet = rms_energy < 0.01
    if is_too_quiet:
        warnings.append(f"Audio is too quiet (RMS: {rms_energy:.4f})")

    if duration < 1.0:
        warnings.append(f"Audio is very short ({duration:.1f}s)")

    # Overall verdict combines all hard thresholds.
    is_valid = (
        snr_db >= 10  # Minimum acceptable SNR
        and not is_clipped
        and not is_too_quiet
        and duration > 0.5
    )

    metrics = {
        "duration": duration,
        "warnings": warnings,
        "snr_db": float(snr_db),
        "is_clipped": is_clipped,
        "clipping_ratio": float(clipping_ratio),
        "rms_energy": float(rms_energy),
        "is_too_quiet": is_too_quiet,
        "is_valid": is_valid,
    }

    file_desc = f" for {file_path}" if file_path else ""
    if is_valid:
        logger.debug(f"Audio quality validation passed{file_desc}")
    else:
        logger.warning(
            f"Audio quality validation failed{file_desc}: " + ", ".join(warnings)
        )

    return metrics
src/lib/format_converter.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Audio format converter: m4a ↔ wav conversion, sample rate normalization.
3
+
4
+ Converts between m4a (compressed) and wav (lossless) formats.
5
+ Normalizes to 48kHz/24-bit for processing, outputs as m4a/192kbps for final.
6
+ """
7
+
8
+ import logging
9
+ import tempfile
10
+ from pathlib import Path
11
+ from typing import Optional
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class FormatConversionError(Exception):
    """Raised when an audio format conversion or normalization fails."""
20
+
21
+
22
def m4a_to_wav(
    input_path: str,
    output_path: Optional[str] = None,
    target_sr: int = 48000,
    target_bit_depth: int = 24,
) -> str:
    """
    Convert m4a to wav format with normalization.

    Args:
        input_path: Path to input m4a file.
        output_path: Path to output wav file (temp file if None).
        target_sr: Target sample rate in Hz (default: 48000).
        target_bit_depth: Target bit depth (default: 24).

    Returns:
        Path to output wav file.

    Raises:
        FormatConversionError: If conversion fails.
    """
    try:
        from pydub import AudioSegment

        src = Path(input_path)

        if not src.exists():
            raise FormatConversionError(f"Input file not found: {src}")

        # Default to a temp-dir destination when none was given.
        if output_path is None:
            dst = Path(tempfile.gettempdir()) / f"{src.stem}_temp.wav"
        else:
            dst = Path(output_path)
            dst.parent.mkdir(parents=True, exist_ok=True)

        # Decode, then normalize rate / channels / sample width.
        segment = AudioSegment.from_file(str(src), format="m4a")
        segment = (
            segment.set_frame_rate(target_sr)
            .set_channels(1)  # Mono
            .set_sample_width(target_bit_depth // 8)  # Bytes (24-bit = 3 bytes)
        )
        segment.export(str(dst), format="wav")

        logger.debug(f"Converted m4a to wav: {src.name} -> {dst.name}")
        return str(dst)

    except FormatConversionError:
        raise
    except Exception as e:
        raise FormatConversionError(f"Failed to convert m4a to wav: {str(e)}")
77
+
78
+
79
def wav_to_m4a(
    input_path: str, output_path: str, bitrate: str = "192k", sample_rate: int = 48000
) -> str:
    """
    Convert wav to m4a format.

    Args:
        input_path: Path to input wav file.
        output_path: Path to output m4a file.
        bitrate: AAC bitrate (default: "192k").
        sample_rate: Sample rate in Hz (default: 48000).

    Returns:
        Path to output m4a file.

    Raises:
        FormatConversionError: If conversion fails.
    """
    try:
        from pydub import AudioSegment

        src = Path(input_path)
        dst = Path(output_path)

        if not src.exists():
            raise FormatConversionError(f"Input file not found: {src}")

        dst.parent.mkdir(parents=True, exist_ok=True)

        # Load and resample to the requested rate.
        segment = AudioSegment.from_file(str(src), format="wav").set_frame_rate(sample_rate)

        # m4a is AAC audio inside an mp4 container.
        segment.export(
            str(dst),
            format="mp4",
            codec="aac",
            bitrate=bitrate,
            parameters=["-profile:a", "aac_low"],
        )

        logger.debug(f"Converted wav to m4a: {src.name} -> {dst.name}")
        return str(dst)

    except FormatConversionError:
        raise
    except Exception as e:
        raise FormatConversionError(f"Failed to convert wav to m4a: {str(e)}")
131
+
132
+
133
def normalize_to_intermediate(input_path: str, output_path: Optional[str] = None) -> str:
    """
    Normalize any audio format to intermediate wav format (48kHz/24-bit/mono).

    This is the standard intermediate format for all processing.

    Args:
        input_path: Path to input audio file (m4a, wav, mp3, etc.).
        output_path: Path to output wav file (temp file if None).

    Returns:
        Path to normalized wav file.

    Raises:
        FormatConversionError: If normalization fails.
    """
    try:
        from pydub import AudioSegment

        src = Path(input_path)

        if not src.exists():
            raise FormatConversionError(f"Input file not found: {src}")

        # Default to a temp-dir destination when none was given.
        if output_path is None:
            dst = Path(tempfile.gettempdir()) / f"{src.stem}_normalized.wav"
        else:
            dst = Path(output_path)
            dst.parent.mkdir(parents=True, exist_ok=True)

        # pydub needs the container format; infer it from the extension.
        segment = AudioSegment.from_file(str(src), format=src.suffix.lstrip("."))

        # Intermediate processing spec: 48kHz, mono, 24-bit (3-byte samples).
        segment = segment.set_frame_rate(48000).set_channels(1).set_sample_width(3)
        segment.export(str(dst), format="wav")

        logger.debug(f"Normalized to intermediate: {src.name} -> {dst.name}")
        return str(dst)

    except FormatConversionError:
        raise
    except Exception as e:
        raise FormatConversionError(f"Failed to normalize audio: {str(e)}")
186
+
187
+
188
def convert_to_final_output(input_path: str, output_path: str, format: str = "m4a") -> str:
    """
    Convert intermediate wav to final output format.

    Final output is m4a with AAC 192kbps, 48kHz, mono; "wav" is a plain copy.

    Args:
        input_path: Path to input wav file.
        output_path: Path to output file.
        format: Output format (default: "m4a").

    Returns:
        Path to output file.

    Raises:
        FormatConversionError: If conversion fails or the format is unsupported.
    """
    if format == "m4a":
        # Delegate AAC encoding to the wav->m4a converter.
        return wav_to_m4a(input_path, output_path, bitrate="192k", sample_rate=48000)

    if format == "wav":
        # Already in the requested format: copy bytes as-is.
        import shutil

        destination = Path(output_path)
        destination.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(input_path, destination)
        return str(destination)

    raise FormatConversionError(f"Unsupported output format: {format}")
217
+
218
+
219
def batch_convert(
    input_files: list, output_dir: str, output_format: str = "m4a", progress_callback=None
) -> list:
    """
    Convert multiple files to output format.

    Args:
        input_files: List of input file paths.
        output_dir: Output directory.
        output_format: Output format (default: "m4a").
        progress_callback: Optional callback(index, total, filename).

    Returns:
        List of output file paths.

    Raises:
        FormatConversionError: If any conversion fails.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    output_files = []
    total = len(input_files)

    for i, input_file in enumerate(input_files):
        input_path = Path(input_file)

        # Generate output filename
        output_name = f"{input_path.stem}.{output_format}"
        output_path = output_dir / output_name

        if progress_callback:
            progress_callback(i + 1, total, input_path.name)

        # Convert to intermediate wav, then to the final format.
        intermediate = normalize_to_intermediate(str(input_path))
        try:
            final = convert_to_final_output(intermediate, str(output_path), output_format)
        finally:
            # FIX: clean up the intermediate WAV even when the final
            # conversion raises, so failed batches don't leak temp files.
            Path(intermediate).unlink(missing_ok=True)

        output_files.append(final)

    return output_files
263
+
264
+
265
def get_conversion_info(input_path: str) -> dict:
    """
    Get information about required conversion.

    Args:
        input_path: Path to input file.

    Returns:
        Dictionary with conversion details, or {"error": ...} on failure.
    """
    try:
        from pydub import AudioSegment

        source = Path(input_path)

        if not source.exists():
            return {"error": "File not found"}

        # Load the audio just to inspect its properties.
        source_format = source.suffix.lstrip(".")
        segment = AudioSegment.from_file(str(source), format=source_format)

        # Target intermediate spec: 48kHz / mono / 24-bit (3-byte samples).
        matches_target = (
            segment.frame_rate == 48000 and segment.channels == 1 and segment.sample_width == 3
        )

        return {
            "current_format": source_format,
            "current_sample_rate": segment.frame_rate,
            "current_channels": segment.channels,
            "current_sample_width": segment.sample_width,
            "duration_seconds": len(segment) / 1000.0,  # pydub lengths are in ms
            "needs_conversion": not matches_target,
            "target_format": "wav (intermediate) -> m4a (final)",
            "target_sample_rate": 48000,
            "target_channels": 1,
            "target_bit_depth": 24,
        }

    except Exception as e:
        return {"error": str(e)}
304
+
305
+
306
def estimate_output_size(input_path: str, output_format: str = "m4a") -> int:
    """
    Estimate output file size in bytes.

    Args:
        input_path: Path to input file.
        output_format: Output format.

    Returns:
        Estimated file size in bytes (0 when unknown or unsupported).
    """
    try:
        info = get_conversion_info(input_path)
        if "error" in info:
            return 0

        seconds = info["duration_seconds"]

        if output_format == "m4a":
            # AAC 192kbps = 192 * 1000 / 8 bytes per second
            return int(seconds * (192 * 1000 / 8))
        if output_format == "wav":
            # 48kHz * 3 bytes (24-bit) * 1 channel
            return int(seconds * 48000 * 3)
        return 0

    except Exception:
        return 0
src/lib/gpu_utils.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """GPU resource management utilities for ZeroGPU compatibility.
2
+
3
+ This module provides utilities for managing GPU resources, including model device
4
+ transfers, cache management, and context managers for automatic cleanup.
5
+ """
6
+
7
+ import logging
8
+ import time
9
+ from contextlib import contextmanager
10
+ from typing import Any, Optional
11
+
12
+ import torch
13
+
14
+ from src.config.gpu_config import GPUConfig
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
def acquire_gpu(model: torch.nn.Module, device: str = "cuda") -> bool:
    """Move a model onto the requested device.

    Args:
        model: PyTorch model to move to GPU
        device: Target device (default: "cuda")

    Returns:
        bool: True if successful, False otherwise
    """
    began = time.time()
    try:
        model.to(torch.device(device))
        took = time.time() - began
        logger.debug(f"Model {model.__class__.__name__} moved to {device} in {took:.3f}s")
        return True
    except Exception as e:
        # Invalid device strings or OOM land here; caller decides what to do.
        logger.error(f"Failed to move model to {device}: {e}")
        return False
40
+
41
+
42
def release_gpu(model: torch.nn.Module, clear_cache: bool = True) -> bool:
    """Return a model to the CPU and optionally clear the CUDA cache.

    Args:
        model: PyTorch model to move to CPU
        clear_cache: Whether to clear CUDA cache after moving

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        began = time.time()
        model.to(torch.device("cpu"))

        # Cache clearing is opt-in twice: the per-call flag AND global config.
        wants_clear = clear_cache and GPUConfig.ENABLE_CACHE_CLEARING
        if wants_clear and torch.cuda.is_available():
            torch.cuda.empty_cache()

        took = time.time() - began
        if took > GPUConfig.CLEANUP_TIMEOUT:
            logger.warning(
                f"GPU cleanup took {took:.3f}s, exceeding {GPUConfig.CLEANUP_TIMEOUT}s limit"
            )
        else:
            logger.debug(f"GPU released in {took:.3f}s")
        return True

    except Exception as e:
        logger.error(f"Failed to release GPU: {e}")
        return False
72
+
73
+
74
@contextmanager
def gpu_context(model: torch.nn.Module, device: str = "cuda"):
    """Context manager for automatic GPU resource management.

    Acquires GPU on entry and releases it on exit, even if an exception occurs.

    Args:
        model: PyTorch model to manage
        device: Target GPU device (default: "cuda")

    Yields:
        torch.nn.Module: The model on the GPU device (or unchanged on failure)

    Example:
        >>> with gpu_context(my_model) as model:
        ...     result = model(input_data)
    """
    acquired = False
    try:
        acquired = acquire_gpu(model, device)
        if not acquired:
            # Bug fix: nn.Module has no `.device` attribute, so the old
            # `model.device` raised AttributeError here and turned a soft
            # acquisition failure into a crash. Report the device of the
            # first parameter instead.
            try:
                current_device = next(model.parameters()).device
            except StopIteration:
                # Parameter-less module: device is not meaningful.
                current_device = "cpu"
            logger.warning(f"Failed to acquire GPU, model remains on {current_device}")
        yield model
    finally:
        # Only release what we actually acquired.
        if acquired:
            release_gpu(model, clear_cache=True)
100
+
101
+
102
def move_to_device(data: Any, device: torch.device) -> Any:
    """Recursively move tensors to the specified device.

    Handles nested structures like lists, tuples, and dicts; anything that
    is not a tensor or container is returned unchanged.

    Args:
        data: Data to move (tensor, list, tuple, dict, or other)
        device: Target device

    Returns:
        Data with all tensors moved to the device
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        return {key: move_to_device(value, device) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        moved = [move_to_device(element, device) for element in data]
        # Preserve the original container type.
        return moved if isinstance(data, list) else tuple(moved)
    return data
124
+
125
+
126
def get_gpu_memory_info() -> Optional[dict]:
    """Get current GPU memory usage information.

    Returns:
        dict: Memory information with 'allocated_gb' and 'reserved_gb'
        (rounded to 2 decimals), or None if CUDA is unavailable or the
        query fails.
    """
    if not torch.cuda.is_available():
        return None

    try:
        bytes_per_gb = 1024**3
        return {
            "allocated_gb": round(torch.cuda.memory_allocated() / bytes_per_gb, 2),
            "reserved_gb": round(torch.cuda.memory_reserved() / bytes_per_gb, 2),
        }
    except Exception as e:
        logger.error(f"Failed to get GPU memory info: {e}")
        return None
145
+
146
+
147
def log_gpu_usage(operation: str):
    """Log current GPU memory usage for a specific operation.

    No-op when GPU memory information is unavailable.

    Args:
        operation: Description of the operation being performed
    """
    stats = get_gpu_memory_info()
    if not stats:
        return
    logger.info(
        f"[{operation}] GPU Memory - Allocated: {stats['allocated_gb']}GB, "
        f"Reserved: {stats['reserved_gb']}GB"
    )
src/lib/memory_optimizer.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Memory optimization utilities.
3
+
4
+ Provides utilities for processing large audio files (>1 hour) efficiently
5
+ without running out of memory.
6
+ """
7
+
8
+ import gc
9
+ import logging
10
+ from pathlib import Path
11
+ from typing import Iterator, List, Optional, Tuple
12
+
13
+ import numpy as np
14
+
15
+ from src.lib.audio_io import AudioIOError, read_audio
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
class AudioChunker:
    """
    Utility for processing large audio files in chunks.

    Allows processing audio files that are too large to fit in memory
    by streaming them in manageable, optionally overlapping chunks.
    """

    def __init__(self, chunk_duration: float = 60.0, overlap: float = 5.0):
        """
        Initialize audio chunker.

        Args:
            chunk_duration: Duration of each chunk in seconds (default: 60s)
            overlap: Overlap between chunks in seconds (default: 5s)

        Raises:
            ValueError: If overlap >= chunk_duration. Bug fix: previously
                such a configuration made the iteration step zero or
                negative, causing an infinite loop in iter_chunks().
        """
        if overlap >= chunk_duration:
            raise ValueError(
                f"overlap ({overlap}s) must be smaller than "
                f"chunk_duration ({chunk_duration}s)"
            )
        self.chunk_duration = chunk_duration
        self.overlap = overlap

        logger.debug(f"AudioChunker initialized (chunk: {chunk_duration}s, overlap: {overlap}s)")

    def iter_chunks(
        self, file_path: str, target_sr: int = 16000
    ) -> Iterator[Tuple[np.ndarray, int, float, float]]:
        """
        Iterate over audio file in chunks.

        Args:
            file_path: Path to audio file
            target_sr: Target sample rate

        Yields:
            Tuples of (audio_chunk, sample_rate, start_time, end_time)

        Raises:
            AudioIOError: If file cannot be read or chunking fails
        """
        try:
            # Read full audio (we'll optimize this for truly large files later)
            audio, sr = read_audio(file_path, target_sr=target_sr)
            total_duration = len(audio) / sr

            logger.info(
                f"Processing {Path(file_path).name} in chunks "
                f"(duration: {total_duration:.1f}s, chunk size: {self.chunk_duration}s)"
            )

            # Calculate chunk parameters. __init__ guarantees
            # overlap < chunk_duration, so step_samples is positive and the
            # loop below always terminates.
            chunk_samples = int(self.chunk_duration * sr)
            overlap_samples = int(self.overlap * sr)
            step_samples = chunk_samples - overlap_samples

            position = 0
            chunk_idx = 0

            while position < len(audio):
                # Extract chunk (last chunk may be shorter than chunk_samples)
                chunk_start = position
                chunk_end = min(position + chunk_samples, len(audio))
                chunk = audio[chunk_start:chunk_end]

                # Calculate time boundaries
                start_time = chunk_start / sr
                end_time = chunk_end / sr

                logger.debug(
                    f"Chunk {chunk_idx}: {start_time:.1f}s - {end_time:.1f}s "
                    f"({len(chunk) / sr:.1f}s)"
                )

                yield chunk, sr, start_time, end_time

                # Move to next chunk
                position += step_samples
                chunk_idx += 1

                # Force garbage collection between chunks to bound peak memory
                gc.collect()

            logger.info(f"Processed {chunk_idx} chunks")

        except Exception as e:
            logger.error(f"Failed to process chunks: {e}")
            # Chain the original exception for easier debugging.
            raise AudioIOError(f"Chunking failed: {e}") from e

    def process_file_in_chunks(
        self, file_path: str, processor_func, target_sr: int = 16000, **processor_kwargs
    ) -> List:
        """
        Process audio file in chunks with custom processor function.

        Failed chunks are logged and skipped; processing continues with the
        next chunk.

        Args:
            file_path: Path to audio file
            processor_func: Function to process each chunk
                Should accept (audio, sr, start_time, end_time, **kwargs)
            target_sr: Target sample rate
            **processor_kwargs: Additional arguments for processor function

        Returns:
            List of processing results from each successfully processed chunk

        Example:
            >>> def detect_segments(audio, sr, start_time, end_time):
            ...     # Process audio chunk
            ...     return segments
            >>>
            >>> chunker = AudioChunker(chunk_duration=60.0)
            >>> results = chunker.process_file_in_chunks(
            ...     "long_file.m4a",
            ...     detect_segments
            ... )
        """
        results = []

        for chunk, sr, start_time, end_time in self.iter_chunks(file_path, target_sr):
            try:
                result = processor_func(chunk, sr, start_time, end_time, **processor_kwargs)
                results.append(result)

            except Exception as e:
                logger.error(f"Chunk processing failed at {start_time:.1f}s: {e}")
                # Continue with next chunk
                continue

        return results
145
+
146
+
147
class MemoryMonitor:
    """
    Monitor and manage memory usage during processing.

    Uses psutil when available; degrades to a no-op monitor otherwise.
    """

    def __init__(self, max_memory_mb: Optional[float] = None):
        """
        Initialize memory monitor.

        Args:
            max_memory_mb: Maximum memory usage in MB (None = no limit)
        """
        self.max_memory_mb = max_memory_mb

        try:
            import os

            import psutil

            self.process = psutil.Process(os.getpid())
            self.psutil_available = True
        except ImportError:
            # Monitoring becomes a no-op; limits always pass.
            logger.warning("psutil not available, memory monitoring disabled")
            self.psutil_available = False

    def get_current_memory_mb(self) -> float:
        """
        Get current memory usage (RSS) in MB.

        Returns:
            Memory usage in MB, or 0 if unavailable
        """
        if not self.psutil_available:
            return 0.0

        try:
            return self.process.memory_info().rss / 1024 / 1024
        except Exception:
            return 0.0

    def check_memory_limit(self) -> bool:
        """
        Check if memory usage is below limit.

        Returns:
            True if within limit (or no limit set), False if exceeded
        """
        if self.max_memory_mb is None:
            return True

        current_mb = self.get_current_memory_mb()

        if current_mb > self.max_memory_mb:
            logger.warning(
                f"Memory limit exceeded: {current_mb:.1f}MB > {self.max_memory_mb:.1f}MB"
            )
            return False

        return True

    def force_cleanup(self):
        """Force garbage collection and CUDA cache cleanup."""
        gc.collect()

        # Bug fix: CUDA cache clearing was previously gated on
        # psutil_available, which is unrelated to torch. Attempt it whenever
        # torch is importable.
        try:
            import torch

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                logger.debug("Cleared CUDA cache")
        except ImportError:
            pass

        logger.debug("Forced garbage collection")
222
+
223
+
224
def optimize_for_large_files(audio_duration: float) -> dict:
    """
    Get optimization recommendations for large files.

    Args:
        audio_duration: Duration of audio file in seconds

    Returns:
        Dictionary with optimization parameters (chunking, GC frequency,
        batch size), tiered by file duration.
    """
    # Thresholds
    LARGE_FILE_THRESHOLD = 3600  # 1 hour
    VERY_LARGE_FILE_THRESHOLD = 7200  # 2 hours

    # Defaults for files under one hour: no chunking needed.
    config = {
        "use_chunking": False,
        "chunk_duration": 60.0,
        "chunk_overlap": 5.0,
        "force_gc_frequency": 10,  # Force GC every N chunks
        "recommended_batch_size": 32,
    }

    if audio_duration > VERY_LARGE_FILE_THRESHOLD:
        # Very large file (>2 hours): smaller chunks, smaller batches, more GC.
        config["use_chunking"] = True
        config["chunk_duration"] = 30.0
        config["chunk_overlap"] = 3.0
        config["force_gc_frequency"] = 5
        config["recommended_batch_size"] = 16
        logger.info(
            f"Large file detected ({audio_duration / 3600:.1f}h), "
            "using aggressive memory optimization"
        )
    elif audio_duration > LARGE_FILE_THRESHOLD:
        # Large file (>1 hour): enable chunking with moderate settings.
        config["use_chunking"] = True
        config["chunk_duration"] = 60.0
        config["chunk_overlap"] = 5.0
        config["force_gc_frequency"] = 10
        config["recommended_batch_size"] = 24
        logger.info(
            f"Large file detected ({audio_duration / 3600:.1f}h), using memory optimization"
        )

    return config
278
+
279
+
280
def estimate_memory_requirements(
    audio_duration: float, sample_rate: int = 16000, num_models: int = 3, safety_factor: float = 2.0
) -> float:
    """
    Estimate memory requirements for processing.

    Args:
        audio_duration: Duration in seconds
        sample_rate: Sample rate in Hz
        num_models: Number of ML models to load
        safety_factor: Safety multiplier (default: 2.0)

    Returns:
        Estimated memory requirement in MB
    """
    bytes_per_sample = 4  # float32 samples
    audio_mb = (audio_duration * sample_rate * bytes_per_sample) / 1024 / 1024

    # Model overhead (rough estimate, ~500MB per model)
    model_mb = num_models * 500

    # Processing overhead: intermediate buffers, embeddings, etc.
    processing_mb = audio_mb * 2

    total_mb = (audio_mb + model_mb + processing_mb) * safety_factor

    logger.debug(
        f"Estimated memory: audio={audio_mb:.1f}MB, "
        f"models={model_mb:.1f}MB, processing={processing_mb:.1f}MB, "
        f"total={total_mb:.1f}MB (with {safety_factor}x safety factor)"
    )

    return total_mb
src/lib/metadata_logger.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Processing metadata logging utility.
3
+
4
+ Tracks and logs processing metadata for all workflows including timing,
5
+ resource usage, and processing statistics.
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ import os
11
+ import time
12
+ from dataclasses import asdict, dataclass, field
13
+ from datetime import datetime
14
+ from pathlib import Path
15
+ from typing import Any, Dict, Optional
16
+
17
+ import psutil
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
@dataclass
class ProcessingMetadata:
    """
    Metadata record for a single processing job.

    Captures identification, input/output files, timing, resource usage,
    workflow-specific statistics, configuration, and final status.
    """

    # Job identification
    job_id: str
    workflow: str  # 'separation', 'extraction', 'denoising'
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())

    # Input/Output
    input_files: list = field(default_factory=list)
    output_files: list = field(default_factory=list)

    # Timing (seconds)
    start_time: Optional[float] = None
    end_time: Optional[float] = None
    processing_time: Optional[float] = None

    # Resource usage
    peak_memory_mb: float = 0.0
    avg_cpu_percent: float = 0.0

    # Processing statistics (workflow-specific)
    statistics: Dict[str, Any] = field(default_factory=dict)

    # Configuration
    configuration: Dict[str, Any] = field(default_factory=dict)

    # Status
    status: str = "pending"  # pending, running, completed, failed
    error_message: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return this record as a plain dictionary."""
        return asdict(self)

    def to_json(self) -> str:
        """Return this record as a pretty-printed JSON string."""
        return json.dumps(asdict(self), indent=2)
+
66
+
67
class MetadataLogger:
    """
    Logger for processing metadata.

    Tracks timing, resource usage, and statistics for processing jobs.
    """

    def __init__(self, output_dir: Optional[Path] = None):
        """
        Initialize metadata logger.

        Args:
            output_dir: Directory to save metadata logs (default: ./metadata_logs)
        """
        self.output_dir = output_dir or Path("./metadata_logs")
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Only one job is tracked at a time; start_job() replaces this.
        self.current_metadata: Optional[ProcessingMetadata] = None
        self.process = psutil.Process(os.getpid())

        # Resource tracking
        self._start_memory = 0.0  # RSS in MB at job start
        self._cpu_samples = []  # per-update_progress CPU% samples

        logger.debug(f"Metadata logger initialized (output: {self.output_dir})")

    def start_job(
        self, job_id: str, workflow: str, input_files: list, configuration: Dict[str, Any]
    ) -> ProcessingMetadata:
        """
        Start tracking a new processing job.

        Args:
            job_id: Unique job identifier
            workflow: Workflow name ('separation', 'extraction', 'denoising')
            input_files: List of input file paths
            configuration: Job configuration parameters

        Returns:
            ProcessingMetadata object for this job
        """
        # NOTE(review): any previously running job is silently replaced here
        # without being saved — confirm callers always complete/fail first.
        self.current_metadata = ProcessingMetadata(
            job_id=job_id,
            workflow=workflow,
            input_files=[str(f) for f in input_files],
            configuration=configuration,
            start_time=time.time(),
            status="running",
        )

        # Initialize resource tracking
        self._start_memory = self.process.memory_info().rss / 1024 / 1024  # MB
        self._cpu_samples = []

        logger.info(f"Started tracking job: {job_id} ({workflow})")
        return self.current_metadata

    def update_progress(self, statistics: Dict[str, Any]):
        """
        Update job statistics during processing.

        Args:
            statistics: Current processing statistics
        """
        if self.current_metadata is None:
            logger.warning("No active job to update")
            return

        self.current_metadata.statistics.update(statistics)

        # Track resources: keep the highest RSS seen so far.
        current_memory = self.process.memory_info().rss / 1024 / 1024  # MB
        self.current_metadata.peak_memory_mb = max(
            self.current_metadata.peak_memory_mb, current_memory
        )

        # Sample CPU usage (interval=0.1 blocks ~100ms per call).
        try:
            cpu_percent = self.process.cpu_percent(interval=0.1)
            self._cpu_samples.append(cpu_percent)
        except Exception:
            # CPU sampling is best-effort; never fail the job for it.
            pass

    def complete_job(
        self, output_files: list, final_statistics: Optional[Dict[str, Any]] = None
    ) -> ProcessingMetadata:
        """
        Mark job as completed and finalize metadata.

        Args:
            output_files: List of output file paths
            final_statistics: Final processing statistics

        Returns:
            Completed ProcessingMetadata object

        Raises:
            ValueError: If no job is currently being tracked
        """
        if self.current_metadata is None:
            raise ValueError("No active job to complete")

        self.current_metadata.end_time = time.time()
        self.current_metadata.processing_time = (
            self.current_metadata.end_time - self.current_metadata.start_time
        )
        self.current_metadata.output_files = [str(f) for f in output_files]
        self.current_metadata.status = "completed"

        # Update final statistics
        if final_statistics:
            self.current_metadata.statistics.update(final_statistics)

        # Calculate average CPU usage over the collected samples
        if self._cpu_samples:
            self.current_metadata.avg_cpu_percent = sum(self._cpu_samples) / len(self._cpu_samples)

        # Save metadata (must happen before current_metadata is cleared)
        self._save_metadata()

        logger.info(
            f"Completed job: {self.current_metadata.job_id} "
            f"(time: {self.current_metadata.processing_time:.2f}s, "
            f"memory: {self.current_metadata.peak_memory_mb:.2f}MB)"
        )

        completed_metadata = self.current_metadata
        self.current_metadata = None
        return completed_metadata

    def fail_job(self, error_message: str) -> ProcessingMetadata:
        """
        Mark job as failed.

        Args:
            error_message: Error description

        Returns:
            Failed ProcessingMetadata object

        Raises:
            ValueError: If no job is currently being tracked
        """
        if self.current_metadata is None:
            raise ValueError("No active job to fail")

        self.current_metadata.end_time = time.time()
        self.current_metadata.processing_time = (
            self.current_metadata.end_time - self.current_metadata.start_time
        )
        self.current_metadata.status = "failed"
        self.current_metadata.error_message = error_message

        # Save metadata
        self._save_metadata()

        logger.error(f"Failed job: {self.current_metadata.job_id} - {error_message}")

        failed_metadata = self.current_metadata
        self.current_metadata = None
        return failed_metadata

    def _save_metadata(self):
        """Save the current job's metadata to a JSON file (best-effort)."""
        if self.current_metadata is None:
            return

        try:
            # Create filename from workflow and job ID; reruns of the same
            # job_id overwrite the previous file.
            filename = f"{self.current_metadata.workflow}_{self.current_metadata.job_id}.json"
            filepath = self.output_dir / filename

            # Write metadata
            with open(filepath, "w") as f:
                f.write(self.current_metadata.to_json())

            logger.debug(f"Saved metadata: {filepath}")

        except Exception as e:
            # Persisting metadata must not break the processing pipeline.
            logger.error(f"Failed to save metadata: {e}")

    def get_job_history(self, workflow: Optional[str] = None) -> list:
        """
        Get processing history for completed jobs.

        Args:
            workflow: Filter by workflow name (None = all workflows)

        Returns:
            List of ProcessingMetadata dictionaries, newest first
        """
        history = []

        try:
            for metadata_file in self.output_dir.glob("*.json"):
                # Filter by workflow if specified. Filenames are
                # "<workflow>_<job_id>.json", so a prefix match suffices as
                # long as no workflow name is a prefix of another.
                if workflow and not metadata_file.stem.startswith(workflow):
                    continue

                with open(metadata_file) as f:
                    metadata = json.load(f)
                    history.append(metadata)

            # Sort by timestamp (newest first)
            history.sort(key=lambda x: x.get("timestamp", ""), reverse=True)

        except Exception as e:
            logger.error(f"Failed to load job history: {e}")

        return history

    def get_statistics_summary(self, workflow: str) -> Dict[str, Any]:
        """
        Get aggregated statistics for a workflow.

        Args:
            workflow: Workflow name

        Returns:
            Dictionary with aggregated statistics (job counts, success rate,
            and timing/memory aggregates when completed jobs exist)
        """
        history = self.get_job_history(workflow=workflow)

        if not history:
            return {
                "total_jobs": 0,
                "completed_jobs": 0,
                "failed_jobs": 0,
            }

        completed = [j for j in history if j["status"] == "completed"]
        failed = [j for j in history if j["status"] == "failed"]

        summary = {
            "total_jobs": len(history),
            "completed_jobs": len(completed),
            "failed_jobs": len(failed),
            "success_rate": len(completed) / len(history) if history else 0.0,
        }

        if completed:
            # Only include jobs with a recorded (truthy) value for each metric.
            processing_times = [j["processing_time"] for j in completed if j.get("processing_time")]
            memory_usage = [j["peak_memory_mb"] for j in completed if j.get("peak_memory_mb")]

            if processing_times:
                summary["avg_processing_time"] = sum(processing_times) / len(processing_times)
                summary["min_processing_time"] = min(processing_times)
                summary["max_processing_time"] = max(processing_times)

            if memory_usage:
                summary["avg_memory_mb"] = sum(memory_usage) / len(memory_usage)
                summary["peak_memory_mb"] = max(memory_usage)

        return summary
+
316
+
317
+ # Global metadata logger instance
318
+ _global_logger: Optional[MetadataLogger] = None
319
+
320
+
321
+ def get_metadata_logger(output_dir: Optional[Path] = None) -> MetadataLogger:
322
+ """
323
+ Get global metadata logger instance.
324
+
325
+ Args:
326
+ output_dir: Directory to save metadata logs
327
+
328
+ Returns:
329
+ MetadataLogger instance
330
+ """
331
+ global _global_logger
332
+
333
+ if _global_logger is None:
334
+ _global_logger = MetadataLogger(output_dir=output_dir)
335
+
336
+ return _global_logger
src/lib/quality_metrics.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Audio quality metrics: SNR, STOI, PESQ calculation functions.
3
+
4
+ Provides objective quality measurements for audio extraction validation.
5
+ """
6
+
7
+ import logging
8
+ from typing import Optional, Tuple
9
+
10
+ import numpy as np
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class QualityMetricsError(Exception):
    """Raised when an audio quality metric cannot be computed."""
19
+
20
+
21
def calculate_snr(clean_signal: np.ndarray, noisy_signal: np.ndarray) -> float:
    """
    Calculate Signal-to-Noise Ratio (SNR) in dB.

    Measures the ratio of signal power to noise power.
    Higher values indicate cleaner audio.

    Args:
        clean_signal: Clean reference signal
        noisy_signal: Signal with noise

    Returns:
        SNR in dB (inf if signals are identical, -inf if the reference
        is silent)

    Raises:
        QualityMetricsError: If the calculation fails
    """
    try:
        # Truncate to a common length so elementwise subtraction is valid.
        min_len = min(len(clean_signal), len(noisy_signal))
        clean = clean_signal[:min_len]
        noisy = noisy_signal[:min_len]

        # Noise is whatever the noisy signal adds on top of the reference.
        noise = noisy - clean

        signal_power = np.mean(clean**2)
        noise_power = np.mean(noise**2)

        # Edge case: no noise at all (identical signals)
        if noise_power == 0:
            return float("inf")

        # Edge case: silent reference
        if signal_power == 0:
            return float("-inf")

        # Bug fix: cast to a plain Python float so the return value matches
        # the annotation instead of leaking np.float64.
        return float(10 * np.log10(signal_power / noise_power))

    except Exception as e:
        # Chain the original exception for easier debugging.
        raise QualityMetricsError(f"Failed to calculate SNR: {str(e)}") from e
66
+
67
+
68
def calculate_snr_segmental(
    signal: np.ndarray, sample_rate: int, frame_length_ms: int = 20
) -> float:
    """
    Calculate segmental SNR for signal without clean reference.

    Useful when you don't have a clean reference - estimates SNR by
    averaging per-frame power (in dB) over half-overlapping frames.

    Args:
        signal: Audio signal
        sample_rate: Sample rate in Hz
        frame_length_ms: Frame length in milliseconds

    Returns:
        Segmental SNR in dB (0.0 when the signal is too short for a frame
        or every frame is silent)

    Raises:
        QualityMetricsError: If the calculation fails
    """
    try:
        frame_len = int(sample_rate * frame_length_ms / 1000)
        hop = frame_len // 2  # 50% overlap between consecutive frames

        # Per-frame power, lazily evaluated.
        frame_powers = (
            np.mean(signal[start : start + frame_len] ** 2)
            for start in range(0, len(signal) - frame_len, hop)
        )
        # Silent frames (zero power) are excluded from the average.
        per_frame_db = [10 * np.log10(power) for power in frame_powers if power > 0]

        if not per_frame_db:
            return 0.0

        return np.mean(per_frame_db)

    except Exception as e:
        raise QualityMetricsError(f"Failed to calculate segmental SNR: {str(e)}")
106
+
107
+
108
def calculate_stoi(
    clean_signal: np.ndarray, degraded_signal: np.ndarray, sample_rate: int, extended: bool = True
) -> float:
    """
    Calculate Short-Time Objective Intelligibility (STOI) score.

    Measures speech intelligibility on a 0-1 scale (higher = better).
    Extended STOI (e-STOI) is better for intermediate quality levels.

    Args:
        clean_signal: Clean reference signal
        degraded_signal: Degraded signal to evaluate
        sample_rate: Sample rate in Hz
        extended: Use extended STOI (default: True)

    Returns:
        STOI score (0-1)

    Raises:
        QualityMetricsError: If calculation fails
    """
    try:
        from pystoi import stoi

        # Trim both signals to the shorter length so they align sample-for-sample.
        common = min(len(clean_signal), len(degraded_signal))
        return stoi(
            clean_signal[:common], degraded_signal[:common], sample_rate, extended=extended
        )

    except Exception as e:
        raise QualityMetricsError(f"Failed to calculate STOI: {str(e)}")
144
+
145
+
146
def calculate_pesq(
    reference_signal: np.ndarray, degraded_signal: np.ndarray, sample_rate: int, mode: str = "wb"
) -> float:
    """
    Calculate Perceptual Evaluation of Speech Quality (PESQ) score.

    Correlates with human perception of quality. Range: -0.5 to 4.5
    (higher = better). PESQ is only defined at 16 kHz (wideband) or
    8 kHz (narrowband).

    Args:
        reference_signal: Reference (clean) signal
        degraded_signal: Degraded signal to evaluate
        sample_rate: Sample rate in Hz (must be 8000 or 16000)
        mode: 'wb' (wideband, 16kHz) or 'nb' (narrowband, 8kHz)

    Returns:
        PESQ score

    Raises:
        QualityMetricsError: If calculation fails or sample rate is invalid
    """
    try:
        from pesq import pesq

        # Trim both signals to the shorter length.
        common = min(len(reference_signal), len(degraded_signal))
        reference = reference_signal[:common]
        degraded = degraded_signal[:common]

        # Validate the sample rate against the selected PESQ mode.
        if mode == "wb" and sample_rate != 16000:
            raise QualityMetricsError(
                f"Wideband PESQ requires 16kHz sample rate, got {sample_rate}Hz. "
                "Resample before calling this function."
            )
        elif mode == "nb" and sample_rate != 8000:
            raise QualityMetricsError(
                f"Narrowband PESQ requires 8kHz sample rate, got {sample_rate}Hz. "
                "Resample before calling this function."
            )

        return pesq(sample_rate, reference, degraded, mode)

    except QualityMetricsError:
        # Validation errors pass through unwrapped.
        raise
    except Exception as e:
        raise QualityMetricsError(f"Failed to calculate PESQ: {str(e)}")
195
+
196
+
197
def calculate_pesq_with_resampling(
    reference_signal: np.ndarray, degraded_signal: np.ndarray, sample_rate: int, mode: str = "wb"
) -> float:
    """
    Calculate PESQ with automatic resampling to required sample rate.

    Args:
        reference_signal: Reference signal
        degraded_signal: Degraded signal
        sample_rate: Current sample rate
        mode: 'wb' (wideband, 16kHz) or 'nb' (narrowband, 8kHz)

    Returns:
        PESQ score

    Raises:
        QualityMetricsError: If calculation fails
    """
    try:
        from pesq import pesq
        from scipy.signal import resample

        # Trim both signals to the shorter length.
        common = min(len(reference_signal), len(degraded_signal))
        reference = reference_signal[:common]
        degraded = degraded_signal[:common]

        # Resample both signals to the rate the selected PESQ mode expects.
        required_sr = 16000 if mode == "wb" else 8000
        if sample_rate != required_sr:
            new_len = int(len(reference) * required_sr / sample_rate)
            reference = resample(reference, new_len)
            degraded = resample(degraded, new_len)

        return pesq(required_sr, reference, degraded, mode)

    except Exception as e:
        raise QualityMetricsError(f"Failed to calculate PESQ with resampling: {str(e)}")
237
+
238
+
239
def validate_extraction_quality(
    original_signal: np.ndarray,
    extracted_signal: np.ndarray,
    sample_rate: int,
    snr_threshold: float = 20.0,
    stoi_threshold: float = 0.75,
    pesq_threshold: float = 2.5,
) -> dict:
    """
    Validate extraction quality against thresholds.

    Calculates all three metrics and checks if they meet minimum thresholds.

    Args:
        original_signal: Original (noisy) signal
        extracted_signal: Extracted (cleaned) signal
        sample_rate: Sample rate in Hz
        snr_threshold: Minimum SNR in dB (default: 20)
        stoi_threshold: Minimum STOI score (default: 0.75)
        pesq_threshold: Minimum PESQ score (default: 2.5)

    Returns:
        Dictionary with metrics and pass/fail status
    """
    results = {
        "snr": None,
        "snr_pass": False,
        "stoi": None,
        "stoi_pass": False,
        "pesq": None,
        "pesq_pass": False,
        "overall_pass": False,
    }

    try:
        # Each metric: (result key, zero-arg computation, minimum threshold).
        metric_runs = (
            ("snr", lambda: calculate_snr(original_signal, extracted_signal), snr_threshold),
            (
                "stoi",
                lambda: calculate_stoi(
                    original_signal, extracted_signal, sample_rate, extended=True
                ),
                stoi_threshold,
            ),
            (
                "pesq",
                lambda: calculate_pesq_with_resampling(
                    original_signal, extracted_signal, sample_rate, mode="wb"
                ),
                pesq_threshold,
            ),
        )

        # A failed metric is logged and left as None/False rather than aborting
        # the whole validation.
        for name, compute, threshold in metric_runs:
            try:
                value = compute()
            except Exception as e:
                logger.warning(f"{name.upper()} calculation failed: {e}")
            else:
                results[name] = value
                results[f"{name}_pass"] = value >= threshold

        # Overall pass requires every individual metric to have passed.
        results["overall_pass"] = (
            results["snr_pass"] and results["stoi_pass"] and results["pesq_pass"]
        )

    except Exception as e:
        logger.error(f"Quality validation failed: {e}")

    return results
310
+
311
+
312
def get_quality_label(metric_name: str, value: float) -> str:
    """
    Get quality label for a metric value.

    Args:
        metric_name: Metric name ('snr', 'stoi', 'pesq')
        value: Metric value

    Returns:
        Quality label string
    """
    # (strict lower bound, label) pairs per metric, checked best-first; the
    # fallback label applies when the value fails every bound.
    scales = {
        "snr": ((40, "Excellent"), (30, "Very Good"), (20, "Good"), (10, "Fair")),
        "stoi": ((0.9, "Excellent"), (0.8, "Very Good"), (0.7, "Good"), (0.6, "Fair")),
        "pesq": ((3.5, "Excellent"), (3.0, "Good"), (2.5, "Fair"), (2.0, "Poor")),
    }
    fallbacks = {"snr": "Poor", "stoi": "Poor", "pesq": "Bad"}

    if metric_name not in scales:
        return "Unknown"

    for lower_bound, label in scales[metric_name]:
        if value > lower_bound:
            return label
    return fallbacks[metric_name]
360
+
361
+
362
def generate_quality_report(metrics: dict) -> str:
    """
    Generate human-readable quality report.

    Args:
        metrics: Dictionary from validate_extraction_quality()

    Returns:
        Formatted report string
    """
    lines = ["=== Voice Extraction Quality Report ===", ""]

    # (metric key, value formatter) for each report row; metrics that were
    # never computed (None) are omitted from the report.
    rows = (
        ("snr", "SNR: {:.2f} dB"),
        ("stoi", "STOI: {:.3f}"),
        ("pesq", "PESQ: {:.2f}"),
    )
    for key, value_fmt in rows:
        value = metrics[key]
        if value is None:
            continue
        verdict = "PASS" if metrics[f"{key}_pass"] else "FAIL"
        grade = get_quality_label(key, value)
        lines.append(f"{value_fmt.format(value)} [{verdict}] - {grade}")

    overall = "PASS" if metrics["overall_pass"] else "FAIL"
    lines.append("")
    lines.append(f"Overall Quality: [{overall}]")

    return "\n".join(lines)