"""Dataset utilities for the Octuple-style compound tokenizer. Mirrors ``dataset.py`` but yields compound steps shaped ``(T, N_AXES)`` instead of flat token sequences. Each step is a tuple of feature ids that the compound model embeds in parallel and sums. """ from __future__ import annotations import random from dataclasses import dataclass from pathlib import Path from typing import List, Sequence import pretty_midi import torch from torch.utils.data import DataLoader, Dataset from compound import ( AXIS_SIZES, N_AXES, SENTINELS, STEP_BOS, STEP_EOS, STEP_PAD, encode_compound, ) DEFAULT_BLOCK_SIZE = 512 DEFAULT_BATCH_SIZE = 16 DEFAULT_SPLIT_RATIO = 0.9 DEFAULT_SEED = 17 @dataclass class CompoundDatasetStats: n_files_seen: int n_files_encoded: int n_files_failed: int n_steps_total: int n_chunks_total: int n_train_chunks: int n_val_chunks: int class CompoundChunkDataset(Dataset): """Returns (input, target) pairs of shape (block_size-1, N_AXES) each.""" def __init__(self, chunks: Sequence[torch.Tensor]) -> None: self._chunks = list(chunks) def __len__(self) -> int: return len(self._chunks) def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: chunk = self._chunks[idx] return chunk[:-1], chunk[1:] def _eos_pad_step() -> List[int]: """A PAD step used to fill block boundaries.""" s = list(SENTINELS) s[0] = STEP_PAD return s def _bos_step_separator() -> List[int]: """A BOS step inserted between concatenated piece sequences.""" s = list(SENTINELS) s[0] = STEP_BOS return s def load_encoded_compound_sequences( sample_dir: Path, ) -> tuple[List[List[List[int]]], int]: paths = sorted(sample_dir.rglob("*.mid")) + sorted(sample_dir.rglob("*.midi")) sequences: List[List[List[int]]] = [] n_failed = 0 for p in paths: try: pm = pretty_midi.PrettyMIDI(str(p)) sequences.append(encode_compound(pm)) except Exception: n_failed += 1 return sequences, n_failed def concat_sequences(seqs: Sequence[Sequence[Sequence[int]]]) -> List[List[int]]: """Concatenate compound sequences with a BOS separator between pieces.""" flat: List[List[int]] = [] sep = _bos_step_separator() for i, seq in enumerate(seqs): flat.extend(seq) if i < len(seqs) - 1: flat.append(sep) return flat def chunk_compound_stream(stream: Sequence[Sequence[int]], block_size: int) -> List[torch.Tensor]: n_chunks = len(stream) // block_size chunks: List[torch.Tensor] = [] for i in range(n_chunks): block = stream[i * block_size:(i + 1) * block_size] chunks.append(torch.tensor(block, dtype=torch.long)) return chunks def split_chunks( chunks: Sequence[torch.Tensor], split_ratio: float, seed: int, ) -> tuple[List[torch.Tensor], List[torch.Tensor]]: rng = random.Random(seed) indices = list(range(len(chunks))) rng.shuffle(indices) n_train = int(len(indices) * split_ratio) train = [chunks[i] for i in indices[:n_train]] val = [chunks[i] for i in indices[n_train:]] return train, val def build_compound_dataloaders( sample_dir: Path | None = None, block_size: int = DEFAULT_BLOCK_SIZE, batch_size: int = DEFAULT_BATCH_SIZE, split_ratio: float = DEFAULT_SPLIT_RATIO, seed: int = DEFAULT_SEED, ) -> tuple[DataLoader, DataLoader, CompoundDatasetStats]: if sample_dir is None: root = Path(__file__).resolve().parent.parent sample_dir = root / "data" / "gigamidi" / "sample" if not sample_dir.exists(): raise FileNotFoundError(f"Sample dir not found: {sample_dir}") midi_paths = sorted(sample_dir.rglob("*.mid")) + sorted(sample_dir.rglob("*.midi")) sequences, n_failed = load_encoded_compound_sequences(sample_dir) stream = concat_sequences(sequences) if stream: max_per_axis = [max(s[a] for s in stream) for a in range(N_AXES)] for a, mx in enumerate(max_per_axis): assert mx < AXIS_SIZES[a], ( f"axis {a} value {mx} exceeds size {AXIS_SIZES[a]}" ) chunks = chunk_compound_stream(stream, block_size) train_chunks, val_chunks = split_chunks(chunks, split_ratio, seed) train_ds = CompoundChunkDataset(train_chunks) val_ds = CompoundChunkDataset(val_chunks) train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True) val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False) stats = CompoundDatasetStats( n_files_seen=len(midi_paths), n_files_encoded=len(sequences), n_files_failed=n_failed, n_steps_total=len(stream), n_chunks_total=len(chunks), n_train_chunks=len(train_chunks), n_val_chunks=len(val_chunks), ) return train_loader, val_loader, stats if __name__ == "__main__": train_loader, val_loader, stats = build_compound_dataloaders() print( f"[compound_dataset] files seen/encoded/failed: " f"{stats.n_files_seen}/{stats.n_files_encoded}/{stats.n_files_failed}" ) print( f"[compound_dataset] steps={stats.n_steps_total} " f"chunks={stats.n_chunks_total} train={stats.n_train_chunks} val={stats.n_val_chunks}" ) print(f"[compound_dataset] axis sizes = {AXIS_SIZES}")