Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| """ | |
| Analyze MIDI Dataset | |
| Quick analysis of your MIDI files to see: | |
| - Instrument distribution (piano, guitar, drums, etc.) | |
| - Tempo distribution | |
| - Note statistics | |
| - File quality | |
| """ | |
| import argparse | |
| from collections import Counter, defaultdict | |
| from pathlib import Path | |
| import random | |
| import numpy as np | |
| import pretty_midi | |
| from tqdm import tqdm | |
# General MIDI instrument families: the table maps each block of eight
# consecutive program numbers (0-7, 8-15, ..., 120-127) to a family name.
# Names are listed in program order, so position N covers programs 8N..8N+7.
INSTRUMENT_FAMILIES = {
    range(8 * index, 8 * index + 8): family
    for index, family in enumerate(
        (
            "Piano",
            "Chromatic Percussion",
            "Organ",
            "Guitar",
            "Bass",
            "Strings",
            "Ensemble",
            "Brass",
            "Reed",
            "Pipe",
            "Synth Lead",
            "Synth Pad",
            "Synth Effects",
            "Ethnic",
            "Percussive",
            "Sound Effects",
        )
    )
}
def get_instrument_family(program: int) -> str:
    """Look up the instrument family for a GM program number.

    Scans ``INSTRUMENT_FAMILIES`` in definition order and returns the name
    of the first range containing ``program``; returns ``"Unknown"`` when
    no range matches.
    """
    matches = (
        family
        for span, family in INSTRUMENT_FAMILIES.items()
        if program in span
    )
    return next(matches, "Unknown")
def analyze_midi(midi_path: str) -> dict:
    """Analyze a single MIDI file.

    Args:
        midi_path: Path to the MIDI file; it is passed through ``str()``
            before loading, so path-like objects work too.

    Returns:
        ``{"valid": False, "error": <message>}`` if the file cannot be
        loaded. Otherwise a dict with these keys:

        - ``valid``: always ``True``.
        - ``instruments``: one entry per track, either ``"Drums"`` or the
          GM family name of the track's program.
        - ``has_drums``: whether any track is a drum track.
        - ``note_count``: total number of notes across all tracks.
        - ``duration``: ``pm.get_end_time()``.
        - ``tempos``: tempo values from ``get_tempo_changes()``, or
          ``[120.0]`` when that list is empty.
        - ``time_signatures``: ``"numerator/denominator"`` strings, one per
          time-signature change in the file.
        - ``avg_tempo``: plain (unweighted) mean of ``tempos``.
        - ``pitch_range``: ``(min_pitch, max_pitch)``, or ``None`` when the
          file has no notes.
        - ``avg_velocity``: mean note velocity, or ``None`` when the file
          has no notes.
    """
    try:
        pm = pretty_midi.PrettyMIDI(str(midi_path))
    except Exception as e:
        # Best-effort by design: any failure while loading marks the file
        # invalid instead of aborting a whole-dataset scan.
        return {"valid": False, "error": str(e)}

    # Only the tempo values are needed; the change times are discarded.
    _, tempos = pm.get_tempo_changes()
    tempo_values = tempos.tolist() if len(tempos) > 0 else [120.0]

    # Single pass over the tracks: collect labels and note statistics together.
    instruments = []
    has_drums = False
    pitches = []
    velocities = []
    for inst in pm.instruments:
        if inst.is_drum:
            has_drums = True
            instruments.append("Drums")
        else:
            instruments.append(get_instrument_family(inst.program))
        for note in inst.notes:
            pitches.append(note.pitch)
            velocities.append(note.velocity)

    return {
        "valid": True,
        "instruments": instruments,
        "has_drums": has_drums,
        "note_count": len(pitches),
        "duration": pm.get_end_time(),
        "tempos": tempo_values,
        "time_signatures": [
            f"{ts.numerator}/{ts.denominator}"
            for ts in pm.time_signature_changes
        ],
        "avg_tempo": np.mean(tempo_values),
        # Always present so callers can index these keys without a KeyError
        # on files that contain no notes.
        "pitch_range": (min(pitches), max(pitches)) if pitches else None,
        "avg_velocity": np.mean(velocities) if velocities else None,
    }
