#!/usr/bin/env python3
"""
Evaluate generated MIDI folders with one-shot quality metrics.
"""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
# Directory two levels up from this file (the parent of the folder that
# contains the script). It is put at the front of sys.path before the
# `modelw` import that follows, presumably so that package resolves when the
# script is run directly from a checkout -- TODO confirm the repo layout.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    # Only insert once so repeated imports of this module do not grow sys.path.
    sys.path.insert(0, str(PROJECT_ROOT))
from modelw.eval_metrics import EvaluationConfig, MIDIEvaluator
def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the MIDI quality evaluation CLI."""
    parser = argparse.ArgumentParser(description="Evaluate MIDI quality metrics")
    parser.add_argument("midi_dir", type=str, help="Directory containing MIDI files")
    parser.add_argument(
        "--output",
        type=str,
        default="midi_quality_results.json",
        help="Output JSON path",
    )
    parser.add_argument(
        "--metadata",
        type=str,
        default=None,
        help="Optional metadata.json path with prompt information",
    )
    parser.add_argument("--min-notes", type=int, default=20, help="Minimum note count")
    parser.add_argument("--max-notes", type=int, default=2000, help="Maximum note count")
    parser.add_argument(
        "--min-duration-seconds",
        type=float,
        default=5.0,
        help="Minimum file duration",
    )
    parser.add_argument(
        "--acceptance-threshold",
        type=float,
        default=0.75,
        help="Composite score threshold for acceptance",
    )
    return parser


def _print_summary(summary: dict, output_path: str) -> None:
    """Print the aggregate metrics in *summary* and the resolved output path."""
    rule = "=" * 60
    print("\n" + rule)
    print("MIDI Quality Evaluation")
    print(rule)
    print(f"Files: {summary['total_files']}")
    print(f"Valid: {summary['valid_files']}")
    print(f"Hard validity rate: {summary['hard_validity_rate']:.2%}")
    print(f"Acceptance rate: {summary['acceptance_rate']:.2%}")
    print(f"Composite score: {summary['mean_composite_score']:.4f}")
    # NOTE(review): printed without a float format spec, unlike the other
    # means -- presumably because the value can be None when no metadata is
    # supplied; confirm against MIDIEvaluator.
    print(f"Prompt match score: {summary['mean_prompt_match_score']}")
    print(f"Key/scale adherence: {summary['mean_key_scale_adherence']:.4f}")
    print(f"Rhythm grid accuracy: {summary['mean_rhythm_grid_accuracy']:.4f}")
    print(f"Velocity expression: {summary['mean_velocity_expressiveness']:.4f}")
    print(f"Repeat/variation: {summary['mean_repetition_variation_balance']:.4f}")
    print(f"Section coherence: {summary['mean_section_coherence']:.4f}")
    # NOTE(review): also printed unformatted; same caveat as the prompt match
    # score above.
    print(f"Track role integrity: {summary['mean_track_role_integrity']}")
    print(f"Saved JSON: {Path(output_path).resolve()}")


def main(argv: list[str] | None = None) -> None:
    """Run the MIDI quality evaluation from the command line.

    Parses the arguments, builds an ``EvaluationConfig`` from them, runs
    ``MIDIEvaluator.evaluate_directory`` on the given directory, and prints
    the ``"summary"`` entry of the returned results.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the behavior of the
            original zero-argument ``main()``.

    Raises:
        SystemExit: Raised by argparse for ``--help`` or invalid arguments.
    """
    args = _build_parser().parse_args(argv)

    config = EvaluationConfig(
        min_notes=args.min_notes,
        max_notes=args.max_notes,
        min_duration_seconds=args.min_duration_seconds,
        acceptance_threshold=args.acceptance_threshold,
    )
    evaluator = MIDIEvaluator(config)
    # The output path is handed to the evaluator; this script itself never
    # writes the JSON file, it only reports the resolved path afterwards.
    results = evaluator.evaluate_directory(
        midi_dir=args.midi_dir,
        metadata_path=args.metadata,
        output_path=args.output,
    )
    _print_summary(results["summary"], args.output)


if __name__ == "__main__":
    main()
|