sad / src /data /__init__.py
haochengsama's picture
add missing files batch 13 (400)
278b5e7 verified
Raw
History Blame Contribute Delete
9.17 kB
"""
Data loading utilities for SAD.
Supports:
- Tiny debug dataset (random token ids for smoke tests)
- OpenWebText via HuggingFace datasets
- Generic text dataset from a file
All datasets return batches of shape [B, seq_len] with attention_mask.
"""
import random
from typing import Optional, Iterator
import torch
from torch.utils.data import Dataset, DataLoader, IterableDataset
class TinyDebugDataset(Dataset):
    """
    Random token-id dataset for smoke tests.

    Does NOT load any real data. All samples are generated once in
    ``__init__`` from a private ``torch.Generator`` seeded with ``seed``, so
    the contents are reproducible and building the dataset does not touch
    torch's global RNG state.

    Args:
        vocab_size: ids are drawn uniformly from ``[0, vocab_size - 1)``; the
            top id ``vocab_size - 1`` is never produced.
        seq_len: length of every sample.
        num_samples: number of samples in the dataset.
        mask_token_id: any drawn id equal to this value is replaced by 0 so
            the ground truth never contains the mask token.
        seed: seed for the private generator.
    """

    def __init__(self, vocab_size: int, seq_len: int, num_samples: int = 512,
                 mask_token_id: int = 50256, seed: int = 42):
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.num_samples = num_samples
        self.mask_token_id = mask_token_id
        # Pre-generate to make iteration reproducible. A local generator avoids
        # torch.manual_seed(), which would reset the global RNG for every other
        # consumer (model init, dropout, DataLoader shuffling, ...).
        generator = torch.Generator().manual_seed(seed)
        data = torch.randint(0, vocab_size - 1, (num_samples, seq_len), generator=generator)
        # Avoid mask token in ground truth.
        self.data = data.masked_fill(data == mask_token_id, 0)

    def __len__(self):
        """Return the number of pre-generated samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return sample ``idx`` with an all-ones attention mask (no padding)."""
        return {
            "input_ids": self.data[idx],
            "attention_mask": torch.ones(self.seq_len, dtype=torch.long),
        }
def _parse_split_slice(split: str, total_len: int):
    """Resolve the slice in an HF-style split string to a ``(start, end)`` range.

    Supported forms (negative values count from the end):
      * absolute bounds: ``train[:-100000]``, ``train[-100000:]``, ``train[10:20]``
      * percentage bounds: ``train[:10%]``, ``train[-5%:]`` (rounded to the
        closest row, like HF's default rounding)
      * a single index: ``train[7]`` -> ``(7, 8)``
    A split without ``[...]`` covers the whole range. Both bounds are
    clamped into ``[0, total_len]``.

    Args:
        split: split string, e.g. ``"train[:-100000]"``.
        total_len: number of rows in the un-sliced split.

    Returns:
        ``(start, end)`` suitable for ``range(start, end)``; the range is
        empty when ``start >= end``.

    Raises:
        ValueError: if a bound is neither an integer nor a percentage.
    """
    import re

    match = re.match(r"^(.+?)\[(.+)\]$", split)
    if not match:
        return 0, total_len
    slice_str = match.group(2).strip()

    def _to_idx(token, default):
        # Convert one bound to an absolute (possibly out-of-range) row index.
        token = token.strip()
        if not token:
            return default
        if token.endswith("%"):
            val = int(round(float(token[:-1].strip()) * total_len / 100))
        else:
            val = int(token)
        return total_len + val if val < 0 else val

    def _clamp(idx):
        return max(0, min(total_len, idx))

    if ":" in slice_str:
        parts = slice_str.split(":")
        start = _to_idx(parts[0], 0)
        end = _to_idx(parts[1], total_len)
        return _clamp(start), _clamp(end)
    idx = _to_idx(slice_str, 0)
    return _clamp(idx), _clamp(idx + 1)
def _owt_special_token_ids(tokenizer):
    """Return ``(bos_id, eos_id, pad_id)`` for subsample mode.

    GPT-2 style tokenizers commonly share ``<|endoftext|>`` between
    bos/eos/pad and may report ``bos_token_id`` as ``None``; each missing id
    falls back to the first non-None of (eos, bos, pad).

    Raises:
        ValueError: if the tokenizer defines none of bos/eos/pad.
    """
    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id
    pad_id = tokenizer.pad_token_id
    fallback = next((c for c in (eos_id, bos_id, pad_id) if c is not None), None)
    if fallback is None:
        raise ValueError("tokenizer has no bos/eos/pad token id; cannot build subsample loader")
    return (
        fallback if bos_id is None else bos_id,
        fallback if eos_id is None else eos_id,
        fallback if pad_id is None else pad_id,
    )


