"""Caption-conditioned MIDI dataset and dataloaders. Reads records from captions_llm.jsonl and returns fixed-length MIDI token windows plus text captions for conditioning. """ from __future__ import annotations import json import random from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, List, Optional import pretty_midi import torch from torch.utils.data import DataLoader, Dataset, get_worker_info from compound import SENTINELS, STEP_PAD, encode_compound from tokenizer import encode DEFAULT_MAX_SEQ_LEN = 512 DEFAULT_BATCH_SIZE = 64 DEFAULT_SPLIT_RATIO = 0.95 DEFAULT_SEED = 17 DEFAULT_NUM_WORKERS = 4 @dataclass class CaptionDatasetStats: n_total_records: int n_train_records: int n_val_records: int def _load_jsonl_records(jsonl_path: Path) -> List[Dict[str, Any]]: records: List[Dict[str, Any]] = [] with open(jsonl_path) as f: for line in f: line = line.strip() if not line: continue records.append(json.loads(line)) return records def _caption_from_record(record: Dict[str, Any]) -> str: llm_caption = str(record.get("caption", "")).strip() if llm_caption: return llm_caption return str(record.get("caption_template", "")).strip() class MidiCaptionDataset(Dataset): """Dataset yielding (midi_tokens, attention_mask, caption, path).""" def __init__( self, records: List[Dict[str, Any]], is_train: bool, max_seq_len: int = DEFAULT_MAX_SEQ_LEN, seed: int = DEFAULT_SEED, ) -> None: self.records = records self.is_train = is_train self.max_seq_len = max_seq_len self._seed = seed self._rng = random.Random(seed) self._token_cache: Dict[int, Optional[List[int]]] = {} self._valid_indices: set[int] = set() self._n_tokenize_failures = 0 self._n_short_records = 0 self._did_full_scan = False def __len__(self) -> int: return len(self.records) def _tokenize_record(self, idx: int) -> Optional[List[int]]: if idx in self._token_cache: return self._token_cache[idx] rec = self.records[idx] midi_path = Path(str(rec.get("path", ""))) try: pm = pretty_midi.PrettyMIDI(str(midi_path)) ids = encode(pm) except Exception: self._n_tokenize_failures += 1 self._token_cache[idx] = None return None if len(ids) < self.max_seq_len: self._n_short_records += 1 self._token_cache[idx] = ids self._valid_indices.add(idx) return ids def _crop(self, ids: List[int]) -> List[int]: if self.is_train: max_start = len(ids) - self.max_seq_len start = 0 if max_start <= 0 else self._rng.randint(0, max_start) return ids[start:start + self.max_seq_len] return ids[: self.max_seq_len] def _window_with_mask(self, ids: List[int]) -> tuple[List[int], List[int]]: window = self._crop(ids) valid_len = len(window) if valid_len < self.max_seq_len: window = window + [0] * (self.max_seq_len - valid_len) attention_mask = [1] * valid_len + [0] * (self.max_seq_len - valid_len) return window, attention_mask def _log_scan_summary(self) -> None: n_cached_valid = len(self._valid_indices) n_cached_invalid = sum( 1 for v in self._token_cache.values() if v is None ) print( "[caption_dataloader] scan summary: " f"records={len(self.records)} valid={n_cached_valid} " f"tokenize_failures={self._n_tokenize_failures} " f"shorter_than_max_seq_len={self._n_short_records} " f"invalid={n_cached_invalid}" ) def __getitem__(self, idx: int) -> Dict[str, Any]: ids = self._tokenize_record(idx) # If this file cannot be parsed/tokenized, sample another valid item. if ids is None: if not self._valid_indices: # Warmup pass to discover all valid indices and diagnostics. for probe_idx in range(len(self.records)): self._tokenize_record(probe_idx) self._did_full_scan = True self._log_scan_summary() if not self._valid_indices: raise RuntimeError( "No valid MIDI records found that meet max_seq_len." ) replacement_idx = self._rng.choice(list(self._valid_indices)) ids = self._tokenize_record(replacement_idx) rec = self.records[replacement_idx] else: rec = self.records[idx] assert ids is not None if ( not self._did_full_scan and len(self._token_cache) == len(self.records) ): self._did_full_scan = True self._log_scan_summary() window, attention_mask = self._window_with_mask(ids) caption = _caption_from_record(rec) input_ids = torch.tensor(window, dtype=torch.long) attention_mask_t = torch.tensor(attention_mask, dtype=torch.long) return { "input_ids": input_ids, "attention_mask": attention_mask_t, "caption": caption, "path": str(rec.get("path", "")), } class MidiCompoundCaptionDataset(Dataset): """Dataset yielding compound MIDI steps + caption metadata.""" def __init__( self, records: List[Dict[str, Any]], is_train: bool, max_seq_len: int = DEFAULT_MAX_SEQ_LEN, seed: int = DEFAULT_SEED, ) -> None: self.records = records self.is_train = is_train self.max_seq_len = max_seq_len self._seed = seed self._rng = random.Random(seed) self._token_cache: Dict[int, Optional[List[List[int]]]] = {} self._valid_indices: set[int] = set() self._n_tokenize_failures = 0 self._n_short_records = 0 self._did_full_scan = False self._pad_step = list(SENTINELS) self._pad_step[0] = STEP_PAD def __len__(self) -> int: return len(self.records) def _tokenize_record(self, idx: int) -> Optional[List[List[int]]]: if idx in self._token_cache: return self._token_cache[idx] rec = self.records[idx] midi_path = Path(str(rec.get("path", ""))) try: pm = pretty_midi.PrettyMIDI(str(midi_path)) steps = encode_compound(pm) except Exception: self._n_tokenize_failures += 1 self._token_cache[idx] = None return None if len(steps) < self.max_seq_len: self._n_short_records += 1 self._token_cache[idx] = steps self._valid_indices.add(idx) return steps def _crop( self, steps: List[List[int]] ) -> List[List[int]]: if self.is_train: max_start = len(steps) - self.max_seq_len start = 0 if max_start <= 0 else self._rng.randint(0, max_start) return steps[start:start + self.max_seq_len] return steps[: self.max_seq_len] def _window_with_mask( self, steps: List[List[int]] ) -> tuple[List[List[int]], List[int]]: window = self._crop(steps) valid_len = len(window) if valid_len < self.max_seq_len: window = window + [self._pad_step] * (self.max_seq_len - valid_len) attention_mask = [1] * valid_len + [0] * (self.max_seq_len - valid_len) return window, attention_mask def _log_scan_summary(self) -> None: n_cached_valid = len(self._valid_indices) n_cached_invalid = sum( 1 for v in self._token_cache.values() if v is None ) print( "[compound_caption_dataloader] scan summary: " f"records={len(self.records)} valid={n_cached_valid} " f"tokenize_failures={self._n_tokenize_failures} " f"shorter_than_max_seq_len={self._n_short_records} " f"invalid={n_cached_invalid}" ) def __getitem__(self, idx: int) -> Dict[str, Any]: steps = self._tokenize_record(idx) if steps is None: if not self._valid_indices: for probe_idx in range(len(self.records)): self._tokenize_record(probe_idx) self._did_full_scan = True self._log_scan_summary() if not self._valid_indices: raise RuntimeError( "No valid MIDI records found that meet max_seq_len." ) replacement_idx = self._rng.choice(list(self._valid_indices)) steps = self._tokenize_record(replacement_idx) rec = self.records[replacement_idx] else: rec = self.records[idx] assert steps is not None if ( not self._did_full_scan and len(self._token_cache) == len(self.records) ): self._did_full_scan = True self._log_scan_summary() window, attention_mask = self._window_with_mask(steps) caption = _caption_from_record(rec) compound_input = torch.tensor(window, dtype=torch.long) attention_mask_t = torch.tensor(attention_mask, dtype=torch.long) return { "compound_input": compound_input, "attention_mask": attention_mask_t, "caption": caption, "path": str(rec.get("path", "")), } def _collate_caption_batch(items: List[Dict[str, Any]]) -> Dict[str, Any]: return { "input_ids": torch.stack([x["input_ids"] for x in items], dim=0), "attention_mask": torch.stack( [x["attention_mask"] for x in items], dim=0 ), "captions": [x["caption"] for x in items], "paths": [x["path"] for x in items], } def _collate_compound_caption_batch(items: List[Dict[str, Any]]) -> Dict[str, Any]: return { "compound_input": torch.stack([x["compound_input"] for x in items], dim=0), "attention_mask": torch.stack( [x["attention_mask"] for x in items], dim=0 ), "captions": [x["caption"] for x in items], "paths": [x["path"] for x in items], } def _seed_dataset_worker(worker_id: int) -> None: info = get_worker_info() if info is None: return ds = info.dataset if isinstance(ds, (MidiCaptionDataset, MidiCompoundCaptionDataset)): ds._rng.seed(ds._seed + worker_id) def build_caption_dataloaders( jsonl_path: Path | str, max_seq_len: int = DEFAULT_MAX_SEQ_LEN, batch_size: int = DEFAULT_BATCH_SIZE, split_ratio: float = DEFAULT_SPLIT_RATIO, seed: int = DEFAULT_SEED, num_workers: int = DEFAULT_NUM_WORKERS, pin_memory: Optional[bool] = None, ) -> tuple[DataLoader, DataLoader, CaptionDatasetStats]: """Build train/val DataLoaders from captions_llm.jsonl records. - 95/5 split with fixed random seed - train random crop to max_seq_len - val deterministic first-window crop - drop_last=True for train to keep fixed contrastive matrix shape """ path = Path(jsonl_path) if not path.exists(): raise FileNotFoundError(f"JSONL not found: {path}") if not 0.0 < split_ratio < 1.0: raise ValueError("split_ratio must be between 0 and 1.") records = _load_jsonl_records(path) if not records: raise ValueError(f"No records found in {path}") rng = random.Random(seed) indices = list(range(len(records))) rng.shuffle(indices) n_train = int(len(indices) * split_ratio) train_indices = indices[:n_train] val_indices = indices[n_train:] train_records = [records[i] for i in train_indices] val_records = [records[i] for i in val_indices] if pin_memory is None: pin_memory = torch.cuda.is_available() train_ds = MidiCaptionDataset( records=train_records, is_train=True, max_seq_len=max_seq_len, seed=seed, ) val_ds = MidiCaptionDataset( records=val_records, is_train=False, max_seq_len=max_seq_len, seed=seed + 1, ) train_loader = DataLoader( train_ds, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers, pin_memory=pin_memory, collate_fn=_collate_caption_batch, worker_init_fn=_seed_dataset_worker, ) val_loader = DataLoader( val_ds, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers, pin_memory=pin_memory, collate_fn=_collate_caption_batch, worker_init_fn=_seed_dataset_worker, ) stats = CaptionDatasetStats( n_total_records=len(records), n_train_records=len(train_records), n_val_records=len(val_records), ) return train_loader, val_loader, stats def build_compound_caption_dataloaders( jsonl_path: Path | str, max_seq_len: int = DEFAULT_MAX_SEQ_LEN, batch_size: int = DEFAULT_BATCH_SIZE, split_ratio: float = DEFAULT_SPLIT_RATIO, seed: int = DEFAULT_SEED, num_workers: int = DEFAULT_NUM_WORKERS, pin_memory: Optional[bool] = None, ) -> tuple[DataLoader, DataLoader, CaptionDatasetStats]: """Build train/val loaders that return compound_input of shape (B, T, N_AXES).""" path = Path(jsonl_path) if not path.exists(): raise FileNotFoundError(f"JSONL not found: {path}") if not 0.0 < split_ratio < 1.0: raise ValueError("split_ratio must be between 0 and 1.") records = _load_jsonl_records(path) if not records: raise ValueError(f"No records found in {path}") rng = random.Random(seed) indices = list(range(len(records))) rng.shuffle(indices) n_train = int(len(indices) * split_ratio) train_indices = indices[:n_train] val_indices = indices[n_train:] train_records = [records[i] for i in train_indices] val_records = [records[i] for i in val_indices] if pin_memory is None: pin_memory = torch.cuda.is_available() train_ds = MidiCompoundCaptionDataset( records=train_records, is_train=True, max_seq_len=max_seq_len, seed=seed, ) val_ds = MidiCompoundCaptionDataset( records=val_records, is_train=False, max_seq_len=max_seq_len, seed=seed + 1, ) train_loader = DataLoader( train_ds, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers, pin_memory=pin_memory, collate_fn=_collate_compound_caption_batch, worker_init_fn=_seed_dataset_worker, ) val_loader = DataLoader( val_ds, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers, pin_memory=pin_memory, collate_fn=_collate_compound_caption_batch, worker_init_fn=_seed_dataset_worker, ) stats = CaptionDatasetStats( n_total_records=len(records), n_train_records=len(train_records), n_val_records=len(val_records), ) return train_loader, val_loader, stats