| |
| """ |
| Model validation script to test inference pipeline. |
| Run this before deploying to production. |
| """ |
|
|
| import os |
| import sys |
| import torch |
| import numpy as np |
| import soundfile as sf |
| from pathlib import Path |
|
|
| |
# Parent of the directory that contains this script. It is placed at the very
# front of the import search path so the `ml_models` import that follows is
# looked up there first.
BACKEND_DIR = Path(__file__).parent.parent
sys.path.insert(0, os.fspath(BACKEND_DIR))
|
|
| from ml_models import TranscriptionEngine, WEIGHTS_DIR |
|
|
|
|
def generate_test_audio(duration=5.0, sr=16000, frequency=440.0):
    """Generate a sine-wave test tone.

    Args:
        duration: Length of the tone in seconds.
        sr: Sample rate in Hz.
        frequency: Frequency of the sine wave in Hz.

    Returns:
        1-D ``np.float32`` array of ``int(sr * duration)`` samples with a
        peak amplitude of 0.5.
    """
    num_samples = int(sr * duration)
    # Sample k is taken at exactly k / sr seconds. np.linspace with its
    # default endpoint=True would spread the samples over [0, duration]
    # inclusive, stretching the spacing to duration / (n - 1) and slightly
    # detuning the tone from the requested frequency.
    t = np.arange(num_samples) / sr
    audio = 0.5 * np.sin(2 * np.pi * frequency * t)
    return audio.astype(np.float32)
|
|
|
|
def generate_silence(duration=5.0, sr=16000):
    """Return an all-zero audio buffer.

    Args:
        duration: Length of the buffer in seconds.
        sr: Sample rate in Hz.

    Returns:
        1-D ``np.float32`` array holding ``int(sr * duration)`` zeros.
    """
    sample_count = int(sr * duration)
    silent = np.zeros(shape=(sample_count,), dtype=np.float32)
    return silent
|
|
|
|
def validate_bass_model(engine, test_dir):
    """Run the bass-model smoke tests and report the outcome on stdout.

    Three synthetic clips are written into ``test_dir`` and transcribed:
    silence, a 440 Hz tone (expected to be ignored) and a 41.2 Hz tone
    (expected to be detected).

    Args:
        engine: Object exposing ``transcribe_bass(path)``. The result is
            indexed as ``result[0]`` (sized collection of notes) and
            ``result[1]`` (metadata, only printed).
        test_dir: ``pathlib.Path`` of an existing directory in which the WAV
            fixtures are written.

    Returns:
        True if all three checks pass. False as soon as one check fails or
        raises; any exception is reported together with its traceback.
    """
    import traceback

    sample_rate = 16000  # same rate as the fixture generators' default

    print("\n" + "=" * 60)
    print("Testing BASS Model")
    print("=" * 60)

    # Test 1: silence should produce (almost) no notes.
    print("\nTest 1: Silence")
    silence_path = test_dir / "silence.wav"
    sf.write(silence_path, generate_silence(), sample_rate)

    try:
        result = engine.transcribe_bass(str(silence_path))
        notes = result[0]
        metadata = result[1]

        print(f" Notes detected: {len(notes)}")
        print(f" Metadata: {metadata}")

        if len(notes) > 20:
            print(" β WARNING: Too many false positives in silence!")
            return False
        print(" β PASS: Acceptable false positive rate")
    except Exception as exc:
        print(f" β FAIL: {exc}")
        traceback.print_exc()
        return False

    # Test 2: a tone above the bass range should be (mostly) ignored.
    print("\nTest 2: High frequency (A440 - out of bass range)")
    tone_path = test_dir / "tone_high.wav"
    sf.write(tone_path, generate_test_audio(frequency=440), sample_rate)

    try:
        result = engine.transcribe_bass(str(tone_path))
        notes = result[0]

        print(f" Notes detected: {len(notes)}")

        if len(notes) > 10:
            print(" β WARNING: Detecting out-of-range frequencies!")
            return False
        print(" β PASS: Correctly ignoring out-of-range")
    except Exception as exc:
        print(f" β FAIL: {exc}")
        # Same diagnostics as Test 1, so every failure can be debugged alike.
        traceback.print_exc()
        return False

    # Test 3: a tone inside the bass range must yield at least one note.
    print("\nTest 3: Bass frequency (E1 = 41.2 Hz)")
    bass_path = test_dir / "tone_bass.wav"
    sf.write(bass_path, generate_test_audio(frequency=41.2), sample_rate)

    try:
        result = engine.transcribe_bass(str(bass_path))
        notes = result[0]

        print(f" Notes detected: {len(notes)}")

        if len(notes) == 0:
            print(" β WARNING: Not detecting bass frequencies!")
            return False
        if len(notes) > 5:
            # Still a pass: the sustained tone was merely split into pieces.
            print(" β PASS: Detecting bass (note: may split sustained tone)")
        else:
            print(" β PASS: Detecting bass")
    except Exception as exc:
        print(f" β FAIL: {exc}")
        traceback.print_exc()
        return False

    print("\nβ Bass model validation PASSED")
    return True
