| """ |
| Data loading utilities for SAD. |
| |
| Supports: |
| - Tiny debug dataset (random token ids for smoke tests) |
| - OpenWebText via HuggingFace datasets |
| - Generic text dataset from a file |
| |
| All datasets return batches of shape [B, seq_len] with attention_mask. |
| """ |
|
|
| import random |
| from typing import Optional, Iterator |
|
|
| import torch |
| from torch.utils.data import Dataset, DataLoader, IterableDataset |
|
|
|
|
class TinyDebugDataset(Dataset):
    """
    Smoke-test dataset of uniformly random token ids (no real data is loaded).

    Args:
        vocab_size: Vocabulary size. Ids are drawn from ``[0, vocab_size - 2]``
            (``torch.randint``'s upper bound is exclusive), so id
            ``vocab_size - 1`` never appears.
        seq_len: Length of every sample.
        num_samples: Number of samples held in memory.
        mask_token_id: Any sampled occurrence of this id is replaced with 0.
        seed: Passed to ``torch.manual_seed``. Note this reseeds the
            process-global torch RNG as a side effect.

    Each item is a dict with ``input_ids`` (LongTensor [seq_len]) and an
    all-ones ``attention_mask`` (LongTensor [seq_len]).
    """

    def __init__(self, vocab_size: int, seq_len: int, num_samples: int = 512,
                 mask_token_id: int = 50256, seed: int = 42):
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.num_samples = num_samples
        self.mask_token_id = mask_token_id

        # Reseed the global torch RNG so the sampled ids are reproducible.
        torch.manual_seed(seed)
        tokens = torch.randint(0, vocab_size - 1, (num_samples, seq_len))
        # Keep the mask id out of the "clean" data by mapping it to id 0.
        self.data = tokens.masked_fill(tokens == mask_token_id, 0)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # No padding in this dataset, so every position is attended.
        full_mask = torch.ones(self.seq_len, dtype=torch.long)
        return {"input_ids": self.data[idx], "attention_mask": full_mask}
|
|
|
|
| def _parse_split_slice(split: str, total_len: int): |
| """Parse HF split slice like train[:-100000] or train[-100000:] and return (start, end).""" |
| import re |
| m = re.match(r"^(.+?)\[(.+)\]$", split) |
| if not m: |
| return 0, total_len |
| slice_str = m.group(2).strip() |
|
|
| def _to_idx(s, default): |
| s = s.strip() |
| if not s: |
| return default |
| val = int(s) |
| return total_len + val if val < 0 else val |
|
|
| if ":" in slice_str: |
| parts = slice_str.split(":") |
| start = _to_idx(parts[0], 0) |
| end = _to_idx(parts[1], total_len) |
| return max(0, start), min(total_len, end) |
| else: |
| idx = _to_idx(slice_str, 0) |
| return max(0, idx), min(total_len, idx + 1) |
|
|
|
|
def _owt_pack_token_ids(all_ids, seq_len):
    """Split ``all_ids`` into consecutive, non-overlapping chunks of exactly ``seq_len``.

    Tokens left over after the last full chunk are dropped. The ``+ 1`` in the
    range bound keeps the final chunk when ``len(all_ids)`` is an exact
    multiple of ``seq_len``.
    """
    last_start = len(all_ids) - seq_len
    return [all_ids[i:i + seq_len] for i in range(0, last_start + 1, seq_len)]


def _owt_tokenize_and_pack(examples, tokenizer, seq_len):
    """Batched ``datasets.map`` function for ``mode="pack"``.

    Tokenizes every document in the batch, concatenates the ids with no
    separators, and packs them into full ``seq_len`` chunks. Tokens left over
    at the end of each map batch are dropped.
    """
    all_ids = []
    for text in examples["text"]:
        all_ids.extend(
            tokenizer(text, truncation=False, padding=False,
                      return_attention_mask=False)["input_ids"]
        )
    return {"input_ids": _owt_pack_token_ids(all_ids, seq_len)}


def _owt_collate_packed(examples):
    """Stack pre-packed chunks. Every position is real data, so the mask is all ones."""
    ids = torch.stack([torch.tensor(e["input_ids"], dtype=torch.long) for e in examples])
    return {"input_ids": ids, "attention_mask": torch.ones_like(ids)}