def _load_owt_dataset(split: str, cache_dir: Optional[str], has_slice: bool):
    """Load OpenWebText from a local parquet snapshot or the Hugging Face Hub.

    With a slice in ``split`` (e.g. ``train[:-100000]``) a non-streaming
    ``Dataset`` restricted to that slice is returned; otherwise a streaming
    dataset over the whole ``train`` split (the split name itself is not
    used in that case).

    Raises:
        ImportError: if ``datasets`` is not installed.
        FileNotFoundError: if ``cache_dir`` has no ``plain_text/train-*.parquet``.
    """
    try:
        from datasets import load_dataset
    except ImportError as err:
        raise ImportError("pip install datasets") from err
    import glob as _glob

    if cache_dir is None:
        if has_slice:
            # The Hub loader understands HF slice syntax natively.
            return load_dataset("Skylion007/openwebtext", split=split, streaming=False)
        return load_dataset("Skylion007/openwebtext", split="train", streaming=True)

    parquet_files = sorted(_glob.glob(f"{cache_dir}/plain_text/train-*.parquet"))
    if not parquet_files:
        raise FileNotFoundError(
            f"No parquet files found in {cache_dir}/plain_text/. "
            "Run: huggingface-cli download Skylion007/openwebtext "
            "--repo-type dataset --local-dir <cache_dir>"
        )
    data_files = {"train": parquet_files}
    if not has_slice:
        return load_dataset("parquet", data_files=data_files, split="train", streaming=True)
    # An exact slice (e.g. train[:-100000]) requires non-streaming so we can
    # select a precise range. This matches HDLM's behaviour.
    ds = load_dataset("parquet", data_files=data_files, split="train", streaming=False)
    start, end = _parse_split_slice(split, len(ds))
    return ds.select(range(start, end))


def _shard_for_ddp_rank(ds):
    """Return this process's disjoint 1/world_size shard of ``ds`` under DDP.

    ``ds`` is returned unchanged when torch.distributed cannot be imported
    or no process group is initialised.
    """
    try:
        import torch.distributed as dist
    except ImportError:
        return ds
    if not (dist.is_available() and dist.is_initialized()):
        return ds
    return ds.shard(num_shards=dist.get_world_size(), index=dist.get_rank())


def _limit_num_samples(ds, max_samples: int):
    """Keep at most ``max_samples`` examples of ``ds``.

    Sized (map-style) datasets are checked first so the count can be clamped
    to ``len(ds)``; a map-style ``take`` (present in newer ``datasets``
    releases) is not guaranteed to clamp. Streaming datasets use ``take``.
    """
    if hasattr(ds, "__len__"):
        return ds.select(range(min(max_samples, len(ds))))
    return ds.take(max_samples)


