Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| """ | |
| Validate generated MIDI files for quality. | |
| Checks: | |
| - File integrity (can be parsed) | |
| - Note count and distribution | |
| - Pitch diversity | |
| - Temporal structure | |
| - Velocity patterns | |
| """ | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| from collections import Counter | |
| import numpy as np | |
| import pretty_midi | |
| from tqdm import tqdm | |
def analyze_midi(midi_path: str) -> dict:
    """Parse one MIDI file and summarise its note content.

    Args:
        midi_path: Filesystem path of the MIDI file to load.

    Returns:
        ``{"valid": False, "error": <message>}`` when the file cannot be
        parsed or contains no notes. Otherwise ``{"valid": True, ...}`` with
        the following statistics, pooled across every instrument:
        - note count
        - pitch statistics
        - velocity statistics
        - note-length statistics
        - total duration
        - note density
        - most-common-pitch ratio
        - instrument count
    """
    try:
        midi = pretty_midi.PrettyMIDI(midi_path)
    except Exception as exc:
        # Any load failure marks the file invalid instead of aborting a batch.
        return {"valid": False, "error": str(exc)}

    # Pool the notes of every instrument track into one flat list.
    notes = [note for track in midi.instruments for note in track.notes]
    if not notes:
        return {"valid": False, "error": "No notes found"}

    pitch_values = [note.pitch for note in notes]
    velocity_values = [note.velocity for note in notes]
    note_lengths = [note.end - note.start for note in notes]

    end_time = midi.get_end_time()
    # A zero-length file would divide by zero; report a density of 0 instead.
    density = len(notes) / end_time if end_time > 0 else 0

    # Fraction of all notes that land on the single most frequent pitch.
    top_pitch_share = max(Counter(pitch_values).values()) / len(pitch_values)

    return {
        "valid": True,
        "note_count": len(notes),
        "unique_pitches": len(set(pitch_values)),
        "pitch_range": max(pitch_values) - min(pitch_values),
        "pitch_mean": round(np.mean(pitch_values), 2),
        "pitch_std": round(np.std(pitch_values), 2),
        "velocity_mean": round(np.mean(velocity_values), 2),
        "velocity_std": round(np.std(velocity_values), 2),
        "duration_mean": round(np.mean(note_lengths), 4),
        "duration_std": round(np.std(note_lengths), 4),
        "total_duration": round(end_time, 2),
        "note_density": round(density, 2),
        "most_common_pitch_ratio": round(top_pitch_share, 4),
        "num_instruments": len(midi.instruments),
    }
def _find_midi_files(midi_dir: Path) -> list:
    """Return all regular files under *midi_dir* ending in .mid/.midi (any case), sorted."""
    # rglob("*.mid") is case-sensitive on POSIX and would miss "*.MID";
    # the is_file() check skips directories that happen to end in ".mid".
    # Sorting makes the report order reproducible across filesystems.
    return sorted(
        path
        for path in midi_dir.rglob("*")
        if path.is_file() and path.suffix.lower() in (".mid", ".midi")
    )


def _pct(part: int, whole: int) -> float:
    """Return part/whole as a percentage, or 0.0 when whole is zero."""
    return part / whole * 100 if whole else 0.0


def _check_quality(
    analysis: dict,
    min_notes: int,
    max_notes: int,
    min_unique_pitches: int,
    min_duration: float,
    max_repetition_ratio: float,
) -> list:
    """Return the names of the quality thresholds a valid analysis violates."""
    failed = []
    if analysis["note_count"] < min_notes:
        failed.append("too_few_notes")
    if analysis["note_count"] > max_notes:
        failed.append("too_many_notes")
    if analysis["unique_pitches"] < min_unique_pitches:
        failed.append("low_pitch_diversity")
    if analysis["total_duration"] < min_duration:
        failed.append("too_short")
    if analysis["most_common_pitch_ratio"] > max_repetition_ratio:
        failed.append("too_repetitive")
    return failed


def _print_summary(results: dict, quality_failures: Counter) -> None:
    """Print the human-readable validation report for a finished batch."""
    total = results["total"]
    print("\n" + "=" * 60)
    print("Validation Summary")
    print("=" * 60)
    print(f"Total files: {total}")
    # _pct avoids the ZeroDivisionError an empty batch used to raise here.
    print(f"Valid (parseable): {results['valid']} ({_pct(results['valid'], total):.1f}%)")
    print(f"Invalid: {results['invalid']}")
    print(f"Passed quality: {results['passed_quality']} ({_pct(results['passed_quality'], total):.1f}%)")
    print(f"Failed quality: {results['failed_quality']}")
    summary = results.get("summary")
    if summary:
        print("\nValid file statistics:")
        print(f" Avg notes: {summary['avg_notes']}")
        print(f" Avg unique pitches: {summary['avg_unique_pitches']}")
        print(f" Avg duration: {summary['avg_duration']}s")
        print(f" Avg note density: {summary['avg_note_density']} notes/s")
    if quality_failures:
        print("\nQuality failure breakdown:")
        for reason, count in quality_failures.most_common():
            print(f" {reason}: {count}")


