"""
Custom dataset loaders for different training corpora.

Provides DataLoader builders for WikiText-103, C4 and The Pile.

Usage in config:
    data:
      dataset: wikitext103
      # or
      dataset: pile
      # or
      dataset: your_custom_dataset
"""
|
|
| from typing import Optional |
| import torch |
| from torch.utils.data import DataLoader |
|
|
|
|
def build_wikitext_dataloader(
    tokenizer,
    split: str = "train",
    seq_len: int = 512,
    batch_size: int = 32,
    num_workers: int = 4,
    cache_dir: Optional[str] = None,
):
    """Build a DataLoader over WikiText-103 packed into fixed-length token chunks.

    Each ``map`` batch of raw text lines is tokenized line by line, the token
    ids are concatenated into one stream, and that stream is cut into
    non-overlapping chunks of exactly ``seq_len`` tokens. Leftover tokens at
    the end of a ``map`` batch (fewer than ``seq_len``) are discarded.

    Args:
        tokenizer: Callable returning a mapping with an ``"input_ids"`` list
            when called as ``tokenizer(text, truncation=False, padding=False)``
            (presumably a Hugging Face tokenizer).
        split: Dataset split to load; the DataLoader shuffles only for
            ``"train"``.
        seq_len: Number of tokens per example. Must be positive.
        batch_size: Number of chunks per batch.
        num_workers: Number of DataLoader worker processes.
        cache_dir: Optional ``datasets`` cache directory.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding dicts with ``"input_ids"``
        (int64) and an all-ones ``"attention_mask"``, each of shape
        ``(<= batch_size, seq_len)``.

    Raises:
        ValueError: If ``seq_len`` is not positive.
    """
    from datasets import load_dataset

    if seq_len <= 0:
        raise ValueError(f"seq_len must be positive, got {seq_len}")

    ds = load_dataset("wikitext", "wikitext-103-v1", split=split, cache_dir=cache_dir)

    def tokenize_and_chunk(examples):
        all_ids = []
        for text in examples["text"]:
            all_ids.extend(tokenizer(text, truncation=False, padding=False)["input_ids"])
        # "+ 1" keeps the last complete chunk when len(all_ids) is an exact
        # multiple of seq_len; without it that full chunk was silently dropped.
        stop = len(all_ids) - seq_len + 1
        chunks = [all_ids[i:i + seq_len] for i in range(0, stop, seq_len)]
        return {"input_ids": chunks}

    # Chunking changes the row count, which is allowed because every input
    # column ("text") is removed.
    ds = ds.map(tokenize_and_chunk, batched=True, remove_columns=["text"])
    ds.set_format(type="torch")

    def collate_fn(examples):
        # as_tensor is a no-op for rows already formatted as tensors and
        # converts plain lists, so collation does not depend on set_format.
        ids = torch.stack(
            [torch.as_tensor(e["input_ids"], dtype=torch.long) for e in examples]
        )
        return {"input_ids": ids, "attention_mask": torch.ones_like(ids)}

    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=(split == "train"),
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
    )