def _owt_fit_document(toks, seq_len, bos_id, eos_id, pad_id):
    """Wrap one tokenized document as [BOS] ... [EOS] and fit it to ``seq_len``.

    Returns ``(ids, attention_mask)``, both lists of length ``seq_len``.
    Documents longer than ``seq_len`` get a uniformly random window, drawn with
    the ``random`` module. Shorter ones are right-padded with ``pad_id`` and
    have mask 0 on the pads.
    """
    if not toks or toks[0] != bos_id:
        toks = [bos_id] + toks
    # len == 1 means toks is only the BOS. When bos_id == eos_id, the
    # last-token check alone would mistake it for an EOS and skip appending one.
    if len(toks) == 1 or toks[-1] != eos_id:
        toks = toks + [eos_id]

    length = len(toks)
    if length > seq_len:
        start = random.randint(0, length - seq_len)
        return toks[start:start + seq_len], [1] * seq_len
    n_pad = seq_len - length
    return toks + [pad_id] * n_pad, [1] * length + [0] * n_pad


def _owt_collate_subsample(examples, tokenizer, seq_len, bos_id, eos_id, pad_id):
    """Tokenize raw documents and fit each one to ``seq_len`` (``mode="subsample"``)."""
    input_ids_list, attn_list = [], []
    for ex in examples:
        toks = tokenizer(ex["text"], truncation=False, padding=False,
                         return_attention_mask=False)["input_ids"]
        ids, attn = _owt_fit_document(toks, seq_len, bos_id, eos_id, pad_id)
        input_ids_list.append(ids)
        attn_list.append(attn)
    return {
        "input_ids": torch.tensor(input_ids_list, dtype=torch.long),
        "attention_mask": torch.tensor(attn_list, dtype=torch.long),
    }


