"""
Memory-efficient dataset utilities for tokenized JSONL training data.
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import List, Tuple
import torch
from torch.utils.data import Dataset


class TokenizedJsonlDataset(Dataset):
    """
    Random-access dataset over tokenized JSONL using line byte offsets.
    This avoids loading all samples into RAM.
    """

    def __init__(
        self,
        path: str,
        split: str = "train",
        val_ratio: float = 0.02,
        split_seed: int = 17,
    ) -> None:
        self.path = Path(path)
        if not self.path.exists():
            raise FileNotFoundError(f"Tokenized dataset not found: {self.path}")
        self.split = split
        self.val_ratio = val_ratio
        self.split_seed = split_seed
        self.offsets: List[int] = []
        self._build_offsets()

    def _hash_to_split(self, idx: int) -> bool:
        # Deterministic split using index so train/val is stable across runs.
        h = (idx * 1103515245 + self.split_seed) & 0x7FFFFFFF
        p = (h % 10_000) / 10_000.0
        return p < self.val_ratio

    def _build_offsets(self) -> None:
        # Record the byte offset of each line that belongs to this split.
        with self.path.open("rb") as f:
            idx = 0
            while True:
                offset = f.tell()
                line = f.readline()
                if not line:
                    break
                if self.split == "val":
                    keep = self._hash_to_split(idx)
                else:
                    keep = not self._hash_to_split(idx)
                if keep:
                    self.offsets.append(offset)
                idx += 1

    def __len__(self) -> int:
        return len(self.offsets)

    def __getitem__(self, index: int) -> List[int]:
        # Seek directly to the stored byte offset and parse a single JSON line.
        offset = self.offsets[index]
        with self.path.open("rb") as f:
            f.seek(offset)
            line = f.readline().decode("utf-8").strip()
        row = json.loads(line)
        ids = row.get("input_ids")
        if not isinstance(ids, list) or not ids:
            raise ValueError(f"Invalid input_ids at index {index}")
        return [int(x) for x in ids]
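

# Illustrative note (an assumption about the data, enforced only by the check
# in __getitem__): each line of the tokenized JSONL file is expected to be a
# standalone JSON object with a non-empty "input_ids" list, for example:
#
#     {"input_ids": [101, 2023, 2003, 1037, 7099, 102]}
#
# Any extra keys on a row are ignored.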


class CausalCollator:
    """
    Pads/truncates sequences and produces labels for next-token training.
    """

    def __init__(self, pad_token_id: int = 0, max_seq_len: int = 512) -> None:
        self.pad_token_id = pad_token_id
        self.max_seq_len = max_seq_len

    def __call__(self, batch: List[List[int]]) -> Tuple[torch.Tensor, torch.Tensor]:
        # Truncate to max_seq_len, then pad every sequence to the batch maximum.
        clipped = [x[: self.max_seq_len] for x in batch]
        max_len = max(len(x) for x in clipped)
        input_ids = []
        labels = []
        for seq in clipped:
            pad_len = max_len - len(seq)
            padded = seq + [self.pad_token_id] * pad_len
            # -100 marks padded positions so the loss ignores them.
            label = seq + [-100] * pad_len
            input_ids.append(padded)
            labels.append(label)
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)
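

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, assuming the JSONL schema noted
    # above): write a tiny temporary dataset, build train/val splits, and
    # collate one batch with torch's DataLoader. The toy sequences and the
    # temporary file path are placeholders, not real training data.
    import tempfile

    from torch.utils.data import DataLoader

    with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as tmp:
        for i in range(100):
            # Variable-length toy sequences; real data would come from a tokenizer.
            tmp.write(json.dumps({"input_ids": list(range(1, 4 + i % 5))}) + "\n")
        demo_path = tmp.name

    train_ds = TokenizedJsonlDataset(demo_path, split="train")
    val_ds = TokenizedJsonlDataset(demo_path, split="val")
    print(f"train samples: {len(train_ds)}, val samples: {len(val_ds)}")

    collator = CausalCollator(pad_token_id=0, max_seq_len=8)
    loader = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collator)
    input_ids, labels = next(iter(loader))
    print(f"batch input_ids shape: {tuple(input_ids.shape)}")
    print(f"batch labels shape: {tuple(labels.shape)}")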