# ROOM / scripts / test_tokenizer.py
# Uploaded via huggingface_hub by solo363614 (commit aed1d05, verified)
#!/usr/bin/env python3
"""
Test MIDI Tokenizer
See how the tokenizer encodes your MIDI files.
Shows the actual token sequences and reconstructs MIDI.
"""
import argparse
import sys
from pathlib import Path
import random
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from modelw.tokenizer import MIDITokenizer
def test_single_file(tokenizer: MIDITokenizer, midi_path: str, verbose: bool = True):
    """Encode one MIDI file, report a token-type breakdown, and decode it back.

    Args:
        tokenizer: The tokenizer under test.
        midi_path: Path to the MIDI file to round-trip.
        verbose: When True, also print the first 50 tokens of the sequence.

    Returns:
        A ``(ok, token_count)`` tuple; ``(False, 0)`` on any failure.
    """
    rule = "─" * 60
    print(f"\n{rule}")
    print(f"File: {Path(midi_path).name}")
    print(f"{rule}")
    try:
        token_ids = tokenizer.encode(midi_path)
        tokens = [tokenizer.id_to_token.get(tid, "<UNK>") for tid in token_ids]
        print(f"✓ Encoded: {len(token_ids)} tokens")

        # Counter keys in display order; only non-zero entries are printed below.
        counts = {
            "special": 0,
            "tempo": 0,
            "instrument": 0,
            "mood": 0,
            "bar": 0,
            "position": 0,
            "pitch": 0,
            "velocity": 0,
            "duration": 0,
        }
        # Prefix → category dispatch table; anything unmatched is "special".
        prefix_table = (
            ("<TEMPO", "tempo"),
            ("<INST", "instrument"),
            ("<MOOD", "mood"),
            ("<POS", "position"),
            ("<PITCH", "pitch"),
            ("<VEL", "velocity"),
            ("<DUR", "duration"),
        )
        for tok in tokens:
            if tok == "<BAR>":
                counts["bar"] += 1
                continue
            for prefix, category in prefix_table:
                if tok.startswith(prefix):
                    counts[category] += 1
                    break
            else:
                counts["special"] += 1

        print("\nToken breakdown:")
        for ttype, count in counts.items():
            if count:
                print(f" {ttype:12s}: {count:4d}")

        if verbose:
            print("\nFirst 50 tokens:")
            for idx, tok in enumerate(tokens[:50]):
                print(f" {idx:3d}: {tok}")
            if len(tokens) > 50:
                print(f" ... ({len(tokens) - 50} more)")

        # Round-trip: rebuild a PrettyMIDI object from the token ids.
        print("\nDecoding back to MIDI...")
        pm = tokenizer.decode(token_ids)
        if pm.instruments:
            notes = pm.instruments[0].notes
            print(f"✓ Reconstructed: {len(notes)} notes, {pm.get_end_time():.1f}s duration")
        else:
            print("⚠️ No instruments in reconstruction")
        return True, len(token_ids)
    except Exception as e:
        # Best-effort test script: report the failure and keep going.
        print(f"✗ Error: {e}")
        return False, 0
def test_dataset(data_dir: str, num_samples: int = 10, verbose: bool = False):
    """Test tokenization on a random sample of MIDI files under *data_dir*.

    Prints a per-file report via :func:`test_single_file`, then a summary
    with the success rate and average token count.

    Args:
        data_dir: Directory searched recursively for ``.mid`` / ``.midi`` files.
        num_samples: Maximum number of files to sample and test.
        verbose: Forwarded to ``test_single_file`` (show full token dumps).
    """
    print(f"\n{'='*60}")
    print("MIDI Tokenizer Test")
    print(f"{'='*60}")
    # Create tokenizer
    tokenizer = MIDITokenizer()
    print(f"\nTokenizer vocab size: {tokenizer.vocab_size}")
    # Find MIDI files (both common extensions, recursively)
    data_path = Path(data_dir)
    midi_files = list(data_path.rglob("*.mid")) + list(data_path.rglob("*.midi"))
    print(f"Found {len(midi_files)} MIDI files")
    if not midi_files:
        print("⚠️ No MIDI files found!")
        return
    # Sample files so a huge dataset still tests quickly
    if len(midi_files) > num_samples:
        midi_files = random.sample(midi_files, num_samples)
    print(f"Testing {len(midi_files)} files...")
    # Test each file
    success = 0
    total_tokens = 0
    for midi_path in midi_files:
        ok, num_tokens = test_single_file(tokenizer, str(midi_path), verbose=verbose)
        if ok:
            success += 1
            total_tokens += num_tokens
    # Summary
    print(f"\n{'='*60}")
    print("Summary")
    print(f"{'='*60}")
    print(f"Success rate: {success}/{len(midi_files)} ({success/len(midi_files)*100:.0f}%)")
    # BUGFIX: the old one-liner `print(f"Avg tokens: ..." if success > 0 else "N/A")`
    # printed a bare "N/A" with no label when every file failed; keep the label
    # in both branches so the summary line is always consistent.
    if success > 0:
        print(f"Avg tokens: {total_tokens/success:.0f}")
    else:
        print("Avg tokens: N/A")
    print(f"{'='*60}")