|
|
|
|
def build_c4_dataloader(
    tokenizer,
    split: str = "train",
    seq_len: int = 512,
    batch_size: int = 32,
    num_workers: int = 4,
    cache_dir: Optional[str] = None,
    streaming: bool = False,
):
    """Build a DataLoader over C4 (Colossal Clean Crawled Corpus, "en" config).

    Each ``map`` batch of documents is tokenized, concatenated into one token
    stream, and cut into non-overlapping chunks of exactly ``seq_len`` tokens.
    Leftover tokens at the end of a ``map`` batch are discarded.

    Args:
        tokenizer: Callable returning a mapping with an ``"input_ids"`` list
            when called as ``tokenizer(text, truncation=False, padding=False)``
            (presumably a Hugging Face tokenizer).
        split: Dataset split to load.
        seq_len: Number of tokens per example. Must be positive.
        batch_size: Number of chunks per batch.
        num_workers: Number of DataLoader worker processes.
        cache_dir: Optional ``datasets`` cache directory.
        streaming: If True, load the dataset in streaming mode. Rows are then
            not converted by ``set_format`` and the DataLoader never shuffles.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding dicts with ``"input_ids"``
        (int64) and an all-ones ``"attention_mask"``, each of shape
        ``(<= batch_size, seq_len)``. Shuffling is enabled only for the
        ``"train"`` split in non-streaming mode.

    Raises:
        ValueError: If ``seq_len`` is not positive.
    """
    from datasets import load_dataset

    if seq_len <= 0:
        raise ValueError(f"seq_len must be positive, got {seq_len}")

    ds = load_dataset("c4", "en", split=split, cache_dir=cache_dir, streaming=streaming)

    def tokenize_and_chunk(examples):
        all_ids = []
        for text in examples["text"]:
            all_ids.extend(tokenizer(text, truncation=False, padding=False)["input_ids"])
        # "+ 1" keeps the last complete chunk when len(all_ids) is an exact
        # multiple of seq_len; without it that full chunk was silently dropped.
        stop = len(all_ids) - seq_len + 1
        chunks = [all_ids[i:i + seq_len] for i in range(0, stop, seq_len)]
        return {"input_ids": chunks}

    # All original columns are removed so the batched map may change row count.
    ds = ds.map(tokenize_and_chunk, batched=True, remove_columns=["text", "timestamp", "url"])

    if not streaming:
        ds.set_format(type="torch")

    def collate_fn(examples):
        # In streaming mode set_format is skipped, so input_ids arrive as plain
        # lists; torch.stack alone would reject them. as_tensor handles both.
        ids = torch.stack(
            [torch.as_tensor(e["input_ids"], dtype=torch.long) for e in examples]
        )
        return {"input_ids": ids, "attention_mask": torch.ones_like(ids)}

    return DataLoader(
        ds,
        batch_size=batch_size,
        # DataLoader cannot shuffle an iterable (streaming) dataset.
        shuffle=(split == "train" and not streaming),
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
    )
|
|
|
|
def build_pile_dataloader(
    tokenizer,
    split: str = "train",
    seq_len: int = 512,
    batch_size: int = 32,
    num_workers: int = 4,
    cache_dir: Optional[str] = None,
    streaming: bool = True,
):
    """Build a DataLoader over The Pile (~825GB), streamed by default.

    Each ``map`` batch of documents is tokenized, concatenated into one token
    stream, and cut into non-overlapping chunks of exactly ``seq_len`` tokens.
    Leftover tokens at the end of a ``map`` batch are discarded.

    Args:
        tokenizer: Callable returning a mapping with an ``"input_ids"`` list
            when called as ``tokenizer(text, truncation=False, padding=False)``
            (presumably a Hugging Face tokenizer).
        split: Dataset split to load.
        seq_len: Number of tokens per example. Must be positive.
        batch_size: Number of chunks per batch.
        num_workers: Number of DataLoader worker processes.
        cache_dir: Optional ``datasets`` cache directory.
        streaming: If True (default), load the dataset in streaming mode and
            never shuffle in the DataLoader.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding dicts with ``"input_ids"``
        (int64) and an all-ones ``"attention_mask"``, each of shape
        ``(<= batch_size, seq_len)``. Shuffling is enabled only for the
        ``"train"`` split in non-streaming mode.

    Raises:
        ValueError: If ``seq_len`` is not positive.
    """
    from datasets import load_dataset

    if seq_len <= 0:
        raise ValueError(f"seq_len must be positive, got {seq_len}")

    ds = load_dataset("EleutherAI/pile", split=split, cache_dir=cache_dir, streaming=streaming)

    def tokenize_and_chunk(examples):
        all_ids = []
        for text in examples["text"]:
            all_ids.extend(tokenizer(text, truncation=False, padding=False)["input_ids"])
        # "+ 1" keeps the last complete chunk when len(all_ids) is an exact
        # multiple of seq_len; without it that full chunk was silently dropped.
        stop = len(all_ids) - seq_len + 1
        chunks = [all_ids[i:i + seq_len] for i in range(0, stop, seq_len)]
        return {"input_ids": chunks}

    # All original columns are removed so the batched map may change row count.
    ds = ds.map(tokenize_and_chunk, batched=True, remove_columns=["text", "meta"])

    if not streaming:
        ds.set_format(type="torch")

    def collate_fn(examples):
        # Streaming rows are plain lists (set_format is not applied), which
        # torch.stack rejects; as_tensor converts lists and passes tensors through.
        ids = torch.stack(
            [torch.as_tensor(e["input_ids"], dtype=torch.long) for e in examples]
        )
        return {"input_ids": ids, "attention_mask": torch.ones_like(ids)}

    return DataLoader(
        ds,
        batch_size=batch_size,
        # DataLoader cannot shuffle an iterable (streaming) dataset.
        shuffle=(split == "train" and not streaming),
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
    )
|
|