# backend/ml_models/validate_models.py
"""
Model validation script to test inference pipeline.
Run this before deploying to production.

Synthetic audio fixtures (silence, sine tones, a noise-burst train) are written
to a ``test_audio`` directory under the current working directory and fed to
each available model of the ``TranscriptionEngine``. The process exits with
status 0 only when at least one model was validated and none of them failed.
"""

import os
import sys
import traceback
from pathlib import Path

import numpy as np
import soundfile as sf
import torch

# Add backend to path
BACKEND_DIR = Path(__file__).parent.parent
sys.path.insert(0, str(BACKEND_DIR))

from ml_models import TranscriptionEngine, WEIGHTS_DIR  # noqa: E402

# Sample rate (Hz) used for every fixture written by this script.
SAMPLE_RATE = 16000


def generate_test_audio(duration=5.0, sr=16000, frequency=440.0):
    """Generate a test sine wave.

    Args:
        duration: Length of the signal in seconds.
        sr: Sample rate in Hz.
        frequency: Frequency of the sine wave in Hz.

    Returns:
        A float32 numpy array of ``int(sr * duration)`` samples with peak
        amplitude 0.5.
    """
    n_samples = int(sr * duration)
    # Sample instants are exactly 1/sr apart. A linspace that includes the
    # endpoint would stretch the spacing to duration/(n-1) and detune the tone.
    t = np.arange(n_samples) / sr
    audio = 0.5 * np.sin(2 * np.pi * frequency * t)
    return audio.astype(np.float32)


def generate_silence(duration=5.0, sr=16000):
    """Generate silence.

    Args:
        duration: Length of the signal in seconds.
        sr: Sample rate in Hz.

    Returns:
        A float32 numpy array of ``int(sr * duration)`` zeros.
    """
    return np.zeros(int(sr * duration), dtype=np.float32)


def _write_silence(test_dir):
    """Write the shared silence fixture into *test_dir* and return its path."""
    silence_path = test_dir / "silence.wav"
    sf.write(silence_path, generate_silence(sr=SAMPLE_RATE), SAMPLE_RATE)
    return silence_path


def _report_failure(exc):
    """Print a FAIL line for *exc* followed by the active traceback."""
    print(f" ✗ FAIL: {exc}")
    traceback.print_exc()


def validate_bass_model(engine, test_dir):
    """Validate bass model.

    Runs three checks through ``engine.transcribe_bass``: silence, a 440 Hz
    tone, and a 41.2 Hz tone.

    Args:
        engine: Object exposing ``transcribe_bass(path)``; the result is
            indexed as ``result[0]`` (notes) and ``result[1]`` (metadata).
        test_dir: ``pathlib.Path`` of an existing directory for fixtures.

    Returns:
        True if every check passed, False on the first failed check or on any
        exception raised by the engine.
    """
    print("\n" + "=" * 60)
    print("Testing BASS Model")
    print("=" * 60)

    # Test 1: Silence (should produce 0 or very few notes)
    print("\nTest 1: Silence")
    silence_path = _write_silence(test_dir)
    try:
        result = engine.transcribe_bass(str(silence_path))
        notes = result[0]
        metadata = result[1]
        print(f" Notes detected: {len(notes)}")
        print(f" Metadata: {metadata}")
        if len(notes) > 20:
            print(" ⚠ WARNING: Too many false positives in silence!")
            return False
        print(" ✓ PASS: Acceptable false positive rate")
    except Exception as e:
        _report_failure(e)
        return False

    # Test 2: Pure tone outside bass range (should be empty)
    print("\nTest 2: High frequency (A440 - out of bass range)")
    tone_path = test_dir / "tone_high.wav"
    sf.write(
        tone_path,
        generate_test_audio(sr=SAMPLE_RATE, frequency=440),
        SAMPLE_RATE,
    )
    try:
        result = engine.transcribe_bass(str(tone_path))
        notes = result[0]
        print(f" Notes detected: {len(notes)}")
        if len(notes) > 10:
            print(" ⚠ WARNING: Detecting out-of-range frequencies!")
            return False
        print(" ✓ PASS: Correctly ignoring out-of-range")
    except Exception as e:
        _report_failure(e)
        return False

    # Test 3: Bass frequency (E1 = 41.2 Hz)
    print("\nTest 3: Bass frequency (E1 = 41.2 Hz)")
    bass_path = test_dir / "tone_bass.wav"
    sf.write(
        bass_path,
        generate_test_audio(sr=SAMPLE_RATE, frequency=41.2),
        SAMPLE_RATE,
    )
    try:
        result = engine.transcribe_bass(str(bass_path))
        notes = result[0]
        print(f" Notes detected: {len(notes)}")
        if len(notes) == 0:
            print(" ⚠ WARNING: Not detecting bass frequencies!")
            return False
        if len(notes) > 5:
            # Still a pass: a sustained tone may be reported as several notes.
            print(" ✓ PASS: Detecting bass (note: may split sustained tone)")
        else:
            print(" ✓ PASS: Detecting bass")
    except Exception as e:
        _report_failure(e)
        return False

    print("\n✓ Bass model validation PASSED")
    return True


