| """ |
| data/dataloader.py |
| |
| Streaming dataloader for the pre-tokenized binary shards produced by |
| tokenizer/tokenize_dataset.py. |
| |
| Each shard is a flat binary file of np.uint16 token IDs. |
| 100M tokens * 2 bytes = ~200MB per shard. |
| |
| Strategy: |
| 1. Discover all shards matching split name (train/val). |
| 2. Shuffle shard order at start of each epoch. |
| 3. For each shard, load it fully into memory and yield non-overlapping |
| chunks of (context_length + 1) tokens. |
| 4. Inputs = chunk[:-1] (length context_length) |
| Targets = chunk[1:] (length context_length, shifted right by 1) |
| |
| When no data shards exist yet (tokenization not done), a SyntheticDataset |
| can be used for architecture testing. |
| """ |
|
|
| import os |
| import glob |
| import random |
| import numpy as np |
| import torch |
| from torch.utils.data import IterableDataset, DataLoader |
|
|
|
|
| |
| |
| |
|
|
def find_shards(data_dir: str, split: str) -> list[str]:
    """
    Return the sorted list of shard paths for the given split.

    A shard is any file directly inside ``data_dir`` whose name matches
    ``{split}_*.bin``.

    Args:
        data_dir : directory containing .bin shard files
        split    : 'train' or 'val'

    Returns:
        Matching paths in sorted (lexicographic) order; an empty list when
        nothing matches or the directory does not exist.
    """
    # Escape the directory part so glob metacharacters in the path itself
    # ('[', ']', '*', '?') are matched literally. Without this, a directory
    # such as "runs/exp[1]" would silently match no shards at all.
    pattern = os.path.join(glob.escape(data_dir), f"{split}_*.bin")
    return sorted(glob.glob(pattern))
|
|
|
|
| |
| |
| |
|
|
class ShardedTokenDataset(IterableDataset):
    """
    IterableDataset that streams token chunks from binary shards.

    Each shard is read as a flat array of np.uint16 token IDs and cut into
    non-overlapping chunks of (context_length + 1) tokens; trailing tokens
    that do not fill a whole chunk are dropped.

    Each worker processes a disjoint subset of shards so we get
    proper parallelism with DataLoader(num_workers=N).

    Usage:
        dataset = ShardedTokenDataset(data_dir, split='train', context_length=1024)
        loader  = DataLoader(dataset, batch_size=4)
        for input_ids, targets in loader:
            ...
    """

    def __init__(
        self,
        data_dir: str,
        split: str,
        context_length: int,
        shuffle_shards: bool = True,
    ):
        """
        Args:
            data_dir       : path to directory with .bin shard files
            split          : 'train' or 'val'
            context_length : sequence length (model context length)
            shuffle_shards : shuffle shard order on every iteration (epoch)

        Raises:
            FileNotFoundError: if no shard for `split` exists in `data_dir`.
        """
        super().__init__()
        self.context_length = context_length
        self.shuffle_shards = shuffle_shards

        self.shards = find_shards(data_dir, split)
        if not self.shards:
            raise FileNotFoundError(
                f"No {split} shards found in {data_dir}.\n"
                f"Run tokenizer/tokenize_dataset.py first to generate data."
            )
        print(f"[DataLoader] Found {len(self.shards)} {split} shards in {data_dir}")

    def _select_shards(self, worker_info) -> list[str]:
        """
        Return the shard paths this process should read, in visiting order.

        Args:
            worker_info : result of torch.utils.data.get_worker_info();
                          None in the main process, otherwise an object
                          exposing `id` and `num_workers`.
        """
        shards = list(self.shards)

        # Partition BEFORE shuffling. Worker processes do not share the state
        # of Python's `random` module, so shuffling first would make every
        # worker stride over a different permutation and the subsets would
        # overlap (duplicated shards) and leave gaps (skipped shards).
        if worker_info is not None:
            shards = shards[worker_info.id :: worker_info.num_workers]

        # Shuffling inside the worker's own subset keeps the subsets disjoint
        # while still randomising the visiting order.
        if self.shuffle_shards:
            random.shuffle(shards)
        return shards

    def __iter__(self):
        """
        Yield (input_ids, targets) pairs of int64 tensors, each of length
        context_length, where targets is input_ids advanced by one token.
        Both tensors are views of the same freshly allocated chunk.
        """
        shards = self._select_shards(torch.utils.data.get_worker_info())
        chunk = self.context_length + 1

        for shard_path in shards:
            # Keep the shard in its compact uint16 form; widening the whole
            # array up front would multiply its memory footprint for no gain.
            tokens = np.fromfile(shard_path, dtype=np.uint16)

            n_chunks = len(tokens) // chunk
            for i in range(n_chunks):
                start = i * chunk
                # astype() returns a new int64 array, so the tensor owns its
                # memory and is already in the dtype the caller receives.
                seq = torch.from_numpy(
                    tokens[start : start + chunk].astype(np.int64)
                )
                yield seq[:-1], seq[1:]