# Tempo bins in ascending order: (exclusive upper bound in BPM, report label).
_TEMPO_BINS = (
    (60, "Very Slow (<60)"),
    (90, "Slow (60-90)"),
    (120, "Medium (90-120)"),
    (150, "Fast (120-150)"),
    (float("inf"), "Very Fast (>150)"),
)

# File suffixes treated as MIDI, compared case-insensitively.
_MIDI_SUFFIXES = frozenset({".mid", ".midi"})


def _find_midi_files(data_path):
    """Return every regular file under ``data_path`` with a MIDI suffix, sorted."""
    # One traversal with a case-folded suffix test: each file appears exactly
    # once regardless of suffix casing, and the sorted order makes the
    # non-sampling ``[:max_files]`` slice deterministic.
    return sorted(
        path
        for path in data_path.rglob("*")
        if path.suffix.lower() in _MIDI_SUFFIXES and path.is_file()
    )


def _tempo_label(tempo):
    """Return the report label of the tempo bin that ``tempo`` falls into."""
    for upper_bound, label in _TEMPO_BINS:
        if tempo < upper_bound:
            return label
    # Anything that compares below no bound (e.g. NaN) lands in the last bin.
    return _TEMPO_BINS[-1][1]


def analyze_dataset(
    data_dir: str,
    max_files: int = None,
    sample: bool = True,
):
    """Analyze every MIDI file under ``data_dir`` and print a summary report.

    Args:
        data_dir: Directory searched recursively for ``.mid`` / ``.midi``
            files (suffix matched case-insensitively).
        max_files: Upper limit on the number of files analyzed. ``None``,
            zero or a negative value means no limit.
        sample: When the limit applies, pick a random sample if ``True``,
            otherwise take the first ``max_files`` files in sorted order.

    Returns:
        ``None``; all results are written to stdout.
    """
    midi_files = _find_midi_files(Path(data_dir))

    print(f"\n{'='*60}")
    print("MIDI Dataset Analysis")
    print(f"{'='*60}")
    print(f"Directory: {data_dir}")
    print(f"Total files found: {len(midi_files)}")
    if not midi_files:
        print("\n⚠️ No MIDI files found!")
        return

    # A non-positive limit is treated as "no limit"; passing it on would make
    # random.sample raise ValueError or the slice drop files from the end.
    if max_files is not None and max_files <= 0:
        max_files = None
    if max_files is not None and len(midi_files) > max_files:
        if sample:
            print(f"Sampling {max_files} files for analysis...")
            midi_files = random.sample(midi_files, max_files)
        else:
            midi_files = midi_files[:max_files]
    print(f"Analyzing {len(midi_files)} files...\n")

    # Aggregates over all successfully parsed files.
    instrument_counter = Counter()
    tempo_bins = Counter()
    valid_count = 0
    invalid_count = 0
    total_notes = 0
    durations = []
    has_drums_count = 0
    piano_only_count = 0
    multi_instrument_count = 0

    for midi_path in tqdm(midi_files, desc="Analyzing"):
        result = analyze_midi(midi_path)
        if not result["valid"]:
            invalid_count += 1
            continue
        valid_count += 1
        total_notes += result["note_count"]
        durations.append(result["duration"])

        # Count each family once per file so the percentages below (which
        # divide by the number of files) can never exceed 100%.
        families = set(result["instruments"])
        instrument_counter.update(families)
        if result["has_drums"]:
            has_drums_count += 1

        # Drums are ignored when classifying the melodic line-up. Comparing
        # sets means a file with several piano tracks still counts as
        # piano-only.
        melodic_families = families - {"Drums"}
        if melodic_families == {"Piano"}:
            piano_only_count += 1
        if len(melodic_families) > 1:
            multi_instrument_count += 1

        tempo_bins[_tempo_label(result["avg_tempo"])] += 1

    # Print results
    print(f"\n{'='*60}")
    print("RESULTS")
    print(f"{'='*60}")
    print("\n📊 File Statistics:")
    print(f" Valid files: {valid_count} ({valid_count/len(midi_files)*100:.1f}%)")
    print(f" Invalid files: {invalid_count}")
    if valid_count == 0:
        # Every remaining statistic divides by valid_count; stop here rather
        # than crash with ZeroDivisionError.
        print("\n⚠️ No valid MIDI files - nothing to summarize.")
        return
    print(f" Total notes: {total_notes:,}")
    print(f" Avg notes/file: {total_notes/valid_count:.0f}")

    print("\n⏱️ Duration:")
    print(f" Average: {np.mean(durations):.1f}s")
    print(f" Median: {np.median(durations):.1f}s")
    print(f" Range: {min(durations):.1f}s - {max(durations):.1f}s")

    print("\n🎹 Instrument Distribution:")
    for inst, count in instrument_counter.most_common(15):
        pct = count / valid_count * 100
        bar = "█" * int(pct / 2)
        print(f" {inst:20s} {count:6d} ({pct:5.1f}%) {bar}")

    print("\n🎼 Composition Types:")
    print(f" Piano only: {piano_only_count:6d} ({piano_only_count/valid_count*100:.1f}%)")
    print(f" With drums: {has_drums_count:6d} ({has_drums_count/valid_count*100:.1f}%)")
    print(f" Multi-instrument: {multi_instrument_count:6d} ({multi_instrument_count/valid_count*100:.1f}%)")

    print("\n🎵 Tempo Distribution:")
    # Walk the bins slowest to fastest (not alphabetically); skip empty bins.
    for _, tempo_range in _TEMPO_BINS:
        count = tempo_bins[tempo_range]
        if not count:
            continue
        pct = count / valid_count * 100
        bar = "█" * int(pct / 2)
        print(f" {tempo_range:20s} {count:6d} ({pct:5.1f}%) {bar}")

    print(f"\n{'='*60}")
    print("✅ Analysis complete!")
    print(f"{'='*60}")

    # Recommendations
    print("\n💡 Recommendations:")
    # Valid files may contain no tracks at all, leaving the counter empty.
    if instrument_counter:
        top_instrument = instrument_counter.most_common(1)[0][0]
        print(f" • Most common instrument: {top_instrument}")
    if piano_only_count / valid_count > 0.3:
        print(" • Lots of piano-only files - good for piano generation")
    if has_drums_count / valid_count > 0.5:
        print(" • Strong drum presence - rhythm patterns will train well")
    if multi_instrument_count / valid_count > 0.4:
        print(" • Many multi-instrument files - good arrangement learning")
def main(argv=None):
    """Parse command-line arguments and run the dataset analysis.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which is the behavior of
            a plain ``main()`` call.
    """
    parser = argparse.ArgumentParser(description="Analyze MIDI dataset")
    parser.add_argument(
        "data_dir",
        type=str,
        help="Directory containing MIDI files",
    )
    parser.add_argument(
        "--max-files",
        type=int,
        default=5000,
        help="Max files to analyze (default: 5000, use -1 for all)",
    )
    parser.add_argument(
        "--all",
        action="store_true",
        help="Analyze all files (may take a while for 190k files)",
    )
    args = parser.parse_args(argv)

    # --all or any non-positive --max-files means "no limit". Forwarding a
    # negative number other than -1 would reach random.sample and raise
    # ValueError.
    analyze_everything = args.all or args.max_files <= 0
    max_files = None if analyze_everything else args.max_files
    analyze_dataset(args.data_dir, max_files=max_files)
# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()