File size: 7,402 Bytes
7f974df | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 | """
data/dataloader.py
Streaming dataloader for the pre-tokenized binary shards produced by
tokenizer/tokenize_dataset.py.
Each shard is a flat binary file of np.uint16 token IDs.
100M tokens * 2 bytes = ~200MB per shard.
Strategy:
1. Discover all shards matching split name (train/val).
2. Shuffle shard order at start of each epoch.
3. For each shard, load it fully into memory and yield non-overlapping
chunks of (context_length + 1) tokens.
4. Inputs = chunk[:-1] (length context_length)
Targets = chunk[1:] (length context_length, shifted right by 1)
When no data shards exist yet (tokenization not done), SyntheticDataset
can be used for architecture testing.
"""
import os
import glob
import random
import numpy as np
import torch
from torch.utils.data import IterableDataset, DataLoader
# ------------------------------------------------------------------ #
# SHARD DISCOVERY
# ------------------------------------------------------------------ #
def find_shards(data_dir: str, split: str) -> list[str]:
    """
    Return the sorted list of shard paths for the given split.

    Shards are files named ``{split}_*.bin`` located directly inside
    ``data_dir`` (sub-directories are not searched).

    Args:
        data_dir : directory containing .bin shard files
        split    : 'train' or 'val'

    Returns:
        Lexicographically sorted list of matching paths; empty when
        nothing matches.
    """
    # Escape the directory part so glob metacharacters that happen to be in
    # the path itself (e.g. "runs/exp[1]") are matched literally. Only the
    # "*" after the split prefix is meant to act as a wildcard.
    pattern = os.path.join(glob.escape(data_dir), f"{split}_*.bin")
    return sorted(glob.glob(pattern))
