from __future__ import annotations import json from pathlib import Path import torch from src.data.the_stack import _normalize_repo_name from src.data.the_stack_bpe import BPETokenDataset, _train_tokenizer def _collect_wikitext_text( repo_id: str, config_name: str, split: str, target_bytes: int, ) -> str: from datasets import load_dataset ds = load_dataset(repo_id, config_name, split=split) chunks: list[str] = [] total = 0 for sample in ds: text = sample.get("text") or "" if not isinstance(text, str): continue block = text.strip("\n") if not block: continue block = block + "\n\n" chunks.append(block) total += len(block.encode("utf-8")) if total >= target_bytes: break if total == 0: raise RuntimeError(f"No usable WikiText text found for {repo_id}:{config_name}:{split}") return "".join(chunks) def load_wikitext_bpe( seq_len: int = 256, device: str = "cpu", data_dir: str = "data_cache", repo_id: str = "wikitext", config_name: str = "wikitext-2-raw-v1", target_bytes: int = 2_000_000, vocab_size: int = 4096, ) -> tuple[BPETokenDataset, BPETokenDataset]: Path(data_dir).mkdir(parents=True, exist_ok=True) prefix = ( f"{_normalize_repo_name(repo_id)}_{config_name.replace('-', '_')}_{target_bytes}_bpe{vocab_size}" ) tokenizer_path = Path(data_dir) / f"{prefix}_tokenizer.json" train_ids_path = Path(data_dir) / f"{prefix}_train_ids.pt" val_ids_path = Path(data_dir) / f"{prefix}_val_ids.pt" meta_path = Path(data_dir) / f"{prefix}_meta.json" if tokenizer_path.exists() and train_ids_path.exists() and val_ids_path.exists() and meta_path.exists(): train_ids = torch.load(train_ids_path, map_location="cpu") val_ids = torch.load(val_ids_path, map_location="cpu") meta = json.loads(meta_path.read_text(encoding="utf-8")) actual_vocab_size = int(meta["vocab_size"]) else: train_text = _collect_wikitext_text( repo_id=repo_id, config_name=config_name, split="train", target_bytes=target_bytes, ) val_text = _collect_wikitext_text( repo_id=repo_id, config_name=config_name, split="validation", target_bytes=max(250_000, target_bytes // 8), ) tokenizer = _train_tokenizer(text=train_text, vocab_size=vocab_size) train_ids = torch.tensor(tokenizer.encode(train_text).ids, dtype=torch.long) val_ids = torch.tensor(tokenizer.encode(val_text).ids, dtype=torch.long) actual_vocab_size = tokenizer.get_vocab_size() tokenizer.save(str(tokenizer_path)) torch.save(train_ids, train_ids_path) torch.save(val_ids, val_ids_path) meta_path.write_text( json.dumps( { "repo_id": repo_id, "config_name": config_name, "target_bytes": target_bytes, "vocab_size": actual_vocab_size, "train_token_count": int(train_ids.numel()), "val_token_count": int(val_ids.numel()), }, indent=2, ), encoding="utf-8", ) train = BPETokenDataset( token_ids=train_ids, vocab_size=actual_vocab_size, split="train", seq_len=seq_len, device=device, split_data=False, ) val = BPETokenDataset( token_ids=val_ids, vocab_size=actual_vocab_size, split="val", seq_len=seq_len, device=device, split_data=False, ) return train, val