muse-archive / scripts /validate_models.py
lamooon's picture
Upload scripts/validate_models.py with huggingface_hub
58a923b verified
# backend/ml_models/validate_models.py
"""
Model validation script to test inference pipeline.
Run this before deploying to production.
"""
import os
import sys
import torch
import numpy as np
import soundfile as sf
from pathlib import Path
# Put the parent of this file's directory at the front of sys.path.
# NOTE(review): presumably this is what lets the ``ml_models`` import that
# follows resolve when the file is run directly as a script - confirm against
# the repository layout.
BACKEND_DIR = Path(__file__).parent.parent
sys.path.insert(0, str(BACKEND_DIR))
from ml_models import TranscriptionEngine, WEIGHTS_DIR
def generate_test_audio(duration=5.0, sr=16000, frequency=440.0):
    """Generate a mono sine-wave test signal.

    Args:
        duration: Length of the signal in seconds.
        sr: Sample rate in Hz.
        frequency: Frequency of the sine wave in Hz.

    Returns:
        A 1-D ``float32`` NumPy array holding ``int(sr * duration)`` samples
        of a sine wave with a peak amplitude of 0.5.
    """
    num_samples = int(sr * duration)
    # Sample n is taken at time n / sr.  ``np.linspace(0, duration, N)`` would
    # include the endpoint and space the samples duration / (N - 1) apart,
    # which slightly detunes the generated tone.
    t = np.arange(num_samples) / sr
    audio = 0.5 * np.sin(2 * np.pi * frequency * t)
    return audio.astype(np.float32)
def generate_silence(duration=5.0, sr=16000):
    """Return an all-zero ``float32`` buffer of ``int(sr * duration)`` samples.

    Args:
        duration: Length of the buffer in seconds.
        sr: Sample rate in Hz.
    """
    sample_count = int(sr * duration)
    silence = np.zeros((sample_count,), dtype=np.float32)
    return silence
def validate_bass_model(engine, test_dir):
    """Validate the bass transcription model against synthetic audio.

    Writes three 16 kHz WAV fixtures into ``test_dir`` and runs
    ``engine.transcribe_bass`` on each of them:

    1. Silence - fails if more than 20 notes are returned.
    2. A 440 Hz sine tone (treated as out of bass range) - fails if more
       than 10 notes are returned.
    3. A 41.2 Hz sine tone - fails if no notes are returned.

    Args:
        engine: Object providing ``transcribe_bass(path)``.  Element 0 of the
            returned value is used as the sized collection of notes and
            element 1 as metadata (only printed).
        test_dir: ``pathlib.Path`` of an existing, writable directory that
            receives ``silence.wav``, ``tone_high.wav`` and ``tone_bass.wav``.

    Returns:
        ``True`` when all three tests pass; ``False`` as soon as one test
        fails or the engine raises.
    """
    import traceback

    sample_rate = 16000

    def report_failure(exc):
        # Print the traceback for every failing test so that engine errors
        # are equally debuggable regardless of which test hit them.
        print(f" ✗ FAIL: {exc}")
        traceback.print_exc()

    print("\n" + "=" * 60)
    print("Testing BASS Model")
    print("=" * 60)

    # Test 1: Silence (should produce 0 or very few notes)
    print("\nTest 1: Silence")
    silence_path = test_dir / "silence.wav"
    sf.write(silence_path, generate_silence(sr=sample_rate), sample_rate)
    try:
        result = engine.transcribe_bass(str(silence_path))
        notes = result[0]
        metadata = result[1]
        print(f" Notes detected: {len(notes)}")
        print(f" Metadata: {metadata}")
        if len(notes) > 20:
            print(" ⚠ WARNING: Too many false positives in silence!")
            return False
        print(" ✓ PASS: Acceptable false positive rate")
    except Exception as e:
        report_failure(e)
        return False

    # Test 2: Pure tone outside bass range (should be empty)
    print("\nTest 2: High frequency (A440 - out of bass range)")
    tone_path = test_dir / "tone_high.wav"
    sf.write(
        tone_path,
        generate_test_audio(sr=sample_rate, frequency=440),
        sample_rate,
    )
    try:
        result = engine.transcribe_bass(str(tone_path))
        notes = result[0]
        print(f" Notes detected: {len(notes)}")
        if len(notes) > 10:
            print(" ⚠ WARNING: Detecting out-of-range frequencies!")
            return False
        print(" ✓ PASS: Correctly ignoring out-of-range")
    except Exception as e:
        report_failure(e)
        return False

    # Test 3: Bass frequency (E1 = 41.2 Hz)
    print("\nTest 3: Bass frequency (E1 = 41.2 Hz)")
    bass_path = test_dir / "tone_bass.wav"
    sf.write(
        bass_path,
        generate_test_audio(sr=sample_rate, frequency=41.2),
        sample_rate,
    )
    try:
        result = engine.transcribe_bass(str(bass_path))
        notes = result[0]
        print(f" Notes detected: {len(notes)}")
        if len(notes) == 0:
            print(" ⚠ WARNING: Not detecting bass frequencies!")
            return False
        if len(notes) > 5:
            # Still a pass: many notes for one sustained tone only means the
            # tone was split into several segments.
            print(" ✓ PASS: Detecting bass (note: may split sustained tone)")
        else:
            print(" ✓ PASS: Detecting bass")
    except Exception as e:
        report_failure(e)
        return False

    print("\n✓ Bass model validation PASSED")
    return True
