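"""Byte-level dataset utilities for The Stack.

Streams a slice of a Stack-family code corpus from the Hugging Face Hub,
caches the raw bytes on disk, and exposes train/val datasets that serve
random fixed-length byte windows for language-model training.
"""
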
import os
import json
from pathlib import Path

import torch


class ByteCorpusDataset:
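    """Byte-level LM dataset: a 90/10 train/val split over one contiguous byte corpus."""
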
    def __init__(
        self,
        data: bytes,
        split: str = "train",
        seq_len: int = 256,
        device: str = "cpu",
    ) -> None:
        assert split in ("train", "val")
        self.seq_len = seq_len
        self.device = device

        # Interpret each byte as a token id in [0, 255]; frombuffer avoids
        # building an intermediate Python list for multi-megabyte corpora.
        tensor = torch.frombuffer(bytearray(data), dtype=torch.uint8).long()
        n = int(0.9 * len(tensor))  # 90/10 train/val split point
        self.data = tensor[:n] if split == "train" else tensor[n:]

    @property
    def vocab_size(self) -> int:
        return 256

    def get_batch(self, batch_size: int) -> tuple[torch.Tensor, torch.Tensor]:
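        # Last valid window start: both x and its one-byte-shifted target y
        # must fit inside the corpus.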
        max_start = len(self.data) - self.seq_len - 1
        if max_start <= 0:
            raise ValueError(
                f"Corpus too small for seq_len={self.seq_len}. "
                f"Need more than {self.seq_len + 1} bytes, got {len(self.data)}."
            )
        starts = torch.randint(0, max_start, (batch_size,))
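        # y is x shifted right by one byte: standard next-byte prediction targets.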
        x = torch.stack([self.data[s: s + self.seq_len] for s in starts])
        y = torch.stack([self.data[s + 1: s + self.seq_len + 1] for s in starts])
        return x.to(self.device), y.to(self.device)

    def __len__(self) -> int:
        return len(self.data)


def _normalize_repo_name(repo_id: str) -> str:
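    # Make the repo id safe to embed in a cache filename.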
    return repo_id.replace("/", "__").replace("-", "_")


def _stream_the_stack_text(
    repo_id: str,
    lang: str,
    target_bytes: int,
) -> bytes:
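    # Fast path: try fetching a per-language JSONL file directly from the Hub;
    # on any failure, fall through to the `datasets` fallback below.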
    try:
        from huggingface_hub import hf_hub_download

        json_path = hf_hub_download(repo_id, f"data/{lang}/data.json", repo_type="dataset")
        chunks: list[bytes] = []
        total = 0
        with open(json_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                sample = json.loads(line)
                text = sample.get("content") or ""
                if not isinstance(text, str) or not text.strip():
                    continue
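                # Separate files with a blank line so document boundaries survive.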
                encoded = text.encode("utf-8", errors="ignore") + b"\n\n"
                chunks.append(encoded)
                total += len(encoded)
                if total >= target_bytes:
                    break
        if total > 0:
            return b"".join(chunks)
    except Exception:
        pass

    # Fallback: stream the dataset through the `datasets` library.
    try:
        from datasets import load_dataset
    except ImportError as exc:
        raise ImportError(
            "The `datasets` package is required to load The Stack. "
            "Install it with `pip install datasets`."
        ) from exc

    try:
        # Most Stack-family repos accept the language as a config name.
        ds = load_dataset(repo_id, lang, split="train", streaming=True)
    except Exception:
        # Otherwise select the per-language directory inside the repo.
        data_dir = f"data/{lang}"
        try:
            ds = load_dataset(repo_id, data_dir=data_dir, split="train", streaming=True)
        except Exception as exc:
            raise RuntimeError(
                f"Failed to load repo={repo_id} lang={lang}. "
                "If this is a gated dataset, accept the Hugging Face terms first or "
                "switch to a public Stack-family subset such as `bigcode/the-stack-smol-xs`."
            ) from exc

    chunks: list[bytes] = []
    total = 0
    for sample in ds:
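        # The Stack stores code under `content`; `text` covers other repos.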
        text = sample.get("content") or sample.get("text") or ""
        if not isinstance(text, str) or not text.strip():
            continue
        encoded = text.encode("utf-8", errors="ignore") + b"\n\n"
        chunks.append(encoded)
        total += len(encoded)
        if total >= target_bytes:
            break

    if total == 0:
        raise RuntimeError(f"No text content collected from repo={repo_id} lang={lang}.")
    return b"".join(chunks)


def load_the_stack_text(
    data_dir: str = "data_cache",
    repo_id: str = "bigcode/the-stack-smol-xs",
    lang: str = "python",
    target_bytes: int = 8_000_000,
) -> bytes:
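    # Cache the raw byte blob on disk so repeated runs skip the network.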
    Path(data_dir).mkdir(parents=True, exist_ok=True)
    cache_name = f"{_normalize_repo_name(repo_id)}_{lang}_{target_bytes}.bin"
    cache_path = os.path.join(data_dir, cache_name)

    if os.path.exists(cache_path):
        return Path(cache_path).read_bytes()

    data = _stream_the_stack_text(repo_id=repo_id, lang=lang, target_bytes=target_bytes)
    Path(cache_path).write_bytes(data)
    return data


def load_the_stack(
    seq_len: int = 256,
    device: str = "cpu",
    data_dir: str = "data_cache",
    repo_id: str = "bigcode/the-stack-smol-xs",
    lang: str = "python",
    target_bytes: int = 8_000_000,
) -> tuple[ByteCorpusDataset, ByteCorpusDataset]:
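    # Both splits slice the same cached corpus (90/10 inside ByteCorpusDataset).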
    data = load_the_stack_text(
        data_dir=data_dir,
        repo_id=repo_id,
        lang=lang,
        target_bytes=target_bytes,
    )
    train = ByteCorpusDataset(data=data, split="train", seq_len=seq_len, device=device)
    val = ByteCorpusDataset(data=data, split="val", seq_len=seq_len, device=device)
    return train, val
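

if __name__ == "__main__":
    # Minimal usage sketch: fetch ~8 MB of Python source, build the byte-level
    # splits, and draw one training batch. Requires network access on first run.
    train, val = load_the_stack(seq_len=256, device="cpu")
    print(f"train: {len(train):,} bytes  val: {len(val):,} bytes  vocab: {train.vocab_size}")
    x, y = train.get_batch(batch_size=4)
    print(x.shape, y.shape)  # torch.Size([4, 256]) torch.Size([4, 256])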