"""
Data loading utilities for SAD.

Supports:
  - Tiny debug dataset (random token ids for smoke tests)
  - OpenWebText via HuggingFace datasets
  - Generic text dataset from a file

All datasets return batches of shape [B, seq_len] with attention_mask.
"""

import glob
import random
import re
from typing import Optional, Iterator

import torch
from torch.utils.data import Dataset, DataLoader, IterableDataset

# Matches an HF-style sliced split such as "train[:-100000]".
_SPLIT_SLICE_RE = re.compile(r"^(.+?)\[(.+)\]$")


class TinyDebugDataset(Dataset):
    """
    Random token-id dataset for smoke tests. Does NOT load any real data.

    Token ids are drawn from ``[0, vocab_size - 2]`` (the upper bound of
    ``torch.randint`` is exclusive), and any occurrence of ``mask_token_id``
    is replaced by 0 so the ground truth never contains the mask token.

    Args:
        vocab_size: size of the vocabulary to sample ids from.
        seq_len: length of every sample.
        num_samples: number of pre-generated samples.
        mask_token_id: id that must not appear in the data.
        seed: seed of the private generator used to create the data.
    """

    def __init__(self, vocab_size: int, seq_len: int, num_samples: int = 512,
                 mask_token_id: int = 50256, seed: int = 42):
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.num_samples = num_samples
        self.mask_token_id = mask_token_id

        # Use a private generator: seeding the global RNG here would silently
        # override whatever seed the caller configured for the rest of the run.
        generator = torch.Generator().manual_seed(seed)
        # Pre-generate to make iteration reproducible.
        data = torch.randint(0, vocab_size - 1, (num_samples, seq_len),
                             generator=generator)
        # Avoid mask token in ground truth.
        self.data = torch.where(data == mask_token_id, torch.zeros_like(data), data)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return {
            "input_ids": self.data[idx],
            "attention_mask": torch.ones(self.seq_len, dtype=torch.long),
        }


def _parse_split_slice(split: str, total_len: int):
    """
    Parse an HF split slice like ``train[:-100000]`` or ``train[-100000:]``.

    Args:
        split: split name, optionally followed by ``[start:end]`` or ``[idx]``.
        total_len: number of rows in the un-sliced dataset.

    Returns:
        ``(start, end)`` clamped to ``[0, total_len]``. A split without a
        slice yields ``(0, total_len)``; a single index yields ``(idx, idx + 1)``.

    Raises:
        ValueError: if a bound is not an integer (e.g. percent slices).
    """
    m = _SPLIT_SLICE_RE.match(split)
    if not m:
        return 0, total_len
    slice_str = m.group(2).strip()

    def _to_idx(s, default):
        s = s.strip()
        if not s:
            return default
        val = int(s)
        # Negative bounds count from the end, as in Python slicing.
        return total_len + val if val < 0 else val

    if ":" in slice_str:
        parts = slice_str.split(":")
        start = _to_idx(parts[0], 0)
        end = _to_idx(parts[1], total_len)
        return max(0, start), min(total_len, end)

    idx = _to_idx(slice_str, 0)
    return max(0, idx), min(total_len, idx + 1)


def _chunk_token_ids(token_ids, seq_len: int):
    """Split a flat token list into consecutive full ``seq_len`` chunks.

    The ragged tail (fewer than ``seq_len`` tokens) is dropped. The ``+ 1`` in
    the range bound keeps the final chunk when the length is an exact multiple
    of ``seq_len``.
    """
    return [token_ids[i:i + seq_len]
            for i in range(0, len(token_ids) - seq_len + 1, seq_len)]


def _resolve_special_token_ids(tokenizer):
    """Return ``(bos_id, eos_id, pad_id)``, filling missing ids with a fallback.

    The fallback is the first non-None id among eos, bos, pad.

    Raises:
        ValueError: if the tokenizer defines none of the three ids.
    """
    # Common with GPT-2: bos/eos/pad all share <|endoftext|> = 50256,
    # but bos_token_id may be None.
    candidates = (
        tokenizer.eos_token_id,
        tokenizer.bos_token_id,
        tokenizer.pad_token_id,
    )
    fallback = next((c for c in candidates if c is not None), None)
    if fallback is None:
        raise ValueError("tokenizer has no bos/eos/pad token id; cannot build subsample loader")
    bos_id = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else fallback
    eos_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else fallback
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else fallback
    return bos_id, eos_id, pad_id