|
|
|
|
def validate_vocal_model(engine, test_dir):
    """Run the vocal-model smoke tests and report the outcome on stdout.

    Two synthetic clips are written into ``test_dir`` and transcribed:
    silence (expected to have a low voiced ratio) and a 261.6 Hz tone
    (expected to yield at least one note).

    Args:
        engine: Object exposing ``transcribe_vocals(path)``. The result is
            indexed as ``result[0]`` (sized collection of notes) and
            ``result[1]`` (metadata supporting ``.get('voiced_ratio', 0)``).
        test_dir: ``pathlib.Path`` of an existing directory in which the WAV
            fixtures are written.

    Returns:
        True if both checks pass. False as soon as one check fails or
        raises; any exception is reported together with its traceback.
    """
    import traceback

    sample_rate = 16000  # same rate as the fixture generators' default

    print("\n" + "=" * 60)
    print("Testing VOCAL Model")
    print("=" * 60)

    # Test 1: silence should not be classified as voiced.
    print("\nTest 1: Silence")
    silence_path = test_dir / "silence.wav"
    # Write the fixture here instead of relying on another validator having
    # created it: when the bass model is skipped, nothing else produces
    # silence.wav and this test would fail on a missing file.
    sf.write(silence_path, generate_silence(), sample_rate)

    try:
        result = engine.transcribe_vocals(str(silence_path))
        notes = result[0]
        metadata = result[1]

        print(f" Notes detected: {len(notes)}")
        voiced_ratio = metadata.get('voiced_ratio', 0)
        print(f" Voiced ratio: {voiced_ratio:.3f}")

        if voiced_ratio > 0.3:
            print(" β WARNING: High voiced ratio in silence!")
            return False
        print(" β PASS: Low false positive rate")
    except Exception as exc:
        print(f" β FAIL: {exc}")
        traceback.print_exc()
        return False

    # Test 2: a tone in the vocal range must yield at least one note.
    print("\nTest 2: Vocal frequency (C4 = 261.6 Hz)")
    vocal_path = test_dir / "tone_vocal.wav"
    sf.write(vocal_path, generate_test_audio(frequency=261.6), sample_rate)

    try:
        result = engine.transcribe_vocals(str(vocal_path))
        notes = result[0]
        metadata = result[1]

        print(f" Notes detected: {len(notes)}")
        print(f" Voiced ratio: {metadata.get('voiced_ratio', 0):.3f}")

        if len(notes) == 0:
            print(" β WARNING: Not detecting vocal frequencies!")
            return False
        print(" β PASS: Detecting vocal frequencies")
    except Exception as exc:
        print(f" β FAIL: {exc}")
        traceback.print_exc()
        return False

    print("\nβ Vocal model validation PASSED")
    return True