def validate_vocal_model(engine, test_dir):
    """Validate vocal model.

    Runs two checks through ``engine.transcribe_vocals``: silence and a
    261.6 Hz tone.

    Args:
        engine: Object exposing ``transcribe_vocals(path)``; the result is
            indexed as ``result[0]`` (notes) and ``result[1]`` (metadata, read
            with ``.get('voiced_ratio', 0)``).
        test_dir: ``pathlib.Path`` of an existing directory for fixtures.

    Returns:
        True if every check passed, False on the first failed check or on any
        exception raised by the engine.
    """
    print("\n" + "=" * 60)
    print("Testing VOCAL Model")
    print("=" * 60)

    # Test 1: Silence
    print("\nTest 1: Silence")
    # Write the fixture here so this validator does not depend on the bass
    # validator having run first.
    silence_path = _write_silence(test_dir)
    try:
        result = engine.transcribe_vocals(str(silence_path))
        notes = result[0]
        metadata = result[1]
        voiced_ratio = metadata.get('voiced_ratio', 0)
        print(f" Notes detected: {len(notes)}")
        print(f" Voiced ratio: {voiced_ratio:.3f}")
        if voiced_ratio > 0.3:
            print(" ⚠ WARNING: High voiced ratio in silence!")
            return False
        print(" ✓ PASS: Low false positive rate")
    except Exception as e:
        _report_failure(e)
        return False

    # Test 2: Vocal range frequency
    print("\nTest 2: Vocal frequency (C4 = 261.6 Hz)")
    vocal_path = test_dir / "tone_vocal.wav"
    sf.write(
        vocal_path,
        generate_test_audio(sr=SAMPLE_RATE, frequency=261.6),
        SAMPLE_RATE,
    )
    try:
        result = engine.transcribe_vocals(str(vocal_path))
        notes = result[0]
        metadata = result[1]
        print(f" Notes detected: {len(notes)}")
        print(f" Voiced ratio: {metadata.get('voiced_ratio', 0):.3f}")
        if len(notes) == 0:
            print(" ⚠ WARNING: Not detecting vocal frequencies!")
            return False
        print(" ✓ PASS: Detecting vocal frequencies")
    except Exception as e:
        _report_failure(e)
        return False

    print("\n✓ Vocal model validation PASSED")
    return True


