sllm / data /dataloader.py
geeteshcodes's picture
Initial commit
7f974df verified
raw
history blame
7.4 kB
"""
data/dataloader.py
Streaming dataloader for the pre-tokenized binary shards produced by
tokenizer/tokenize_dataset.py.
Each shard is a flat binary file of np.uint16 token IDs.
100M tokens * 2 bytes = ~200MB per shard.
Strategy:
1. Discover all shards matching split name (train/val).
2. Shuffle shard order at start of each epoch.
3. For each shard, read it fully into memory (np.fromfile) and yield
non-overlapping chunks of (context_length + 1) tokens; a trailing partial
chunk is dropped.
4. Inputs = chunk[:-1] (length context_length)
Targets = chunk[1:] (length context_length; targets[t] is the token that
follows inputs[t], i.e. the next-token labels)
When no data shards exist yet (tokenization not done), SyntheticDataset
can be used for architecture testing.
"""
import os
import glob
import random
import numpy as np
import torch
from torch.utils.data import IterableDataset, DataLoader
# ------------------------------------------------------------------ #
# SHARD DISCOVERY
# ------------------------------------------------------------------ #
def find_shards(data_dir: str, split: str) -> list[str]:
    """
    Locate every binary shard belonging to one split.

    Matches files named ``<split>_*.bin`` directly inside ``data_dir``.

    Args:
        data_dir : directory containing .bin shard files
        split    : 'train' or 'val'

    Returns:
        Matching paths in lexicographic order (empty list if none match).
    """
    shard_glob = os.path.join(data_dir, "{}_*.bin".format(split))
    matches = glob.glob(shard_glob)
    matches.sort()
    return matches
# ------------------------------------------------------------------ #
# ITERABLE DATASET
# ------------------------------------------------------------------ #
class ShardedTokenDataset(IterableDataset):
    """
    IterableDataset that streams token chunks from binary shards.

    Each shard is read with ``np.fromfile`` as uint16 token IDs and cut into
    non-overlapping windows of ``context_length + 1`` tokens; a trailing
    partial window is dropped. Each window yields ``(input_ids, targets)``,
    both int64 tensors of length ``context_length``, where ``targets[t]`` is
    the token following ``input_ids[t]``.

    With ``DataLoader(num_workers=N)`` every worker derives the same shard
    permutation for the epoch and then takes a disjoint stride of it, so
    each shard is read by exactly one worker per epoch.

    Usage:
        dataset = ShardedTokenDataset(data_dir, split='train', context_length=1024)
        loader = DataLoader(dataset, batch_size=4)
        for input_ids, targets in loader:
            ...
    """

    def __init__(
        self,
        data_dir: str,
        split: str,
        context_length: int,
        shuffle_shards: bool = True,
    ):
        """
        Args:
            data_dir       : path to directory with .bin shard files
            split          : 'train' or 'val'
            context_length : sequence length (model context length)
            shuffle_shards : shuffle shard order each epoch (train only)

        Raises:
            FileNotFoundError: if no ``<split>_*.bin`` shards exist in data_dir.
        """
        super().__init__()
        self.context_length = context_length
        self.shuffle_shards = shuffle_shards
        self.shards = find_shards(data_dir, split)
        if not self.shards:
            raise FileNotFoundError(
                f"No {split} shards found in {data_dir}.\n"
                f"Run tokenizer/tokenize_dataset.py first to generate data."
            )
        print(f"[DataLoader] Found {len(self.shards)} {split} shards in {data_dir}")

    def _assign_shards(self, worker_info) -> list[str]:
        """
        Return the shards the current process should read, in reading order.

        Args:
            worker_info : result of ``torch.utils.data.get_worker_info()``
                          (``None`` when loading in the main process).
        """
        shards = list(self.shards)
        if worker_info is None:
            # Single-process loading: one shuffle from the global RNG.
            if self.shuffle_shards:
                random.shuffle(shards)
            return shards
        if self.shuffle_shards:
            # DataLoader seeds each worker's `random` module differently
            # (base_seed + worker_id). Shuffling with the global RNG would
            # give every worker its own permutation, and the strided split
            # below would then overlap / skip shards. Recover the base seed,
            # which is shared by all workers of this epoch, so every worker
            # agrees on a single permutation.
            epoch_seed = worker_info.seed - worker_info.id
            random.Random(epoch_seed).shuffle(shards)
        return shards[worker_info.id :: worker_info.num_workers]

    def __iter__(self):
        shards = self._assign_shards(torch.utils.data.get_worker_info())
        chunk = self.context_length + 1  # +1 so targets can be shifted by one
        for shard_path in shards:
            # Keep the raw uint16 array; widen one chunk at a time instead of
            # materialising an int32 copy of the entire shard.
            tokens = np.fromfile(shard_path, dtype=np.uint16)
            n_chunks = len(tokens) // chunk
            for start in range(0, n_chunks * chunk, chunk):
                # astype() returns a fresh int64 array, so the tensor does not
                # alias the shard buffer.
                seq = torch.from_numpy(tokens[start : start + chunk].astype(np.int64))
                yield seq[:-1], seq[1:]  # each (context_length,)
