"""
Memory-efficient dataset utilities for tokenized JSONL training data.
"""

from __future__ import annotations

import json
from pathlib import Path
from typing import List, Tuple

import torch
from torch.utils.data import Dataset


class TokenizedJsonlDataset(Dataset):
    """
    Random-access dataset over tokenized JSONL using line byte offsets.
    This avoids loading all samples into RAM.
    """

    def __init__(self, path: str, split: str = "train", val_ratio: float = 0.02, split_seed: int = 17) -> None:
        self.path = Path(path)
        if not self.path.exists():
            raise FileNotFoundError(f"Tokenized dataset not found: {self.path}")
        if split not in ("train", "val"):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")
        self.split = split
        self.val_ratio = val_ratio
        self.split_seed = split_seed
        self.offsets: List[int] = []
        self._build_offsets()

    def _hash_to_split(self, idx: int) -> bool:
        # Deterministic pseudo-random split keyed on the line index, so the
        # train/val assignment is stable across runs. Returns True for "val".
        h = (idx * 1103515245 + self.split_seed) & 0x7FFFFFFF
        p = (h % 10_000) / 10_000.0
        return p < self.val_ratio

    def _build_offsets(self) -> None:
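        # Single streaming pass: f.tell() before each readline() is that line's byte offset.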
        with self.path.open("rb") as f:
            idx = 0
            while True:
                offset = f.tell()
                line = f.readline()
                if not line:
                    break
                if self.split == "val":
                    keep = self._hash_to_split(idx)
                else:
                    keep = not self._hash_to_split(idx)
                if keep:
                    self.offsets.append(offset)
                idx += 1

    def __len__(self) -> int:
        return len(self.offsets)

    def __getitem__(self, index: int) -> List[int]:
        offset = self.offsets[index]
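        # Re-open per item instead of caching a handle, so the dataset stays
        # safe to use with multi-worker DataLoaders.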
        with self.path.open("rb") as f:
            f.seek(offset)
            line = f.readline().decode("utf-8").strip()
        row = json.loads(line)
        ids = row.get("input_ids")
        if not isinstance(ids, list) or not ids:
            raise ValueError(f"Invalid input_ids at index {index}")
        return [int(x) for x in ids]


class CausalCollator:
    """
    Pads/truncates sequences and produces labels for next-token training.
    """

    def __init__(self, pad_token_id: int = 0, max_seq_len: int = 512) -> None:
        self.pad_token_id = pad_token_id
        self.max_seq_len = max_seq_len

    def __call__(self, batch: List[List[int]]) -> Tuple[torch.Tensor, torch.Tensor]:
        clipped = [x[: self.max_seq_len] for x in batch]
        max_len = max(len(x) for x in clipped)
        input_ids = []
        labels = []
        for seq in clipped:
            pad_len = max_len - len(seq)
            padded = seq + [self.pad_token_id] * pad_len
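            # Padding positions get -100 so cross-entropy skips them
            # (PyTorch's default ignore_index).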
            label = seq + [-100] * pad_len
            input_ids.append(padded)
            labels.append(label)
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)
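

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module API). The file path,
# batch size, and pad/max-length values below are placeholder assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    # Placeholder path: point this at a real tokenized JSONL file.
    data_path = "data/tokenized.jsonl"

    train_ds = TokenizedJsonlDataset(data_path, split="train")
    val_ds = TokenizedJsonlDataset(data_path, split="val")
    collate = CausalCollator(pad_token_id=0, max_seq_len=512)

    loader = DataLoader(train_ds, batch_size=8, shuffle=True, collate_fn=collate)
    input_ids, labels = next(iter(loader))
    print(f"train={len(train_ds)} val={len(val_ds)}")
    print(f"batch input_ids={tuple(input_ids.shape)} labels={tuple(labels.shape)}")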