Text Generation
Transformers
PyTorch
English
taonet_mini_t2
taonet
taotern
ssm
state-space-model
dplr
custom_code
experimental
Instructions to use TaoTern/TaoNet-mini-T2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use TaoTern/TaoNet-mini-T2 with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="TaoTern/TaoNet-mini-T2", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("TaoTern/TaoNet-mini-T2", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use TaoTern/TaoNet-mini-T2 with vLLM:
Install from pip and serve model
```bash
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "TaoTern/TaoNet-mini-T2"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "TaoTern/TaoNet-mini-T2",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
```
Use Docker
docker model run hf.co/TaoTern/TaoNet-mini-T2
- SGLang
How to use TaoTern/TaoNet-mini-T2 with SGLang:
Install from pip and serve model
```bash
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "TaoTern/TaoNet-mini-T2" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "TaoTern/TaoNet-mini-T2",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
```
Use Docker images
```bash
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "TaoTern/TaoNet-mini-T2" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "TaoTern/TaoNet-mini-T2",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
```
- Docker Model Runner
How to use TaoTern/TaoNet-mini-T2 with Docker Model Runner:
docker model run hf.co/TaoTern/TaoNet-mini-T2
| import json | |
| from pathlib import Path | |
def md(text):
    """Wrap *text* as a Jupyter markdown cell dictionary.

    Leading and trailing newlines are removed from *text*; the remainder is
    split into lines that keep their line endings and stored under "source".
    """
    source_lines = text.strip("\n").splitlines(keepends=True)
    return dict(cell_type="markdown", metadata={}, source=source_lines)
def code(text):
    """Wrap *text* as an unexecuted Jupyter code cell dictionary.

    Leading and trailing newlines are removed from *text*; the remainder is
    split into lines that keep their line endings and stored under "source".
    The cell carries no execution count and an empty output list.
    """
    source_lines = text.strip("\n").splitlines(keepends=True)
    return dict(
        cell_type="code",
        execution_count=None,
        metadata={},
        outputs=[],
        source=source_lines,
    )
# Relative paths of the two .ipynb files this script deals with: the quick
# sine-wave benchmark and the longer research benchmark.
| QUICK_NOTEBOOK_PATH = Path("output/jupyter-notebook/gamma-s4-sinewave-benchmark.ipynb") | |
| RESEARCH_NOTEBOOK_PATH = Path("output/jupyter-notebook/gamma-s4-research-benchmark.ipynb") | |
def setup_cell():
    """Build the notebook's environment-setup code cell.

    The generated cell:

    * detects Google Colab via ``"google.colab" in sys.modules``;
    * on Colab, obtains a GitHub token (``GITHUB_TOKEN`` / ``GH_TOKEN``
      environment variables, then Colab ``userdata``, then an interactive
      ``getpass`` prompt), clones the repository into ``/content`` or hard
      resets an existing checkout to ``origin/main``, and changes into it;
    * locally, uses the current working directory as the repository;
    * puts the repository on ``sys.path`` and evicts previously imported
      ``gamma_space_model``, ``csrc`` and ``tilelang`` modules so a re-run
      picks up fresh code;
    * prints the short commit hash when it can be determined.

    Returns:
        The code-cell dictionary produced by ``code``.
    """
    return code(
        r"""
import os
import sys
import subprocess
import importlib
from pathlib import Path

IN_COLAB = "google.colab" in sys.modules
REPO_DIR = Path.cwd()

if IN_COLAB:
    print("Running in Google Colab")
    from google.colab import userdata
    from getpass import getpass

    REPO_NAME = "gamma_ssm_s4_v2"
    REPO_DIR = Path("/content") / REPO_NAME
    GITHUB_REPO = "StarMists/gamma_SSM_S4_enhanced"

    token = os.environ.get("GITHUB_TOKEN") or os.environ.get("GH_TOKEN")
    if not token:
        try:
            token = userdata.get("GITHUB_TOKEN")
        except Exception:
            token = None
    if not token:
        token = getpass("GitHub personal access token for private repo access: ").strip()

    def run_git(*args):
        # A failing subprocess.run(..., check=True) raises CalledProcessError
        # whose message echoes the full argv. For the clone that argv contains
        # the token-bearing URL, so re-raise with the token masked and drop
        # the original exception from the traceback chain.
        try:
            subprocess.run(["git", *args], check=True)
        except subprocess.CalledProcessError as exc:
            masked = [arg.replace(token, "***") if token else arg for arg in exc.cmd]
            raise subprocess.CalledProcessError(exc.returncode, masked) from None

    clone_url = f"https://{token}@github.com/{GITHUB_REPO}.git"
    if REPO_DIR.exists():
        run_git("-C", str(REPO_DIR), "fetch", "origin")
        run_git("-C", str(REPO_DIR), "checkout", "main")
        run_git("-C", str(REPO_DIR), "reset", "--hard", "origin/main")
    else:
        run_git("clone", clone_url, str(REPO_DIR))
    os.chdir(REPO_DIR)
    sys.path.insert(0, str(REPO_DIR))
else:
    print("Running locally from", REPO_DIR)
    sys.path.insert(0, str(REPO_DIR))

importlib.invalidate_caches()
for name in list(sys.modules):
    if (
        name == "gamma_space_model"
        or name.startswith("gamma_space_model.")
        or name == "csrc"
        or name.startswith("csrc.")
        or name == "tilelang"
        or name.startswith("tilelang.")
    ):
        del sys.modules[name]

try:
    commit = subprocess.check_output(["git", "-C", str(REPO_DIR), "rev-parse", "--short", "HEAD"], text=True).strip()
    print("Repo commit:", commit)
except (subprocess.CalledProcessError, OSError) as exc:
    # Best effort only: the notebook still works without a commit hash,
    # but say why it is missing instead of failing silently.
    print("Repo commit unavailable:", exc)
"""
    )
# Build the code cell that prepares the notebook's runtime namespace.
# The generated cell imports the stdlib/NumPy/pandas/matplotlib/torch names the
# later cells use plus four block classes from gamma_space_model, seeds
# random, NumPy and torch with SEED = 7, picks DEVICE (CUDA when available,
# otherwise CPU), enables AMP only on CUDA, creates one module-level GradScaler
# (torch.amp.GradScaler when present, else torch.cuda.amp.GradScaler), and
# defines synchronize(), which calls torch.cuda.synchronize() only on CUDA.
# It also prints the device, whether GammaS4Block exposes
# allocate_deployment_cache, and a warning when running on CPU.
| def imports_cell(): | |
| return code( | |
| r""" | |
| import math | |
| import random | |
| import time | |
| import urllib.request | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader, TensorDataset | |
| from gamma_space_model import GammaSingleBlock, GammaS4Block, GammaS4MinimalBlock, S4TernaryDPLRBlock | |
| SEED = 7 | |
| random.seed(SEED) | |
| np.random.seed(SEED) | |
| torch.manual_seed(SEED) | |
| if torch.cuda.is_available(): | |
| torch.cuda.manual_seed_all(SEED) | |
| DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| USE_AMP = DEVICE.type == "cuda" | |
| if hasattr(torch, "amp") and hasattr(torch.amp, "GradScaler"): | |
| scaler = torch.amp.GradScaler("cuda", enabled=USE_AMP) | |
| else: | |
| scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP) | |
| def synchronize(): | |
| if DEVICE.type == "cuda": | |
| torch.cuda.synchronize() | |
| print("Device:", DEVICE) | |
| print("Deployment cache available on GammaS4Block:", hasattr(GammaS4Block, "allocate_deployment_cache")) | |
| if DEVICE.type != "cuda": | |
| print("WARNING: running on CPU. Treat speed numbers as smoke-test only, not main benchmark evidence.") | |
| """ | |
| ) | |
def shared_helpers_cell():
    """Build the code cell holding helpers shared by both notebooks.

    The generated cell defines:

    * ``make_forecasting_split`` - synthetic multi-channel next-step
      forecasting data returned as a ``TensorDataset`` of (input, target)
      sequences shifted by one step;
    * ``SequenceForecaster`` and ``build_forecasting_model`` - a linear
      in/out projection around a stack of blocks selected by ``kind``;
    * ``profile_train_step`` - forward/backward/optimizer timings for one
      batch;
    * ``run_epoch`` - one training or evaluation pass. Evaluation
      (``optimizer=None``) runs with autograd disabled so no graph is built
      or retained for validation batches;
    * ``benchmark_inference`` - full-sequence versus step-by-step timings,
      including the optional deployment and balanced-deployment caches;
    * ``show_benchmark_tables`` - displays the result frame in three
      narrower tables.

    The cell relies on names created by the imports cell (``np``, ``torch``,
    ``nn``, ``F``, ``time``, ``DEVICE``, ``USE_AMP``, ``scaler``,
    ``synchronize``, the block classes) and on IPython's ``display``.

    Returns:
        The code-cell dictionary produced by ``code``.
    """
    return code(
        r"""
def make_forecasting_split(config, split, seed):
    rng = np.random.default_rng(seed)
    count = config["train_samples"] if split == "train" else config["val_samples"]
    seq_len = config["seq_len"]
    features = config["features"]
    complexity = config["complexity"]
    t = np.linspace(0.0, 1.0, seq_len + 1, dtype=np.float32)
    data = np.zeros((count, seq_len + 1, features), dtype=np.float32)

    for i in range(count):
        phase = rng.uniform(0.0, 2.0 * np.pi)
        chirp = np.sin(2.0 * np.pi * (1.0 + 1.75 * t**2) * rng.uniform(0.9, 1.15) + phase)
        slow = np.sin(2.0 * np.pi * (0.4 + 0.25 * complexity) * t + 0.7 * phase)
        medium = np.sin(2.0 * np.pi * (2.5 + complexity) * t + 1.1 * phase)
        fast = np.cos(2.0 * np.pi * (5.0 + 2.0 * complexity) * t + 1.5 * phase)
        bursts = (np.sin(2.0 * np.pi * (3.0 + complexity) * t + 0.4 * phase) > 0.8).astype(np.float32)
        bursts = bursts * np.sin(2.0 * np.pi * (10.0 + complexity) * t + 0.3 * phase)
        delayed = np.roll(chirp, 4 * complexity)
        modulated = medium * (1.0 + 0.35 * slow)
        components = [chirp, slow, medium, fast, bursts, delayed, modulated]

        for channel in range(features):
            weights = rng.normal(0.0, 1.0, size=len(components)).astype(np.float32)
            weights[: 2 + complexity] *= 1.2
            signal = sum(w * c for w, c in zip(weights, components))
            if complexity >= 2:
                signal += 0.20 * np.tanh(np.roll(signal, channel + 1))
            if complexity >= 3:
                signal += 0.10 * np.sin(signal * (0.5 + 0.08 * channel))
            signal += rng.normal(0.0, 0.03 + 0.01 * complexity, size=seq_len + 1).astype(np.float32)
            data[i, :, channel] = signal

    data -= data.mean(axis=1, keepdims=True)
    data /= data.std(axis=1, keepdims=True) + 1e-5
    return TensorDataset(torch.from_numpy(data[:, :-1, :]), torch.from_numpy(data[:, 1:, :]))


class SequenceForecaster(nn.Module):
    def __init__(self, input_dim, model_dim, layers, block_factory):
        super().__init__()
        self.in_proj = nn.Linear(input_dim, model_dim)
        self.layers = nn.ModuleList([block_factory(model_dim) for _ in range(layers)])
        self.out_proj = nn.Linear(model_dim, input_dim)

    def forward(self, x):
        x = self.in_proj(x)
        for layer in self.layers:
            x, _ = layer(x, state=None, return_state=False)
        return self.out_proj(x), None


def build_forecasting_model(kind, config, overrides=None):
    overrides = overrides or {}
    d_model = config["d_model"]
    hidden_dim = config["hidden_dim"]
    num_layers = config["num_layers"]
    input_dim = config["features"]

    if kind == "gamma_baseline":
        block_factory = lambda width: GammaSingleBlock(
            d_model=width,
            hidden_dim=hidden_dim,
            dropout=0.0,
        )
    elif kind == "gamma_s4_enhanced":
        block_factory = lambda width: GammaS4Block(
            d_model=width,
            hidden_dim=hidden_dim,
            kernel_mode=overrides.get("kernel_mode", "auto"),
            kernel_threshold=overrides.get("kernel_threshold", 384),
            discretization=overrides.get("discretization", "bilinear"),
            gate=overrides.get("gate", True),
            input_gate=overrides.get("input_gate", True),
            activation=overrides.get("activation", "gelu"),
            use_D=overrides.get("use_D", True),
            layer_scale_init=overrides.get("layer_scale_init", 0.1),
        )
    elif kind == "gamma_s4_minimal":
        block_factory = lambda width: GammaS4MinimalBlock(
            d_model=width,
            hidden_dim=hidden_dim,
            kernel_mode=overrides.get("kernel_mode", "auto"),
            kernel_threshold=overrides.get("kernel_threshold", 384),
            discretization=overrides.get("discretization", "bilinear"),
            use_D=overrides.get("use_D", True),
        )
    elif kind == "s4_ternary_dplr_ssm":
        block_factory = lambda width: S4TernaryDPLRBlock(
            d_model=width,
            hidden_dim=hidden_dim,
            rank=overrides.get("rank", 1),
            kernel_mode=overrides.get("kernel_mode", "auto"),
            kernel_threshold=overrides.get("kernel_threshold", 256),
            gate=overrides.get("gate", True),
            input_gate=overrides.get("input_gate", True),
            activation=overrides.get("activation", "gelu"),
            use_D=overrides.get("use_D", True),
            layer_scale_init=overrides.get("layer_scale_init", 0.1),
        )
    else:
        raise ValueError(kind)

    return SequenceForecaster(input_dim, d_model, num_layers, block_factory).to(DEVICE)


def profile_train_step(model, batch_x, batch_y):
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    model.train()
    batch_x = batch_x.to(DEVICE)
    batch_y = batch_y.to(DEVICE)
    optimizer.zero_grad(set_to_none=True)

    synchronize()
    t0 = time.perf_counter()
    with torch.autocast(device_type=DEVICE.type, enabled=USE_AMP):
        pred, _ = model(batch_x)
        loss = F.mse_loss(pred, batch_y)
    synchronize()
    t1 = time.perf_counter()

    scaler.scale(loss).backward()
    synchronize()
    t2 = time.perf_counter()

    scaler.step(optimizer)
    scaler.update()
    synchronize()
    t3 = time.perf_counter()

    return {
        "forward_ms": 1000.0 * (t1 - t0),
        "backward_ms": 1000.0 * (t2 - t1),
        "optimizer_ms": 1000.0 * (t3 - t2),
        "loss": float(loss.detach().cpu()),
    }


def run_epoch(model, loader, optimizer=None):
    training = optimizer is not None
    model.train(training)
    losses = []
    synchronize()
    start = time.perf_counter()
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(DEVICE)
        batch_y = batch_y.to(DEVICE)
        if training:
            optimizer.zero_grad(set_to_none=True)
        # Track gradients only when training: evaluation passes would
        # otherwise build an autograd graph for every validation batch.
        with torch.set_grad_enabled(training), torch.autocast(device_type=DEVICE.type, enabled=USE_AMP):
            pred, _ = model(batch_x)
            loss = F.mse_loss(pred, batch_y)
        if training:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        losses.append(loss.detach().item())
    synchronize()
    return float(np.mean(losses)), time.perf_counter() - start


def benchmark_inference(model, sample_x):
    model.eval()
    sample_x = sample_x.to(DEVICE)
    with torch.no_grad():
        for _ in range(2):
            _ = model(sample_x)
        synchronize()
        t0 = time.perf_counter()
        pred, _ = model(sample_x)
        synchronize()
        full_latency = time.perf_counter() - t0

        hidden = model.in_proj(sample_x)
        states, caches, outputs = [], [], []
        synchronize()
        t_cache = time.perf_counter()
        for layer in model.layers:
            ssm = getattr(layer, "ssm", None)
            if ssm is None:
                states.append(None)
                caches.append(None)
            else:
                states.append(ssm.init_state(sample_x.size(0), DEVICE, hidden.dtype))
                if hasattr(layer, "allocate_inference_cache"):
                    caches.append(layer.allocate_inference_cache(sample_x.size(0), sample_x.size(1), DEVICE, hidden.dtype))
                else:
                    caches.append(None)
        synchronize()
        cache_setup = time.perf_counter() - t_cache

        synchronize()
        t1 = time.perf_counter()
        for step in range(sample_x.size(1)):
            token = hidden[:, step, :]
            new_outputs = token
            for idx, layer in enumerate(model.layers):
                if caches[idx] is None:
                    new_outputs, states[idx] = layer.step(new_outputs, states[idx])
                else:
                    new_outputs, states[idx] = layer.step(new_outputs, states[idx], cache=caches[idx])
            outputs.append(new_outputs)
        recurrent = model.out_proj(torch.stack(outputs, dim=1))
        synchronize()
        recurrent_latency = time.perf_counter() - t1

        def run_cached_recurrent(cache_allocator_name):
            hidden_local = model.in_proj(sample_x)
            states_local, caches_local, outputs_local = [], [], []
            synchronize()
            cache_start = time.perf_counter()
            for layer in model.layers:
                ssm = getattr(layer, "ssm", None)
                if ssm is None:
                    states_local.append(None)
                    caches_local.append(None)
                else:
                    states_local.append(ssm.init_state(sample_x.size(0), DEVICE, hidden_local.dtype))
                    if hasattr(layer, cache_allocator_name):
                        allocator = getattr(layer, cache_allocator_name)
                        caches_local.append(allocator(sample_x.size(0), sample_x.size(1), DEVICE, hidden_local.dtype))
                    elif hasattr(layer, "allocate_deployment_cache"):
                        caches_local.append(layer.allocate_deployment_cache(sample_x.size(0), sample_x.size(1), DEVICE, hidden_local.dtype))
                    elif hasattr(layer, "allocate_inference_cache"):
                        caches_local.append(layer.allocate_inference_cache(sample_x.size(0), sample_x.size(1), DEVICE, hidden_local.dtype))
                    else:
                        caches_local.append(None)
            synchronize()
            cache_elapsed = time.perf_counter() - cache_start

            synchronize()
            start = time.perf_counter()
            for step in range(sample_x.size(1)):
                token = hidden_local[:, step, :]
                new_outputs = token
                for idx, layer in enumerate(model.layers):
                    if caches_local[idx] is None:
                        new_outputs, states_local[idx] = layer.step(new_outputs, states_local[idx])
                    else:
                        new_outputs, states_local[idx] = layer.step(new_outputs, states_local[idx], cache=caches_local[idx])
                outputs_local.append(new_outputs)
            recurrent_out = model.out_proj(torch.stack(outputs_local, dim=1))
            synchronize()
            elapsed = time.perf_counter() - start
            return cache_elapsed, elapsed, recurrent_out

        lightweight_latency = float("nan")
        lightweight_tokens_per_s = float("nan")
        deploy_supported = any(hasattr(layer, "allocate_deployment_cache") for layer in model.layers)
        if deploy_supported:
            cache_setup_lightweight, lightweight_latency, recurrent_light = run_cached_recurrent("allocate_deployment_cache")
            lightweight_tokens_per_s = (sample_x.shape[0] * sample_x.shape[1]) / max(lightweight_latency, 1e-9)
        else:
            cache_setup_lightweight = float("nan")
            recurrent_light = None

        balanced_latency = float("nan")
        balanced_tokens_per_s = float("nan")
        balanced_supported = any(hasattr(layer, "allocate_balanced_deployment_cache") for layer in model.layers)
        if balanced_supported:
            cache_setup_balanced, balanced_latency, recurrent_balanced = run_cached_recurrent("allocate_balanced_deployment_cache")
            balanced_tokens_per_s = (sample_x.shape[0] * sample_x.shape[1]) / max(balanced_latency, 1e-9)
        else:
            cache_setup_balanced = float("nan")
            recurrent_balanced = None

        tokens = sample_x.shape[0] * sample_x.shape[1]
        return {
            "full_latency_ms": 1000.0 * full_latency,
            "full_tokens_per_s": tokens / max(full_latency, 1e-9),
            "cache_setup_ms": 1000.0 * cache_setup,
            "recurrent_latency_ms": 1000.0 * recurrent_latency,
            "recurrent_tokens_per_s": tokens / max(recurrent_latency, 1e-9),
            "recurrent_match_mse": float(F.mse_loss(recurrent, pred).detach().cpu()),
            "deploy_supported": deploy_supported,
            "deploy_cache_setup_ms": 1000.0 * cache_setup_lightweight,
            "deploy_recurrent_latency_ms": 1000.0 * lightweight_latency if lightweight_latency == lightweight_latency else float("nan"),
            "deploy_recurrent_tokens_per_s": lightweight_tokens_per_s,
            "deploy_match_mse": float(F.mse_loss(recurrent_light, pred).detach().cpu()) if recurrent_light is not None else float("nan"),
            "balanced_deploy_supported": balanced_supported,
            "balanced_deploy_cache_setup_ms": 1000.0 * cache_setup_balanced,
            "balanced_deploy_recurrent_latency_ms": 1000.0 * balanced_latency if balanced_latency == balanced_latency else float("nan"),
            "balanced_deploy_recurrent_tokens_per_s": balanced_tokens_per_s,
            "balanced_deploy_match_mse": float(F.mse_loss(recurrent_balanced, pred).detach().cpu()) if recurrent_balanced is not None else float("nan"),
            "prediction": pred.detach().cpu(),
            "recurrent_prediction": recurrent.detach().cpu(),
        }


def show_benchmark_tables(df, title="Benchmark"):
    if df.empty:
        display(df)
        return

    def available(columns):
        return [col for col in columns if col in df.columns]

    ordered = df.sort_values(["task", "kind"]).reset_index(drop=True)
    normal_cols = available([
        "task",
        "kind",
        "params",
        "train_loss",
        "val_loss",
        "mean_epoch_time_s",
        "expected_full_mode",
        "forward_ms",
        "backward_ms",
        "optimizer_ms",
        "full_latency_ms",
        "full_tokens_per_s",
        "cache_setup_ms",
        "recurrent_latency_ms",
        "recurrent_tokens_per_s",
        "recurrent_match_mse",
    ])
    deploy_cols = available([
        "task",
        "kind",
        "deploy_supported",
        "deploy_cache_setup_ms",
        "deploy_recurrent_latency_ms",
        "deploy_recurrent_tokens_per_s",
        "deploy_match_mse",
    ])
    balanced_cols = available([
        "task",
        "kind",
        "balanced_deploy_supported",
        "balanced_deploy_cache_setup_ms",
        "balanced_deploy_recurrent_latency_ms",
        "balanced_deploy_recurrent_tokens_per_s",
        "balanced_deploy_match_mse",
    ])

    print(f"{title}: normal/full-sequence and full recurrent")
    display(ordered[normal_cols])
    print(f"{title}: deployment-lite recurrent")
    display(ordered[deploy_cols])
    print(f"{title}: balanced deployment recurrent")
    display(ordered[balanced_cols])
"""
    )
# Assemble and return the ordered list of cell dictionaries for the quick
# benchmark notebook. In order: an intro markdown cell, the setup and imports
# cells, a configuration cell (QUICK_TASKS, MODEL_OVERRIDES, ACTIVE_TASKS,
# MODELS), the shared helpers cell, a cell defining train_and_benchmark, a
# heading plus the run loop that fills all_metrics / histories /
# trained_models and shows the summary tables, a 2x2 metric plot cell, a note
# on the table layout, a per-channel prediction plot cell, and closing notes.
# NOTE(review): in the run loop, trained_models stores the metrics dictionary
# (not the nn.Module) per (task, kind); the plotting cell reads
# "sample_target" and "prediction" from it, so the name is misleading but the
# usage is consistent.
| def quick_notebook(): | |
| cells = [ | |
| md( | |
| """ | |
| # Gamma Baseline vs Gamma S4 Enhanced Quick Benchmark | |
| This is the fast notebook for day-to-day iteration. | |
| It is intentionally narrow: | |
| - compare `gamma_baseline`, `gamma_s4_enhanced`, and `s4_ternary_dplr_ssm` | |
| - use practical sequence lengths | |
| - keep `kernel_mode` conservative | |
| - report training speed, inference speed, and one-step profiling | |
| """ | |
| ), | |
| setup_cell(), | |
| imports_cell(), | |
| code( | |
| r""" | |
| QUICK_TASKS = { | |
| "simple": dict(seq_len=192, features=4, train_samples=256, val_samples=64, epochs=4, batch_size=32, d_model=48, hidden_dim=64, num_layers=2, complexity=1), | |
| "moderate": dict(seq_len=320, features=6, train_samples=320, val_samples=80, epochs=5, batch_size=24, d_model=64, hidden_dim=96, num_layers=2, complexity=2), | |
| } | |
| MODEL_OVERRIDES = { | |
| "gamma_baseline": {}, | |
| "gamma_s4_enhanced": { | |
| "kernel_mode": "auto", | |
| "kernel_threshold": 384, | |
| "discretization": "bilinear", | |
| "gate": True, | |
| "input_gate": True, | |
| "activation": "gelu", | |
| "use_D": True, | |
| "layer_scale_init": 0.1, | |
| }, | |
| "s4_ternary_dplr_ssm": { | |
| "kernel_mode": "auto", | |
| "kernel_threshold": 256, | |
| "rank": 1, | |
| "gate": True, | |
| "input_gate": True, | |
| "activation": "gelu", | |
| "use_D": True, | |
| "layer_scale_init": 0.1, | |
| }, | |
| } | |
| ACTIVE_TASKS = ["simple", "moderate"] | |
| MODELS = ["gamma_baseline", "gamma_s4_enhanced", "s4_ternary_dplr_ssm"] | |
| """ | |
| ), | |
| shared_helpers_cell(), | |
| code( | |
| r""" | |
| def train_and_benchmark(task_name, kind): | |
| cfg = QUICK_TASKS[task_name] | |
| train_ds = make_forecasting_split(cfg, "train", seed=SEED + 11) | |
| val_ds = make_forecasting_split(cfg, "val", seed=SEED + 29) | |
| train_loader = DataLoader(train_ds, batch_size=cfg["batch_size"], shuffle=True) | |
| val_loader = DataLoader(val_ds, batch_size=cfg["batch_size"], shuffle=False) | |
| model = build_forecasting_model(kind, cfg, overrides=MODEL_OVERRIDES.get(kind)) | |
| optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=1e-4) | |
| history = {"train_loss": [], "val_loss": [], "epoch_time_s": []} | |
| first_batch_x, first_batch_y = next(iter(train_loader)) | |
| profile = profile_train_step(build_forecasting_model(kind, cfg, overrides=MODEL_OVERRIDES.get(kind)), first_batch_x, first_batch_y) | |
| for epoch in range(cfg["epochs"]): | |
| train_loss, epoch_time = run_epoch(model, train_loader, optimizer=optimizer) | |
| val_loss, _ = run_epoch(model, val_loader) | |
| history["train_loss"].append(train_loss) | |
| history["val_loss"].append(val_loss) | |
| history["epoch_time_s"].append(epoch_time) | |
| print(f"{task_name} | {kind} | epoch={epoch+1:02d} train={train_loss:.6f} val={val_loss:.6f}") | |
| sample_x, sample_y = next(iter(val_loader)) | |
| sample_x = sample_x[:2] | |
| sample_y = sample_y[:2] | |
| inf = benchmark_inference(model, sample_x) | |
| metrics = { | |
| "task": task_name, | |
| "kind": kind, | |
| "params": sum(p.numel() for p in model.parameters()), | |
| "train_loss": history["train_loss"][-1], | |
| "val_loss": history["val_loss"][-1], | |
| "mean_epoch_time_s": float(np.mean(history["epoch_time_s"])), | |
| **profile, | |
| "sample_target": sample_y.cpu(), | |
| **inf, | |
| } | |
| return metrics, history, model | |
| """ | |
| ), | |
| md( | |
| """ | |
| ## Run the Quick Experiment | |
| This is the cell you will usually run in Colab first. | |
| """ | |
| ), | |
| code( | |
| r""" | |
| all_metrics = [] | |
| histories = {} | |
| trained_models = {} | |
| for task_name in ACTIVE_TASKS: | |
| for kind in MODELS: | |
| metrics, history, model = train_and_benchmark(task_name, kind) | |
| all_metrics.append({k: v for k, v in metrics.items() if k not in {"sample_target", "prediction", "recurrent_prediction"}}) | |
| histories[(task_name, kind)] = history | |
| trained_models[(task_name, kind)] = metrics | |
| summary_df = pd.DataFrame(all_metrics).sort_values(["task", "val_loss"]).reset_index(drop=True) | |
| show_benchmark_tables(summary_df, title="Quick benchmark") | |
| """ | |
| ), | |
| code( | |
| r""" | |
| fig, axes = plt.subplots(2, 2, figsize=(14, 8)) | |
| metrics_to_plot = ["val_loss", "mean_epoch_time_s", "full_tokens_per_s", "recurrent_tokens_per_s"] | |
| for ax, metric in zip(axes.flatten(), metrics_to_plot): | |
| pivot = summary_df.pivot(index="task", columns="kind", values=metric).loc[ACTIVE_TASKS] | |
| pivot.plot(ax=ax, marker="o") | |
| ax.set_title(metric.replace("_", " ").title()) | |
| ax.grid(alpha=0.2) | |
| plt.tight_layout() | |
| plt.show() | |
| """ | |
| ), | |
| md( | |
| """ | |
| The benchmark tables above are intentionally split into normal, deployment-lite, and balanced deployment views so the important columns remain visible in Colab. | |
| """ | |
| ), | |
| code( | |
| r""" | |
| PLOT_TASK = ACTIVE_TASKS[-1] | |
| baseline = trained_models[(PLOT_TASK, "gamma_baseline")] | |
| target = baseline["sample_target"][0].numpy() | |
| baseline_pred = baseline["prediction"][0].numpy() | |
| comparison_kinds = [kind for kind in MODELS if kind != "gamma_baseline"] | |
| for compare_kind in comparison_kinds: | |
| candidate = trained_models[(PLOT_TASK, compare_kind)] | |
| candidate_pred = candidate["prediction"][0].numpy() | |
| channels = range(min(3, target.shape[-1])) | |
| time_axis = np.arange(target.shape[0]) | |
| fig, axes = plt.subplots(len(list(channels)), 1, figsize=(12, 3.5 * len(list(channels))), sharex=True) | |
| if target.shape[-1] == 1: | |
| axes = [axes] | |
| for row, channel in enumerate(channels): | |
| ax = axes[row] | |
| ax.plot(time_axis, target[:, channel], label="ground truth", linewidth=2) | |
| ax.plot(time_axis, baseline_pred[:, channel], label="baseline", alpha=0.9) | |
| ax.plot(time_axis, candidate_pred[:, channel], label=compare_kind, alpha=0.9) | |
| ax.set_title(f"{PLOT_TASK} task, channel {channel}") | |
| ax.grid(alpha=0.2) | |
| if row == 0: | |
| ax.legend() | |
| plt.tight_layout() | |
| plt.show() | |
| """ | |
| ), | |
| md( | |
| """ | |
| ## Reading This Notebook | |
| Use this notebook for fast feedback: | |
| - if either S4-style model loses badly here, do not trust it on bigger tasks | |
| - if one of them is competitive here, then move to the research notebook | |
| """ | |
| ), | |
| ] | |
| return cells | |
| def research_notebook(): | |
| cells = [ | |
| md( | |
| """ | |
| # Gamma S4 Practical Benchmark | |
| This notebook is the second benchmark track after the quick notebook. | |
| It is meant to be closer to practical sequence modeling while staying reasonable on Colab: | |
| - one harder long-context synthetic benchmark | |
| - enhanced-model ablations on that harder task | |
| - an optional lightweight token benchmark | |
| """ | |
| ), | |
| setup_cell(), | |
| imports_cell(), | |
| code( | |
| r""" | |
| PRACTICAL_CONFIGS = { | |
| "current_reference": dict(seq_len=320, features=6, train_samples=320, val_samples=80, epochs=5, batch_size=24, d_model=64, hidden_dim=96, num_layers=2, complexity=2), | |
| "long_context": dict(seq_len=768, features=8, train_samples=256, val_samples=64, epochs=4, batch_size=12, d_model=80, hidden_dim=128, num_layers=3, complexity=3), | |
| } | |
| RUN_PRACTICAL_SWEEP = True | |
| RUN_ABLATIONS = False | |
| RUN_TOKEN_TASK = True | |
| PRACTICAL_MODELS = ["gamma_baseline", "gamma_s4_enhanced", "s4_ternary_dplr_ssm"] | |
| MODEL_OVERRIDES = { | |
| "gamma_baseline": {}, | |
| "gamma_s4_enhanced": { | |
| "kernel_mode": "auto", | |
| "kernel_threshold": 384, | |
| "discretization": "bilinear", | |
| "gate": True, | |
| "input_gate": True, | |
| "activation": "gelu", | |
| "use_D": True, | |
| "layer_scale_init": 0.1, | |
| }, | |
| "gamma_s4_minimal": { | |
| "kernel_mode": "auto", | |
| "kernel_threshold": 512, | |
| "discretization": "bilinear", | |
| "use_D": True, | |
| }, | |
| "s4_ternary_dplr_ssm": { | |
| "kernel_mode": "auto", | |
| "kernel_threshold": 256, | |
| "rank": 1, | |
| "gate": True, | |
| "input_gate": True, | |
| "activation": "gelu", | |
| "use_D": True, | |
| "layer_scale_init": 0.1, | |
| }, | |
| } | |
| ABLATIONS = [ | |
| ("default", {}), | |
| ("no_input_gate", {"input_gate": False}), | |
| ("no_gate", {"gate": False}), | |
| ("no_skip_D", {"use_D": False}), | |
| ("euler", {"discretization": "euler"}), | |
| ] | |
| """ | |
| ), | |
| shared_helpers_cell(), | |
| code( | |
| r""" | |
| def train_practical_model(task_name, kind, overrides=None): | |
| cfg = PRACTICAL_CONFIGS[task_name] | |
| train_ds = make_forecasting_split(cfg, "train", seed=SEED + 101) | |
| val_ds = make_forecasting_split(cfg, "val", seed=SEED + 151) | |
| train_loader = DataLoader(train_ds, batch_size=cfg["batch_size"], shuffle=True) | |
| val_loader = DataLoader(val_ds, batch_size=cfg["batch_size"], shuffle=False) | |
| model = build_forecasting_model(kind, cfg, overrides=overrides or MODEL_OVERRIDES.get(kind)) | |
| optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=1e-4) | |
| history = {"train_loss": [], "val_loss": [], "epoch_time_s": []} | |
| for epoch in range(cfg["epochs"]): | |
| train_loss, epoch_time = run_epoch(model, train_loader, optimizer=optimizer) | |
| val_loss, _ = run_epoch(model, val_loader) | |
| history["train_loss"].append(train_loss) | |
| history["val_loss"].append(val_loss) | |
| history["epoch_time_s"].append(epoch_time) | |
| print(f"{task_name} | {kind} | epoch={epoch+1:02d} train={train_loss:.6f} val={val_loss:.6f}") | |
| sample_x, sample_y = next(iter(val_loader)) | |
| sample_x = sample_x[:2] | |
| sample_y = sample_y[:2] | |
| inf = benchmark_inference(model, sample_x) | |
| result = { | |
| "task": task_name, | |
| "kind": kind, | |
| "params": sum(p.numel() for p in model.parameters()), | |
| "train_loss": history["train_loss"][-1], | |
| "val_loss": history["val_loss"][-1], | |
| "mean_epoch_time_s": float(np.mean(history["epoch_time_s"])), | |
| "expected_full_mode": "conv" if (kind != "gamma_baseline" and cfg["seq_len"] >= MODEL_OVERRIDES.get(kind, {}).get("kernel_threshold", 10**9)) else "recurrent_like", | |
| "sample_target": sample_y.cpu(), | |
| **inf, | |
| } | |
| return result, history, model | |
| """ | |
| ), | |
| md( | |
| """ | |
| ## Practical Sweep | |
| """ | |
| ), | |
| code( | |
| r""" | |
| practical_rows = [] | |
| practical_artifacts = {} | |
| practical_models = {} | |
| if RUN_PRACTICAL_SWEEP: | |
| for task_name in PRACTICAL_CONFIGS: | |
| for kind in PRACTICAL_MODELS: | |
| result, _, model = train_practical_model(task_name, kind) | |
| practical_artifacts[(task_name, kind)] = result | |
| practical_models[(task_name, kind)] = model | |
| practical_rows.append({k: v for k, v in result.items() if k not in {"sample_target", "prediction", "recurrent_prediction"}}) | |
| practical_df = pd.DataFrame(practical_rows) | |
| show_benchmark_tables(practical_df, title="Practical sweep") | |
| """ | |
| ), | |
| code( | |
| r""" | |
| if not practical_df.empty: | |
| fig, axes = plt.subplots(2, 2, figsize=(14, 8)) | |
| for ax, metric in zip(axes.flatten(), ["val_loss", "mean_epoch_time_s", "full_tokens_per_s", "recurrent_tokens_per_s"]): | |
| pivot = practical_df.pivot(index="task", columns="kind", values=metric) | |
| pivot.plot(ax=ax, marker="o") | |
| ax.set_title(metric.replace("_", " ").title()) | |
| ax.grid(alpha=0.2) | |
| plt.tight_layout() | |
| plt.show() | |
| """ | |
| ), | |
| md( | |
| """ | |
| ## Practical Inference Scaling | |
| This section benchmarks trained practical models across sequence lengths and batch sizes. It separates full-sequence prefill-like throughput from recurrent/deployment decode-like throughput. | |
| """ | |
| ), | |
| code( | |
| r""" | |
| inference_scaling_rows = [] | |
| if RUN_PRACTICAL_SWEEP and practical_models: | |
| INFERENCE_SEQ_LENS = [64, 192, 384, 768] | |
| INFERENCE_BATCH_SIZES = [1, 2] | |
| for (task_name, kind), model in practical_models.items(): | |
| feature_dim = model.out_proj.out_features | |
| for seq_len in INFERENCE_SEQ_LENS: | |
| for batch_size in INFERENCE_BATCH_SIZES: | |
| sample_x = torch.randn(batch_size, seq_len, feature_dim) | |
| metrics = benchmark_inference(model, sample_x) | |
| inference_scaling_rows.append({ | |
| "task": task_name, | |
| "kind": kind, | |
| "batch_size": batch_size, | |
| "seq_len": seq_len, | |
| "full_latency_ms": metrics["full_latency_ms"], | |
| "full_tokens_per_s": metrics["full_tokens_per_s"], | |
| "recurrent_latency_ms": metrics["recurrent_latency_ms"], | |
| "recurrent_tokens_per_s": metrics["recurrent_tokens_per_s"], | |
| "deploy_recurrent_latency_ms": metrics["deploy_recurrent_latency_ms"], | |
| "deploy_recurrent_tokens_per_s": metrics["deploy_recurrent_tokens_per_s"], | |
| "balanced_deploy_recurrent_latency_ms": metrics["balanced_deploy_recurrent_latency_ms"], | |
| "balanced_deploy_recurrent_tokens_per_s": metrics["balanced_deploy_recurrent_tokens_per_s"], | |
| "recurrent_match_mse": metrics["recurrent_match_mse"], | |
| "deploy_match_mse": metrics["deploy_match_mse"], | |
| "balanced_deploy_match_mse": metrics["balanced_deploy_match_mse"], | |
| }) | |
| inference_scaling_df = pd.DataFrame(inference_scaling_rows) | |
| display(inference_scaling_df.sort_values(["task", "kind", "seq_len", "batch_size"])) | |
| """ | |
| ), | |
| code( | |
| r""" | |
| if "inference_scaling_df" in globals() and not inference_scaling_df.empty: | |
| for task_name in inference_scaling_df["task"].unique(): | |
| subset = inference_scaling_df[(inference_scaling_df["task"] == task_name) & (inference_scaling_df["batch_size"] == 1)] | |
| fig, axes = plt.subplots(1, 2, figsize=(14, 4)) | |
| for kind in subset["kind"].unique(): | |
| model_subset = subset[subset["kind"] == kind].sort_values("seq_len") | |
| axes[0].plot(model_subset["seq_len"], model_subset["full_tokens_per_s"], marker="o", label=kind) | |
| axes[1].plot(model_subset["seq_len"], model_subset["recurrent_tokens_per_s"], marker="o", label=kind) | |
| axes[0].set_title(f"{task_name}: full-sequence throughput, batch=1") | |
| axes[1].set_title(f"{task_name}: recurrent throughput, batch=1") | |
| for ax in axes: | |
| ax.set_xlabel("seq_len") | |
| ax.set_ylabel("tokens/s") | |
| ax.grid(alpha=0.25) | |
| ax.legend() | |
| plt.tight_layout() | |
| plt.show() | |
| """ | |
| ), | |
| md( | |
| """ | |
| ## Task Visual Preview | |
| These plots show what the synthetic tasks actually look like before we discuss scores. | |
| """ | |
| ), | |
| code( | |
| r""" | |
| if RUN_PRACTICAL_SWEEP: | |
| fig, axes = plt.subplots(len(PRACTICAL_CONFIGS), 1, figsize=(14, 3.5 * len(PRACTICAL_CONFIGS)), sharex=False) | |
| if len(PRACTICAL_CONFIGS) == 1: | |
| axes = [axes] | |
| for ax, task_name in zip(axes, PRACTICAL_CONFIGS): | |
| cfg = PRACTICAL_CONFIGS[task_name] | |
| preview_ds = make_forecasting_split(cfg, "val", seed=SEED + 151) | |
| preview_x, preview_y = preview_ds[0] | |
| channels = range(min(3, preview_y.shape[-1])) | |
| time_axis = np.arange(preview_y.shape[0]) | |
| for channel in channels: | |
| ax.plot(time_axis, preview_y[:, channel].numpy(), label=f"channel {channel}", linewidth=1.5) | |
| ax.set_title(f"{task_name} target preview") | |
| ax.grid(alpha=0.2) | |
| if task_name == list(PRACTICAL_CONFIGS.keys())[0]: | |
| ax.legend(ncol=min(3, preview_y.shape[-1])) | |
| plt.tight_layout() | |
| plt.show() | |
| """ | |
| ), | |
| md( | |
| """ | |
| ## Prediction Comparison Plots | |
| These are the most presentation-friendly plots in the notebook: | |
| ground truth vs baseline vs each S4-style model on the same held-out sample. | |
| """ | |
| ), | |
| code( | |
| r""" | |
| if practical_artifacts: | |
| for task_name in PRACTICAL_CONFIGS: | |
| baseline = practical_artifacts[(task_name, "gamma_baseline")] | |
| target = baseline["sample_target"][0].numpy() | |
| baseline_pred = baseline["prediction"][0].numpy() | |
| for compare_kind in [kind for kind in PRACTICAL_MODELS if kind != "gamma_baseline"]: | |
| candidate = practical_artifacts[(task_name, compare_kind)] | |
| candidate_pred = candidate["prediction"][0].numpy() | |
| channels = range(min(3, target.shape[-1])) | |
| time_axis = np.arange(target.shape[0]) | |
| fig, axes = plt.subplots(len(list(channels)), 1, figsize=(14, 3.5 * len(list(channels))), sharex=True) | |
| if target.shape[-1] == 1: | |
| axes = [axes] | |
| for row, channel in enumerate(channels): | |
| ax = axes[row] | |
| ax.plot(time_axis, target[:, channel], label="ground truth", linewidth=2) | |
| ax.plot(time_axis, baseline_pred[:, channel], label="baseline", alpha=0.9) | |
| ax.plot(time_axis, candidate_pred[:, channel], label=compare_kind, alpha=0.9) | |
| ax.set_title(f"{task_name} prediction comparison, channel {channel}") | |
| ax.grid(alpha=0.2) | |
| if row == 0: | |
| ax.legend() | |
| plt.tight_layout() | |
| plt.show() | |
| """ | |
| ), | |
| md( | |
| """ | |
| ## Error Comparison Plots | |
| These show where each model is missing the target signal. Lower absolute error should visually hug zero. | |
| """ | |
| ), | |
| code( | |
| r""" | |
| if practical_artifacts: | |
| for task_name in PRACTICAL_CONFIGS: | |
| baseline = practical_artifacts[(task_name, "gamma_baseline")] | |
| target = baseline["sample_target"][0].numpy() | |
| baseline_err = baseline["prediction"][0].numpy() - target | |
| for compare_kind in [kind for kind in PRACTICAL_MODELS if kind != "gamma_baseline"]: | |
| candidate = practical_artifacts[(task_name, compare_kind)] | |
| candidate_err = candidate["prediction"][0].numpy() - target | |
| channels = range(min(2, target.shape[-1])) | |
| time_axis = np.arange(target.shape[0]) | |
| fig, axes = plt.subplots(len(list(channels)), 1, figsize=(14, 3.0 * len(list(channels))), sharex=True) | |
| if len(list(channels)) == 1: | |
| axes = [axes] | |
| for row, channel in enumerate(channels): | |
| ax = axes[row] | |
| ax.plot(time_axis, baseline_err[:, channel], label="baseline error", alpha=0.9) | |
| ax.plot(time_axis, candidate_err[:, channel], label=f"{compare_kind} error", alpha=0.9) | |
| ax.axhline(0.0, color="black", linewidth=1, alpha=0.5) | |
| ax.set_title(f"{task_name} error comparison, channel {channel}") | |
| ax.grid(alpha=0.2) | |
| if row == 0: | |
| ax.legend() | |
| plt.tight_layout() | |
| plt.show() | |
| """ | |
| ), | |
| md( | |
| """ | |
| ## Enhanced Ablations On Long Context | |
| """ | |
| ), | |
| code( | |
| r""" | |
| ablation_rows = [] | |
| if RUN_ABLATIONS: | |
| for name, override in ABLATIONS: | |
| merged = {**MODEL_OVERRIDES["gamma_s4_enhanced"], **override} | |
| result, _, _ = train_practical_model("long_context", "gamma_s4_enhanced", overrides=merged) | |
| row = {k: v for k, v in result.items() if k not in {"sample_target", "prediction", "recurrent_prediction"}} | |
| row["ablation"] = name | |
| ablation_rows.append(row) | |
| ablation_df = pd.DataFrame(ablation_rows) | |
| ablation_df | |
| """ | |
| ), | |
| code( | |
| r""" | |
| if not ablation_df.empty: | |
| fig, axes = plt.subplots(1, 3, figsize=(16, 4)) | |
| ablation_df.sort_values("val_loss").plot.bar(x="ablation", y="val_loss", ax=axes[0], legend=False) | |
| ablation_df.sort_values("mean_epoch_time_s").plot.bar(x="ablation", y="mean_epoch_time_s", ax=axes[1], legend=False) | |
| ablation_df.sort_values("recurrent_tokens_per_s").plot.bar(x="ablation", y="recurrent_tokens_per_s", ax=axes[2], legend=False) | |
| axes[0].set_title("Validation Loss") | |
| axes[1].set_title("Mean Epoch Time") | |
| axes[2].set_title("Recurrent Tokens / s") | |
| for ax in axes: | |
| ax.tick_params(axis="x", rotation=25) | |
| ax.grid(alpha=0.2) | |
| plt.tight_layout() | |
| plt.show() | |
| """ | |
| ), | |
| md( | |
| """ | |
| ## Optional Token-Lite Task | |
| """ | |
| ), | |
| code( | |
| r""" | |
| @torch.no_grad() | |
| def benchmark_token_inference(model, token_batch, target_batch, vocab_size, repeats=6): | |
| model.eval() | |
| x = token_batch.to(DEVICE) | |
| y = target_batch.to(DEVICE) | |
| batch, seq_len = x.shape | |
| token_count = batch * seq_len | |
| rows = [] | |
| def reset_memory(): | |
| if DEVICE.type == "cuda": | |
| torch.cuda.reset_peak_memory_stats() | |
| def max_memory_mb(): | |
| if DEVICE.type != "cuda": | |
| return float("nan") | |
| return torch.cuda.max_memory_allocated() / (1024 ** 2) | |
| reset_memory() | |
| with torch.autocast(device_type=DEVICE.type, enabled=USE_AMP): | |
| logits = model(x) | |
| loss = F.cross_entropy(logits.reshape(-1, vocab_size), y.reshape(-1)) | |
| synchronize() | |
| start = time.perf_counter() | |
| for _ in range(repeats): | |
| with torch.autocast(device_type=DEVICE.type, enabled=USE_AMP): | |
| logits = model(x) | |
| synchronize() | |
| elapsed = time.perf_counter() - start | |
| full_logits = logits.detach() | |
| rows.append({ | |
| "mode": "prefill_full_sequence", | |
| "latency_ms": 1000.0 * elapsed / repeats, | |
| "tokens_per_s": token_count * repeats / max(elapsed, 1e-9), | |
| "ce": float(loss.detach().cpu()), | |
| "match_mse": 0.0, | |
| "max_memory_mb": max_memory_mb(), | |
| }) | |
| hidden = model.embed(x) | |
| states = [] | |
| caches = [] | |
| cache_start = time.perf_counter() | |
| for layer in model.layers: | |
| state = layer.ssm.init_state(hidden.size(0), DEVICE, hidden.dtype) | |
| states.append(state) | |
| if hasattr(layer, "allocate_inference_cache"): | |
| caches.append(layer.allocate_inference_cache(hidden.size(0), seq_len, DEVICE, hidden.dtype)) | |
| else: | |
| caches.append(None) | |
| synchronize() | |
| cache_setup_ms = 1000.0 * (time.perf_counter() - cache_start) | |
| def recurrent_pass(cache_list): | |
| local_states = [state.clone() for state in states] | |
| outputs = [] | |
| for t in range(seq_len): | |
| step_x = hidden[:, t, :] | |
| for layer_idx, layer in enumerate(model.layers): | |
| cache = None if cache_list is None else cache_list[layer_idx] | |
| try: | |
| step_x, local_states[layer_idx] = layer.step(step_x, local_states[layer_idx], cache=cache) | |
| except TypeError: | |
| step_x, local_states[layer_idx] = layer.step(step_x, local_states[layer_idx]) | |
| outputs.append(model.head(step_x)) | |
| return torch.stack(outputs, dim=1) | |
| reset_memory() | |
| recurrent_logits = recurrent_pass(caches) | |
| synchronize() | |
| start = time.perf_counter() | |
| for _ in range(repeats): | |
| recurrent_logits = recurrent_pass(caches) | |
| synchronize() | |
| elapsed = time.perf_counter() - start | |
| recurrent_loss = F.cross_entropy(recurrent_logits.reshape(-1, vocab_size), y.reshape(-1)) | |
| rows.append({ | |
| "mode": "decode_recurrent_exact", | |
| "cache_setup_ms": cache_setup_ms, | |
| "latency_ms": 1000.0 * elapsed / repeats, | |
| "tokens_per_s": token_count * repeats / max(elapsed, 1e-9), | |
| "ce": float(recurrent_loss.detach().cpu()), | |
| "match_mse": float(F.mse_loss(recurrent_logits, full_logits).detach().cpu()), | |
| "max_memory_mb": max_memory_mb(), | |
| }) | |
| deploy_supported = all(hasattr(layer, "allocate_deployment_cache") for layer in model.layers) | |
| if deploy_supported: | |
| cache_start = time.perf_counter() | |
| deploy_caches = [ | |
| layer.allocate_deployment_cache(hidden.size(0), seq_len, DEVICE, hidden.dtype) | |
| for layer in model.layers | |
| ] | |
| synchronize() | |
| deploy_cache_setup_ms = 1000.0 * (time.perf_counter() - cache_start) | |
| reset_memory() | |
| deploy_logits = recurrent_pass(deploy_caches) | |
| synchronize() | |
| start = time.perf_counter() | |
| for _ in range(repeats): | |
| deploy_logits = recurrent_pass(deploy_caches) | |
| synchronize() | |
| elapsed = time.perf_counter() - start | |
| deploy_loss = F.cross_entropy(deploy_logits.reshape(-1, vocab_size), y.reshape(-1)) | |
| rows.append({ | |
| "mode": "decode_deploy_lite", | |
| "cache_setup_ms": deploy_cache_setup_ms, | |
| "latency_ms": 1000.0 * elapsed / repeats, | |
| "tokens_per_s": token_count * repeats / max(elapsed, 1e-9), | |
| "ce": float(deploy_loss.detach().cpu()), | |
| "match_mse": float(F.mse_loss(deploy_logits, full_logits).detach().cpu()), | |
| "max_memory_mb": max_memory_mb(), | |
| }) | |
| balanced_supported = all(hasattr(layer, "allocate_balanced_deployment_cache") for layer in model.layers) | |
| if balanced_supported: | |
| cache_start = time.perf_counter() | |
| balanced_caches = [ | |
| layer.allocate_balanced_deployment_cache(hidden.size(0), seq_len, DEVICE, hidden.dtype) | |
| for layer in model.layers | |
| ] | |
| synchronize() | |
| balanced_cache_setup_ms = 1000.0 * (time.perf_counter() - cache_start) | |
| reset_memory() | |
| balanced_logits = recurrent_pass(balanced_caches) | |
| synchronize() | |
| start = time.perf_counter() | |
| for _ in range(repeats): | |
| balanced_logits = recurrent_pass(balanced_caches) | |
| synchronize() | |
| elapsed = time.perf_counter() - start | |
| balanced_loss = F.cross_entropy(balanced_logits.reshape(-1, vocab_size), y.reshape(-1)) | |
| rows.append({ | |
| "mode": "decode_balanced", | |
| "cache_setup_ms": balanced_cache_setup_ms, | |
| "latency_ms": 1000.0 * elapsed / repeats, | |
| "tokens_per_s": token_count * repeats / max(elapsed, 1e-9), | |
| "ce": float(balanced_loss.detach().cpu()), | |
| "match_mse": float(F.mse_loss(balanced_logits, full_logits).detach().cpu()), | |
| "max_memory_mb": max_memory_mb(), | |
| }) | |
| return rows | |
| """ | |
| ), | |
| code( | |
| r""" | |
| if RUN_TOKEN_TASK: | |
| TOKEN_DATA_URL = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" | |
| TOKEN_DATA_PATH = Path("tmp/jupyter-notebook/tinyshakespeare.txt") | |
| TOKEN_DATA_PATH.parent.mkdir(parents=True, exist_ok=True) | |
| if not TOKEN_DATA_PATH.exists(): | |
| urllib.request.urlretrieve(TOKEN_DATA_URL, TOKEN_DATA_PATH) | |
| text = TOKEN_DATA_PATH.read_text(encoding="utf-8") | |
| vocab = sorted(set(text)) | |
| stoi = {ch: i for i, ch in enumerate(vocab)} | |
| tokens = torch.tensor([stoi[ch] for ch in text], dtype=torch.long) | |
| TOKEN_CFG = { | |
| "seq_len": 192, | |
| "train_samples": 1200, | |
| "val_samples": 240, | |
| "epochs": 2, | |
| "batch_size": 12 if DEVICE.type == "cuda" else 6, | |
| } | |
| def make_token_split(seq_len, train_samples, val_samples): | |
| max_start = len(tokens) - seq_len - 1 | |
| starts = torch.linspace(0, max_start - 1, steps=train_samples + val_samples).long() | |
| x = torch.stack([tokens[s : s + seq_len] for s in starts]) | |
| y = torch.stack([tokens[s + 1 : s + seq_len + 1] for s in starts]) | |
| return TensorDataset(x[:train_samples], y[:train_samples]), TensorDataset(x[train_samples:], y[train_samples:]) | |
| class TokenForecaster(nn.Module): | |
| def __init__(self, vocab_size, kind): | |
| super().__init__() | |
| self.embed = nn.Embedding(vocab_size, 64) | |
| if kind == "gamma_baseline": | |
| factory = lambda: GammaSingleBlock(d_model=64, hidden_dim=96, dropout=0.0) | |
| elif kind == "gamma_s4_enhanced": | |
| factory = lambda: GammaS4Block(d_model=64, hidden_dim=96, kernel_mode="auto", kernel_threshold=160) | |
| elif kind == "s4_ternary_dplr_ssm": | |
| factory = lambda: S4TernaryDPLRBlock(d_model=64, hidden_dim=96, kernel_mode="auto", kernel_threshold=160) | |
| else: | |
| factory = lambda: GammaS4MinimalBlock(d_model=64, hidden_dim=96, kernel_mode="auto", kernel_threshold=160) | |
| self.layers = nn.ModuleList([factory(), factory()]) | |
| self.head = nn.Linear(64, vocab_size) | |
| def forward(self, x): | |
| x = self.embed(x) | |
| for layer in self.layers: | |
| x, _ = layer(x, state=None, return_state=False) | |
| return self.head(x) | |
| token_train, token_val = make_token_split( | |
| seq_len=TOKEN_CFG["seq_len"], | |
| train_samples=TOKEN_CFG["train_samples"], | |
| val_samples=TOKEN_CFG["val_samples"], | |
| ) | |
| token_rows = [] | |
| token_inference_rows = [] | |
| for kind in PRACTICAL_MODELS: | |
| model = TokenForecaster(len(vocab), kind).to(DEVICE) | |
| optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=1e-4) | |
| train_loader = DataLoader(token_train, batch_size=TOKEN_CFG["batch_size"], shuffle=True) | |
| val_loader = DataLoader(token_val, batch_size=TOKEN_CFG["batch_size"], shuffle=False) | |
| history = [] | |
| for epoch in range(TOKEN_CFG["epochs"]): | |
| model.train() | |
| train_losses = [] | |
| for batch_x, batch_y in train_loader: | |
| batch_x = batch_x.to(DEVICE) | |
| batch_y = batch_y.to(DEVICE) | |
| optimizer.zero_grad(set_to_none=True) | |
| with torch.autocast(device_type=DEVICE.type, enabled=USE_AMP): | |
| logits = model(batch_x) | |
| loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), batch_y.reshape(-1)) | |
| scaler.scale(loss).backward() | |
| scaler.step(optimizer) | |
| scaler.update() | |
| train_losses.append(loss.detach().item()) | |
| model.eval() | |
| val_losses = [] | |
| with torch.no_grad(): | |
| for batch_x, batch_y in val_loader: | |
| batch_x = batch_x.to(DEVICE) | |
| batch_y = batch_y.to(DEVICE) | |
| logits = model(batch_x) | |
| loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), batch_y.reshape(-1)) | |
| val_losses.append(loss.detach().item()) | |
| history.append((float(np.mean(train_losses)), float(np.mean(val_losses)))) | |
| print(kind, epoch + 1, history[-1]) | |
| token_rows.append({ | |
| "kind": kind, | |
| "train_ce": history[-1][0], | |
| "val_ce": history[-1][1], | |
| "val_ppl": math.exp(history[-1][1]), | |
| "seq_len": TOKEN_CFG["seq_len"], | |
| "train_samples": TOKEN_CFG["train_samples"], | |
| }) | |
| sample_x, sample_y = next(iter(val_loader)) | |
| for row in benchmark_token_inference(model, sample_x[:2], sample_y[:2], len(vocab)): | |
| row["kind"] = kind | |
| row["seq_len"] = TOKEN_CFG["seq_len"] | |
| row["batch_size"] = min(2, sample_x.size(0)) | |
| token_inference_rows.append(row) | |
| token_df = pd.DataFrame(token_rows).sort_values("val_ce") | |
| display(token_df) | |
| token_inference_df = pd.DataFrame(token_inference_rows) | |
| display(token_inference_df.sort_values(["kind", "mode"])) | |
| """ | |
| ), | |
| md( | |
| """ | |
| Use this notebook after the quick benchmark. The `long_context` task is the more practical synthetic benchmark, and the optional token-lite section gives a small language-like check without making Colab costs too high. | |
| `RUN_ABLATIONS` is off by default because the long-context ablation sweep is still materially more expensive than the main comparison. | |
| """ | |
| ), | |
| ] | |
| return cells | |
def write_notebook(path, cells):
    """Write ``cells`` to ``path`` as a Jupyter notebook (nbformat 4.5) JSON file.

    Args:
        path: Destination file. A ``pathlib.Path``, a ``str`` or any other
            ``os.PathLike`` is accepted. Missing parent directories are
            created.
        cells: JSON-serializable sequence of notebook cell objects; stored
            verbatim under the notebook's ``"cells"`` key.

    Raises:
        TypeError: If ``cells`` holds a value ``json.dumps`` cannot encode.
        OSError: If the directory or the file cannot be created or written.
    """
    # Local import so the function also works for callers that pass a plain
    # string and so the block does not rely on a file-level import.
    from pathlib import Path

    path = Path(path)
    notebook = {
        "cells": cells,
        "metadata": {
            "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
            "language_info": {"name": "python", "version": "3.11"},
        },
        "nbformat": 4,
        "nbformat_minor": 5,
    }
    # NOTE(review): nbformat 4.5 expects every cell to carry an "id" field;
    # confirm the cell builders that produce `cells` supply one.

    # Serialize before touching the filesystem so that a serialization error
    # does not leave an empty directory behind.
    payload = json.dumps(notebook, indent=2)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(payload, encoding="utf-8")
    print(f"Wrote {path}")
# Emit both generated notebooks, quick benchmark first, then the research one.
# Each cell list is built immediately before its notebook is written.
for notebook_path, build_cells in (
    (QUICK_NOTEBOOK_PATH, quick_notebook),
    (RESEARCH_NOTEBOOK_PATH, research_notebook),
):
    write_notebook(notebook_path, build_cells())