Coda / src /tokenize_compound_cache.py
Prajanya Gupta
initial deploy
6b7b403
"""Pre-tokenize MIDI files with the compound tokenizer and cache chunks.
This is useful for preparing data while longer training jobs are running.
It writes train/val chunk tensors and a small JSON stats file.
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import torch
from compound_dataset import (
concat_sequences,
chunk_compound_stream,
load_encoded_compound_sequences,
split_chunks,
)
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(
description="Tokenize MIDI set with compound tokenizer and cache chunks."
)
p.add_argument(
"--sample-dir",
type=str,
default="data/lmd_sample_10000",
help="Directory containing .mid/.midi files.",
)
p.add_argument(
"--out-dir",
type=str,
default="data/compound_cache",
help="Directory to write cached tensors + stats JSON.",
)
p.add_argument("--block-size", type=int, default=512)
p.add_argument("--split-ratio", type=float, default=0.9)
p.add_argument("--seed", type=int, default=17)
return p.parse_args()
def _stack_or_empty(chunks, block_size: int) -> torch.Tensor:
    """Stack chunk tensors along a new leading axis, or build an empty placeholder."""
    if chunks:
        return torch.stack(chunks, dim=0)
    # NOTE(review): the trailing 7 is presumably the number of fields per
    # compound token step -- confirm against compound_dataset.
    return torch.empty(0, block_size, 7, dtype=torch.long)


def main() -> None:
    """Tokenize a directory of MIDI files and cache train/val chunk tensors.

    Reads the options from the command line, encodes the files under
    ``--sample-dir`` via ``load_encoded_compound_sequences``, concatenates and
    chunks the result, splits the chunks into train and val sets, and writes
    three files into ``--out-dir``: ``compound_train_chunks.pt``,
    ``compound_val_chunks.pt`` and ``compound_cache_stats.json``. A short
    summary is printed to stdout.

    Raises:
        FileNotFoundError: If the sample directory does not exist.
        NotADirectoryError: If the sample path exists but is not a directory.
    """
    args = parse_args()
    sample_dir = Path(args.sample_dir)
    out_dir = Path(args.out_dir)

    # Validate the input before touching the filesystem, so a mistyped
    # --sample-dir does not leave a stray (empty) output directory behind.
    if not sample_dir.exists():
        raise FileNotFoundError(f"Sample dir not found: {sample_dir}")
    if not sample_dir.is_dir():
        raise NotADirectoryError(f"Sample dir is not a directory: {sample_dir}")
    # Create the output directory before the expensive work so that an
    # unwritable destination fails fast.
    out_dir.mkdir(parents=True, exist_ok=True)

    train_path = out_dir / "compound_train_chunks.pt"
    val_path = out_dir / "compound_val_chunks.pt"
    stats_path = out_dir / "compound_cache_stats.json"

    # This listing only feeds the "n_files_seen" statistic; the loader below
    # is handed the directory itself rather than this list.
    midi_paths = sorted(sample_dir.rglob("*.mid")) + sorted(
        sample_dir.rglob("*.midi")
    )

    sequences, n_failed = load_encoded_compound_sequences(sample_dir)
    stream = concat_sequences(sequences)
    chunks = chunk_compound_stream(stream, block_size=args.block_size)
    train_chunks, val_chunks = split_chunks(
        chunks, split_ratio=args.split_ratio, seed=args.seed
    )

    # Either side of the split may be empty (e.g. ratio 0 or 1, or too little
    # data); an empty placeholder tensor is saved in that case.
    train_tensor = _stack_or_empty(train_chunks, args.block_size)
    val_tensor = _stack_or_empty(val_chunks, args.block_size)
    torch.save(train_tensor, train_path)
    torch.save(val_tensor, val_path)

    stats = {
        "sample_dir": str(sample_dir),
        "block_size": args.block_size,
        "split_ratio": args.split_ratio,
        "seed": args.seed,
        "n_files_seen": len(midi_paths),
        "n_files_encoded": len(sequences),
        "n_files_failed": n_failed,
        "n_steps_total": len(stream),
        "n_chunks_total": len(chunks),
        "n_train_chunks": len(train_chunks),
        "n_val_chunks": len(val_chunks),
        "train_tensor_shape": list(train_tensor.shape),
        "val_tensor_shape": list(val_tensor.shape),
    }
    # Explicit encoding keeps the stats file independent of the platform
    # default locale.
    stats_path.write_text(json.dumps(stats, indent=2), encoding="utf-8")

    print(
        "[compound-cache] files seen/encoded/failed: "
        f"{stats['n_files_seen']}/{stats['n_files_encoded']}/{stats['n_files_failed']}"
    )
    print(
        "[compound-cache] steps/chunks/train/val: "
        f"{stats['n_steps_total']}/{stats['n_chunks_total']}/"
        f"{stats['n_train_chunks']}/{stats['n_val_chunks']}"
    )
    print(
        "[compound-cache] wrote: "
        f"{train_path}, "
        f"{val_path}, "
        f"{stats_path}"
    )
# Run the caching pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()