Spaces:
Running
Running
| import os | |
| import subprocess | |
| import pytest | |
| from pathlib import Path | |
| from tests.utils import generate_reference_images, compare_images | |
@pytest.fixture(name="input_file")
def fixture_input_file():
    """Provide the path of the audio file used as input for the separation tests.

    Registered with pytest under the name ``input_file`` (the parameter name that
    ``test_model_separation`` requests), while the function itself keeps the
    distinct ``fixture_`` prefixed name.

    Returns:
        The relative path of the test input file, as a string.
    """
    return "tests/inputs/mardy20s.flac"
@pytest.fixture(name="reference_dir")
def fixture_reference_dir():
    """Provide the directory that holds the reference images.

    Registered with pytest under the name ``reference_dir`` (the parameter name
    that ``test_model_separation`` requests), while the function itself keeps the
    distinct ``fixture_`` prefixed name.

    Returns:
        The relative path of the reference image directory, as a string.
    """
    return "tests/inputs/reference"
@pytest.fixture(name="cleanup_output_files")
def fixture_cleanup_output_files():
    """Yield a list for the test to fill with output paths, then delete those files.

    Registered with pytest under the name ``cleanup_output_files``. The test
    appends (or extends) the yielded list with every file it expects to create;
    once the test has finished, each listed file that exists is removed.

    Yields:
        An initially empty list that collects the paths to delete afterwards.
    """
    # Populated by the test function while it runs
    output_files = []

    # Hand control to the test; everything below is teardown
    yield output_files

    for file in output_files:
        if os.path.exists(file):
            print(f"Test output file exists: {file}")
            os.remove(file)
def run_separation_test(model, audio_path, expected_files):
    """Run the ``audio-separator`` CLI for one model and check that its outputs exist.

    Args:
        model: Model filename handed to the CLI through ``-m``.
        audio_path: Path of the audio file to separate.
        expected_files: Paths of the files the run is expected to produce.

    Returns:
        The ``subprocess.CompletedProcess`` describing the CLI invocation.

    Raises:
        AssertionError: If the command exits with a non-zero code, or if an
            expected output file is missing or empty.
    """
    # Remove leftovers from earlier runs so they cannot satisfy the checks below
    for path in expected_files:
        if not os.path.exists(path):
            continue
        print(f"Deleting existing test output file {path}")
        os.remove(path)

    # check=False: the exit code is asserted explicitly so stderr can be reported
    command = ["audio-separator", "-m", model, audio_path]
    completed = subprocess.run(command, capture_output=True, text=True, check=False)

    assert completed.returncode == 0, f"Command failed with output: {completed.stderr}"

    # Every expected output must exist and contain data
    for path in expected_files:
        assert os.path.exists(path), f"Output file {path} was not created"
        assert os.path.getsize(path) > 0, f"Output file {path} is empty"

    return completed
def validate_audio_output(output_file, reference_dir, waveform_threshold=0.999, spectrogram_threshold=None):
    """Compare an audio file's waveform and spectrogram images with stored references.

    Images for ``output_file`` are generated into a ``temp_images`` directory next
    to it and compared with ``expected_<name>_waveform.png`` and
    ``expected_<name>_spectrogram.png`` in ``reference_dir``, where ``<name>`` is
    the output file's base name without its extension.

    Args:
        output_file: Path to the audio output file
        reference_dir: Directory containing reference images
        waveform_threshold: Minimum similarity required for waveform images (0.0-1.0)
        spectrogram_threshold: Minimum similarity for spectrogram images (0.0-1.0);
            falls back to waveform_threshold when None

    Returns:
        Tuple of booleans: (waveform_match, spectrogram_match). Both are False
        when either reference image is missing.
    """
    # A missing spectrogram threshold inherits the waveform threshold
    spectrogram_threshold = waveform_threshold if spectrogram_threshold is None else spectrogram_threshold

    # Reference images are named after the output file's stem
    stem = os.path.splitext(os.path.basename(output_file))[0]
    expected_waveform_path = os.path.join(reference_dir, f"expected_{stem}_waveform.png")
    expected_spectrogram_path = os.path.join(reference_dir, f"expected_{stem}_spectrogram.png")

    # Generated images go into a scratch directory beside the output file
    scratch_dir = os.path.join(os.path.dirname(output_file), "temp_images")
    os.makedirs(scratch_dir, exist_ok=True)
    actual_waveform_path, actual_spectrogram_path = generate_reference_images(output_file, scratch_dir, prefix="actual_")

    # Without both references there is nothing to compare against
    if not (os.path.exists(expected_waveform_path) and os.path.exists(expected_spectrogram_path)):
        print(f"Warning: Reference images not found for {output_file}")
        print(f"Expected: {expected_waveform_path} and {expected_spectrogram_path}")
        return False, False

    waveform_similarity, waveform_match = compare_images(
        expected_waveform_path,
        actual_waveform_path,
        min_similarity_threshold=waveform_threshold,
    )
    spectrogram_similarity, spectrogram_match = compare_images(
        expected_spectrogram_path,
        actual_spectrogram_path,
        min_similarity_threshold=spectrogram_threshold,
    )

    print(f"Validation results for {output_file}:\n")
    print(f" Waveform similarity: {waveform_similarity:.4f} (match: {waveform_match}, threshold: {waveform_threshold:.2f})\n")
    print(f" Spectrogram similarity: {spectrogram_similarity:.4f} (match: {spectrogram_match}, threshold: {spectrogram_threshold:.2f})\n")

    # The generated images are intentionally left in the scratch directory
    return waveform_match, spectrogram_match