def _wrap_and_fit(toks, seq_len: int, bos_id: int, eos_id: int, pad_id: int):
    """Wrap one document as [BOS] ... [EOS] and fit it to ``seq_len`` tokens.

    Documents longer than ``seq_len`` are cut to a random contiguous window
    (drawn from the ``random`` module); shorter ones are right-padded with
    ``pad_id`` and get attention_mask 0 on the pads.

    Returns:
        ``(input_ids, attention_mask)`` as two lists of length ``seq_len``.
    """
    if not toks or toks[0] != bos_id:
        toks = [bos_id] + toks
    if toks[-1] != eos_id:
        toks = toks + [eos_id]

    length = len(toks)
    if length > seq_len:
        start = random.randint(0, length - seq_len)
        return toks[start:start + seq_len], [1] * seq_len

    n_pad = seq_len - length
    return toks + [pad_id] * n_pad, [1] * length + [0] * n_pad


def _load_openwebtext(split: str, cache_dir: Optional[str], has_slice: bool):
    """Load OpenWebText from local parquet files or from the HF hub.

    Sliced splits are loaded non-streaming so an exact row range can be
    selected; un-sliced splits are streamed (always from "train").

    Raises:
        ImportError: if the ``datasets`` package is not installed.
        FileNotFoundError: if ``cache_dir`` is given but holds no parquet files.
    """
    try:
        from datasets import load_dataset
    except ImportError as err:
        raise ImportError("pip install datasets") from err

    if cache_dir is None:
        if has_slice:
            return load_dataset("Skylion007/openwebtext", split=split, streaming=False)
        return load_dataset("Skylion007/openwebtext", split="train", streaming=True)

    parquet_files = sorted(glob.glob(f"{cache_dir}/plain_text/train-*.parquet"))
    if not parquet_files:
        raise FileNotFoundError(
            f"No parquet files found in {cache_dir}/plain_text/. "
            "Run: huggingface-cli download Skylion007/openwebtext "
            f"--repo-type dataset --local-dir {cache_dir}"
        )

    if not has_slice:
        return load_dataset(
            "parquet",
            data_files={"train": parquet_files},
            split="train",
            streaming=True,
        )

    # Exact slice (e.g. train[:-100000]) requires non-streaming so we can
    # select a precise range. This matches HDLM's behaviour.
    ds = load_dataset(
        "parquet",
        data_files={"train": parquet_files},
        split="train",
        streaming=False,
    )
    start, end = _parse_split_slice(split, len(ds))
    return ds.select(range(start, end))


