#!/usr/bin/env python3 """ Evaluate generated MIDI folders with one-shot quality metrics. """ from __future__ import annotations import argparse import sys from pathlib import Path PROJECT_ROOT = Path(__file__).resolve().parents[1] if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) from modelw.eval_metrics import EvaluationConfig, MIDIEvaluator def main(): parser = argparse.ArgumentParser(description="Evaluate MIDI quality metrics") parser.add_argument("midi_dir", type=str, help="Directory containing MIDI files") parser.add_argument( "--output", type=str, default="midi_quality_results.json", help="Output JSON path", ) parser.add_argument( "--metadata", type=str, default=None, help="Optional metadata.json path with prompt information", ) parser.add_argument("--min-notes", type=int, default=20, help="Minimum note count") parser.add_argument("--max-notes", type=int, default=2000, help="Maximum note count") parser.add_argument( "--min-duration-seconds", type=float, default=5.0, help="Minimum file duration", ) parser.add_argument( "--acceptance-threshold", type=float, default=0.75, help="Composite score threshold for acceptance", ) args = parser.parse_args() config = EvaluationConfig( min_notes=args.min_notes, max_notes=args.max_notes, min_duration_seconds=args.min_duration_seconds, acceptance_threshold=args.acceptance_threshold, ) evaluator = MIDIEvaluator(config) results = evaluator.evaluate_directory( midi_dir=args.midi_dir, metadata_path=args.metadata, output_path=args.output, ) summary = results["summary"] print("\n" + "=" * 60) print("MIDI Quality Evaluation") print("=" * 60) print(f"Files: {summary['total_files']}") print(f"Valid: {summary['valid_files']}") print(f"Hard validity rate: {summary['hard_validity_rate']:.2%}") print(f"Acceptance rate: {summary['acceptance_rate']:.2%}") print(f"Composite score: {summary['mean_composite_score']:.4f}") print(f"Prompt match score: {summary['mean_prompt_match_score']}") print(f"Key/scale adherence: {summary['mean_key_scale_adherence']:.4f}") print(f"Rhythm grid accuracy: {summary['mean_rhythm_grid_accuracy']:.4f}") print(f"Velocity expression: {summary['mean_velocity_expressiveness']:.4f}") print(f"Repeat/variation: {summary['mean_repetition_variation_balance']:.4f}") print(f"Section coherence: {summary['mean_section_coherence']:.4f}") print(f"Track role integrity: {summary['mean_track_role_integrity']}") print(f"Saved JSON: {Path(args.output).resolve()}") if __name__ == "__main__": main()