"""
data/dataloader.py

Streaming dataloader for the pre-tokenized binary shards produced by
tokenizer/tokenize_dataset.py.

Each shard is a flat binary file of np.uint16 token IDs.
100M tokens * 2 bytes = ~200MB per shard.

Strategy:
  1. Discover all shards matching the split name (train/val).
  2. Shuffle the shard order at the start of each epoch (train only).
  3. For each shard, memory-map it and yield non-overlapping chunks of
     (context_length + 1) tokens. Chunks never span two shards; the
     leftover tail of each shard (fewer than context_length + 1 tokens)
     is dropped.
  4. Inputs  = chunk[:-1]  (length context_length)
     Targets = chunk[1:]   (length context_length; targets[i] is the
                            token that follows inputs[i])

When no data shards exist yet (tokenization not done), SyntheticDataset
can be used for architecture testing.
"""

import os
import glob
import random

import numpy as np
import torch
from torch.utils.data import IterableDataset, DataLoader


# ------------------------------------------------------------------ #
# SHARD DISCOVERY
# ------------------------------------------------------------------ #

def find_shards(data_dir: str, split: str) -> list[str]:
    """
    Returns sorted list of shard paths for the given split.

    Shards are matched by the glob pattern ``{split}_*.bin`` inside
    ``data_dir`` and sorted lexicographically, so the unshuffled order
    is deterministic.

    Args:
        data_dir : directory containing .bin shard files
        split    : 'train' or 'val'
    """
    pattern = os.path.join(data_dir, f"{split}_*.bin")
    return sorted(glob.glob(pattern))


# ------------------------------------------------------------------ #
# ITERABLE DATASET
# ------------------------------------------------------------------ #

class ShardedTokenDataset(IterableDataset):
    """
    IterableDataset that streams token chunks from binary shards.

    Each worker processes a disjoint subset of shards so we get proper
    parallelism with DataLoader(num_workers=N). When shuffling, every
    worker derives the shard permutation from the same seed, so the
    per-worker strided slices partition the shard list exactly (no shard
    is read twice or skipped).

    Usage:
        dataset = ShardedTokenDataset(data_dir, split='train', context_length=1024)
        loader  = DataLoader(dataset, batch_size=4)
        for input_ids, targets in loader:
            ...
    """

    def __init__(
        self,
        data_dir: str,
        split: str,
        context_length: int,
        shuffle_shards: bool = True,
    ):
        """
        Args:
            data_dir       : path to directory with .bin shard files
            split          : 'train' or 'val'
            context_length : sequence length (model context length), >= 1
            shuffle_shards : shuffle shard order each epoch (train only)

        Raises:
            ValueError        : if context_length < 1
            FileNotFoundError : if no shards for ``split`` exist in data_dir
        """
        super().__init__()
        if context_length < 1:
            raise ValueError(f"context_length must be >= 1, got {context_length}")
        self.context_length = context_length
        self.shuffle_shards = shuffle_shards
        # How many times __iter__ has run on this instance. Mixed into the
        # shuffle seed so persistent workers (whose worker seed never
        # changes) still get a fresh shard order every epoch.
        self._epoch = 0

        self.shards = find_shards(data_dir, split)
        if not self.shards:
            raise FileNotFoundError(
                f"No {split} shards found in {data_dir}.\n"
                f"Run tokenizer/tokenize_dataset.py first to generate data."
            )

        print(f"[DataLoader] Found {len(self.shards)} {split} shards in {data_dir}")

    def _shards_for_worker(self, worker_id: int, num_workers: int, seed=None) -> list[str]:
        """
        Return the shard paths assigned to one worker.

        The full shard list is (optionally) shuffled and then strided by
        ``num_workers``. All workers must pass the same ``seed`` so they
        agree on the permutation; otherwise the slices overlap.

        Args:
            worker_id   : index of this worker in [0, num_workers)
            num_workers : total number of workers sharing the shard list
            seed        : int seed for the shuffle, or None to use the
                          global ``random`` state (single-process loading)
        """
        shards = list(self.shards)
        if self.shuffle_shards:
            if seed is None:
                random.shuffle(shards)
            else:
                random.Random(seed).shuffle(shards)
        return shards[worker_id::num_workers]

    @staticmethod
    def _load_shard(shard_path: str) -> np.ndarray:
        """
        Memory-map a shard of uint16 token IDs read-only.

        Only whole uint16 items are mapped (a trailing odd byte is ignored),
        and a zero-length file yields an empty array because np.memmap
        cannot map an empty file.
        """
        n_tokens = os.path.getsize(shard_path) // np.dtype(np.uint16).itemsize
        if n_tokens == 0:
            return np.empty(0, dtype=np.uint16)
        return np.memmap(shard_path, dtype=np.uint16, mode="r", shape=(n_tokens,))

    def __iter__(self):
        """
        Yield (input_ids, targets) pairs of int64 tensors, each of shape
        (context_length,), from this worker's share of the shards.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading: the global RNG advances between epochs,
            # so each epoch gets a new order.
            shards = self._shards_for_worker(0, 1, seed=None)
        else:
            # PyTorch gives each worker seed = base_seed + worker_id, with
            # base_seed shared by all workers of one DataLoader iterator.
            # Recover base_seed so every worker shuffles identically.
            base_seed = worker_info.seed - worker_info.id
            shards = self._shards_for_worker(
                worker_info.id,
                worker_info.num_workers,
                seed=base_seed + self._epoch,
            )
        self._epoch += 1

        chunk = self.context_length + 1  # +1 so we can shift for targets

        for shard_path in shards:
            tokens = self._load_shard(shard_path)

            # Yield non-overlapping chunks; the short tail is dropped.
            n_chunks = len(tokens) // chunk
            for start in range(0, n_chunks * chunk, chunk):
                window = tokens[start : start + chunk]
                # np.array(..., dtype=int64) copies out of the memmap into a
                # plain ndarray; two separate copies keep input_ids and
                # targets independent (no shared storage).
                input_ids = torch.from_numpy(np.array(window[:-1], dtype=np.int64))  # (context_length,)
                targets = torch.from_numpy(np.array(window[1:], dtype=np.int64))     # (context_length,)
                yield input_ids, targets


# ------------------------------------------------------------------ #
# SYNTHETIC DATASET (for testing without real data)
# ------------------------------------------------------------------ #

class SyntheticDataset(IterableDataset):
    """
    Generates random token sequences for architecture testing.
    Use when real shards are not yet available.

    Note: ``n_batches`` is the total number of *sequences* yielded per
    epoch (split across DataLoader workers); the DataLoader groups them
    into batches of ``batch_size``.
    """

    def __init__(self, vocab_size: int, context_length: int, n_batches: int = 1000):
        super().__init__()
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.n_batches = n_batches

    def __iter__(self):
        """Yield (input_ids, targets) pairs of shape (context_length,)."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            worker_id, num_workers = 0, 1
        else:
            worker_id, num_workers = worker_info.id, worker_info.num_workers

        # Stride the sample indices across workers so the epoch length does
        # not grow with num_workers.
        for _ in range(worker_id, self.n_batches, num_workers):
            seq = torch.randint(0, self.vocab_size, (self.context_length + 1,))
            input_ids = seq[:-1]
            targets = seq[1:]
            yield input_ids, targets


