"""
Mmap Data Loader
================
Memory-mapped binary dataset for efficient training.
Loads tokenized .bin files created by tokenize_data.py.
Uses mmap so RAM overhead stays low; the OS page cache manages residency.
"""
import os
import struct
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader, DistributedSampler
def _is_rank0() -> bool:
    return not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0
HEADER_MAGIC = b"FREQTOK1"
HEADER_SIZE = 32
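# On-disk header layout (little-endian, HEADER_SIZE = 32 bytes total, as
# implied by the reads below):
#   bytes  0-7    magic       b"FREQTOK1"
#   bytes  8-11   version     uint32
#   bytes 12-19   num_tokens  uint64
#   bytes 20-23   seq_len     uint32
#   bytes 24-27   vocab_size  uint32
#   bytes 28-31   padding (token data starts at offset HEADER_SIZE)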
def read_vocab_size(path: str) -> int:
"""Read vocab_size from binary data file header without loading the data."""
with open(path, "rb") as f:
magic = f.read(8)
assert magic == HEADER_MAGIC, f"Invalid file format: {path}"
_version = struct.unpack("<I", f.read(4))[0]
_num_tokens = struct.unpack("<Q", f.read(8))[0]
_seq_len = struct.unpack("<I", f.read(4))[0]
vocab_size = struct.unpack("<I", f.read(4))[0]
return vocab_size
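# For reference, a minimal sketch of a writer that produces this header.
# tokenize_data.py is the canonical producer; this merely mirrors the layout
# the readers in this file expect and is not part of the training path.
def write_header(f, num_tokens: int, seq_len: int, vocab_size: int, version: int = 1) -> None:
    header = HEADER_MAGIC
    header += struct.pack("<I", version)      # format version
    header += struct.pack("<Q", num_tokens)   # total token count
    header += struct.pack("<I", seq_len)      # seq_len used at tokenize time
    header += struct.pack("<I", vocab_size)   # tokenizer vocab size
    header += b"\x00" * (HEADER_SIZE - len(header))  # pad to 32 bytes
    f.write(header)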
class MmapDataset(Dataset):
"""Memory-mapped token dataset. Serves (input, target) pairs of length seq_len."""
def __init__(self, path: str, seq_len: int):
assert os.path.exists(path), f"Data file not found: {path}"
self.seq_len = seq_len
# Read header
with open(path, "rb") as f:
magic = f.read(8)
assert magic == HEADER_MAGIC, f"Invalid file format: {path}"
            _version = struct.unpack("<I", f.read(4))[0]  # format version (unused)
self.num_tokens = struct.unpack("<Q", f.read(8))[0]
_ = struct.unpack("<I", f.read(4))[0] # seq_len (unused)
self.vocab_size = struct.unpack("<I", f.read(4))[0]
        # Token dtype must mirror what tokenize_data.py wrote: ids fit in
        # uint16 when vocab_size < 65536, otherwise uint32 is needed.
        dtype = np.uint16 if self.vocab_size < 65536 else np.uint32
        self.data = np.memmap(path, dtype=dtype, mode="r", offset=HEADER_SIZE)
        # Number of complete samples we can serve. Each sample needs
        # seq_len + 1 tokens (seq_len inputs plus one lookahead token for the
        # shifted targets), and samples start at stride seq_len, so the last
        # start index i must satisfy i * seq_len + seq_len + 1 <= len(data).
        # E.g. 10 tokens with seq_len = 3 yield 3 samples.
        self.n_samples = (len(self.data) - 1) // seq_len
def __len__(self):
return self.n_samples
def __getitem__(self, idx):
start = idx * self.seq_len
end = start + self.seq_len + 1
chunk = torch.from_numpy(self.data[start:end].astype(np.int64))
x = chunk[:-1]
y = chunk[1:]
return x, y
def get_dataloaders(
data_dir: str,
seq_len: int,
batch_size: int,
train_file: str = "train.bin",
val_file: str = "val.bin",
num_workers: int = 4,
distributed: bool = False,
):
"""Create train and val dataloaders from binary token files."""
train_path = os.path.join(data_dir, train_file)
val_path = os.path.join(data_dir, val_file)
train_ds = MmapDataset(train_path, seq_len)
if _is_rank0():
print(f"Train: {train_ds.num_tokens:,} tokens, {len(train_ds)} batches of T={seq_len}")
val_ds = None
val_loader = None
if os.path.exists(val_path):
val_ds = MmapDataset(val_path, seq_len)
if _is_rank0():
print(f"Val: {val_ds.num_tokens:,} tokens, {len(val_ds)} batches of T={seq_len}")
val_sampler = DistributedSampler(val_ds, shuffle=False) if distributed else None
val_loader = DataLoader(
val_ds,
batch_size=batch_size,
shuffle=False,
sampler=val_sampler,
num_workers=num_workers,
pin_memory=True,
drop_last=True,
persistent_workers=num_workers > 0,
prefetch_factor=4 if num_workers > 0 else None,
)
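    # NOTE: with a shuffling DistributedSampler, the training loop should call
    # train_loader.sampler.set_epoch(epoch) at the start of each epoch so the
    # shuffle differs between epochs.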
train_sampler = DistributedSampler(train_ds, shuffle=True) if distributed else None
train_loader = DataLoader(
train_ds,
batch_size=batch_size,
shuffle=(train_sampler is None),
sampler=train_sampler,
num_workers=num_workers,
pin_memory=True,
drop_last=True,
persistent_workers=num_workers > 0,
prefetch_factor=4 if num_workers > 0 else None,
)
return train_loader, val_loader
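

if __name__ == "__main__":
    # Minimal smoke test (a sketch: assumes data/train.bin exists and holds at
    # least 4 * 512 + 1 tokens; adjust the path and sizes to your files).
    train_loader, _ = get_dataloaders("data", seq_len=512, batch_size=4, num_workers=0)
    x, y = next(iter(train_loader))
    print(f"x: {tuple(x.shape)} {x.dtype}, y: {tuple(y.shape)} {y.dtype}")
    assert torch.equal(x[:, 1:], y[:, :-1])  # targets are inputs shifted by one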