File size: 3,774 Bytes
742c943
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from __future__ import annotations

import json
from pathlib import Path

import torch

from src.data.the_stack import _normalize_repo_name
from src.data.the_stack_bpe import BPETokenDataset, _train_tokenizer


def _collect_wikitext_text(
    repo_id: str,
    config_name: str,
    split: str,
    target_bytes: int,
) -> str:
    from datasets import load_dataset

    ds = load_dataset(repo_id, config_name, split=split)
    chunks: list[str] = []
    total = 0
    for sample in ds:
        text = sample.get("text") or ""
        if not isinstance(text, str):
            continue
        block = text.strip("\n")
        if not block:
            continue
        block = block + "\n\n"
        chunks.append(block)
        total += len(block.encode("utf-8"))
        if total >= target_bytes:
            break
    if total == 0:
        raise RuntimeError(f"No usable WikiText text found for {repo_id}:{config_name}:{split}")
    return "".join(chunks)


def load_wikitext_bpe(
    seq_len: int = 256,
    device: str = "cpu",
    data_dir: str = "data_cache",
    repo_id: str = "wikitext",
    config_name: str = "wikitext-2-raw-v1",
    target_bytes: int = 2_000_000,
    vocab_size: int = 4096,
) -> tuple[BPETokenDataset, BPETokenDataset]:
    """Build (or load from cache) BPE-tokenized WikiText train/val datasets.

    Four artifacts are cached under ``data_dir``: the tokenizer, the train
    token ids, the validation token ids and a JSON metadata file. Their
    file names are keyed by the repo, config, ``target_bytes`` and
    ``vocab_size``. When all four exist they are loaded; otherwise the
    text is collected, a tokenizer is trained on the train text only,
    both splits are encoded, and the artifacts are written.

    Args:
        seq_len: Forwarded to ``BPETokenDataset``.
        device: Forwarded to ``BPETokenDataset``.
        data_dir: Cache directory; created if it does not exist.
        repo_id: Dataset identifier used to collect text.
        config_name: Dataset configuration name.
        target_bytes: Byte budget for the train text. The validation
            budget is ``max(250_000, target_bytes // 8)``.
        vocab_size: Vocabulary size requested from ``_train_tokenizer``.

    Returns:
        A ``(train, val)`` pair of ``BPETokenDataset`` objects.
    """
    cache_root = Path(data_dir)
    cache_root.mkdir(parents=True, exist_ok=True)

    repo_tag = _normalize_repo_name(repo_id)
    config_tag = config_name.replace("-", "_")
    stem = f"{repo_tag}_{config_tag}_{target_bytes}_bpe{vocab_size}"

    tokenizer_path = cache_root / f"{stem}_tokenizer.json"
    train_ids_path = cache_root / f"{stem}_train_ids.pt"
    val_ids_path = cache_root / f"{stem}_val_ids.pt"
    meta_path = cache_root / f"{stem}_meta.json"

    cache_files = (tokenizer_path, train_ids_path, val_ids_path, meta_path)
    cache_complete = all(artifact.exists() for artifact in cache_files)

    if not cache_complete:
        # Cache miss: gather raw text for both splits.
        train_text = _collect_wikitext_text(
            repo_id=repo_id,
            config_name=config_name,
            split="train",
            target_bytes=target_bytes,
        )
        val_budget = max(250_000, target_bytes // 8)
        val_text = _collect_wikitext_text(
            repo_id=repo_id,
            config_name=config_name,
            split="validation",
            target_bytes=val_budget,
        )

        # The tokenizer is fitted on the train text only and then applied
        # to both splits.
        tokenizer = _train_tokenizer(text=train_text, vocab_size=vocab_size)
        train_encoding = tokenizer.encode(train_text)
        train_ids = torch.tensor(train_encoding.ids, dtype=torch.long)
        val_encoding = tokenizer.encode(val_text)
        val_ids = torch.tensor(val_encoding.ids, dtype=torch.long)
        # Record the size reported by the trained tokenizer rather than
        # the requested one.
        actual_vocab_size = tokenizer.get_vocab_size()

        tokenizer.save(str(tokenizer_path))
        torch.save(train_ids, train_ids_path)
        torch.save(val_ids, val_ids_path)
        # Metadata is written last of the four artifacts.
        metadata = {
            "repo_id": repo_id,
            "config_name": config_name,
            "target_bytes": target_bytes,
            "vocab_size": actual_vocab_size,
            "train_token_count": int(train_ids.numel()),
            "val_token_count": int(val_ids.numel()),
        }
        meta_path.write_text(json.dumps(metadata, indent=2), encoding="utf-8")
    else:
        # Cache hit: reuse the stored token ids and the recorded vocab size.
        train_ids = torch.load(train_ids_path, map_location="cpu")
        val_ids = torch.load(val_ids_path, map_location="cpu")
        cached_meta = json.loads(meta_path.read_text(encoding="utf-8"))
        actual_vocab_size = int(cached_meta["vocab_size"])

    # Both datasets share everything except their token ids and split label.
    shared_kwargs = {
        "vocab_size": actual_vocab_size,
        "seq_len": seq_len,
        "device": device,
        "split_data": False,
    }
    train = BPETokenDataset(token_ids=train_ids, split="train", **shared_kwargs)
    val = BPETokenDataset(token_ids=val_ids, split="val", **shared_kwargs)
    return train, val