# ------------------------------------------------------------------ #
# SYNTHETIC DATASET (for testing without real data)
# ------------------------------------------------------------------ #
class SyntheticDataset(IterableDataset):
    """
    Generates random token sequences for architecture testing.
    Use when real shards are not yet available.

    ``n_batches`` is the total number of *sequences* produced per epoch; the
    DataLoader groups them into batches. Under ``DataLoader(num_workers=N)``
    the sequences are divided among the workers, so the epoch length does
    not grow with N.
    """

    def __init__(self, vocab_size: int, context_length: int, n_batches: int = 1000):
        super().__init__()
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.n_batches = n_batches

    def _num_local_sequences(self, worker_info) -> int:
        """Number of sequences this process yields (all of them in the main process)."""
        if worker_info is None:
            return self.n_batches
        # Every worker runs its own copy of __iter__; give each a strided
        # share so the workers together produce exactly n_batches sequences.
        return len(range(worker_info.id, self.n_batches, worker_info.num_workers))

    def __iter__(self):
        n_local = self._num_local_sequences(torch.utils.data.get_worker_info())
        for _ in range(n_local):
            seq = torch.randint(0, self.vocab_size, (self.context_length + 1,))
            yield seq[:-1], seq[1:]  # inputs, next-token targets
# ------------------------------------------------------------------ #
# FACTORY FUNCTION
# ------------------------------------------------------------------ #
def build_dataloader(
    data_dir: str,
    split: str,
    context_length: int,
    batch_size: int,
    num_workers: int = 2,
    use_synthetic: bool = False,
    vocab_size: int = 32_000,
) -> DataLoader:
    """
    Builds and returns a DataLoader for the given split.

    Falls back to SyntheticDataset if use_synthetic=True or no shards are
    found. Shard order is shuffled only for the 'train' split. Pinned host
    memory is enabled only when a CUDA device is available.

    Args:
        data_dir       : directory with .bin shards
        split          : 'train' or 'val'
        context_length : model context length (e.g. 1024)
        batch_size     : number of sequences per batch
        num_workers    : DataLoader workers (0 = main process)
        use_synthetic  : force synthetic data (for testing)
        vocab_size     : needed for synthetic fallback

    Returns:
        DataLoader yielding (input_ids, targets) each of shape (B, T)
    """
    if use_synthetic:
        dataset = SyntheticDataset(vocab_size, context_length)
        print("[DataLoader] Using synthetic data (use_synthetic=True)")
    else:
        try:
            dataset = ShardedTokenDataset(
                data_dir=data_dir,
                split=split,
                context_length=context_length,
                shuffle_shards=(split == "train"),
            )
        except FileNotFoundError as e:
            print(f"[DataLoader] WARNING: {e}")
            print("[DataLoader] Falling back to synthetic data for testing.")
            dataset = SyntheticDataset(vocab_size, context_length)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        # Pinning only speeds up host->CUDA copies; without a CUDA device it
        # is wasted work (and recent PyTorch releases warn about it).
        pin_memory=torch.cuda.is_available(),
    )
# ------------------------------------------------------------------ #
# QUICK CHECK
# ------------------------------------------------------------------ #
if __name__ == "__main__":
    # Smoke test: build a synthetic loader from the 100M config and print the
    # shapes/dtype of the first four batches.
    import sys

    # Make the repository root importable so `model.config` resolves when
    # this file is executed directly as a script.
    this_file = os.path.abspath(__file__)
    repo_root = os.path.dirname(os.path.dirname(this_file))
    sys.path.insert(0, repo_root)
    from model.config import SLLM_100M

    cfg = SLLM_100M
    print("Testing with synthetic data...")
    loader = build_dataloader(
        data_dir="tokenizer/data",
        split="train",
        context_length=cfg.context_length,
        batch_size=4,
        num_workers=0,
        use_synthetic=True,
        vocab_size=cfg.vocab_size,
    )
    # zip() exhausts range(4) before asking the loader for a fifth batch.
    for batch_idx, (inputs, targets) in zip(range(4), loader):
        print(f"Batch {batch_idx}: input_ids={inputs.shape}, targets={targets.shape}, dtype={inputs.dtype}")
    print("DataLoader OK")