def validate_vocal_model(engine, test_dir):
    """Validate the vocal transcription model against synthetic audio.

    Writes two 16 kHz WAV fixtures into ``test_dir`` and runs
    ``engine.transcribe_vocals`` on each of them:

    1. Silence - fails if the reported ``voiced_ratio`` exceeds 0.3.
    2. A 261.6 Hz sine tone - fails if no notes are returned.

    Args:
        engine: Object providing ``transcribe_vocals(path)``.  Element 0 of
            the returned value is used as the sized collection of notes and
            element 1 as a metadata mapping queried with
            ``.get('voiced_ratio', 0)``.
        test_dir: ``pathlib.Path`` of an existing, writable directory that
            receives ``silence.wav`` and ``tone_vocal.wav``.

    Returns:
        ``True`` when both tests pass; ``False`` as soon as one test fails or
        the engine raises.
    """
    import traceback

    sample_rate = 16000

    def report_failure(exc):
        # Always include the traceback; the bare message alone is rarely
        # enough to locate a failure inside the engine.
        print(f" ✗ FAIL: {exc}")
        traceback.print_exc()

    print("\n" + "=" * 60)
    print("Testing VOCAL Model")
    print("=" * 60)

    # Test 1: Silence
    print("\nTest 1: Silence")
    silence_path = test_dir / "silence.wav"
    # Write the fixture here instead of relying on validate_bass_model having
    # created it: that function does not run when the bass model is skipped,
    # in which case the file may not exist.
    sf.write(silence_path, generate_silence(sr=sample_rate), sample_rate)
    try:
        result = engine.transcribe_vocals(str(silence_path))
        notes = result[0]
        metadata = result[1]
        voiced_ratio = metadata.get('voiced_ratio', 0)
        print(f" Notes detected: {len(notes)}")
        print(f" Voiced ratio: {voiced_ratio:.3f}")
        if voiced_ratio > 0.3:
            print(" ⚠ WARNING: High voiced ratio in silence!")
            return False
        print(" ✓ PASS: Low false positive rate")
    except Exception as e:
        report_failure(e)
        return False

    # Test 2: Vocal range frequency
    print("\nTest 2: Vocal frequency (C4 = 261.6 Hz)")
    vocal_path = test_dir / "tone_vocal.wav"
    sf.write(
        vocal_path,
        generate_test_audio(sr=sample_rate, frequency=261.6),
        sample_rate,
    )
    try:
        result = engine.transcribe_vocals(str(vocal_path))
        notes = result[0]
        metadata = result[1]
        voiced_ratio = metadata.get('voiced_ratio', 0)
        print(f" Notes detected: {len(notes)}")
        print(f" Voiced ratio: {voiced_ratio:.3f}")
        if len(notes) == 0:
            print(" ⚠ WARNING: Not detecting vocal frequencies!")
            return False
        print(" ✓ PASS: Detecting vocal frequencies")
    except Exception as e:
        report_failure(e)
        return False

    print("\n✓ Vocal model validation PASSED")
    return True
