| """ |
| Memory-efficient dataset utilities for tokenized JSONL training data. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
| from typing import Iterator, List, Tuple |
|
|
| import torch |
| from torch.utils.data import Dataset |
|
|
|
|
class TokenizedJsonlDataset(Dataset):
    """
    Random-access dataset over tokenized JSONL using line byte offsets,
    which avoids loading all samples into RAM. Each line must be a JSON
    object with a non-empty "input_ids" list of token ids (see the
    write_tokenized_jsonl sketch after this class).
    """

    def __init__(self, path: str, split: str = "train", val_ratio: float = 0.02, split_seed: int = 17) -> None:
        self.path = Path(path)
        if not self.path.exists():
            raise FileNotFoundError(f"Tokenized dataset not found: {self.path}")
        if split not in ("train", "val"):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")
        self.split = split
        self.val_ratio = val_ratio
        self.split_seed = split_seed
        self.offsets: List[int] = []
        self._build_offsets()

    def _hash_to_split(self, idx: int) -> bool:
        """Return True if line idx belongs to the validation split.

        An LCG-style integer hash makes the train/val partition
        deterministic across runs for a fixed split_seed. For example,
        with the default seed 17, idx 0 hashes to h = 17, so p = 0.0017
        and the line lands in "val" at the default val_ratio of 0.02.
        """
        h = (idx * 1103515245 + self.split_seed) & 0x7FFFFFFF
        p = (h % 10_000) / 10_000.0
        return p < self.val_ratio

    def _build_offsets(self) -> None:
        # One linear pass over the file: record the byte offset of every
        # line that belongs to the requested split so __getitem__ can
        # seek straight to it later.
        with self.path.open("rb") as f:
            idx = 0
            while True:
                offset = f.tell()
                line = f.readline()
                if not line:
                    break
                is_val = self._hash_to_split(idx)
                keep = is_val if self.split == "val" else not is_val
                if keep:
                    self.offsets.append(offset)
                idx += 1

    def __len__(self) -> int:
        return len(self.offsets)

    def __getitem__(self, index: int) -> List[int]:
        offset = self.offsets[index]
        # Open the file per call so the dataset is safe to use from
        # multiple DataLoader worker processes (no shared handle state).
        with self.path.open("rb") as f:
            f.seek(offset)
            line = f.readline().decode("utf-8").strip()
        row = json.loads(line)
        ids = row.get("input_ids")
        if not isinstance(ids, list) or not ids:
            raise ValueError(f"Invalid input_ids at index {index}")
        return [int(x) for x in ids]
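

# A minimal helper sketch, not part of the original API: it writes a file in
# the format TokenizedJsonlDataset expects (one JSON object per line with a
# non-empty "input_ids" list). `tokenize` is a hypothetical stand-in for
# whatever encoder you use (a callable mapping str -> List[int]).
def write_tokenized_jsonl(texts: List[str], tokenize, out_path: str) -> None:
    with Path(out_path).open("w", encoding="utf-8") as f:
        for text in texts:
            ids = tokenize(text)
            f.write(json.dumps({"input_ids": ids}) + "\n")

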
class CausalCollator:
    """
    Pads and truncates sequences and produces labels for next-token
    training. Labels mirror input_ids, with padded positions set to -100
    (the default ignore_index of torch.nn.CrossEntropyLoss) so they are
    excluded from the loss; the one-position shift is left to the model
    or loss function.
    """

    def __init__(self, pad_token_id: int = 0, max_seq_len: int = 512) -> None:
        self.pad_token_id = pad_token_id
        self.max_seq_len = max_seq_len

    def __call__(self, batch: List[List[int]]) -> Tuple[torch.Tensor, torch.Tensor]:
        # Truncate to the configured maximum, then pad every sequence to
        # the longest one remaining in this batch (dynamic padding).
        clipped = [x[: self.max_seq_len] for x in batch]
        max_len = max(len(x) for x in clipped)
        input_ids: List[List[int]] = []
        labels: List[List[int]] = []
        for seq in clipped:
            pad_len = max_len - len(seq)
            input_ids.append(seq + [self.pad_token_id] * pad_len)
            labels.append(seq + [-100] * pad_len)  # -100 masks padding from the loss
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)
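

if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming nothing beyond this module: write a
    # tiny throwaway JSONL file (illustrative token ids), then round-trip it
    # through the dataset and collator via a standard DataLoader.
    import os
    import tempfile

    from torch.utils.data import DataLoader

    with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as tmp:
        for i in range(100):
            tmp.write(json.dumps({"input_ids": list(range(1, 4 + i % 5))}) + "\n")
        tmp_path = tmp.name

    # "train" and "val" deterministically partition the same file.
    train_ds = TokenizedJsonlDataset(tmp_path, split="train")
    val_ds = TokenizedJsonlDataset(tmp_path, split="val")
    print(f"train={len(train_ds)} val={len(val_ds)}")

    collator = CausalCollator(pad_token_id=0, max_seq_len=8)
    loader = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collator)
    input_ids, labels = next(iter(loader))
    print(tuple(input_ids.shape), tuple(labels.shape))  # both (batch, batch_max_len)

    os.unlink(tmp_path)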