Spaces:
Running on Zero
Running on Zero
| import os | |
| import json | |
| from pathlib import Path | |
| import torch | |
class ByteCorpusDataset:
    """Byte-level language-modeling dataset over a raw byte corpus.

    The corpus is tokenized trivially: each byte is one token id in
    [0, 255]. The byte stream is split 90/10 into train/val slices,
    and batches are random contiguous windows from the chosen slice.
    """

    def __init__(
        self,
        data: bytes,
        split: str = "train",
        seq_len: int = 256,
        device: str = "cpu",
    ) -> None:
        """Build one split of the dataset.

        Args:
            data: The full raw byte corpus (both splits come from it).
            split: Either "train" (first 90%) or "val" (last 10%).
            seq_len: Length of each sampled sequence.
            device: Device batches are moved to in ``get_batch``.

        Raises:
            ValueError: If ``split`` is not "train" or "val".
        """
        # Raise instead of assert: asserts are stripped under `python -O`.
        if split not in ("train", "val"):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")
        self.seq_len = seq_len
        self.device = device
        # One long tensor of token ids; each byte maps to its own id.
        tensor = torch.tensor(list(data), dtype=torch.long)
        n = int(0.9 * len(tensor))  # 90/10 train/val boundary
        self.data = tensor[:n] if split == "train" else tensor[n:]

    def vocab_size(self) -> int:
        """Vocabulary size: all 256 possible byte values."""
        return 256

    def get_batch(self, batch_size: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Sample a random batch of next-byte-prediction pairs.

        Args:
            batch_size: Number of sequences to sample.

        Returns:
            ``(x, y)``: both ``(batch_size, seq_len)`` long tensors on
            ``self.device``; ``y`` is ``x`` shifted right by one byte.

        Raises:
            ValueError: If this split holds fewer than ``seq_len + 1`` bytes.
        """
        # The last valid start s must satisfy s + seq_len + 1 <= len(data),
        # i.e. s <= max_start (inclusive).
        max_start = len(self.data) - self.seq_len - 1
        if max_start < 0:
            raise ValueError(
                f"Corpus too small for seq_len={self.seq_len}. "
                f"Need at least {self.seq_len + 1} bytes, got {len(self.data)}."
            )
        # torch.randint's high bound is exclusive, so sample from
        # [0, max_start] via max_start + 1 — the original `(0, max_start)`
        # could never draw the final valid window (off-by-one), and
        # rejected a corpus of exactly seq_len + 1 bytes.
        starts = torch.randint(0, max_start + 1, (batch_size,))
        x = torch.stack([self.data[s : s + self.seq_len] for s in starts])
        y = torch.stack([self.data[s + 1 : s + self.seq_len + 1] for s in starts])
        return x.to(self.device), y.to(self.device)

    def __len__(self) -> int:
        """Number of byte tokens in this split."""
        return len(self.data)
| def _normalize_repo_name(repo_id: str) -> str: | |
| return repo_id.replace("/", "__").replace("-", "_") | |
def _collect_text_bytes(samples, target_bytes: int, keys: tuple[str, ...]) -> bytes:
    """Accumulate UTF-8 text from dict-like samples until target_bytes is reached.

    For each sample, the first truthy string found under ``keys`` is encoded
    (ignoring un-encodable characters) and terminated with a blank line.
    Returns the concatenated bytes; empty if no usable text was found.
    """
    chunks: list[bytes] = []
    total = 0
    for sample in samples:
        text = ""
        for key in keys:
            candidate = sample.get(key)
            if candidate:
                text = candidate
                break
        if not isinstance(text, str) or not text.strip():
            continue
        encoded = text.encode("utf-8", errors="ignore") + b"\n\n"
        chunks.append(encoded)
        total += len(encoded)
        if total >= target_bytes:
            break
    return b"".join(chunks)


def _stream_the_stack_text(
    repo_id: str,
    lang: str,
    target_bytes: int,
) -> bytes:
    """Download roughly ``target_bytes`` of source text for ``lang`` from ``repo_id``.

    Tries two strategies in order:
      1. A raw JSON-lines file at ``data/{lang}/data.json`` fetched via
         ``huggingface_hub`` (cheap, no `datasets` dependency).
      2. Streaming the dataset with ``datasets.load_dataset``, first by
         config name, then by ``data_dir``.

    Raises:
        ImportError: If strategy 1 fails and `datasets` is not installed.
        RuntimeError: If the dataset cannot be loaded or yields no text.
    """
    # Strategy 1: best-effort — any failure (missing file, gated repo,
    # malformed JSON) falls through to the `datasets` path below.
    try:
        from huggingface_hub import hf_hub_download

        json_path = hf_hub_download(repo_id, f"data/{lang}/data.json", repo_type="dataset")

        def _iter_jsonl():
            # Yield one parsed JSON object per non-blank line.
            with open(json_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        yield json.loads(line)

        data = _collect_text_bytes(_iter_jsonl(), target_bytes, ("content",))
        if data:
            return data
    except Exception:
        pass

    # Strategy 2: stream via the `datasets` library.
    try:
        from datasets import load_dataset
    except ImportError as exc:
        raise ImportError(
            "datasets package is required for The Stack loading. "
            "Install with `pip install datasets`."
        ) from exc
    try:
        ds = load_dataset(repo_id, lang, split="train", streaming=True)
    except Exception:
        # Some repos expose languages as data directories, not configs.
        data_dir = f"data/{lang}"
        try:
            ds = load_dataset(repo_id, data_dir=data_dir, split="train", streaming=True)
        except Exception as exc:
            raise RuntimeError(
                f"Failed to load repo={repo_id} lang={lang}. "
                "If this is a gated dataset, accept the Hugging Face terms first or "
                "switch to a public Stack-family subset such as `bigcode/the-stack-smol-xs`."
            ) from exc
    # Streaming samples may store text under "content" or "text".
    data = _collect_text_bytes(ds, target_bytes, ("content", "text"))
    if not data:
        raise RuntimeError(f"No text content collected from repo={repo_id} lang={lang}.")
    return data
def load_the_stack_text(
    data_dir: str = "data_cache",
    repo_id: str = "bigcode/the-stack-smol-xs",
    lang: str = "python",
    target_bytes: int = 8_000_000,
) -> bytes:
    """Return ~``target_bytes`` of raw source text, with a local byte cache.

    On a cache hit the bytes are read straight from disk; otherwise the
    text is streamed from the Hub and the cache file is written.

    Args:
        data_dir: Directory for cached corpus files (created if missing).
        repo_id: Hugging Face dataset repo to pull from.
        lang: Language subset within the repo.
        target_bytes: Approximate corpus size to collect.
    """
    cache_dir = Path(data_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)
    # Cache key covers every argument that changes the payload, so
    # different repos/languages/sizes never collide.
    # (Unified on pathlib — the original mixed os.path and Path for the
    # same file.)
    cache_path = cache_dir / f"{_normalize_repo_name(repo_id)}_{lang}_{target_bytes}.bin"
    if cache_path.exists():
        return cache_path.read_bytes()
    data = _stream_the_stack_text(repo_id=repo_id, lang=lang, target_bytes=target_bytes)
    cache_path.write_bytes(data)
    return data
def load_the_stack(
    seq_len: int = 256,
    device: str = "cpu",
    data_dir: str = "data_cache",
    repo_id: str = "bigcode/the-stack-smol-xs",
    lang: str = "python",
    target_bytes: int = 8_000_000,
) -> tuple[ByteCorpusDataset, ByteCorpusDataset]:
    """Load The Stack text and wrap it as train/val byte datasets.

    Fetches (or reads from cache) the raw byte corpus, then builds one
    ByteCorpusDataset per split over the same underlying bytes.
    """
    raw = load_the_stack_text(
        data_dir=data_dir,
        repo_id=repo_id,
        lang=lang,
        target_bytes=target_bytes,
    )
    train, val = (
        ByteCorpusDataset(data=raw, split=part, seq_len=seq_len, device=device)
        for part in ("train", "val")
    )
    return train, val