File size: 3,791 Bytes
278b5e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
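"""
Custom dataset loaders for different training corpora.

Each ``build_*_dataloader`` function loads a Hugging Face dataset, tokenizes
the text, packs the token ids into fixed-length ``seq_len`` chunks, and
returns a PyTorch ``DataLoader``.

Usage in config:
  data:
    dataset: wikitext103
    # or
    dataset: pile
    # or
    dataset: your_custom_dataset
"""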

from typing import Optional
import torch
from torch.utils.data import DataLoader


def build_wikitext_dataloader(
    tokenizer,
    split: str = "train",
    seq_len: int = 512,
    batch_size: int = 32,
    num_workers: int = 4,
    cache_dir: Optional[str] = None,
):
    """Build a DataLoader over WikiText-103 packed into fixed-length token chunks.

    Within each ``map`` batch, every raw line is tokenized and the resulting
    ids are concatenated into one stream. That stream is then cut into
    non-overlapping windows of exactly ``seq_len`` ids. Trailing tokens in a
    ``map`` batch that do not fill a whole window are dropped; they are not
    carried over to the next batch.

    Args:
        tokenizer: Callable that takes a single string and returns a mapping
            with an ``"input_ids"`` list (e.g. a Hugging Face tokenizer).
        split: Dataset split passed to ``load_dataset``. Shuffling is enabled
            only for ``"train"``.
        seq_len: Number of token ids per example. Must be positive.
        batch_size: DataLoader batch size.
        num_workers: Number of DataLoader worker processes.
        cache_dir: Optional Hugging Face datasets cache directory.

    Returns:
        A ``DataLoader`` yielding dicts with ``"input_ids"`` and an all-ones
        ``"attention_mask"``, both of shape ``(batch, seq_len)``.

    Raises:
        ValueError: If ``seq_len`` is not positive.
    """
    if seq_len <= 0:
        # Fail before downloading anything. A zero step would otherwise raise
        # deep inside ``map``, and a negative one silently yields no data.
        raise ValueError(f"seq_len must be positive, got {seq_len}")

    from datasets import load_dataset

    ds = load_dataset("wikitext", "wikitext-103-v1", split=split, cache_dir=cache_dir)

    def tokenize_and_chunk(examples):
        all_ids = []
        for text in examples["text"]:
            all_ids.extend(tokenizer(text, truncation=False, padding=False)["input_ids"])
        # The "+ 1" keeps the final window when it ends exactly at len(all_ids).
        # Every slice therefore has exactly seq_len ids.
        last_start = len(all_ids) - seq_len + 1
        chunks = [all_ids[i:i + seq_len] for i in range(0, last_start, seq_len)]
        return {"input_ids": chunks}

    # The map may return a different number of rows than it receives. That is
    # allowed here because every original column is removed.
    ds = ds.map(tokenize_and_chunk, batched=True, remove_columns=["text"])
    ds.set_format(type="torch")

    def collate_fn(examples):
        # as_tensor accepts both tensor rows (after set_format) and plain lists.
        # NOTE(review): this is a nested function. With num_workers > 0 under
        # the "spawn" start method it cannot be pickled; confirm the target
        # platforms use "fork".
        ids = torch.stack([torch.as_tensor(e["input_ids"]) for e in examples])
        return {"input_ids": ids, "attention_mask": torch.ones_like(ids)}

    return DataLoader(
        ds, batch_size=batch_size, shuffle=(split == "train"),
        num_workers=num_workers, collate_fn=collate_fn, pin_memory=True
    )


def build_c4_dataloader(
    tokenizer,
    split: str = "train",
    seq_len: int = 512,
    batch_size: int = 32,
    num_workers: int = 4,
    cache_dir: Optional[str] = None,
    streaming: bool = False,
):
    """Build a DataLoader over C4 (English) packed into fixed-length token chunks.

    Within each ``map`` batch, every document is tokenized and the resulting
    ids are concatenated into one stream. That stream is then cut into
    non-overlapping windows of exactly ``seq_len`` ids. Trailing tokens in a
    ``map`` batch that do not fill a whole window are dropped.

    Args:
        tokenizer: Callable that takes a single string and returns a mapping
            with an ``"input_ids"`` list (e.g. a Hugging Face tokenizer).
        split: Dataset split passed to ``load_dataset``. Shuffling is enabled
            only for ``"train"`` in non-streaming mode.
        seq_len: Number of token ids per example. Must be positive.
        batch_size: DataLoader batch size.
        num_workers: Number of DataLoader worker processes.
        cache_dir: Optional Hugging Face datasets cache directory.
        streaming: If True, load the dataset lazily instead of materializing
            it on disk.

    Returns:
        A ``DataLoader`` yielding dicts with ``"input_ids"`` and an all-ones
        ``"attention_mask"``, both of shape ``(batch, seq_len)``.

    Raises:
        ValueError: If ``seq_len`` is not positive.
    """
    if seq_len <= 0:
        raise ValueError(f"seq_len must be positive, got {seq_len}")

    from datasets import load_dataset

    ds = load_dataset("c4", "en", split=split, cache_dir=cache_dir, streaming=streaming)

    def tokenize_and_chunk(examples):
        all_ids = []
        for text in examples["text"]:
            all_ids.extend(tokenizer(text, truncation=False, padding=False)["input_ids"])
        # The "+ 1" keeps the final window when it ends exactly at len(all_ids).
        last_start = len(all_ids) - seq_len + 1
        chunks = [all_ids[i:i + seq_len] for i in range(0, last_start, seq_len)]
        return {"input_ids": chunks}

    ds = ds.map(tokenize_and_chunk, batched=True, remove_columns=["text", "timestamp", "url"])

    if not streaming:
        ds.set_format(type="torch")

    def collate_fn(examples):
        # In streaming mode set_format is skipped, so rows hold plain Python
        # lists. as_tensor handles both lists and already-formatted tensors;
        # torch.stack on raw lists would raise TypeError.
        ids = torch.stack([torch.as_tensor(e["input_ids"]) for e in examples])
        return {"input_ids": ids, "attention_mask": torch.ones_like(ids)}

    # DataLoader cannot shuffle an iterable (streaming) dataset.
    return DataLoader(
        ds, batch_size=batch_size, shuffle=(split == "train" and not streaming),
        num_workers=num_workers, collate_fn=collate_fn, pin_memory=True
    )


def build_pile_dataloader(
    tokenizer,
    split: str = "train",
    seq_len: int = 512,
    batch_size: int = 32,
    num_workers: int = 4,
    cache_dir: Optional[str] = None,
    streaming: bool = True,  # The Pile is very large; streaming is recommended.
):
    """Build a DataLoader over The Pile (825GB) packed into fixed-length token chunks.

    Within each ``map`` batch, every document is tokenized and the resulting
    ids are concatenated into one stream. That stream is then cut into
    non-overlapping windows of exactly ``seq_len`` ids. Trailing tokens in a
    ``map`` batch that do not fill a whole window are dropped.

    Args:
        tokenizer: Callable that takes a single string and returns a mapping
            with an ``"input_ids"`` list (e.g. a Hugging Face tokenizer).
        split: Dataset split passed to ``load_dataset``. Shuffling is enabled
            only for ``"train"`` in non-streaming mode.
        seq_len: Number of token ids per example. Must be positive.
        batch_size: DataLoader batch size.
        num_workers: Number of DataLoader worker processes.
        cache_dir: Optional Hugging Face datasets cache directory.
        streaming: If True (the default), load the dataset lazily.

    Returns:
        A ``DataLoader`` yielding dicts with ``"input_ids"`` and an all-ones
        ``"attention_mask"``, both of shape ``(batch, seq_len)``.

    Raises:
        ValueError: If ``seq_len`` is not positive.
    """
    if seq_len <= 0:
        raise ValueError(f"seq_len must be positive, got {seq_len}")

    from datasets import load_dataset

    ds = load_dataset("EleutherAI/pile", split=split, cache_dir=cache_dir, streaming=streaming)

    def tokenize_and_chunk(examples):
        all_ids = []
        for text in examples["text"]:
            all_ids.extend(tokenizer(text, truncation=False, padding=False)["input_ids"])
        # The "+ 1" keeps the final window when it ends exactly at len(all_ids).
        last_start = len(all_ids) - seq_len + 1
        chunks = [all_ids[i:i + seq_len] for i in range(0, last_start, seq_len)]
        return {"input_ids": chunks}

    ds = ds.map(tokenize_and_chunk, batched=True, remove_columns=["text", "meta"])

    if not streaming:
        ds.set_format(type="torch")

    def collate_fn(examples):
        # Streaming rows are plain Python lists. torch.stack alone would raise
        # TypeError on them, so convert each row first; already-formatted
        # tensors pass through as_tensor unchanged.
        ids = torch.stack([torch.as_tensor(e["input_ids"]) for e in examples])
        return {"input_ids": ids, "attention_mask": torch.ones_like(ids)}

    # DataLoader cannot shuffle an iterable (streaming) dataset. A materialized
    # train split is shuffled, as in the other loaders.
    return DataLoader(
        ds, batch_size=batch_size, shuffle=(split == "train" and not streaming),
        num_workers=num_workers, collate_fn=collate_fn, pin_memory=True
    )