File size: 9,167 Bytes
278b5e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
"""
Data loading utilities for SAD.

Supports:
  - Tiny debug dataset (random token ids for smoke tests)
  - OpenWebText via HuggingFace datasets

Not yet implemented in this module: a generic text dataset read from a file.

All DataLoaders built here yield batches of shape [B, seq_len] with an attention_mask.
"""

import random
from typing import Optional, Iterator

import torch
from torch.utils.data import Dataset, DataLoader, IterableDataset


class TinyDebugDataset(Dataset):
    """
    Synthetic dataset of uniformly random token ids, for smoke tests only.

    Nothing is read from disk. Every sample is generated once in ``__init__``,
    so indexing the same position always returns the same tensor.

    Note: construction calls ``torch.manual_seed(seed)``, which reseeds the
    *global* torch RNG as a side effect.
    """

    def __init__(self, vocab_size: int, seq_len: int, num_samples: int = 512,
                 mask_token_id: int = 50256, seed: int = 42):
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.num_samples = num_samples
        self.mask_token_id = mask_token_id

        torch.manual_seed(seed)
        # Ids are drawn from [0, vocab_size - 2]: the upper bound is exclusive,
        # so the top id is never sampled.
        raw = torch.randint(0, vocab_size - 1, (num_samples, seq_len))
        # Any id that equals the mask token is replaced by 0 so the ground
        # truth never contains the mask token.
        self.data = raw.masked_fill(raw == mask_token_id, 0)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        row = self.data[idx]
        # Every position is real data, hence an all-ones attention mask.
        full_mask = torch.ones(self.seq_len, dtype=torch.long)
        return dict(input_ids=row, attention_mask=full_mask)


def _parse_split_slice(split: str, total_len: int):
    """Parse HF split slice like train[:-100000] or train[-100000:] and return (start, end)."""
    import re
    m = re.match(r"^(.+?)\[(.+)\]$", split)
    if not m:
        return 0, total_len
    slice_str = m.group(2).strip()

    def _to_idx(s, default):
        s = s.strip()
        if not s:
            return default
        val = int(s)
        return total_len + val if val < 0 else val

    if ":" in slice_str:
        parts = slice_str.split(":")
        start = _to_idx(parts[0], 0)
        end = _to_idx(parts[1], total_len)
        return max(0, start), min(total_len, end)
    else:
        idx = _to_idx(slice_str, 0)
        return max(0, idx), min(total_len, idx + 1)


def _pack_collate(examples):
    """Stack pre-chunked ``input_ids`` rows into a batch with an all-ones mask."""
    ids = torch.stack([torch.tensor(e["input_ids"], dtype=torch.long) for e in examples])
    return {"input_ids": ids, "attention_mask": torch.ones_like(ids)}


class _PackChunker:
    """Batched ``datasets.map`` function for "pack" mode.

    Tokenizes every document in the map batch, concatenates the ids (no
    separators between documents) and cuts them into non-overlapping chunks of
    exactly ``seq_len`` tokens. The trailing partial chunk of each map batch is
    dropped. Defined at module level (not as a closure) so it stays picklable
    when DataLoader workers are started with the "spawn" method.
    """

    def __init__(self, tokenizer, seq_len):
        self.tokenizer = tokenizer
        self.seq_len = seq_len

    def __call__(self, examples):
        seq_len = self.seq_len
        all_ids = []
        for text in examples["text"]:
            all_ids.extend(self.tokenizer(text, truncation=False, padding=False,
                                          return_attention_mask=False)["input_ids"])
        # "+ 1" so a final chunk ending exactly at len(all_ids) is kept
        # (previously a concatenation of exactly k*seq_len lost its last chunk).
        chunks = [all_ids[i:i + seq_len]
                  for i in range(0, len(all_ids) - seq_len + 1, seq_len)]
        return {"input_ids": chunks}


def _resolve_special_ids(tokenizer):
    """Return ``(bos_id, eos_id, pad_id)``, filling missing ids from the others.

    GPT-2 commonly uses <|endoftext|> (50256) for bos/eos/pad, but
    ``bos_token_id`` may be None. Any None id falls back to the first non-None
    of eos, bos, pad (in that order).

    Raises:
        ValueError: if all three ids are None.
    """
    candidates = [
        tokenizer.eos_token_id,
        tokenizer.bos_token_id,
        tokenizer.pad_token_id,
    ]
    fallback = next((c for c in candidates if c is not None), None)
    if fallback is None:
        raise ValueError("tokenizer has no bos/eos/pad token id; cannot build subsample loader")

    def _or_fallback(token_id):
        return token_id if token_id is not None else fallback

    return (
        _or_fallback(tokenizer.bos_token_id),
        _or_fallback(tokenizer.eos_token_id),
        _or_fallback(tokenizer.pad_token_id),
    )