def validate_drum_model(engine, test_dir):
    """Validate the drum transcription model against synthetic audio.

    Writes two 16 kHz WAV fixtures into ``test_dir`` and runs
    ``engine.transcribe_drums`` on each of them:

    1. Silence - fails if more than 50 hits are returned.
    2. Three seconds containing one 100-sample noise burst per second -
       fails if no hits are returned.

    Args:
        engine: Object providing ``transcribe_drums(path)``.  Element 0 of
            the returned value is used as the sized collection of hits and
            element 1 as a metadata mapping queried with
            ``.get('hits_per_class', {})``.
        test_dir: ``pathlib.Path`` of an existing, writable directory that
            receives ``silence.wav`` and ``impulse.wav``.

    Returns:
        ``True`` when both tests pass; ``False`` as soon as one test fails or
        the engine raises.
    """
    import traceback

    sample_rate = 16000
    burst_length = 100

    def report_failure(exc):
        # Always include the traceback; the bare message alone is rarely
        # enough to locate a failure inside the engine.
        print(f" ✗ FAIL: {exc}")
        traceback.print_exc()

    print("\n" + "=" * 60)
    print("Testing DRUM Model")
    print("=" * 60)

    # Test 1: Silence
    print("\nTest 1: Silence")
    silence_path = test_dir / "silence.wav"
    # Write the fixture here instead of relying on validate_bass_model having
    # created it: that function does not run when the bass model is skipped,
    # in which case the file may not exist.
    sf.write(silence_path, generate_silence(sr=sample_rate), sample_rate)
    try:
        result = engine.transcribe_drums(str(silence_path))
        hits = result[0]
        metadata = result[1]
        print(f" Hits detected: {len(hits)}")
        print(f" Hits per class: {metadata.get('hits_per_class', {})}")
        if len(hits) > 50:
            print(" ⚠ WARNING: Too many false positives in silence!")
            return False
        print(" ✓ PASS: Low false positive rate")
    except Exception as e:
        report_failure(e)
        return False

    # Test 2: Impulse (should detect as drum hit)
    print("\nTest 2: Impulse (synthetic drum)")
    impulse_path = test_dir / "impulse.wav"
    # Seeded local generator: the fixture is identical on every run and the
    # global NumPy random state is left untouched.
    rng = np.random.default_rng(0)
    impulse = np.zeros(sample_rate * 3)
    for start in range(0, len(impulse), sample_rate):  # 1 Hz impulses
        # Noisy impulse
        impulse[start:start + burst_length] = (
            0.8 * rng.standard_normal(burst_length)
        )
    sf.write(impulse_path, impulse.astype(np.float32), sample_rate)
    try:
        result = engine.transcribe_drums(str(impulse_path))
        hits = result[0]
        print(f" Hits detected: {len(hits)}")
        if len(hits) == 0:
            print(" ⚠ WARNING: Not detecting impulses!")
            return False
        print(" ✓ PASS: Detecting impulses")
    except Exception as e:
        report_failure(e)
        return False

    print("\n✓ Drum model validation PASSED")
    return True
def main():
    """Run all validation tests and print a summary.

    Picks a torch device (``cuda``, then ``mps``, then ``cpu``), creates the
    ``test_audio`` directory relative to the current working directory,
    builds a ``TranscriptionEngine`` from ``WEIGHTS_DIR`` and runs the
    validator of every model that ``engine.is_model_available`` reports as
    available.  Unavailable models are recorded as skipped.

    Returns:
        ``True`` only if at least one model was validated and every validated
        model passed.  ``False`` if the engine could not be constructed, if
        any validator failed, or if every model was skipped (nothing was
        actually validated, so the run must not count as a pass).
    """
    import traceback

    print("=" * 60)
    print("MODEL VALIDATION SUITE")
    print("=" * 60)

    # Setup
    # ``torch.backends.mps`` is missing on some torch builds; look it up
    # defensively so that device selection cannot raise AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device = 'cuda'
    elif mps_backend is not None and mps_backend.is_available():
        device = 'mps'
    else:
        device = 'cpu'
    print(f"\nDevice: {device}")
    print(f"Weights directory: {WEIGHTS_DIR}")

    # Create test directory
    test_dir = Path("test_audio")
    test_dir.mkdir(exist_ok=True)

    # Initialize engine
    try:
        engine = TranscriptionEngine(str(WEIGHTS_DIR), device=device)
    except Exception as e:
        print("\n✗ CRITICAL: Could not initialize TranscriptionEngine")
        print(f"Error: {e}")
        traceback.print_exc()
        return False

    # Run tests: (engine model key, label used in messages, validator).
    suites = (
        ('bass', 'Bass', validate_bass_model),
        ('vocals', 'Vocal', validate_vocal_model),
        ('drums', 'Drum', validate_drum_model),
    )
    results = {}
    for key, label, validator in suites:
        if engine.is_model_available(key):
            results[key] = validator(engine, test_dir)
        else:
            print(f"\n⚠ {label} model not available, skipping")
            results[key] = None  # None marks "skipped" in the summary

    # Summary
    print("\n" + "=" * 60)
    print("VALIDATION SUMMARY")
    print("=" * 60)
    for model, result in results.items():
        if result is None:
            status = "SKIPPED"
        elif result:
            status = "✓ PASSED"
        else:
            status = "✗ FAILED"
        print(f" {model}: {status}")

    executed = [outcome for outcome in results.values() if outcome is not None]
    if not executed:
        # all() of an empty sequence is True; without this guard a run in
        # which every model was skipped would be reported as a full pass.
        print("\n✗ NO MODELS AVAILABLE - NOTHING WAS VALIDATED")
        return False
    if all(executed):
        print("\n✓ ALL TESTS PASSED")
        return True
    print("\n✗ SOME TESTS FAILED")
    return False
if __name__ == "__main__":
    # Map main()'s verdict onto the process exit status: 0 for a truthy
    # result, 1 otherwise.
    exit_code = 0 if main() else 1
    sys.exit(exit_code)