# ------------------------------------------------------------------ #
# FACTORY FUNCTION
# ------------------------------------------------------------------ #

def build_dataloader(
    data_dir: str,
    split: str,
    context_length: int,
    batch_size: int,
    num_workers: int = 2,
    use_synthetic: bool = False,
    vocab_size: int = 32_000,
) -> DataLoader:
    """
    Builds and returns a DataLoader for the given split.
    Falls back to SyntheticDataset if use_synthetic=True or no shards found.

    Args:
        data_dir       : directory with .bin shards
        split          : 'train' or 'val'
        context_length : model context length (1024)
        batch_size     : number of sequences per batch
        num_workers    : DataLoader workers (0 = main process)
        use_synthetic  : force synthetic data (for testing)
        vocab_size     : needed for synthetic fallback

    Returns:
        DataLoader yielding (input_ids, targets) each of shape (B, T)
    """
    if use_synthetic:
        dataset = SyntheticDataset(vocab_size, context_length)
        print("[DataLoader] Using synthetic data (use_synthetic=True)")
    else:
        try:
            dataset = ShardedTokenDataset(
                data_dir=data_dir,
                split=split,
                context_length=context_length,
                shuffle_shards=(split == "train"),
            )
        except FileNotFoundError as e:
            print(f"[DataLoader] WARNING: {e}")
            print("[DataLoader] Falling back to synthetic data for testing.")
            dataset = SyntheticDataset(vocab_size, context_length)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        # Pinned memory only speeds up host->GPU copies; without CUDA it is
        # unused and PyTorch warns about it.
        pin_memory=torch.cuda.is_available(),
    )


# ------------------------------------------------------------------ #
# QUICK CHECK
# ------------------------------------------------------------------ #

if __name__ == "__main__":
    import sys
    sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from model.config import SLLM_100M

    cfg = SLLM_100M

    print("Testing with synthetic data...")
    loader = build_dataloader(
        data_dir="tokenizer/data",
        split="train",
        context_length=cfg.context_length,
        batch_size=4,
        num_workers=0,
        use_synthetic=True,
        vocab_size=cfg.vocab_size,
    )

    for i, (x, y) in enumerate(loader):
        print(f"Batch {i}: input_ids={x.shape}, targets={y.shape}, dtype={x.dtype}")
        if i == 3:
            break

    print("DataLoader OK")