def validate_batch(
    midi_dir: str,
    output_path: str = None,
    min_notes: int = 20,
    max_notes: int = 2000,
    min_unique_pitches: int = 5,
    min_duration: float = 5.0,
    max_repetition_ratio: float = 0.5,
):
    """Analyse every MIDI file under a directory and apply quality thresholds.

    Args:
        midi_dir: Directory searched recursively for ``.mid``/``.midi`` files
            (extension matched case-insensitively).
        output_path: If truthy, the results dict is written there as JSON
            (parent directories are created as needed).
        min_notes: Files with fewer notes fail with ``too_few_notes``.
        max_notes: Files with more notes fail with ``too_many_notes``.
        min_unique_pitches: Fewer distinct pitches fails with
            ``low_pitch_diversity``.
        min_duration: A shorter ``total_duration`` fails with ``too_short``.
        max_repetition_ratio: A higher ``most_common_pitch_ratio`` fails with
            ``too_repetitive``.

    Returns:
        Dict containing:
        - ``total``, ``valid``, ``invalid``, ``passed_quality`` and
          ``failed_quality`` counts;
        - ``files``: the per-file analyses;
        - ``summary``: averages, present only when at least one file was
          valid;
        - ``quality_failure_counts``.

    Raises:
        FileNotFoundError: If *midi_dir* is not an existing directory.
    """
    midi_dir = Path(midi_dir)
    if not midi_dir.is_dir():
        raise FileNotFoundError(f"MIDI directory not found: {midi_dir}")

    midi_files = _find_midi_files(midi_dir)
    print(f"Found {len(midi_files)} MIDI files")

    results = {
        "total": len(midi_files),
        "valid": 0,
        "invalid": 0,
        "passed_quality": 0,
        "failed_quality": 0,
        "files": [],
    }
    quality_failures = Counter()

    for midi_path in tqdm(midi_files, desc="Validating"):
        analysis = analyze_midi(str(midi_path))
        analysis["path"] = str(midi_path)
        if not analysis.get("valid"):
            results["invalid"] += 1
            analysis["quality_passed"] = False
            # Both unparseable and note-less files land in this one bucket.
            quality_failures["parse_error"] += 1
        else:
            results["valid"] += 1
            failed = _check_quality(
                analysis,
                min_notes=min_notes,
                max_notes=max_notes,
                min_unique_pitches=min_unique_pitches,
                min_duration=min_duration,
                max_repetition_ratio=max_repetition_ratio,
            )
            if failed:
                results["failed_quality"] += 1
                analysis["quality_passed"] = False
                analysis["quality_failures"] = failed
                quality_failures.update(failed)
            else:
                results["passed_quality"] += 1
                analysis["quality_passed"] = True
        results["files"].append(analysis)

    valid_analyses = [a for a in results["files"] if a.get("valid")]
    if valid_analyses:
        results["summary"] = {
            "avg_notes": round(np.mean([a["note_count"] for a in valid_analyses]), 1),
            "avg_unique_pitches": round(np.mean([a["unique_pitches"] for a in valid_analyses]), 1),
            "avg_duration": round(np.mean([a["total_duration"] for a in valid_analyses]), 1),
            "avg_note_density": round(np.mean([a["note_density"] for a in valid_analyses]), 2),
        }
    results["quality_failure_counts"] = dict(quality_failures)

    _print_summary(results, quality_failures)

    if output_path:
        out_file = Path(output_path)
        out_file.parent.mkdir(parents=True, exist_ok=True)
        with open(out_file, "w", encoding="utf-8") as fh:
            json.dump(results, fh, indent=2)
        print(f"\nResults saved to {output_path}")

    return results
def main(argv=None):
    """Command-line entry point: parse arguments and run ``validate_batch``.

    Every quality threshold accepted by ``validate_batch`` is exposed as a
    flag. Defaults are identical to ``validate_batch``'s defaults.

    Args:
        argv: Optional argument list (without the program name). When
            ``None``, argparse reads ``sys.argv[1:]`` as before.
    """
    parser = argparse.ArgumentParser(description="Validate MIDI files")
    parser.add_argument(
        "midi_dir",
        type=str,
        help="Directory containing MIDI files",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="validation_results.json",
        help="Output JSON file path",
    )
    parser.add_argument(
        "--min-notes",
        type=int,
        default=20,
        help="Minimum note count",
    )
    parser.add_argument(
        "--max-notes",
        type=int,
        default=2000,
        help="Maximum note count",
    )
    # The three flags below were previously unreachable from the CLI.
    parser.add_argument(
        "--min-unique-pitches",
        type=int,
        default=5,
        help="Minimum number of distinct pitches",
    )
    parser.add_argument(
        "--min-duration",
        type=float,
        default=5.0,
        help="Minimum total duration in seconds",
    )
    parser.add_argument(
        "--max-repetition-ratio",
        type=float,
        default=0.5,
        help="Maximum fraction of notes on the single most common pitch",
    )
    args = parser.parse_args(argv)
    validate_batch(
        args.midi_dir,
        args.output,
        min_notes=args.min_notes,
        max_notes=args.max_notes,
        min_unique_pitches=args.min_unique_pitches,
        min_duration=args.min_duration,
        max_repetition_ratio=args.max_repetition_ratio,
    )
# Launch the command-line interface only when run as a script, not on import.
if __name__ == "__main__":
    main()