def build_owt_dataloader(
    tokenizer,
    split: str = "train",
    seq_len: int = 512,
    batch_size: int = 32,
    num_workers: int = 4,
    cache_dir: Optional[str] = None,
    max_samples: Optional[int] = None,
    seed: int = 42,
    mode: str = "subsample",
    shard_across_ranks: bool = True,
) -> DataLoader:
    """
    Build an OpenWebText DataLoader.

    Args:
        tokenizer: HF-style tokenizer, called as ``tokenizer(text, ...)``;
            in subsample mode it must expose at least one of
            ``bos_token_id`` / ``eos_token_id`` / ``pad_token_id``.
        split: ``"train"`` (streamed) or an HF slice such as
            ``"train[:-100000]"`` (training head, shuffled) or
            ``"train[-100000:]"`` (held-out tail, not shuffled).
        seq_len: length of every sample.
        batch_size, num_workers: forwarded to the DataLoader.
        cache_dir: directory with a local ``plain_text/train-*.parquet``
            snapshot; ``None`` loads from the Hugging Face Hub.
        max_samples: cap on documents, applied after rank sharding (so it
            is per rank when sharding is active).
        seed: shuffle seed for training splits.
        mode:
            "subsample" – HDLM-aligned default. Each sample = one document
            wrapped as [BOS] ... [EOS]. Long docs: random seq_len-token window.
            Short docs: pad to seq_len with attention_mask=0 on pads.
            "pack" – legacy packing. Tokenize all docs, concatenate, split
            into non-overlapping seq_len chunks (cross-document, no separators).
        shard_across_ranks:
            True (train): each rank takes a disjoint slice under DDP.
            False (val): every rank iterates the full (deterministic) set —
            eval currently runs on rank 0 only, so sharding would bias
            metrics to a single shard.

    Returns:
        DataLoader yielding dicts with ``input_ids`` and ``attention_mask``
        LongTensors of shape ``[batch, seq_len]``.

    Raises:
        ValueError: unknown ``mode``, or (subsample) a tokenizer with no
            bos/eos/pad id. Both are checked before any data is loaded.
        ImportError: ``datasets`` is not installed.
        FileNotFoundError: ``cache_dir`` contains no parquet files.
    """
    # Validate cheap arguments before touching the (possibly very large) dataset.
    if mode not in ("subsample", "pack"):
        raise ValueError(f"Unknown mode: {mode!r} (expected 'subsample' or 'pack')")
    if mode == "subsample":
        bos_id, eos_id, pad_id = _owt_special_token_ids(tokenizer)

    has_slice = "[" in split and "]" in split
    ds = _load_owt_dataset(split, cache_dir, has_slice)

    # Shuffle training splits only: train[:-N] is the training set (the head),
    # train[-N:] is the validation set (the tail).
    is_train = split.startswith("train") and not split.startswith("train[-")
    if is_train:
        if has_slice:
            ds = ds.shuffle(seed=seed)
        else:
            # Streaming datasets can only shuffle within a bounded buffer.
            ds = ds.shuffle(seed=seed, buffer_size=10_000)

    # Shard across ranks so each GPU sees a disjoint slice. Without this,
    # every rank iterates the same stream → multi-GPU training becomes
    # gradient-averaging over identical batches. Eval loaders opt out
    # (shard_across_ranks=False) so rank-0-only evaluation isn't biased
    # to a single shard.
    if shard_across_ranks:
        ds = _shard_for_ddp_rank(ds)

    if max_samples is not None:
        ds = _limit_num_samples(ds, max_samples)

    if mode == "pack":
        def tokenize_and_chunk(examples):
            all_ids = []
            for text in examples["text"]:
                all_ids.extend(tokenizer(text, truncation=False, padding=False,
                                         return_attention_mask=False)["input_ids"])
            # "+ 1" keeps the final chunk when the token count is an exact
            # multiple of seq_len; only a trailing partial chunk is dropped.
            last_start = len(all_ids) - seq_len
            chunks = [all_ids[i:i + seq_len] for i in range(0, last_start + 1, seq_len)]
            return {"input_ids": chunks}

        ds = ds.map(tokenize_and_chunk, batched=True, remove_columns=["text"])

        def collate_fn(examples):
            ids = torch.stack([torch.tensor(e["input_ids"], dtype=torch.long) for e in examples])
            return {"input_ids": ids, "attention_mask": torch.ones_like(ids)}
    else:
        def collate_fn(examples):
            input_ids_list, attn_list = [], []
            for ex in examples:
                toks = tokenizer(ex["text"], truncation=False, padding=False,
                                 return_attention_mask=False)["input_ids"]
                if not toks or toks[0] != bos_id:
                    toks = [bos_id] + toks
                # A single token here (e.g. an empty document that now holds
                # only BOS) still needs EOS, even when EOS shares BOS's id.
                if len(toks) == 1 or toks[-1] != eos_id:
                    toks = toks + [eos_id]
                n_tok = len(toks)
                if n_tok > seq_len:
                    # Long document: keep a random seq_len-token window.
                    start = random.randint(0, n_tok - seq_len)
                    toks = toks[start:start + seq_len]
                    attn = [1] * seq_len
                else:
                    # Short document: right-pad and mask out the padding.
                    n_pad = seq_len - n_tok
                    attn = [1] * n_tok + [0] * n_pad
                    toks = toks + [pad_id] * n_pad
                input_ids_list.append(toks)
                attn_list.append(attn)
            return {
                "input_ids": torch.tensor(input_ids_list, dtype=torch.long),
                "attention_mask": torch.tensor(attn_list, dtype=torch.long),
            }

    return DataLoader(
        ds,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
        persistent_workers=(num_workers > 0),
    )
def build_debug_dataloader(
    vocab_size: int,
    seq_len: int = 64,
    batch_size: int = 4,
    num_samples: int = 64,
    mask_token_id: int = 50256,
) -> DataLoader:
    """Return a shuffled DataLoader over a TinyDebugDataset for smoke tests.

    Incomplete trailing batches are dropped, so the loader yields
    ``num_samples // batch_size`` batches (none when
    ``num_samples < batch_size``).
    """
    dataset = TinyDebugDataset(
        vocab_size=vocab_size,
        seq_len=seq_len,
        num_samples=num_samples,
        mask_token_id=mask_token_id,
    )
    return DataLoader(dataset, batch_size=batch_size, drop_last=True, shuffle=True)