Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| from pathlib import Path | |
| import torch | |
| from src.data.the_stack import _normalize_repo_name | |
| from src.data.the_stack_bpe import BPETokenDataset, _train_tokenizer | |
def _collect_wikitext_text(
    repo_id: str,
    config_name: str,
    split: str,
    target_bytes: int,
) -> str:
    """Concatenate non-empty ``text`` fields of a dataset split into one string.

    The split is loaded with ``datasets.load_dataset`` and walked in order.
    Each usable record contributes its text, stripped of leading/trailing
    newlines and terminated by a blank line. Collection stops as soon as the
    accumulated UTF-8 size reaches ``target_bytes``, so the result may
    overshoot the target by up to one record.

    Args:
        repo_id: Dataset identifier passed to ``load_dataset``.
        config_name: Dataset configuration passed to ``load_dataset``.
        split: Split name passed to ``load_dataset``.
        target_bytes: UTF-8 byte budget after which collection stops.

    Returns:
        The collected blocks joined into a single string.

    Raises:
        RuntimeError: If no record yielded any usable text.
    """
    # Imported lazily so the module can be imported without `datasets`.
    from datasets import load_dataset

    pieces: list[str] = []
    byte_count = 0
    for row in load_dataset(repo_id, config_name, split=split):
        raw = row.get("text") or ""
        # Non-string payloads are treated the same as empty ones: skipped.
        body = raw.strip("\n") if isinstance(raw, str) else ""
        if body:
            paragraph = f"{body}\n\n"
            pieces.append(paragraph)
            byte_count += len(paragraph.encode("utf-8"))
            if byte_count >= target_bytes:
                break

    if not pieces:
        raise RuntimeError(f"No usable WikiText text found for {repo_id}:{config_name}:{split}")
    return "".join(pieces)
def load_wikitext_bpe(
    seq_len: int = 256,
    device: str = "cpu",
    data_dir: str = "data_cache",
    repo_id: str = "wikitext",
    config_name: str = "wikitext-2-raw-v1",
    target_bytes: int = 2_000_000,
    vocab_size: int = 4096,
) -> tuple[BPETokenDataset, BPETokenDataset]:
    """Build train/validation ``BPETokenDataset`` objects for a WikiText config.

    Token ids, the tokenizer and a small JSON metadata file are cached under
    ``data_dir``. File names are keyed by the normalized repo name, the config
    name (dashes replaced by underscores), ``target_bytes`` and ``vocab_size``.
    When all four cache files exist they are loaded; otherwise text is
    collected, a tokenizer is trained on the training text via
    ``_train_tokenizer``, both splits are encoded, and the cache is written.

    Args:
        seq_len: Forwarded to ``BPETokenDataset``.
        device: Forwarded to ``BPETokenDataset``.
        data_dir: Cache directory; created if missing.
        repo_id: Dataset identifier to collect text from.
        config_name: Dataset configuration to collect text from.
        target_bytes: Byte budget for the training text. The validation text
            uses ``max(250_000, target_bytes // 8)``.
        vocab_size: Vocabulary size requested from ``_train_tokenizer``.

    Returns:
        A ``(train, val)`` pair of datasets, both constructed with
        ``split_data=False`` and the vocabulary size reported by the
        tokenizer (or read back from the cached metadata).
    """
    cache_dir = Path(data_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)

    repo_tag = _normalize_repo_name(repo_id)
    config_tag = config_name.replace("-", "_")
    prefix = f"{repo_tag}_{config_tag}_{target_bytes}_bpe{vocab_size}"

    tokenizer_path = cache_dir / f"{prefix}_tokenizer.json"
    train_ids_path = cache_dir / f"{prefix}_train_ids.pt"
    val_ids_path = cache_dir / f"{prefix}_val_ids.pt"
    meta_path = cache_dir / f"{prefix}_meta.json"

    # The cache only counts as usable when every artifact is present.
    required = (tokenizer_path, train_ids_path, val_ids_path, meta_path)
    cache_ready = all(artifact.exists() for artifact in required)

    if cache_ready:
        train_ids = torch.load(train_ids_path, map_location="cpu")
        val_ids = torch.load(val_ids_path, map_location="cpu")
        cached_meta = json.loads(meta_path.read_text(encoding="utf-8"))
        actual_vocab_size = int(cached_meta["vocab_size"])
    else:
        val_target_bytes = max(250_000, target_bytes // 8)
        train_text = _collect_wikitext_text(
            repo_id=repo_id,
            config_name=config_name,
            split="train",
            target_bytes=target_bytes,
        )
        val_text = _collect_wikitext_text(
            repo_id=repo_id,
            config_name=config_name,
            split="validation",
            target_bytes=val_target_bytes,
        )

        # The tokenizer sees the training text only; validation is just encoded.
        tokenizer = _train_tokenizer(text=train_text, vocab_size=vocab_size)
        train_encoding = tokenizer.encode(train_text)
        train_ids = torch.tensor(train_encoding.ids, dtype=torch.long)
        val_encoding = tokenizer.encode(val_text)
        val_ids = torch.tensor(val_encoding.ids, dtype=torch.long)
        actual_vocab_size = tokenizer.get_vocab_size()

        tokenizer.save(str(tokenizer_path))
        torch.save(train_ids, train_ids_path)
        torch.save(val_ids, val_ids_path)

        # Metadata is written last, after the other three artifacts.
        meta_payload = {
            "repo_id": repo_id,
            "config_name": config_name,
            "target_bytes": target_bytes,
            "vocab_size": actual_vocab_size,
            "train_token_count": int(train_ids.numel()),
            "val_token_count": int(val_ids.numel()),
        }
        meta_path.write_text(json.dumps(meta_payload, indent=2), encoding="utf-8")

    shared_kwargs = {
        "vocab_size": actual_vocab_size,
        "seq_len": seq_len,
        "device": device,
        "split_data": False,
    }
    train = BPETokenDataset(token_ids=train_ids, split="train", **shared_kwargs)
    val = BPETokenDataset(token_ids=val_ids, split="val", **shared_kwargs)
    return train, val