"""Pre-tokenize MIDI files with the compound tokenizer and cache chunks.

This is useful for preparing data while longer training jobs are running.
It writes train/val chunk tensors and a small JSON stats file.
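
Example invocation (the script filename below is assumed; the argument
values shown are this script's defaults):

    python cache_compound_chunks.py \
        --sample-dir data/lmd_sample_10000 --out-dir data/compound_cache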
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path

import torch

from compound_dataset import (
    chunk_compound_stream,
    concat_sequences,
    load_encoded_compound_sequences,
    split_chunks,
)


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(
        description="Tokenize MIDI set with compound tokenizer and cache chunks."
    )
    p.add_argument(
        "--sample-dir",
        type=str,
        default="data/lmd_sample_10000",
        help="Directory containing .mid/.midi files.",
    )
    p.add_argument(
        "--out-dir",
        type=str,
        default="data/compound_cache",
        help="Directory to write cached tensors + stats JSON.",
    )
    p.add_argument("--block-size", type=int, default=512)
    p.add_argument("--split-ratio", type=float, default=0.9)
    p.add_argument("--seed", type=int, default=17)
    return p.parse_args()


def main() -> None:
    args = parse_args()
    sample_dir = Path(args.sample_dir)
    out_dir = Path(args.out_dir)

    # Fail before creating any output if the input directory is missing.
    if not sample_dir.exists():
        raise FileNotFoundError(f"Sample dir not found: {sample_dir}")

    out_dir.mkdir(parents=True, exist_ok=True)

    # Count candidate files up front so the stats can report how many of the
    # discovered files were encoded versus how many failed.
    midi_paths = sorted(sample_dir.rglob("*.mid")) + sorted(
        sample_dir.rglob("*.midi")
    )
    sequences, n_failed = load_encoded_compound_sequences(sample_dir)

    # Concatenate every encoded sequence into one stream, slice it into
    # fixed-size blocks, then split the blocks into train/val sets using the
    # seed so the split is reproducible across runs.
    stream = concat_sequences(sequences)
    chunks = chunk_compound_stream(stream, block_size=args.block_size)
    train_chunks, val_chunks = split_chunks(
        chunks, split_ratio=args.split_ratio, seed=args.seed
    )

    # Stack the per-chunk (block_size, 7) tensors into one
    # (n_chunks, block_size, 7) tensor; the empty fallback keeps the same
    # trailing shape so downstream loaders need no special case.
    train_tensor = (
        torch.stack(train_chunks, dim=0)
        if train_chunks
        else torch.empty(0, args.block_size, 7, dtype=torch.long)
    )
    val_tensor = (
        torch.stack(val_chunks, dim=0)
        if val_chunks
        else torch.empty(0, args.block_size, 7, dtype=torch.long)
    )

    torch.save(train_tensor, out_dir / "compound_train_chunks.pt")
    torch.save(val_tensor, out_dir / "compound_val_chunks.pt")
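
    # Sketch: downstream code could reload the cache like this (assumed
    # usage, not executed by this script):
    #   train = torch.load(out_dir / "compound_train_chunks.pt")
    #   val = torch.load(out_dir / "compound_val_chunks.pt")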

    stats = {
        "sample_dir": str(sample_dir),
        "block_size": args.block_size,
        "split_ratio": args.split_ratio,
        "seed": args.seed,
        "n_files_seen": len(midi_paths),
        "n_files_encoded": len(sequences),
        "n_files_failed": n_failed,
        "n_steps_total": len(stream),
        "n_chunks_total": len(chunks),
        "n_train_chunks": len(train_chunks),
        "n_val_chunks": len(val_chunks),
        "train_tensor_shape": list(train_tensor.shape),
        "val_tensor_shape": list(val_tensor.shape),
    }
    (out_dir / "compound_cache_stats.json").write_text(json.dumps(stats, indent=2))

    print(
        "[compound-cache] files seen/encoded/failed: "
        f"{stats['n_files_seen']}/{stats['n_files_encoded']}/{stats['n_files_failed']}"
    )
    print(
        "[compound-cache] steps/chunks/train/val: "
        f"{stats['n_steps_total']}/{stats['n_chunks_total']}/"
        f"{stats['n_train_chunks']}/{stats['n_val_chunks']}"
    )
    print(
        "[compound-cache] wrote: "
        f"{out_dir / 'compound_train_chunks.pt'}, "
        f"{out_dir / 'compound_val_chunks.pt'}, "
        f"{out_dir / 'compound_cache_stats.json'}"
    )


if __name__ == "__main__":
    main()