# ------------------------------------------------------------------ #
# ITERABLE DATASET
# ------------------------------------------------------------------ #
class ShardedTokenDataset(IterableDataset):
    """
    IterableDataset that streams token chunks from binary shards.

    Every shard is read as a flat array of uint16 token IDs and cut into
    non-overlapping windows of (context_length + 1) tokens. Trailing tokens
    that do not fill a whole window are dropped.

    Under DataLoader(num_workers=N) each worker takes a disjoint, strided
    slice of the shard list, so no shard is read twice in one pass.

    Usage:
        dataset = ShardedTokenDataset(data_dir, split='train', context_length=1024)
        loader = DataLoader(dataset, batch_size=4)
        for input_ids, targets in loader:
            ...
    """

    def __init__(
        self,
        data_dir: str,
        split: str,
        context_length: int,
        shuffle_shards: bool = True,
    ):
        """
        Args:
            data_dir       : path to directory with .bin shard files
            split          : 'train' or 'val'
            context_length : sequence length (model context length)
            shuffle_shards : shuffle shard order on every iteration pass

        Raises:
            FileNotFoundError: if no shard for ``split`` exists in ``data_dir``.
        """
        super().__init__()
        self.context_length = context_length
        self.shuffle_shards = shuffle_shards
        self.shards = find_shards(data_dir, split)
        if not self.shards:
            raise FileNotFoundError(
                f"No {split} shards found in {data_dir}.\n"
                "Run tokenizer/tokenize_dataset.py first to generate data."
            )
        print(f"[DataLoader] Found {len(self.shards)} {split} shards in {data_dir}")

    def __iter__(self):
        """
        Yield (input_ids, targets) pairs of int64 tensors, each of length
        context_length, where targets is input_ids advanced by one token.
        """
        worker_info = torch.utils.data.get_worker_info()
        # Work on a copy so self.shards keeps its sorted order.
        shards = list(self.shards)

        # Partition BEFORE shuffling. Each DataLoader worker runs this method
        # in its own process with its own RNG state, so shuffling the full
        # list first would hand every worker a different permutation and the
        # strided slices could then overlap (duplicated shards) or leave
        # shards unread. Slicing the deterministic sorted list keeps the
        # per-worker subsets disjoint and exhaustive.
        if worker_info is not None:
            shards = shards[worker_info.id :: worker_info.num_workers]
        if self.shuffle_shards:
            random.shuffle(shards)

        window = self.context_length + 1  # +1 so we can shift for targets
        for shard_path in shards:
            # Keep the shard as uint16 in memory; only the small per-window
            # slices are widened, which avoids a full-shard dtype copy.
            tokens = np.fromfile(shard_path, dtype=np.uint16)
            usable = (len(tokens) // window) * window
            for start in range(0, usable, window):
                piece = tokens[start : start + window]
                # astype() copies, so the two tensors own separate memory.
                input_ids = torch.from_numpy(piece[:-1].astype(np.int64))
                targets = torch.from_numpy(piece[1:].astype(np.int64))
                yield input_ids, targets
# ------------------------------------------------------------------ #
# SYNTHETIC DATASET (for testing without real data)
# ------------------------------------------------------------------ #
class SyntheticDataset(IterableDataset):
    """
    Stand-in dataset that emits random token sequences.

    Intended for exercising the model/training code before real shards
    exist. Each item is an (input_ids, targets) pair cut from one random
    sequence of context_length + 1 token IDs drawn from [0, vocab_size).
    """

    def __init__(self, vocab_size: int, context_length: int, n_batches: int = 1000):
        """
        Args:
            vocab_size     : exclusive upper bound for the random token IDs
            context_length : length of each yielded input/target tensor
            n_batches      : number of (input_ids, targets) items produced
                             per pass over the dataset
        """
        super().__init__()
        self.n_batches = n_batches
        self.context_length = context_length
        self.vocab_size = vocab_size

    def __iter__(self):
        """Yield n_batches random (input_ids, targets) pairs."""
        for _ in range(self.n_batches):
            # One extra token so inputs and targets can overlap by all but one.
            sample = torch.randint(
                low=0,
                high=self.vocab_size,
                size=(self.context_length + 1,),
            )
            yield sample[:-1], sample[1:]
# ------------------------------------------------------------------ #
# FACTORY FUNCTION
# ------------------------------------------------------------------ #
def build_dataloader(
    data_dir: str,
    split: str,
    context_length: int,
    batch_size: int,
    num_workers: int = 2,
    use_synthetic: bool = False,
    vocab_size: int = 32_000,
) -> DataLoader:
    """
    Build and return a DataLoader for the given split.

    Uses ShardedTokenDataset when shards are available. Falls back to
    SyntheticDataset if use_synthetic=True or if no shards are found
    (a warning is printed in the latter case).

    Args:
        data_dir       : directory with .bin shards
        split          : 'train' or 'val'; shard order is shuffled only
                         for 'train'
        context_length : model context length
        batch_size     : number of sequences per batch
        num_workers    : DataLoader workers (0 = main process)
        use_synthetic  : force synthetic data (for testing)
        vocab_size     : needed for synthetic fallback

    Returns:
        DataLoader yielding (input_ids, targets) each of shape (B, T)
    """
    if use_synthetic:
        dataset = SyntheticDataset(vocab_size, context_length)
        print("[DataLoader] Using synthetic data (use_synthetic=True)")
    else:
        try:
            dataset = ShardedTokenDataset(
                data_dir=data_dir,
                split=split,
                context_length=context_length,
                shuffle_shards=(split == "train"),
            )
        except FileNotFoundError as e:
            # Missing shards are not fatal here: report and substitute
            # random data so the rest of the pipeline can still be exercised.
            print(f"[DataLoader] WARNING: {e}")
            print("[DataLoader] Falling back to synthetic data for testing.")
            dataset = SyntheticDataset(vocab_size, context_length)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        # Pinned host memory speeds up CPU->GPU transfer; without CUDA there
        # is nothing to pin for, and requesting it only makes PyTorch warn.
        pin_memory=torch.cuda.is_available(),
    )
# ------------------------------------------------------------------ #
# QUICK CHECK
# ------------------------------------------------------------------ #
if __name__ == "__main__":
    # Smoke test: pull a few batches of synthetic data through the factory
    # and print their shapes and dtype.
    import sys

    # Make the directory one level above this file's folder importable so
    # that `model.config` resolves (presumably the project root).
    parent_of_package = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    sys.path.insert(0, parent_of_package)
    from model.config import SLLM_100M

    cfg = SLLM_100M
    print("Testing with synthetic data...")
    loader = build_dataloader(
        data_dir="tokenizer/data",
        split="train",
        context_length=cfg.context_length,
        batch_size=4,
        num_workers=0,
        use_synthetic=True,
        vocab_size=cfg.vocab_size,
    )

    last_step = 3  # stop after the fourth batch has been reported
    for step, (inputs, labels) in enumerate(loader):
        print(f"Batch {step}: input_ids={inputs.shape}, targets={labels.shape}, dtype={inputs.dtype}")
        if step == last_step:
            break
    print("DataLoader OK")
|