|
|
|
|
def validate_drum_model(engine, test_dir):
    """Run the drum-model smoke tests and report the outcome on stdout.

    Two synthetic clips are written into ``test_dir`` and transcribed:
    silence (expected to yield few hits) and three short noise bursts one
    second apart (expected to yield at least one hit).

    Args:
        engine: Object exposing ``transcribe_drums(path)``. The result is
            indexed as ``result[0]`` (sized collection of hits) and
            ``result[1]`` (metadata supporting ``.get('hits_per_class', {})``).
        test_dir: ``pathlib.Path`` of an existing directory in which the WAV
            fixtures are written.

    Returns:
        True if both checks pass. False as soon as one check fails or
        raises; any exception is reported together with its traceback.
    """
    import traceback

    sample_rate = 16000  # same rate as the fixture generators' default

    print("\n" + "=" * 60)
    print("Testing DRUM Model")
    print("=" * 60)

    # Test 1: silence should produce (almost) no hits.
    print("\nTest 1: Silence")
    silence_path = test_dir / "silence.wav"
    # Write the fixture here instead of relying on another validator having
    # created it: when the bass model is skipped, nothing else produces
    # silence.wav and this test would fail on a missing file.
    sf.write(silence_path, generate_silence(), sample_rate)

    try:
        result = engine.transcribe_drums(str(silence_path))
        hits = result[0]
        metadata = result[1]

        print(f" Hits detected: {len(hits)}")
        print(f" Hits per class: {metadata.get('hits_per_class', {})}")

        if len(hits) > 50:
            print(" β WARNING: Too many false positives in silence!")
            return False
        print(" β PASS: Low false positive rate")
    except Exception as exc:
        print(f" β FAIL: {exc}")
        traceback.print_exc()
        return False

    # Test 2: synthetic percussive bursts must yield at least one hit.
    print("\nTest 2: Impulse (synthetic drum)")
    impulse_path = test_dir / "impulse.wav"

    # Three seconds of silence with a 100-sample noise burst at the start of
    # each second. The generator is seeded so every run validates the model
    # against the same fixture and the verdict is reproducible.
    rng = np.random.default_rng(0)
    burst_len = 100
    impulse = np.zeros(sample_rate * 3)
    for start in range(0, len(impulse), sample_rate):
        impulse[start:start + burst_len] = 0.8 * rng.standard_normal(burst_len)

    sf.write(impulse_path, impulse.astype(np.float32), sample_rate)

    try:
        result = engine.transcribe_drums(str(impulse_path))
        hits = result[0]

        print(f" Hits detected: {len(hits)}")

        if len(hits) == 0:
            print(" β WARNING: Not detecting impulses!")
            return False
        print(" β PASS: Detecting impulses")
    except Exception as exc:
        print(f" β FAIL: {exc}")
        traceback.print_exc()
        return False

    print("\nβ Drum model validation PASSED")
    return True
|
|
|
|
def main():
    """Run every available model validator and print a summary.

    Picks a torch device, creates the ``test_audio`` directory in the current
    working directory, builds a ``TranscriptionEngine`` from ``WEIGHTS_DIR``
    and runs the bass, vocal and drum validators for each model the engine
    reports as available.

    Returns:
        True only when at least one model was validated and none of the
        validated models failed. False when the engine cannot be created,
        when any validator fails, or when every model was skipped.
    """
    print("=" * 60)
    print("MODEL VALIDATION SUITE")
    print("=" * 60)

    # Prefer CUDA, then Apple MPS, then CPU. `torch.backends.mps` is missing
    # from older PyTorch builds, so look it up defensively instead of letting
    # an AttributeError abort the run.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device = 'cuda'
    elif mps_backend is not None and mps_backend.is_available():
        device = 'mps'
    else:
        device = 'cpu'

    print(f"\nDevice: {device}")
    print(f"Weights directory: {WEIGHTS_DIR}")

    # Directory that receives the generated WAV fixtures.
    test_dir = Path("test_audio")
    test_dir.mkdir(exist_ok=True)

    try:
        engine = TranscriptionEngine(str(WEIGHTS_DIR), device=device)
    except Exception as exc:
        print("\nβ CRITICAL: Could not initialize TranscriptionEngine")
        print(f"Error: {exc}")
        import traceback
        traceback.print_exc()
        return False

    # (engine model key, label used in the skip message, validator)
    suites = (
        ('bass', 'Bass', validate_bass_model),
        ('vocals', 'Vocal', validate_vocal_model),
        ('drums', 'Drum', validate_drum_model),
    )

    # Maps model key -> True (passed), False (failed) or None (skipped).
    results = {}
    for key, label, validator in suites:
        if engine.is_model_available(key):
            results[key] = validator(engine, test_dir)
        else:
            print(f"\nβ {label} model not available, skipping")
            results[key] = None

    print("\n" + "=" * 60)
    print("VALIDATION SUMMARY")
    print("=" * 60)

    for model, result in results.items():
        if result is None:
            status = "SKIPPED"
        elif result:
            status = "β PASSED"
        else:
            status = "β FAILED"
        print(f" {model}: {status}")

    executed = [outcome for outcome in results.values() if outcome is not None]

    # all() over an empty sequence is True, so a run in which every model was
    # skipped must be rejected explicitly; otherwise a deployment gate would
    # go green without a single model having been exercised.
    if not executed:
        print("\nNO MODELS AVAILABLE - NOTHING WAS VALIDATED")
        return False

    if all(executed):
        print("\nβ ALL TESTS PASSED")
        return True

    print("\nβ SOME TESTS FAILED")
    return False
|
|
|
|
if __name__ == "__main__":
    # Map the suite verdict onto the process exit status: 0 when main()
    # returns a truthy value, 1 otherwise.
    exit_status = 0 if main() else 1
    sys.exit(exit_status)