def build_owt_dataloader(
    tokenizer,
    split: str = "train",
    seq_len: int = 512,
    batch_size: int = 32,
    num_workers: int = 4,
    cache_dir: Optional[str] = None,
    max_samples: Optional[int] = None,
    seed: int = 42,
    mode: str = "subsample",
    shard_across_ranks: bool = True,
) -> DataLoader:
    """
    Build an OpenWebText DataLoader.

    Args:
        tokenizer: HF-style tokenizer; called on raw text and queried for
            bos/eos/pad token ids.
        split: split name, optionally sliced (``train[:-100000]``).
        seq_len: number of tokens per sample.
        batch_size: samples per batch.
        num_workers: DataLoader worker processes.
        cache_dir: local directory holding ``plain_text/train-*.parquet``;
            when None the dataset is fetched from the HF hub.
        max_samples: optional cap on the number of documents (applied after
            shuffling and sharding).
        seed: shuffle seed for training splits.
        mode:
            "subsample" – HDLM-aligned default. Each sample = one document
                          wrapped as [BOS] ... [EOS]. Long docs: random
                          seq_len-token window. Short docs: pad to seq_len
                          with attention_mask=0 on pads.
            "pack"      – legacy packing. Tokenize all docs, concatenate,
                          split into non-overlapping seq_len chunks
                          (cross-document, no separators).
        shard_across_ranks:
            True  (train): each rank takes a disjoint slice under DDP.
            False (val):   every rank iterates the full (deterministic) set —
                           eval currently runs on rank 0 only, so sharding
                           would bias metrics to a single shard.

    Returns:
        A DataLoader yielding dicts with "input_ids" and "attention_mask"
        tensors of shape [B, seq_len].

    Raises:
        ValueError: unknown ``mode``, or (subsample) a tokenizer without any
            bos/eos/pad token id.
        ImportError: the ``datasets`` package is missing.
        FileNotFoundError: ``cache_dir`` contains no parquet files.
    """
    # Validate cheap arguments first so a typo does not cost a dataset load.
    if mode not in ("subsample", "pack"):
        raise ValueError(f"Unknown mode: {mode!r} (expected 'subsample' or 'pack')")
    if mode == "subsample":
        bos_id, eos_id, pad_id = _resolve_special_token_ids(tokenizer)

    has_slice = "[" in split and "]" in split
    ds = _load_openwebtext(split, cache_dir, has_slice)

    # Shuffle training splits only. A tail slice like train[-100000:] is val:
    # train[:-N] is the training set (head), train[-N:] the validation set (tail).
    is_train = split.startswith("train") and not split.startswith("train[-")
    if is_train:
        if has_slice:
            ds = ds.shuffle(seed=seed)
        else:
            # Streaming datasets can only be shuffled through a buffer.
            ds = ds.shuffle(seed=seed, buffer_size=10_000)

    # Shard across ranks so each GPU sees a disjoint slice. Without this,
    # every rank iterates the same stream → multi-GPU training becomes
    # gradient-averaging over identical batches. Eval loaders opt out
    # (shard_across_ranks=False) so rank-0-only evaluation isn't biased
    # to a single shard.
    if shard_across_ranks:
        try:
            import torch.distributed as dist
        except ImportError:
            dist = None
        if dist is not None and dist.is_available() and dist.is_initialized():
            ds = ds.shard(num_shards=dist.get_world_size(), index=dist.get_rank())

    if max_samples is not None:
        if hasattr(ds, "take"):
            ds = ds.take(max_samples)
        else:
            n = len(ds) if hasattr(ds, "__len__") else max_samples
            ds = ds.select(range(min(max_samples, n)))

    if mode == "pack":
        def tokenize_and_chunk(examples):
            all_ids = []
            for text in examples["text"]:
                all_ids.extend(tokenizer(text, truncation=False, padding=False,
                                         return_attention_mask=False)["input_ids"])
            return {"input_ids": _chunk_token_ids(all_ids, seq_len)}

        ds = ds.map(tokenize_and_chunk, batched=True, remove_columns=["text"])

        def collate_fn(examples):
            ids = torch.stack([torch.tensor(e["input_ids"], dtype=torch.long)
                               for e in examples])
            return {"input_ids": ids, "attention_mask": torch.ones_like(ids)}

    else:  # mode == "subsample"
        def collate_fn(examples):
            input_ids_list, attn_list = [], []
            for ex in examples:
                toks = tokenizer(ex["text"], truncation=False, padding=False,
                                 return_attention_mask=False)["input_ids"]
                ids, attn = _wrap_and_fit(toks, seq_len, bos_id, eos_id, pad_id)
                input_ids_list.append(ids)
                attn_list.append(attn)
            return {
                "input_ids": torch.tensor(input_ids_list, dtype=torch.long),
                "attention_mask": torch.tensor(attn_list, dtype=torch.long),
            }

    return DataLoader(
        ds,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
        persistent_workers=(num_workers > 0),
    )


def build_debug_dataloader(
    vocab_size: int,
    seq_len: int = 64,
    batch_size: int = 4,
    num_samples: int = 64,
    mask_token_id: int = 50256,
    seed: int = 42,
) -> DataLoader:
    """
    Build a shuffled DataLoader over a TinyDebugDataset for smoke tests.

    Incomplete trailing batches are dropped. Both the data and the shuffle
    order are driven by ``seed`` through private generators, so the global
    torch RNG state is left untouched.
    """
    ds = TinyDebugDataset(vocab_size, seq_len, num_samples, mask_token_id, seed)
    shuffle_generator = torch.Generator().manual_seed(seed)
    return DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=True,
                      generator=shuffle_generator)