|
|
|
|
| |
| |
| |
|
|
class SyntheticDataset(IterableDataset):
    """
    Stand-in dataset of uniformly random token IDs.

    Intended for exercising the model/training code before any real
    shards have been produced. Despite its name, `n_batches` counts the
    individual sequences yielded per iteration, not batches.
    """

    def __init__(self, vocab_size: int, context_length: int, n_batches: int = 1000):
        """
        Args:
            vocab_size     : exclusive upper bound for generated token IDs
            context_length : length of each yielded input/target sequence
            n_batches      : number of sequences produced per iteration
        """
        super().__init__()
        self.n_batches = n_batches
        self.context_length = context_length
        self.vocab_size = vocab_size

    def __iter__(self):
        """Yield (input_ids, targets) views over one random sequence each."""
        for _ in range(self.n_batches):
            # One extra token so inputs and targets can overlap by all but one.
            window = torch.randint(
                low=0,
                high=self.vocab_size,
                size=(self.context_length + 1,),
            )
            yield window[:-1], window[1:]
|
|
|
|
| |
| |
| |
|
|
def build_dataloader(
    data_dir: str,
    split: str,
    context_length: int,
    batch_size: int,
    num_workers: int = 2,
    use_synthetic: bool = False,
    vocab_size: int = 32_000,
) -> DataLoader:
    """
    Create the DataLoader for one split.

    Real shards are used when available. Synthetic random data is used
    instead when `use_synthetic` is set, or when ShardedTokenDataset raises
    FileNotFoundError because no shards exist (a warning is printed).

    Args:
        data_dir       : directory with .bin shards
        split          : 'train' or 'val'; shard order is shuffled only
                         for 'train'
        context_length : length of every input/target sequence
        batch_size     : number of sequences per batch
        num_workers    : DataLoader workers (0 = main process)
        use_synthetic  : force synthetic data (for testing)
        vocab_size     : token ID range for the synthetic dataset

    Returns:
        DataLoader yielding (input_ids, targets), each of shape (B, T)
        with B <= batch_size and T == context_length.
    """
    loader_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
    )

    # Explicit request for synthetic data: no need to touch the disk at all.
    if use_synthetic:
        dataset = SyntheticDataset(vocab_size, context_length)
        print("[DataLoader] Using synthetic data (use_synthetic=True)")
        return DataLoader(dataset, **loader_options)

    try:
        dataset = ShardedTokenDataset(
            data_dir=data_dir,
            split=split,
            context_length=context_length,
            shuffle_shards=split == "train",
        )
    except FileNotFoundError as missing:
        # No shards yet: report it loudly, then keep going on random data.
        print(f"[DataLoader] WARNING: {missing}")
        print("[DataLoader] Falling back to synthetic data for testing.")
        dataset = SyntheticDataset(vocab_size, context_length)

    return DataLoader(dataset, **loader_options)
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Smoke test: build a loader over synthetic data and print the shape and
    # dtype of the first few batches.
    import sys

    # Put the parent of this file's directory on sys.path so that
    # `model.config` can be imported when the file is run as a script.
    this_dir = os.path.dirname(os.path.abspath(__file__))
    sys.path.insert(0, os.path.dirname(this_dir))
    from model.config import SLLM_100M

    model_cfg = SLLM_100M

    print("Testing with synthetic data...")
    smoke_loader = build_dataloader(
        data_dir="tokenizer/data",
        split="train",
        context_length=model_cfg.context_length,
        batch_size=4,
        num_workers=0,
        use_synthetic=True,
        vocab_size=model_cfg.vocab_size,
    )

    # Batches are numbered from 0; stop right after printing this one.
    last_batch_index = 3
    for step, (inputs, labels) in enumerate(smoke_loader):
        print(f"Batch {step}: input_ids={inputs.shape}, targets={labels.shape}, dtype={inputs.dtype}")
        if step == last_batch_index:
            break

    print("DataLoader OK")
|
|