from __future__ import annotations

import json
import math
from dataclasses import replace
from pathlib import Path
from typing import Any

import torch

from src.data.anchor_synthetic import load_anchor_synthetic
from src.data.openwebmath_bpe import load_openwebmath_bpe
from src.data.shakespeare import load_shakespeare
from src.data.the_stack_bpe import load_the_stack_bpe
from src.data.tinystories_bpe import load_tinystories_bpe
from src.data.wikitext_bpe import load_wikitext_bpe
from src.model.testformer import TestFormerLM
from src.model.testformer_config import TestFormerConfig, build_testformer_config


def _default_learning_rate(cfg: TestFormerConfig) -> float:
    """Pick a default learning rate from model width (wider models get smaller LRs)."""
    if cfg.d_model <= 384:
        return 3.0e-4
    if cfg.d_model <= 640:
        return 2.0e-4
    return 1.5e-4


def _make_cosine_warmup_scheduler(
    optimizer: torch.optim.Optimizer,
    total_steps: int,
    warmup_fraction: float,
) -> torch.optim.lr_scheduler.LambdaLR:
    """Linear warmup over the first `warmup_fraction` of steps, then cosine decay to ~0."""
    warmup_steps = max(1, int(total_steps * warmup_fraction))

    def lr_lambda(current_step: int) -> float:
        if current_step < warmup_steps:
            return float(current_step + 1) / float(warmup_steps)
        progress = (current_step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
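
# Worked example of the schedule (values follow directly from lr_lambda above):
# with total_steps=100 and warmup_fraction=0.02, warmup_steps = max(1, 2) = 2,
# so the LR multiplier is 0.5 at step 0, 1.0 at step 1, and then decays along
# 0.5 * (1 + cos(pi * progress)) toward ~0 at step 99.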


def load_testformer_dataset(
    dataset: str,
    seq_len: int,
    device: str,
    data_dir: str = "data_cache",
    the_stack_repo: str = "bigcode/the-stack-smol-xs",
    the_stack_lang: str = "python",
    the_stack_bytes: int = 8_000_000,
    the_stack_vocab_size: int = 4096,
    tinystories_repo: str = "roneneldan/TinyStories",
    tinystories_bytes: int = 16_000_000,
    tinystories_vocab_size: int = 4096,
    openwebmath_repo: str = "open-web-math/open-web-math",
    openwebmath_bytes: int = 200_000,
    openwebmath_vocab_size: int = 256,
    wikitext_repo: str = "wikitext",
    wikitext_config_name: str = "wikitext-2-raw-v1",
    wikitext_bytes: int = 2_000_000,
    wikitext_vocab_size: int = 4096,
) -> tuple[Any, Any]:
    """Dispatch on the dataset name and return its (train, val) splits."""
    if dataset == "anchor-synthetic":
        # The synthetic anchor task uses a fixed short sequence length and ignores `seq_len`.
        return load_anchor_synthetic(seq_len=24, device=device)
    if dataset == "shakespeare":
        return load_shakespeare(seq_len=seq_len, device=device, data_dir=data_dir)
    if dataset == "the-stack-bpe":
        return load_the_stack_bpe(
            seq_len=seq_len,
            device=device,
            data_dir=data_dir,
            repo_id=the_stack_repo,
            lang=the_stack_lang,
            target_bytes=the_stack_bytes,
            vocab_size=the_stack_vocab_size,
        )
    if dataset == "tinystories-bpe":
        return load_tinystories_bpe(
            seq_len=seq_len,
            device=device,
            data_dir=data_dir,
            repo_id=tinystories_repo,
            target_bytes=tinystories_bytes,
            vocab_size=tinystories_vocab_size,
        )
    if dataset == "openwebmath-bpe":
        return load_openwebmath_bpe(
            seq_len=seq_len,
            device=device,
            data_dir=data_dir,
            repo_id=openwebmath_repo,
            target_bytes=openwebmath_bytes,
            vocab_size=openwebmath_vocab_size,
        )
    if dataset == "wikitext-bpe":
        return load_wikitext_bpe(
            seq_len=seq_len,
            device=device,
            data_dir=data_dir,
            repo_id=wikitext_repo,
            config_name=wikitext_config_name,
            target_bytes=wikitext_bytes,
            vocab_size=wikitext_vocab_size,
        )
    raise ValueError(f"Unknown TestFormer dataset: {dataset}")
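
# Usage sketch (commented out so importing this module has no side effects; the
# sequence length and batch size below are illustrative, not required values):
#
#   train_data, val_data = load_testformer_dataset(
#       dataset="shakespeare", seq_len=256, device="cpu"
#   )
#   x, y = train_data.get_batch(8)  # input ids and next-token target ids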


def evaluate_testformer(
    model: TestFormerLM,
    dataset: Any,
    batch_size: int,
    device: str,
    max_batches: int = 5,
) -> dict[str, float]:
    """Average the next-token loss over `max_batches` validation batches."""
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for _ in range(max_batches):
            x, y = dataset.get_batch(batch_size)
            x = x.to(device)
            y = y.to(device)
            out = model(x, y)
            # Weight each batch's mean loss by its token count for an exact token-level average.
            total_loss += float(out["loss"].item()) * y.numel()
            total_tokens += y.numel()
    mean_loss = total_loss / max(1, total_tokens)
    return {
        "loss": mean_loss,
        "bpb": mean_loss / math.log(2.0),
    }
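
# Note: the "bpb" value is the mean loss converted from nats to bits
# (loss / ln 2). It is a true bits-per-byte only for byte-level vocabularies
# (such as the 256-token openwebmath setting above); for BPE vocabularies it is
# bits per token.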


def train_testformer(
    cfg: TestFormerConfig,
    dataset: str = "anchor-synthetic",
    device: str = "cpu",
    data_dir: str = "data_cache",
    steps: int = 100,
    batch_size: int = 16,
    eval_every: int = 20,
    eval_batches: int = 5,
    learning_rate: float | None = None,
    weight_decay: float = 0.1,
    beta1: float = 0.9,
    beta2: float = 0.95,
    grad_clip: float = 1.0,
    warmup_fraction: float = 0.02,
    the_stack_repo: str = "bigcode/the-stack-smol-xs",
    the_stack_lang: str = "python",
    the_stack_bytes: int = 8_000_000,
    the_stack_vocab_size: int = 4096,
    tinystories_repo: str = "roneneldan/TinyStories",
    tinystories_bytes: int = 16_000_000,
    tinystories_vocab_size: int = 4096,
    openwebmath_repo: str = "open-web-math/open-web-math",
    openwebmath_bytes: int = 200_000,
    openwebmath_vocab_size: int = 256,
    wikitext_repo: str = "wikitext",
    wikitext_config_name: str = "wikitext-2-raw-v1",
    wikitext_bytes: int = 2_000_000,
    wikitext_vocab_size: int = 4096,
) -> tuple[TestFormerLM, list[dict[str, float]], Any, Any]:
    """Train a TestFormerLM on the named dataset and return (model, history, train_data, val_data)."""
    train_data, val_data = load_testformer_dataset(
        dataset=dataset,
        seq_len=cfg.max_seq_len,
        device=device,
        data_dir=data_dir,
        the_stack_repo=the_stack_repo,
        the_stack_lang=the_stack_lang,
        the_stack_bytes=the_stack_bytes,
        the_stack_vocab_size=the_stack_vocab_size,
        tinystories_repo=tinystories_repo,
        tinystories_bytes=tinystories_bytes,
        tinystories_vocab_size=tinystories_vocab_size,
        openwebmath_repo=openwebmath_repo,
        openwebmath_bytes=openwebmath_bytes,
        openwebmath_vocab_size=openwebmath_vocab_size,
        wikitext_repo=wikitext_repo,
        wikitext_config_name=wikitext_config_name,
        wikitext_bytes=wikitext_bytes,
        wikitext_vocab_size=wikitext_vocab_size,
    )

    # Resize the config to match the vocabulary and sequence length the dataset actually provides.
    effective_cfg = replace(
        cfg,
        vocab_size=train_data.vocab_size,
        max_seq_len=getattr(train_data, "seq_len", cfg.max_seq_len),
    )
    model = TestFormerLM(effective_cfg).to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate or _default_learning_rate(effective_cfg),
        betas=(beta1, beta2),
        weight_decay=weight_decay,
    )
    scheduler = _make_cosine_warmup_scheduler(
        optimizer=optimizer,
        total_steps=max(steps, 1),
        warmup_fraction=warmup_fraction,
    )

    history: list[dict[str, float]] = []
    for step in range(steps):
        model.train()
        x, y = train_data.get_batch(batch_size)
        x = x.to(device)
        y = y.to(device)
        out = model(x, y)

        optimizer.zero_grad()
        out["loss"].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        scheduler.step()

        if (step + 1) % eval_every == 0 or step == steps - 1:
            metrics = evaluate_testformer(
                model=model,
                dataset=val_data,
                batch_size=batch_size,
                device=device,
                max_batches=eval_batches,
            )
            history.append(
                {
                    "step": float(step + 1),
                    "train_loss": float(out["loss"].item()),
                    "train_bpb": float(out["loss"].item() / math.log(2.0)),
                    "val_loss": metrics["loss"],
                    "val_bpb": metrics["bpb"],
                    "lr": float(optimizer.param_groups[0]["lr"]),
                }
            )

    model.training_history = history
    return model, history, train_data, val_data
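
# History rows are appended only at evaluation steps (every `eval_every` steps
# and at the final step), so `history[-1]` reflects the end of training; the
# same list is also attached to the returned model as `model.training_history`.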


def summarize_testformer_result(
    preset_name: str,
    motif_name: str,
    dataset_name: str,
    model: TestFormerLM,
    history: list[dict[str, float]],
) -> dict[str, Any]:
    """Collect run metadata, model shape, and final metrics into a JSON-serializable summary."""
    last = history[-1] if history else {}
    return {
        "preset": preset_name,
        "motif": motif_name,
        "dataset": dataset_name,
        "parameters": model.parameter_count(),
        "body_parameters": model.body_parameter_count(),
        "d_model": model.cfg.d_model,
        "n_layers": model.cfg.n_layers,
        "n_heads": model.cfg.n_heads,
        "d_ff": model.cfg.d_ff,
        "qk_dim": model.cfg.qk_dim,
        "v_dim": model.cfg.v_dim,
        "final_train_loss": last.get("train_loss"),
        "final_val_loss": last.get("val_loss"),
        "final_val_bpb": last.get("val_bpb"),
        "history": history,
    }


def save_testformer_json(payload: Any, path: str | Path) -> None:
    """Write `payload` as pretty-printed JSON, creating parent directories as needed."""
    path = Path(path)
    if path.parent:
        path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
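
# Example (the path is illustrative): save_testformer_json({"ok": True},
# "results/smoke.json") writes pretty-printed UTF-8 JSON and creates the
# "results" directory if it does not already exist.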


def build_runtime_testformer_config(
    preset_name: str,
    motif_name: str,
    vocab_size: int = 32000,
    seq_len: int | None = None,
    resid_dropout: float = 0.0,
    attn_dropout: float = 0.0,
    emb_dropout: float = 0.0,
) -> TestFormerConfig:
    """Build a TestFormerConfig from a preset/motif pair with runtime overrides."""
    return build_testformer_config(
        preset_name=preset_name,
        motif_name=motif_name,
        vocab_size=vocab_size,
        max_seq_len=seq_len,
        resid_dropout=resid_dropout,
        attn_dropout=attn_dropout,
        emb_dropout=emb_dropout,
    )
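
# End-to-end sketch (commented out; "small" and "baseline" are placeholder
# preset/motif names for illustration and may not exist in testformer_config):
#
#   cfg = build_runtime_testformer_config("small", "baseline", seq_len=256)
#   model, history, train_data, val_data = train_testformer(
#       cfg, dataset="shakespeare", device="cpu", steps=100
#   )
#   summary = summarize_testformer_result("small", "baseline", "shakespeare", model, history)
#   save_testformer_json(summary, "results/testformer_small_baseline.json")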