# Fallback (waveform_threshold, spectrogram_threshold) pair, used for every model
# that has no entry in MODEL_SIMILARITY_THRESHOLDS
DEFAULT_SIMILARITY_THRESHOLDS = (0.90, 0.80)

# Per-model overrides keyed by model filename, in the same
# (waveform_threshold, spectrogram_threshold) format.
# Lower thresholds are used for models that show more variation between runs.
MODEL_SIMILARITY_THRESHOLDS = {
    # Demucs multi-stem output (e.g. "Other" and "Piano") is a lot more variable
    "htdemucs_6s.yaml": (0.90, 0.70),
}
# Parameter sets for the model separation test.
# Each entry is (model_filename, expected_output_filenames).
MODEL_PARAMS = [
    (
        "kuielab_b_vocals.onnx",
        [
            "mardy20s_(Instrumental)_kuielab_b_vocals.flac",
            "mardy20s_(Vocals)_kuielab_b_vocals.flac",
        ],
    ),
    (
        "MGM_MAIN_v4.pth",
        [
            "mardy20s_(Instrumental)_MGM_MAIN_v4.flac",
            "mardy20s_(Vocals)_MGM_MAIN_v4.flac",
        ],
    ),
    (
        "UVR-MDX-NET-Inst_HQ_4.onnx",
        [
            "mardy20s_(Instrumental)_UVR-MDX-NET-Inst_HQ_4.flac",
            "mardy20s_(Vocals)_UVR-MDX-NET-Inst_HQ_4.flac",
        ],
    ),
    (
        "2_HP-UVR.pth",
        [
            "mardy20s_(Instrumental)_2_HP-UVR.flac",
            "mardy20s_(Vocals)_2_HP-UVR.flac",
        ],
    ),
    (
        "htdemucs_6s.yaml",
        [
            "mardy20s_(Vocals)_htdemucs_6s.flac",
            "mardy20s_(Drums)_htdemucs_6s.flac",
            "mardy20s_(Bass)_htdemucs_6s.flac",
            "mardy20s_(Other)_htdemucs_6s.flac",
            "mardy20s_(Guitar)_htdemucs_6s.flac",
            "mardy20s_(Piano)_htdemucs_6s.flac",
        ],
    ),
    (
        "model_bs_roformer_ep_937_sdr_10.5309.ckpt",
        [
            "mardy20s_(Drum-Bass)_model_bs_roformer_ep_937_sdr_10.flac",
            "mardy20s_(No Drum-Bass)_model_bs_roformer_ep_937_sdr_10.flac",
        ],
    ),
    (
        "model_bs_roformer_ep_317_sdr_12.9755.ckpt",
        [
            "mardy20s_(Instrumental)_model_bs_roformer_ep_317_sdr_12.flac",
            "mardy20s_(Vocals)_model_bs_roformer_ep_317_sdr_12.flac",
        ],
    ),
]
@pytest.mark.parametrize("model,expected_files", MODEL_PARAMS)
def test_model_separation(model, expected_files, input_file, reference_dir, cleanup_output_files):
    """Run the separation CLI for one model and validate its outputs against references.

    Parametrized over ``MODEL_PARAMS``, so it runs once per
    ``(model, expected_files)`` entry.

    Args:
        model: Model filename passed to the CLI.
        expected_files: Output files the model is expected to produce.
        input_file: Path of the input audio file.
        reference_dir: Directory containing the reference images.
        cleanup_output_files: List of paths scheduled for deletion after the test.

    Setting the environment variable ``SKIP_AUDIO_VALIDATION=1`` skips the image
    comparison; the CLI run and the output-file existence checks still happen.
    """
    # Register the outputs for cleanup first, so they are removed even if a later step fails
    cleanup_output_files.extend(expected_files)

    # Run the CLI and check that every expected file was produced
    run_separation_test(model, input_file, expected_files)

    print(f"\nValidating output files for model {model}...")

    # Model-specific thresholds win; both sources are (waveform, spectrogram) tuples
    waveform_threshold, spectrogram_threshold = MODEL_SIMILARITY_THRESHOLDS.get(model, DEFAULT_SIMILARITY_THRESHOLDS)
    print(f"Using thresholds - waveform: {waveform_threshold}, spectrogram: {spectrogram_threshold} for model {model}")

    # The environment does not change between files, so read the switch once
    skip_validation = os.environ.get("SKIP_AUDIO_VALIDATION") == "1"

    for output_file in expected_files:
        if skip_validation:
            print(f"Skipping audio validation for {output_file} (SKIP_AUDIO_VALIDATION=1)")
            continue

        waveform_match, spectrogram_match = validate_audio_output(
            output_file,
            reference_dir,
            waveform_threshold=waveform_threshold,
            spectrogram_threshold=spectrogram_threshold,
        )

        # Both representations must match the stored references
        assert waveform_match, f"Waveform for {output_file} does not match the reference"
        assert spectrogram_match, f"Spectrogram for {output_file} does not match the reference"