""" Custom dataset loaders for different training corpora. Usage in config: data: dataset: wikitext103 # or dataset: pile # or dataset: your_custom_dataset """ from typing import Optional import torch from torch.utils.data import DataLoader def build_wikitext_dataloader( tokenizer, split: str = "train", seq_len: int = 512, batch_size: int = 32, num_workers: int = 4, cache_dir: Optional[str] = None, ): """WikiText-103 dataset.""" from datasets import load_dataset ds = load_dataset("wikitext", "wikitext-103-v1", split=split, cache_dir=cache_dir) def tokenize_and_chunk(examples): all_ids = [] for text in examples["text"]: all_ids.extend(tokenizer(text, truncation=False, padding=False)["input_ids"]) chunks = [all_ids[i:i + seq_len] for i in range(0, len(all_ids) - seq_len, seq_len)] return {"input_ids": chunks} ds = ds.map(tokenize_and_chunk, batched=True, remove_columns=["text"]) ds.set_format(type="torch") def collate_fn(examples): ids = torch.stack([e["input_ids"] for e in examples]) return {"input_ids": ids, "attention_mask": torch.ones_like(ids)} return DataLoader( ds, batch_size=batch_size, shuffle=(split == "train"), num_workers=num_workers, collate_fn=collate_fn, pin_memory=True ) def build_c4_dataloader( tokenizer, split: str = "train", seq_len: int = 512, batch_size: int = 32, num_workers: int = 4, cache_dir: Optional[str] = None, streaming: bool = False, ): """C4 (Colossal Clean Crawled Corpus) dataset.""" from datasets import load_dataset ds = load_dataset("c4", "en", split=split, cache_dir=cache_dir, streaming=streaming) def tokenize_and_chunk(examples): all_ids = [] for text in examples["text"]: all_ids.extend(tokenizer(text, truncation=False, padding=False)["input_ids"]) chunks = [all_ids[i:i + seq_len] for i in range(0, len(all_ids) - seq_len, seq_len)] return {"input_ids": chunks} ds = ds.map(tokenize_and_chunk, batched=True, remove_columns=["text", "timestamp", "url"]) if not streaming: ds.set_format(type="torch") def collate_fn(examples): ids = torch.stack([e["input_ids"] for e in examples]) return {"input_ids": ids, "attention_mask": torch.ones_like(ids)} return DataLoader( ds, batch_size=batch_size, shuffle=(split == "train" and not streaming), num_workers=num_workers, collate_fn=collate_fn, pin_memory=True ) def build_pile_dataloader( tokenizer, split: str = "train", seq_len: int = 512, batch_size: int = 32, num_workers: int = 4, cache_dir: Optional[str] = None, streaming: bool = True, # Pile很大,推荐streaming ): """The Pile dataset (825GB).""" from datasets import load_dataset ds = load_dataset("EleutherAI/pile", split=split, cache_dir=cache_dir, streaming=streaming) def tokenize_and_chunk(examples): all_ids = [] for text in examples["text"]: all_ids.extend(tokenizer(text, truncation=False, padding=False)["input_ids"]) chunks = [all_ids[i:i + seq_len] for i in range(0, len(all_ids) - seq_len, seq_len)] return {"input_ids": chunks} ds = ds.map(tokenize_and_chunk, batched=True, remove_columns=["text", "meta"]) def collate_fn(examples): ids = torch.stack([e["input_ids"] for e in examples]) return {"input_ids": ids, "attention_mask": torch.ones_like(ids)} return DataLoader( ds, batch_size=batch_size, shuffle=False, # streaming不支持shuffle num_workers=num_workers, collate_fn=collate_fn )