def validate_drum_model(engine, test_dir):
    """Validate drum model.

    Runs two checks through ``engine.transcribe_drums``: silence and a train
    of short noise bursts spaced one second apart.

    Args:
        engine: Object exposing ``transcribe_drums(path)``; the result is
            indexed as ``result[0]`` (hits) and ``result[1]`` (metadata, read
            with ``.get('hits_per_class', {})``).
        test_dir: ``pathlib.Path`` of an existing directory for fixtures.

    Returns:
        True if every check passed, False on the first failed check or on any
        exception raised by the engine.
    """
    print("\n" + "=" * 60)
    print("Testing DRUM Model")
    print("=" * 60)

    # Test 1: Silence
    print("\nTest 1: Silence")
    # Write the fixture here so this validator does not depend on the bass
    # validator having run first.
    silence_path = _write_silence(test_dir)
    try:
        result = engine.transcribe_drums(str(silence_path))
        hits = result[0]
        metadata = result[1]
        print(f" Hits detected: {len(hits)}")
        print(f" Hits per class: {metadata.get('hits_per_class', {})}")
        if len(hits) > 50:
            print(" ⚠ WARNING: Too many false positives in silence!")
            return False
        print(" ✓ PASS: Low false positive rate")
    except Exception as e:
        _report_failure(e)
        return False

    # Test 2: Impulse (should detect as drum hit)
    print("\nTest 2: Impulse (synthetic drum)")
    impulse_path = test_dir / "impulse.wav"
    # Create impulse train: 3 seconds, one 100-sample noise burst per second.
    # A fixed seed keeps the fixture identical from run to run.
    rng = np.random.default_rng(0)
    impulse = np.zeros(SAMPLE_RATE * 3)
    for start in range(0, len(impulse), SAMPLE_RATE):  # 1 Hz impulses
        impulse[start:start + 100] = 0.8 * rng.standard_normal(100)
    sf.write(impulse_path, impulse.astype(np.float32), SAMPLE_RATE)
    try:
        result = engine.transcribe_drums(str(impulse_path))
        hits = result[0]
        print(f" Hits detected: {len(hits)}")
        if len(hits) == 0:
            print(" ⚠ WARNING: Not detecting impulses!")
            return False
        print(" ✓ PASS: Detecting impulses")
    except Exception as e:
        _report_failure(e)
        return False

    print("\n✓ Drum model validation PASSED")
    return True


def _select_device():
    """Return 'cuda', 'mps' or 'cpu' depending on what torch reports."""
    if torch.cuda.is_available():
        return 'cuda'
    # torch.backends.mps is absent on older PyTorch builds; guard the lookup
    # so the probe cannot raise AttributeError.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return 'mps'
    return 'cpu'


def main():
    """Run all validation tests.

    Creates ``test_audio`` in the current working directory, builds a
    ``TranscriptionEngine`` from ``WEIGHTS_DIR`` and runs the validator of
    every model the engine reports as available.

    Returns:
        True when at least one model was validated and all validated models
        passed. False when the engine could not be initialised, when any
        validator failed, or when every model was skipped.
    """
    print("=" * 60)
    print("MODEL VALIDATION SUITE")
    print("=" * 60)

    # Setup
    device = _select_device()
    print(f"\nDevice: {device}")
    print(f"Weights directory: {WEIGHTS_DIR}")

    # Create test directory
    test_dir = Path("test_audio")
    test_dir.mkdir(exist_ok=True)

    # Initialize engine
    try:
        engine = TranscriptionEngine(str(WEIGHTS_DIR), device=device)
    except Exception as e:
        print("\n✗ CRITICAL: Could not initialize TranscriptionEngine")
        print(f"Error: {e}")
        traceback.print_exc()
        return False

    # Run tests
    validators = (
        ('bass', "Bass", validate_bass_model),
        ('vocals', "Vocal", validate_vocal_model),
        ('drums', "Drum", validate_drum_model),
    )
    results = {}
    for key, label, validator in validators:
        if engine.is_model_available(key):
            results[key] = validator(engine, test_dir)
        else:
            print(f"\n⚠ {label} model not available, skipping")
            results[key] = None

    # Summary
    print("\n" + "=" * 60)
    print("VALIDATION SUMMARY")
    print("=" * 60)
    for model, result in results.items():
        if result is None:
            status = "SKIPPED"
        elif result:
            status = "✓ PASSED"
        else:
            status = "✗ FAILED"
        print(f" {model}: {status}")

    executed = [r for r in results.values() if r is not None]
    if not executed:
        # all() over an empty list is True; a run that validated nothing must
        # not be reported as a pass.
        print("\n✗ NO MODELS WERE VALIDATED")
        return False
    if all(executed):
        print("\n✓ ALL TESTS PASSED")
        return True
    print("\n✗ SOME TESTS FAILED")
    return False


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)