def show_vocab(tokenizer: MIDITokenizer):
    """Print an overview of the tokenizer's vocabulary, grouped by token family."""

    def _print_indexed(header: str, toks) -> None:
        # Print a header followed by "  id: token" for every token in the group.
        print(header)
        for t in toks:
            print(f" {tokenizer.token_to_id[t]:4d}: {t}")

    sep = "=" * 60
    print(f"\n{sep}")
    print("Tokenizer Vocabulary")
    print(f"{sep}")
    print(f"Total vocab size: {tokenizer.vocab_size}\n")

    _print_indexed("Special tokens:", tokenizer.special_tokens)

    # Tempo is a large range: show only the endpoints.
    print(f"\nTempo tokens ({len(tokenizer.tempo_tokens)}):")
    print(f" {tokenizer.tempo_tokens[0]} ... {tokenizer.tempo_tokens[-1]}")

    _print_indexed(
        f"\nInstrument tokens ({len(tokenizer.instrument_tokens)}):",
        tokenizer.instrument_tokens,
    )
    _print_indexed(
        f"\nMood tokens ({len(tokenizer.mood_tokens)}):",
        tokenizer.mood_tokens,
    )

    # Pitch range endpoints come from the tokenizer config (upper bound exclusive).
    print(f"\nPitch tokens: {len(tokenizer.pitch_tokens)}")
    print(f" {tokenizer.pitch_tokens[0]} (MIDI {tokenizer.config.pitch_range[0]}) to")
    print(f" {tokenizer.pitch_tokens[-1]} (MIDI {tokenizer.config.pitch_range[1]-1})")

    # Remaining families: just the sizes.
    print(f"\nVelocity tokens: {len(tokenizer.velocity_tokens)}")
    print(f"Duration tokens: {len(tokenizer.duration_tokens)}")
    print(f"Position tokens: {len(tokenizer.position_tokens)}")
def main():
    """CLI entry point: dispatch to vocab dump, single-file test, or dataset test."""
    parser = argparse.ArgumentParser(description="Test MIDI tokenizer")
    parser.add_argument(
        "data_dir", type=str, nargs="?", default=None,
        help="Directory containing MIDI files (optional)",
    )
    parser.add_argument("--file", type=str, help="Test a specific MIDI file")
    parser.add_argument(
        "--num-samples", type=int, default=5, help="Number of files to sample",
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true", help="Show full token sequences",
    )
    parser.add_argument("--vocab", action="store_true", help="Show vocabulary")
    args = parser.parse_args()

    tokenizer = MIDITokenizer()

    # Dispatch in priority order: explicit vocab dump, single file, dataset dir.
    if args.vocab:
        show_vocab(tokenizer)
    elif args.file:
        test_single_file(tokenizer, args.file, verbose=True)
    elif args.data_dir:
        test_dataset(args.data_dir, args.num_samples, args.verbose)
    else:
        # Nothing to test: show the vocabulary and a usage hint instead.
        show_vocab(tokenizer)
        print("\n💡 Usage:")
        print(" python scripts/test_tokenizer.py ./data/lakh_midi")
        print(" python scripts/test_tokenizer.py --file song.mid -v")
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()