File size: 2,919 Bytes
aed1d05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#!/usr/bin/env python3
"""
Evaluate generated MIDI folders with one-shot quality metrics.
"""

from __future__ import annotations

import argparse
import sys
from pathlib import Path

# Make the repository root importable so the `modelw` import below resolves
# when this file is executed directly as a script. The root is prepended (not
# appended) so the local checkout takes precedence on sys.path.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
_PROJECT_ROOT_STR = str(PROJECT_ROOT)
if _PROJECT_ROOT_STR not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT_STR)

from modelw.eval_metrics import EvaluationConfig, MIDIEvaluator


def main():
    parser = argparse.ArgumentParser(description="Evaluate MIDI quality metrics")
    parser.add_argument("midi_dir", type=str, help="Directory containing MIDI files")
    parser.add_argument(
        "--output",
        type=str,
        default="midi_quality_results.json",
        help="Output JSON path",
    )
    parser.add_argument(
        "--metadata",
        type=str,
        default=None,
        help="Optional metadata.json path with prompt information",
    )
    parser.add_argument("--min-notes", type=int, default=20, help="Minimum note count")
    parser.add_argument("--max-notes", type=int, default=2000, help="Maximum note count")
    parser.add_argument(
        "--min-duration-seconds",
        type=float,
        default=5.0,
        help="Minimum file duration",
    )
    parser.add_argument(
        "--acceptance-threshold",
        type=float,
        default=0.75,
        help="Composite score threshold for acceptance",
    )
    args = parser.parse_args()

    config = EvaluationConfig(
        min_notes=args.min_notes,
        max_notes=args.max_notes,
        min_duration_seconds=args.min_duration_seconds,
        acceptance_threshold=args.acceptance_threshold,
    )
    evaluator = MIDIEvaluator(config)
    results = evaluator.evaluate_directory(
        midi_dir=args.midi_dir,
        metadata_path=args.metadata,
        output_path=args.output,
    )

    summary = results["summary"]
    print("\n" + "=" * 60)
    print("MIDI Quality Evaluation")
    print("=" * 60)
    print(f"Files:                 {summary['total_files']}")
    print(f"Valid:                 {summary['valid_files']}")
    print(f"Hard validity rate:    {summary['hard_validity_rate']:.2%}")
    print(f"Acceptance rate:       {summary['acceptance_rate']:.2%}")
    print(f"Composite score:       {summary['mean_composite_score']:.4f}")
    print(f"Prompt match score:    {summary['mean_prompt_match_score']}")
    print(f"Key/scale adherence:   {summary['mean_key_scale_adherence']:.4f}")
    print(f"Rhythm grid accuracy:  {summary['mean_rhythm_grid_accuracy']:.4f}")
    print(f"Velocity expression:   {summary['mean_velocity_expressiveness']:.4f}")
    print(f"Repeat/variation:      {summary['mean_repetition_variation_balance']:.4f}")
    print(f"Section coherence:     {summary['mean_section_coherence']:.4f}")
    print(f"Track role integrity:  {summary['mean_track_role_integrity']}")
    print(f"Saved JSON:            {Path(args.output).resolve()}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import. main()
    # returns None, so a normal run still exits with status 0.
    sys.exit(main())