| | """ |
| | Mmap Data Loader |
| | ================ |
| | Memory-mapped binary dataset for efficient training. |
| | Loads tokenized .bin files created by tokenize_data.py. |
| | Uses mmap for low RAM overhead — OS manages page cache. |
| | """ |
| |
|
import os
import struct

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader, DistributedSampler


def _is_rank0() -> bool:
    """True when not running under torch.distributed, or on the global rank-0 process."""
    # Guard with is_available(): on builds without distributed support,
    # dist.is_initialized is not defined.
    if not (dist.is_available() and dist.is_initialized()):
        return True
    return dist.get_rank() == 0


HEADER_MAGIC = b"FREQTOK1"
HEADER_SIZE = 32

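# Header layout, as implied by the field reads in read_vocab_size below
# (nothing here beyond what the code itself implies):
#   bytes  0-7   magic        b"FREQTOK1"
#   bytes  8-11  version      u32, little-endian
#   bytes 12-19  num_tokens   u64
#   bytes 20-23  seq_len      u32
#   bytes 24-27  vocab_size   u32
#   bytes 28-31  unused       (fields total 28 bytes; header is padded to HEADER_SIZE)

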
def read_vocab_size(path: str) -> int:
    """Read vocab_size from a binary data file header without loading the data."""
    with open(path, "rb") as f:
        magic = f.read(8)
        assert magic == HEADER_MAGIC, f"Invalid file format: {path}"
        _version = struct.unpack("<I", f.read(4))[0]
        _num_tokens = struct.unpack("<Q", f.read(8))[0]
        _seq_len = struct.unpack("<I", f.read(4))[0]
        vocab_size = struct.unpack("<I", f.read(4))[0]
        return vocab_size

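
# For reference, a sketch of the writer-side counterpart. This is a
# hypothetical helper (the real writer lives in tokenize_data.py, which is not
# shown here); it simply mirrors the field order and sizes read above. The
# zero-byte padding value is an assumption.
def _pack_header_sketch(num_tokens: int, seq_len: int, vocab_size: int, version: int = 1) -> bytes:
    header = HEADER_MAGIC                      # 8-byte magic
    header += struct.pack("<I", version)       # u32 version
    header += struct.pack("<Q", num_tokens)    # u64 token count
    header += struct.pack("<I", seq_len)       # u32 seq_len at tokenization time
    header += struct.pack("<I", vocab_size)    # u32 vocab size
    return header.ljust(HEADER_SIZE, b"\x00")  # pad 28 -> 32 bytes

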
class MmapDataset(Dataset):
    """Memory-mapped token dataset. Serves (input, target) pairs of length seq_len."""

    def __init__(self, path: str, seq_len: int):
        assert os.path.exists(path), f"Data file not found: {path}"
        self.seq_len = seq_len

        # Parse the fixed-size header (same layout as in read_vocab_size).
        with open(path, "rb") as f:
            magic = f.read(8)
            assert magic == HEADER_MAGIC, f"Invalid file format: {path}"
            _version = struct.unpack("<I", f.read(4))[0]
            self.num_tokens = struct.unpack("<Q", f.read(8))[0]
            _ = struct.unpack("<I", f.read(4))[0]  # seq_len recorded at tokenization time; we chunk with our own
            self.vocab_size = struct.unpack("<I", f.read(4))[0]

        # Token width must match what tokenize_data.py wrote:
        # u16 when every id fits in 16 bits, u32 otherwise.
        dtype = np.uint16 if self.vocab_size < 65536 else np.uint32
        self.data = np.memmap(path, dtype=dtype, mode="r", offset=HEADER_SIZE)

        # Each sample consumes seq_len + 1 tokens (inputs plus the one-step
        # shifted targets), so (len(data) - 1) // seq_len full samples fit.
        self.n_samples = (len(self.data) - 1) // seq_len

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        start = idx * self.seq_len
        end = start + self.seq_len + 1
        # Copy the slice out of the memmap and widen to int64 for embedding lookups.
        # e.g. tokens [t0 t1 t2 t3] with seq_len=3 -> x = [t0 t1 t2], y = [t1 t2 t3]
        chunk = torch.from_numpy(self.data[start:end].astype(np.int64))
        x = chunk[:-1]
        y = chunk[1:]
        return x, y


def get_dataloaders(
    data_dir: str,
    seq_len: int,
    batch_size: int,
    train_file: str = "train.bin",
    val_file: str = "val.bin",
    num_workers: int = 4,
    distributed: bool = False,
):
    """Create train and val dataloaders from binary token files."""
    train_path = os.path.join(data_dir, train_file)
    val_path = os.path.join(data_dir, val_file)

    train_ds = MmapDataset(train_path, seq_len)
    if _is_rank0():
        print(f"Train: {train_ds.num_tokens:,} tokens, {len(train_ds)} samples of T={seq_len}")

    val_ds = None
    val_loader = None
    if os.path.exists(val_path):
        val_ds = MmapDataset(val_path, seq_len)
        if _is_rank0():
            print(f"Val: {val_ds.num_tokens:,} tokens, {len(val_ds)} samples of T={seq_len}")
        val_sampler = DistributedSampler(val_ds, shuffle=False) if distributed else None
        val_loader = DataLoader(
            val_ds,
            batch_size=batch_size,
            shuffle=False,
            sampler=val_sampler,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            persistent_workers=num_workers > 0,
            prefetch_factor=4 if num_workers > 0 else None,
        )

    # Under DDP, DistributedSampler shards samples across ranks; shuffling then
    # happens in the sampler, since DataLoader forbids sampler plus shuffle=True.
    train_sampler = DistributedSampler(train_ds, shuffle=True) if distributed else None
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        persistent_workers=num_workers > 0,
        prefetch_factor=4 if num_workers > 0 else None,
    )

    return train_loader, val_loader
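

# Minimal smoke test (a sketch): assumes a data/ directory containing train.bin
# (and optionally val.bin) produced by tokenize_data.py. The path and sizes
# below are placeholders, not part of this module's API.
if __name__ == "__main__":
    train_loader, val_loader = get_dataloaders("data", seq_len=256, batch_size=4, num_workers=0)
    x, y = next(iter(train_loader))
    print(f"x: {tuple(x.shape)} {x.dtype}, y: {tuple(y.shape)} {y.dtype}")  # (4, 256) torch.int64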