class _SubsampleCollator:
    """Collate raw-text examples into fixed-length ``[B, seq_len]`` batches.

    Each document is tokenized and wrapped as [BOS] ... [EOS]. Documents longer
    than ``seq_len`` are cut to a random ``seq_len`` window (start drawn with
    Python's ``random`` module). Shorter ones are right-padded with ``pad_id``
    and get attention_mask 0 on the padding. Defined at module level (not as a
    closure) so it stays picklable when DataLoader workers use "spawn".
    """

    def __init__(self, tokenizer, seq_len, bos_id, eos_id, pad_id):
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.bos_id = bos_id
        self.eos_id = eos_id
        self.pad_id = pad_id

    def _wrap(self, text):
        """Tokenize ``text`` and make sure it starts with BOS and ends with EOS."""
        toks = self.tokenizer(text, truncation=False, padding=False,
                              return_attention_mask=False)["input_ids"]
        if not toks or toks[0] != self.bos_id:
            toks = [self.bos_id] + toks
        # len(toks) < 2: when bos_id == eos_id (GPT-2), a lone BOS must not also
        # count as the EOS, otherwise an empty document becomes a single token.
        if len(toks) < 2 or toks[-1] != self.eos_id:
            toks = toks + [self.eos_id]
        return toks

    def __call__(self, examples):
        seq_len = self.seq_len
        input_ids_list, attn_list = [], []
        for ex in examples:
            toks = self._wrap(ex["text"])
            n = len(toks)
            if n > seq_len:
                start = random.randint(0, n - seq_len)
                toks = toks[start:start + seq_len]
                attn = [1] * seq_len
            else:
                attn = [1] * n + [0] * (seq_len - n)
                toks = toks + [self.pad_id] * (seq_len - n)
            input_ids_list.append(toks)
            attn_list.append(attn)

        return {
            "input_ids": torch.tensor(input_ids_list, dtype=torch.long),
            "attention_mask": torch.tensor(attn_list, dtype=torch.long),
        }


