abpt / src /data /the_stack.py
Search
feat: add src/ module for script imports
8125804
import os
import json
from pathlib import Path
import torch
class ByteCorpusDataset:
    """Byte-level corpus holding one split and serving random training windows.

    The raw bytes are turned into a 1-D ``torch.long`` tensor of byte values.
    The first 90% of that tensor is the "train" split and the rest is the
    "val" split; an instance keeps only the split it was asked for.
    """

    def __init__(
        self,
        data: bytes,
        split: str = "train",
        seq_len: int = 256,
        device: str = "cpu",
    ) -> None:
        """Build the tensor for the requested split.

        Args:
            data: Raw corpus bytes.
            split: Either "train" (first 90%) or "val" (remaining 10%).
            seq_len: Length of each sampled window returned by ``get_batch``.
            device: Device the sampled batches are moved to.

        Raises:
            ValueError: If ``split`` is neither "train" nor "val".
        """
        # Explicit raise instead of `assert`: assertions are stripped under
        # `python -O`, which would let a bad split silently select "val".
        if split not in ("train", "val"):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")
        self.seq_len = seq_len
        self.device = device
        if data:
            # Read the bytes directly rather than materialising a Python list
            # with one int per byte. `bytearray` provides a writable buffer,
            # and `.to(torch.long)` copies, so the result does not alias it.
            tensor = torch.frombuffer(bytearray(data), dtype=torch.uint8).to(torch.long)
        else:
            # torch.frombuffer rejects zero-length buffers.
            tensor = torch.empty(0, dtype=torch.long)
        n = int(0.9 * len(tensor))
        self.data = tensor[:n] if split == "train" else tensor[n:]

    @property
    def vocab_size(self) -> int:
        """Number of distinct token ids (one per possible byte value)."""
        return 256

    def get_batch(self, batch_size: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Sample ``batch_size`` random windows and their next-byte targets.

        Returns:
            A pair ``(x, y)`` of shape ``(batch_size, seq_len)`` on
            ``self.device``, where ``y`` is ``x`` shifted one position ahead
            in the corpus.

        Raises:
            ValueError: If the split is too short for a window of ``seq_len``.
        """
        max_start = len(self.data) - self.seq_len - 1
        if max_start <= 0:
            raise ValueError(
                f"Corpus too small for seq_len={self.seq_len}. "
                f"Need more than {self.seq_len + 1} bytes, got {len(self.data)}."
            )
        starts = torch.randint(0, max_start, (batch_size,))
        # One (batch_size, seq_len) index matrix gathers every window at once
        # instead of slicing and stacking sample by sample.
        window = starts.unsqueeze(1) + torch.arange(self.seq_len)
        x = self.data[window]
        y = self.data[window + 1]
        return x.to(self.device), y.to(self.device)

    def __len__(self) -> int:
        """Return the number of bytes in this split."""
        return len(self.data)
def _normalize_repo_name(repo_id: str) -> str:
    """Turn a repo id into a flat name: "/" becomes "__" and "-" becomes "_"."""
    substitutions = str.maketrans({"/": "__", "-": "_"})
    return repo_id.translate(substitutions)
def _collect_text_bytes(samples, target_bytes: int, keys: tuple[str, ...]) -> bytes:
    """Concatenate UTF-8 encoded sample texts until ``target_bytes`` is reached.

    Args:
        samples: Iterable of mapping-like records exposing ``.get``.
        target_bytes: Stop once at least this many bytes have been collected.
        keys: Field names tried in order; the first truthy value is used.

    Returns:
        The collected texts, each followed by a blank-line separator, or
        ``b""`` when no record carried usable text.
    """
    chunks: list[bytes] = []
    total = 0
    for sample in samples:
        text = ""
        for key in keys:
            text = sample.get(key) or ""
            if text:
                break
        # Skip records whose text is missing, not a string, or whitespace-only.
        if not isinstance(text, str) or not text.strip():
            continue
        encoded = text.encode("utf-8", errors="ignore") + b"\n\n"
        chunks.append(encoded)
        total += len(encoded)
        if total >= target_bytes:
            break
    return b"".join(chunks)


def _stream_the_stack_text(
    repo_id: str,
    lang: str,
    target_bytes: int,
) -> bytes:
    """Collect roughly ``target_bytes`` of source text from a Hugging Face dataset.

    First tries a best-effort fast path: download ``data/{lang}/data.json`` via
    ``huggingface_hub`` and read it as JSON lines, using the "content" field.
    If that fails or yields nothing, falls back to streaming through
    ``datasets.load_dataset``, first with ``lang`` as the config name and then
    with ``data_dir="data/{lang}"``, using the "content" or "text" field.

    Args:
        repo_id: Hugging Face dataset repository id.
        lang: Language subset to read.
        target_bytes: Stop collecting once at least this many bytes are gathered.

    Returns:
        The concatenated UTF-8 bytes of the collected samples.

    Raises:
        ImportError: If the fallback is needed and ``datasets`` is not installed.
        RuntimeError: If the dataset cannot be loaded or contains no text.
    """
    try:
        from huggingface_hub import hf_hub_download

        json_path = hf_hub_download(repo_id, f"data/{lang}/data.json", repo_type="dataset")
        with open(json_path, "r", encoding="utf-8") as f:
            records = (
                json.loads(stripped)
                for line in f
                if (stripped := line.strip())
            )
            data = _collect_text_bytes(records, target_bytes, ("content",))
        if data:
            return data
    except Exception:
        # The fast path is deliberately best-effort, so any failure falls back
        # to `datasets`. Record the reason at DEBUG level so it stays diagnosable.
        import logging

        logging.getLogger(__name__).debug(
            "Direct JSON download failed for repo=%s lang=%s; falling back to datasets.",
            repo_id,
            lang,
            exc_info=True,
        )

    try:
        from datasets import load_dataset
    except ImportError as exc:
        raise ImportError(
            "datasets package is required for The Stack loading. "
            "Install with `pip install datasets`."
        ) from exc

    try:
        ds = load_dataset(repo_id, lang, split="train", streaming=True)
    except Exception:
        # Retry addressing the language by directory instead of config name.
        data_dir = f"data/{lang}"
        try:
            ds = load_dataset(repo_id, data_dir=data_dir, split="train", streaming=True)
        except Exception as exc:
            raise RuntimeError(
                f"Failed to load repo={repo_id} lang={lang}. "
                "If this is a gated dataset, accept the Hugging Face terms first or "
                "switch to a public Stack-family subset such as `bigcode/the-stack-smol-xs`."
            ) from exc

    data = _collect_text_bytes(ds, target_bytes, ("content", "text"))
    if not data:
        raise RuntimeError(f"No text content collected from repo={repo_id} lang={lang}.")
    return data
def load_the_stack_text(
    data_dir: str = "data_cache",
    repo_id: str = "bigcode/the-stack-smol-xs",
    lang: str = "python",
    target_bytes: int = 8_000_000,
) -> bytes:
    """Return corpus bytes for a repo/language, caching them on disk.

    The cache file lives in ``data_dir`` and is named from the normalized repo
    id, the language and ``target_bytes``. If it exists it is returned as is;
    otherwise the text is streamed and written to the cache.

    Args:
        data_dir: Directory holding cache files; created if missing.
        repo_id: Hugging Face dataset repository id.
        lang: Language subset to read.
        target_bytes: Approximate number of bytes to collect.

    Returns:
        The raw corpus bytes.
    """
    cache_dir = Path(data_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)
    cache_path = cache_dir / f"{_normalize_repo_name(repo_id)}_{lang}_{target_bytes}.bin"
    if cache_path.exists():
        return cache_path.read_bytes()

    data = _stream_the_stack_text(repo_id=repo_id, lang=lang, target_bytes=target_bytes)

    # Write to a per-process temp file and rename it into place. An interrupted
    # write then never leaves a truncated cache file that later calls would
    # return as if it were complete.
    tmp_path = cache_path.with_name(f"{cache_path.name}.{os.getpid()}.tmp")
    try:
        tmp_path.write_bytes(data)
        os.replace(tmp_path, cache_path)
    finally:
        # No-op after a successful replace; cleans up after a failed write.
        tmp_path.unlink(missing_ok=True)
    return data
def load_the_stack(
    seq_len: int = 256,
    device: str = "cpu",
    data_dir: str = "data_cache",
    repo_id: str = "bigcode/the-stack-smol-xs",
    lang: str = "python",
    target_bytes: int = 8_000_000,
) -> tuple[ByteCorpusDataset, ByteCorpusDataset]:
    """Load the corpus and wrap it as a ``(train, val)`` dataset pair.

    Args:
        seq_len: Window length used by both datasets.
        device: Device the sampled batches are moved to.
        data_dir: Directory used for the on-disk corpus cache.
        repo_id: Hugging Face dataset repository id.
        lang: Language subset to read.
        target_bytes: Approximate number of bytes to collect.

    Returns:
        The train dataset followed by the val dataset, built from the same bytes.
    """
    corpus = load_the_stack_text(
        data_dir=data_dir,
        repo_id=repo_id,
        lang=lang,
        target_bytes=target_bytes,
    )
    # Both splits share every argument except the split name.
    shared = {"data": corpus, "seq_len": seq_len, "device": device}
    train_set, val_set = (
        ByteCorpusDataset(split=split_name, **shared)
        for split_name in ("train", "val")
    )
    return train_set, val_set