""" Memory-efficient dataset utilities for tokenized JSONL training data. """ from __future__ import annotations import json from pathlib import Path from typing import Iterator, List, Tuple import torch from torch.utils.data import Dataset class TokenizedJsonlDataset(Dataset): """ Random-access dataset over tokenized JSONL using line byte offsets. This avoids loading all samples into RAM. """ def __init__(self, path: str, split: str = "train", val_ratio: float = 0.02, split_seed: int = 17) -> None: self.path = Path(path) if not self.path.exists(): raise FileNotFoundError(f"Tokenized dataset not found: {self.path}") self.split = split self.val_ratio = val_ratio self.split_seed = split_seed self.offsets: List[int] = [] self._build_offsets() def _hash_to_split(self, idx: int) -> bool: # Deterministic split using index so train/val is stable across runs. h = (idx * 1103515245 + self.split_seed) & 0x7FFFFFFF p = (h % 10_000) / 10_000.0 return p < self.val_ratio def _build_offsets(self) -> None: with self.path.open("rb") as f: idx = 0 while True: offset = f.tell() line = f.readline() if not line: break if self.split == "val": keep = self._hash_to_split(idx) else: keep = not self._hash_to_split(idx) if keep: self.offsets.append(offset) idx += 1 def __len__(self) -> int: return len(self.offsets) def __getitem__(self, index: int) -> List[int]: offset = self.offsets[index] with self.path.open("rb") as f: f.seek(offset) line = f.readline().decode("utf-8").strip() row = json.loads(line) ids = row.get("input_ids") if not isinstance(ids, list) or not ids: raise ValueError(f"Invalid input_ids at index {index}") return [int(x) for x in ids] class CausalCollator: """ Pads/truncates sequences and produces labels for next-token training. """ def __init__(self, pad_token_id: int = 0, max_seq_len: int = 512) -> None: self.pad_token_id = pad_token_id self.max_seq_len = max_seq_len def __call__(self, batch: List[List[int]]) -> Tuple[torch.Tensor, torch.Tensor]: clipped = [x[: self.max_seq_len] for x in batch] max_len = max(len(x) for x in clipped) input_ids = [] labels = [] for seq in clipped: pad_len = max_len - len(seq) padded = seq + [self.pad_token_id] * pad_len label = seq + [-100] * pad_len input_ids.append(padded) labels.append(label) return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)