def build_owt_dataloader(
    tokenizer,
    split: str = "train",
    seq_len: int = 512,
    batch_size: int = 32,
    num_workers: int = 4,
    cache_dir: Optional[str] = None,
    max_samples: Optional[int] = None,
    seed: int = 42,
    mode: str = "subsample",
    shard_across_ranks: bool = True,
) -> DataLoader:
    """
    Build an OpenWebText DataLoader.

    Args:
        tokenizer: Tokenizer called as ``tokenizer(text, truncation=False,
            padding=False, return_attention_mask=False)["input_ids"]``. In
            "subsample" mode its ``bos/eos/pad_token_id`` attributes are read.
        split: HF split string. Without a slice, the full "train" split is
            streamed (the split name itself is not used). With a slice such as
            ``"train[:-100000]"``, the data is loaded non-streaming and exactly
            that range is kept. Splits starting with ``"train[-"`` are treated
            as validation and are not shuffled.
        seq_len: Length of every returned sequence.
        batch_size: DataLoader batch size.
        num_workers: Number of DataLoader worker processes.
        cache_dir: Local snapshot of Skylion007/openwebtext. Parquet files are
            read from ``{cache_dir}/plain_text/train-*.parquet``. If None, the
            dataset is loaded from the HF Hub.
        max_samples: If set, keep at most this many documents. Applied after
            sharding, so the limit is per rank.
        seed: Shuffle seed for training splits.
        mode:
          "subsample" – HDLM-aligned default. Each sample = one document
              wrapped as [BOS] ... [EOS]. Long docs: random seq_len-token
              window (also applied to validation splits). Short docs: pad to
              seq_len with attention_mask=0 on pads.
          "pack" – legacy packing. Tokenize all docs, concatenate, split
              into non-overlapping seq_len chunks (cross-document, no separators).
        shard_across_ranks:
          True (train): each rank takes a disjoint slice under DDP.
          False (val):  every rank iterates the full (deterministic) set —
              eval currently runs on rank 0 only, so sharding would bias
              metrics to a single shard.

    Returns:
        DataLoader yielding dicts with ``input_ids`` and ``attention_mask``,
        both ``torch.long`` tensors of shape [B, seq_len].

    Raises:
        ValueError: unknown ``mode``, or (in "subsample" mode) the tokenizer has
            no bos/eos/pad token id. Both are checked before any data is loaded.
        ImportError: the ``datasets`` package is not installed.
        FileNotFoundError: ``cache_dir`` contains no parquet files.
    """
    # Validate cheap arguments before touching the (large) dataset.
    if mode == "subsample":
        bos_id, eos_id, pad_id = _resolve_special_ids(tokenizer)
    elif mode != "pack":
        raise ValueError(f"Unknown mode: {mode!r} (expected 'subsample' or 'pack')")

    try:
        from datasets import load_dataset
    except ImportError as err:
        raise ImportError("pip install datasets") from err

    import glob as _glob

    has_slice = "[" in split and "]" in split

    if cache_dir is not None:
        parquet_files = sorted(_glob.glob(f"{cache_dir}/plain_text/train-*.parquet"))
        if not parquet_files:
            raise FileNotFoundError(
                f"No parquet files found in {cache_dir}/plain_text/. "
                "Run: huggingface-cli download Skylion007/openwebtext "
                "--repo-type dataset --local-dir <cache_dir>"
            )
        data_files = {"train": parquet_files}
        if has_slice:
            # An exact slice (e.g. train[:-100000]) requires a non-streaming
            # dataset so a precise index range can be selected. This matches
            # HDLM's behaviour.
            ds = load_dataset("parquet", data_files=data_files,
                              split="train", streaming=False)
            start, end = _parse_split_slice(split, len(ds))
            ds = ds.select(range(start, end))
        else:
            ds = load_dataset("parquet", data_files=data_files,
                              split="train", streaming=True)
    elif has_slice:
        # `split`, including its slice, is passed straight to load_dataset.
        ds = load_dataset("Skylion007/openwebtext", split=split, streaming=False)
    else:
        ds = load_dataset("Skylion007/openwebtext", split="train", streaming=True)

    # Shuffle training splits only. train[:-N] (the head) is the training set;
    # a tail slice like train[-N:] is the validation set.
    is_train = split.startswith("train") and not split.startswith("train[-")
    if is_train:
        if has_slice:
            ds = ds.shuffle(seed=seed)
        else:
            # Streaming datasets shuffle within a bounded buffer.
            ds = ds.shuffle(seed=seed, buffer_size=10_000)

    # Shard across ranks so each GPU sees a disjoint slice. Without this,
    # every rank iterates the same stream → multi-GPU training becomes
    # gradient-averaging over identical batches. Eval loaders opt out
    # (shard_across_ranks=False) so rank-0-only evaluation isn't biased
    # to a single shard.
    if shard_across_ranks:
        try:
            import torch.distributed as dist
        except ImportError:
            dist = None
        if dist is not None and dist.is_available() and dist.is_initialized():
            ds = ds.shard(num_shards=dist.get_world_size(), index=dist.get_rank())

    if max_samples is not None:
        if hasattr(ds, "take"):
            ds = ds.take(max_samples)
        else:
            n = len(ds) if hasattr(ds, "__len__") else max_samples
            ds = ds.select(range(min(max_samples, n)))

    if mode == "pack":
        ds = ds.map(_PackChunker(tokenizer, seq_len), batched=True, remove_columns=["text"])
        collate_fn = _pack_collate
    else:
        collate_fn = _SubsampleCollator(tokenizer, seq_len, bos_id, eos_id, pad_id)

    return DataLoader(
        ds,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
        persistent_workers=(num_workers > 0),
    )


def build_debug_dataloader(
    vocab_size: int,
    seq_len: int = 64,
    batch_size: int = 4,
    num_samples: int = 64,
    mask_token_id: int = 50256,
) -> DataLoader:
    """Return a shuffled DataLoader over a freshly generated TinyDebugDataset.

    The dataset uses its default seed. Incomplete trailing batches are dropped,
    so ``num_samples < batch_size`` yields no batches at all.
    """
    dataset = TinyDebugDataset(
        vocab_size=vocab_size,
        seq_len=seq_len,
        num_samples=num_samples,
        mask_token_id=mask_token_id,
    )
    return DataLoader(dataset, shuffle=True, drop_last=True, batch_size=batch_size)