Coda / src /caption_dataloader.py
Prajanya Gupta
initial deploy
6b7b403
"""Caption-conditioned MIDI dataset and dataloaders.
Reads records from captions_llm.jsonl and returns fixed-length MIDI token
windows plus text captions for conditioning.
"""
from __future__ import annotations
import json
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional
import pretty_midi
import torch
from torch.utils.data import DataLoader, Dataset, get_worker_info
from compound import SENTINELS, STEP_PAD, encode_compound
from tokenizer import encode
DEFAULT_MAX_SEQ_LEN = 512
DEFAULT_BATCH_SIZE = 64
DEFAULT_SPLIT_RATIO = 0.95
DEFAULT_SEED = 17
DEFAULT_NUM_WORKERS = 4
@dataclass
class CaptionDatasetStats:
n_total_records: int
n_train_records: int
n_val_records: int
def _load_jsonl_records(jsonl_path: Path) -> List[Dict[str, Any]]:
records: List[Dict[str, Any]] = []
with open(jsonl_path) as f:
for line in f:
line = line.strip()
if not line:
continue
records.append(json.loads(line))
return records
def _caption_from_record(record: Dict[str, Any]) -> str:
llm_caption = str(record.get("caption", "")).strip()
if llm_caption:
return llm_caption
return str(record.get("caption_template", "")).strip()
class MidiCaptionDataset(Dataset):
"""Dataset yielding (midi_tokens, attention_mask, caption, path)."""
def __init__(
self,
records: List[Dict[str, Any]],
is_train: bool,
max_seq_len: int = DEFAULT_MAX_SEQ_LEN,
seed: int = DEFAULT_SEED,
) -> None:
self.records = records
self.is_train = is_train
self.max_seq_len = max_seq_len
self._seed = seed
self._rng = random.Random(seed)
self._token_cache: Dict[int, Optional[List[int]]] = {}
self._valid_indices: set[int] = set()
self._n_tokenize_failures = 0
self._n_short_records = 0
self._did_full_scan = False
def __len__(self) -> int:
return len(self.records)
def _tokenize_record(self, idx: int) -> Optional[List[int]]:
if idx in self._token_cache:
return self._token_cache[idx]
rec = self.records[idx]
midi_path = Path(str(rec.get("path", "")))
try:
pm = pretty_midi.PrettyMIDI(str(midi_path))
ids = encode(pm)
except Exception:
self._n_tokenize_failures += 1
self._token_cache[idx] = None
return None
if len(ids) < self.max_seq_len:
self._n_short_records += 1
self._token_cache[idx] = ids
self._valid_indices.add(idx)
return ids
def _crop(self, ids: List[int]) -> List[int]:
if self.is_train:
max_start = len(ids) - self.max_seq_len
start = 0 if max_start <= 0 else self._rng.randint(0, max_start)
return ids[start:start + self.max_seq_len]
return ids[: self.max_seq_len]
def _window_with_mask(self, ids: List[int]) -> tuple[List[int], List[int]]:
window = self._crop(ids)
valid_len = len(window)
if valid_len < self.max_seq_len:
window = window + [0] * (self.max_seq_len - valid_len)
attention_mask = [1] * valid_len + [0] * (self.max_seq_len - valid_len)
return window, attention_mask
def _log_scan_summary(self) -> None:
n_cached_valid = len(self._valid_indices)
n_cached_invalid = sum(
1 for v in self._token_cache.values() if v is None
)
print(
"[caption_dataloader] scan summary: "
f"records={len(self.records)} valid={n_cached_valid} "
f"tokenize_failures={self._n_tokenize_failures} "
f"shorter_than_max_seq_len={self._n_short_records} "
f"invalid={n_cached_invalid}"
)
def __getitem__(self, idx: int) -> Dict[str, Any]:
ids = self._tokenize_record(idx)
# If this file cannot be parsed/tokenized, sample another valid item.
if ids is None:
if not self._valid_indices:
# Warmup pass to discover all valid indices and diagnostics.
for probe_idx in range(len(self.records)):
self._tokenize_record(probe_idx)
self._did_full_scan = True
self._log_scan_summary()
if not self._valid_indices:
raise RuntimeError(
"No valid MIDI records found that meet max_seq_len."
)
replacement_idx = self._rng.choice(list(self._valid_indices))
ids = self._tokenize_record(replacement_idx)
rec = self.records[replacement_idx]
else:
rec = self.records[idx]
assert ids is not None
if (
not self._did_full_scan
and len(self._token_cache) == len(self.records)
):
self._did_full_scan = True
self._log_scan_summary()
window, attention_mask = self._window_with_mask(ids)
caption = _caption_from_record(rec)
input_ids = torch.tensor(window, dtype=torch.long)
attention_mask_t = torch.tensor(attention_mask, dtype=torch.long)
return {
"input_ids": input_ids,
"attention_mask": attention_mask_t,
"caption": caption,
"path": str(rec.get("path", "")),
}
class MidiCompoundCaptionDataset(Dataset):
"""Dataset yielding compound MIDI steps + caption metadata."""
def __init__(
self,
records: List[Dict[str, Any]],
is_train: bool,
max_seq_len: int = DEFAULT_MAX_SEQ_LEN,
seed: int = DEFAULT_SEED,
) -> None:
self.records = records
self.is_train = is_train
self.max_seq_len = max_seq_len
self._seed = seed
self._rng = random.Random(seed)
self._token_cache: Dict[int, Optional[List[List[int]]]] = {}
self._valid_indices: set[int] = set()
self._n_tokenize_failures = 0
self._n_short_records = 0
self._did_full_scan = False
self._pad_step = list(SENTINELS)
self._pad_step[0] = STEP_PAD
def __len__(self) -> int:
return len(self.records)
def _tokenize_record(self, idx: int) -> Optional[List[List[int]]]:
if idx in self._token_cache:
return self._token_cache[idx]
rec = self.records[idx]
midi_path = Path(str(rec.get("path", "")))
try:
pm = pretty_midi.PrettyMIDI(str(midi_path))
steps = encode_compound(pm)
except Exception:
self._n_tokenize_failures += 1
self._token_cache[idx] = None
return None
if len(steps) < self.max_seq_len:
self._n_short_records += 1
self._token_cache[idx] = steps
self._valid_indices.add(idx)
return steps
def _crop(
self, steps: List[List[int]]
) -> List[List[int]]:
if self.is_train:
max_start = len(steps) - self.max_seq_len
start = 0 if max_start <= 0 else self._rng.randint(0, max_start)
return steps[start:start + self.max_seq_len]
return steps[: self.max_seq_len]
def _window_with_mask(
self, steps: List[List[int]]
) -> tuple[List[List[int]], List[int]]:
window = self._crop(steps)
valid_len = len(window)
if valid_len < self.max_seq_len:
window = window + [self._pad_step] * (self.max_seq_len - valid_len)
attention_mask = [1] * valid_len + [0] * (self.max_seq_len - valid_len)
return window, attention_mask
def _log_scan_summary(self) -> None:
n_cached_valid = len(self._valid_indices)
n_cached_invalid = sum(
1 for v in self._token_cache.values() if v is None
)
print(
"[compound_caption_dataloader] scan summary: "
f"records={len(self.records)} valid={n_cached_valid} "
f"tokenize_failures={self._n_tokenize_failures} "
f"shorter_than_max_seq_len={self._n_short_records} "
f"invalid={n_cached_invalid}"
)
def __getitem__(self, idx: int) -> Dict[str, Any]:
steps = self._tokenize_record(idx)
if steps is None:
if not self._valid_indices:
for probe_idx in range(len(self.records)):
self._tokenize_record(probe_idx)
self._did_full_scan = True
self._log_scan_summary()
if not self._valid_indices:
raise RuntimeError(
"No valid MIDI records found that meet max_seq_len."
)
replacement_idx = self._rng.choice(list(self._valid_indices))
steps = self._tokenize_record(replacement_idx)
rec = self.records[replacement_idx]
else:
rec = self.records[idx]
assert steps is not None
if (
not self._did_full_scan
and len(self._token_cache) == len(self.records)
):
self._did_full_scan = True
self._log_scan_summary()
window, attention_mask = self._window_with_mask(steps)
caption = _caption_from_record(rec)
compound_input = torch.tensor(window, dtype=torch.long)
attention_mask_t = torch.tensor(attention_mask, dtype=torch.long)
return {
"compound_input": compound_input,
"attention_mask": attention_mask_t,
"caption": caption,
"path": str(rec.get("path", "")),
}
def _collate_caption_batch(items: List[Dict[str, Any]]) -> Dict[str, Any]:
return {
"input_ids": torch.stack([x["input_ids"] for x in items], dim=0),
"attention_mask": torch.stack(
[x["attention_mask"] for x in items], dim=0
),
"captions": [x["caption"] for x in items],
"paths": [x["path"] for x in items],
}
def _collate_compound_caption_batch(items: List[Dict[str, Any]]) -> Dict[str, Any]:
return {
"compound_input": torch.stack([x["compound_input"] for x in items], dim=0),
"attention_mask": torch.stack(
[x["attention_mask"] for x in items], dim=0
),
"captions": [x["caption"] for x in items],
"paths": [x["path"] for x in items],
}
def _seed_dataset_worker(worker_id: int) -> None:
info = get_worker_info()
if info is None:
return
ds = info.dataset
if isinstance(ds, (MidiCaptionDataset, MidiCompoundCaptionDataset)):
ds._rng.seed(ds._seed + worker_id)
def build_caption_dataloaders(
jsonl_path: Path | str,
max_seq_len: int = DEFAULT_MAX_SEQ_LEN,
batch_size: int = DEFAULT_BATCH_SIZE,
split_ratio: float = DEFAULT_SPLIT_RATIO,
seed: int = DEFAULT_SEED,
num_workers: int = DEFAULT_NUM_WORKERS,
pin_memory: Optional[bool] = None,
) -> tuple[DataLoader, DataLoader, CaptionDatasetStats]:
"""Build train/val DataLoaders from captions_llm.jsonl records.
- 95/5 split with fixed random seed
- train random crop to max_seq_len
- val deterministic first-window crop
- drop_last=True for train to keep fixed contrastive matrix shape
"""
path = Path(jsonl_path)
if not path.exists():
raise FileNotFoundError(f"JSONL not found: {path}")
if not 0.0 < split_ratio < 1.0:
raise ValueError("split_ratio must be between 0 and 1.")
records = _load_jsonl_records(path)
if not records:
raise ValueError(f"No records found in {path}")
rng = random.Random(seed)
indices = list(range(len(records)))
rng.shuffle(indices)
n_train = int(len(indices) * split_ratio)
train_indices = indices[:n_train]
val_indices = indices[n_train:]
train_records = [records[i] for i in train_indices]
val_records = [records[i] for i in val_indices]
if pin_memory is None:
pin_memory = torch.cuda.is_available()
train_ds = MidiCaptionDataset(
records=train_records,
is_train=True,
max_seq_len=max_seq_len,
seed=seed,
)
val_ds = MidiCaptionDataset(
records=val_records,
is_train=False,
max_seq_len=max_seq_len,
seed=seed + 1,
)
train_loader = DataLoader(
train_ds,
batch_size=batch_size,
shuffle=True,
drop_last=True,
num_workers=num_workers,
pin_memory=pin_memory,
collate_fn=_collate_caption_batch,
worker_init_fn=_seed_dataset_worker,
)
val_loader = DataLoader(
val_ds,
batch_size=batch_size,
shuffle=False,
drop_last=False,
num_workers=num_workers,
pin_memory=pin_memory,
collate_fn=_collate_caption_batch,
worker_init_fn=_seed_dataset_worker,
)
stats = CaptionDatasetStats(
n_total_records=len(records),
n_train_records=len(train_records),
n_val_records=len(val_records),
)
return train_loader, val_loader, stats
def build_compound_caption_dataloaders(
jsonl_path: Path | str,
max_seq_len: int = DEFAULT_MAX_SEQ_LEN,
batch_size: int = DEFAULT_BATCH_SIZE,
split_ratio: float = DEFAULT_SPLIT_RATIO,
seed: int = DEFAULT_SEED,
num_workers: int = DEFAULT_NUM_WORKERS,
pin_memory: Optional[bool] = None,
) -> tuple[DataLoader, DataLoader, CaptionDatasetStats]:
"""Build train/val loaders that return compound_input of shape (B, T, N_AXES)."""
path = Path(jsonl_path)
if not path.exists():
raise FileNotFoundError(f"JSONL not found: {path}")
if not 0.0 < split_ratio < 1.0:
raise ValueError("split_ratio must be between 0 and 1.")
records = _load_jsonl_records(path)
if not records:
raise ValueError(f"No records found in {path}")
rng = random.Random(seed)
indices = list(range(len(records)))
rng.shuffle(indices)
n_train = int(len(indices) * split_ratio)
train_indices = indices[:n_train]
val_indices = indices[n_train:]
train_records = [records[i] for i in train_indices]
val_records = [records[i] for i in val_indices]
if pin_memory is None:
pin_memory = torch.cuda.is_available()
train_ds = MidiCompoundCaptionDataset(
records=train_records,
is_train=True,
max_seq_len=max_seq_len,
seed=seed,
)
val_ds = MidiCompoundCaptionDataset(
records=val_records,
is_train=False,
max_seq_len=max_seq_len,
seed=seed + 1,
)
train_loader = DataLoader(
train_ds,
batch_size=batch_size,
shuffle=True,
drop_last=True,
num_workers=num_workers,
pin_memory=pin_memory,
collate_fn=_collate_compound_caption_batch,
worker_init_fn=_seed_dataset_worker,
)
val_loader = DataLoader(
val_ds,
batch_size=batch_size,
shuffle=False,
drop_last=False,
num_workers=num_workers,
pin_memory=pin_memory,
collate_fn=_collate_compound_caption_batch,
worker_init_fn=_seed_dataset_worker,
)
stats = CaptionDatasetStats(
n_total_records=len(records),
n_train_records=len(train_records),
n_val_records=len(val_records),
)
return train_loader, val_loader, stats