def build_owt_dataloader(
    tokenizer,
    split: str = "train",
    seq_len: int = 512,
    batch_size: int = 32,
    num_workers: int = 4,
    cache_dir: Optional[str] = None,
    max_samples: Optional[int] = None,
    seed: int = 42,
    mode: str = "subsample",
    shard_across_ranks: bool = True,
) -> DataLoader:
    """
    Build an OpenWebText DataLoader.

    Args:
        tokenizer: HuggingFace-style tokenizer, called as
            ``tokenizer(text, truncation=False, padding=False,
            return_attention_mask=False)["input_ids"]``. In "subsample" mode,
            its ``bos_token_id`` / ``eos_token_id`` / ``pad_token_id`` are
            also read; missing ones fall back to the first available of eos,
            bos, pad.
        split: "train" is streamed. A bracketed slice such as
            "train[:-100000]" is loaded non-streaming and then sliced. Splits
            starting with "train" are shuffled, except those starting with
            "train[-" (treated as a held-out tail).
        seq_len: Length of every returned sequence.
        batch_size: DataLoader batch size.
        num_workers: Number of DataLoader worker processes; workers are
            persistent when > 0.
        cache_dir: Directory holding a local copy of Skylion007/openwebtext
            (``<cache_dir>/plain_text/train-*.parquet``). If None, the data is
            loaded from the HuggingFace Hub.
        max_samples: If given, cap the number of documents. The cap is
            applied after sharding, so it is per rank under DDP.
        seed: Seed for dataset shuffling.
        mode:
            "subsample" – HDLM-aligned default. Each sample = one document
                wrapped as [BOS] ... [EOS]. Long docs: random seq_len-token
                window. Short docs: pad to seq_len with attention_mask=0 on pads.
            "pack" – legacy packing. Tokenize all docs, concatenate, split
                into non-overlapping seq_len chunks (cross-document, no separators).
        shard_across_ranks:
            True (train): each rank takes a disjoint slice under DDP.
            False (val): every rank iterates the full (deterministic) set —
                eval currently runs on rank 0 only, so sharding would bias
                metrics to a single shard.

    Returns:
        A DataLoader yielding dicts with "input_ids" and "attention_mask",
        both LongTensors of shape [batch, seq_len]. The final batch may be
        smaller than ``batch_size``.

    Raises:
        ValueError: unknown ``mode``, or (in subsample mode) a tokenizer with
            no bos/eos/pad token id. Both are checked before any data is loaded.
        ImportError: the ``datasets`` package is not installed.
        FileNotFoundError: ``cache_dir`` contains no OpenWebText parquet files.
    """
    # Collate functions are module-level and bound with partial rather than
    # nested closures, so they can be pickled into workers under "spawn".
    from functools import partial
    import glob as _glob

    # --- Validate cheap arguments before touching the (large) dataset. ---
    if mode == "pack":
        collate_fn = _owt_collate_packed
    elif mode == "subsample":
        candidates = [
            tokenizer.eos_token_id,
            tokenizer.bos_token_id,
            tokenizer.pad_token_id,
        ]
        fallback = next((c for c in candidates if c is not None), None)
        if fallback is None:
            raise ValueError("tokenizer has no bos/eos/pad token id; cannot build subsample loader")
        bos_id = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else fallback
        eos_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else fallback
        pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else fallback
        collate_fn = partial(
            _owt_collate_subsample,
            tokenizer=tokenizer, seq_len=seq_len,
            bos_id=bos_id, eos_id=eos_id, pad_id=pad_id,
        )
    else:
        raise ValueError(f"Unknown mode: {mode!r} (expected 'subsample' or 'pack')")

    try:
        from datasets import load_dataset
    except ImportError as err:
        raise ImportError("pip install datasets") from err

    # --- Load: slices need random access (non-streaming); full train streams. ---
    has_slice = "[" in split and "]" in split
    if cache_dir is not None:
        parquet_files = sorted(_glob.glob(f"{cache_dir}/plain_text/train-*.parquet"))
        if not parquet_files:
            raise FileNotFoundError(
                f"No parquet files found in {cache_dir}/plain_text/. "
                "Run: huggingface-cli download Skylion007/openwebtext "
                "--repo-type dataset --local-dir <cache_dir>"
            )
        ds = load_dataset(
            "parquet", data_files={"train": parquet_files},
            split="train", streaming=not has_slice,
        )
        if has_slice:
            # The parquet builder cannot apply the slice itself, so resolve it here.
            start, end = _parse_split_slice(split, len(ds))
            ds = ds.select(range(start, end))
    else:
        ds = load_dataset(
            "Skylion007/openwebtext",
            split=split if has_slice else "train",
            streaming=not has_slice,
        )

    # --- Shuffle training splits; a "train[-N:]" tail is treated as held-out. ---
    is_train = split.startswith("train") and not split.startswith("train[-")
    if is_train:
        if has_slice:
            ds = ds.shuffle(seed=seed)
        else:
            ds = ds.shuffle(seed=seed, buffer_size=10_000)

    # --- Give each DDP rank a disjoint shard (same seed => consistent split). ---
    if shard_across_ranks:
        try:
            import torch.distributed as dist
            if dist.is_available() and dist.is_initialized():
                ds = ds.shard(num_shards=dist.get_world_size(),
                              index=dist.get_rank())
        except ImportError:
            pass

    if max_samples is not None:
        if hasattr(ds, "take"):
            ds = ds.take(max_samples)
        else:
            n = len(ds) if hasattr(ds, "__len__") else max_samples
            ds = ds.select(range(min(max_samples, n)))

    if mode == "pack":
        ds = ds.map(
            partial(_owt_tokenize_and_pack, tokenizer=tokenizer, seq_len=seq_len),
            batched=True,
            remove_columns=["text"],
        )

    return DataLoader(
        ds,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
        persistent_workers=(num_workers > 0),
    )
|
|
|
|
def build_debug_dataloader(
    vocab_size: int,
    seq_len: int = 64,
    batch_size: int = 4,
    num_samples: int = 64,
    mask_token_id: int = 50256,
    seed: int = 42,
) -> DataLoader:
    """
    Build a shuffled DataLoader over a TinyDebugDataset of random token ids.

    Args:
        vocab_size: Forwarded to TinyDebugDataset.
        seq_len: Forwarded to TinyDebugDataset.
        batch_size: DataLoader batch size. Incomplete final batches are
            dropped (``drop_last=True``), so ``num_samples < batch_size``
            yields no batches at all.
        num_samples: Forwarded to TinyDebugDataset.
        mask_token_id: Forwarded to TinyDebugDataset.
        seed: Forwarded to TinyDebugDataset. Previously this was not exposed;
            the default 42 matches that class's default, so existing calls
            behave as before.
    """
    ds = TinyDebugDataset(
        vocab_size,
        seq_len,
        num_samples=num_samples,
        mask_token_id=mask_token_id,
        seed=seed,
    )
    return DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=True)
|
|