import json from pathlib import Path def md(text): return { "cell_type": "markdown", "metadata": {}, "source": text.strip("\n").splitlines(keepends=True), } def code(text): return { "cell_type": "code", "execution_count": None, "metadata": {}, "outputs": [], "source": text.strip("\n").splitlines(keepends=True), } QUICK_NOTEBOOK_PATH = Path("output/jupyter-notebook/gamma-s4-sinewave-benchmark.ipynb") RESEARCH_NOTEBOOK_PATH = Path("output/jupyter-notebook/gamma-s4-research-benchmark.ipynb") def setup_cell(): return code( r""" import os import sys import subprocess import importlib from pathlib import Path IN_COLAB = "google.colab" in sys.modules REPO_DIR = Path.cwd() if IN_COLAB: print("Running in Google Colab") from google.colab import userdata from getpass import getpass REPO_NAME = "gamma_ssm_s4_v2" REPO_DIR = Path("/content") / REPO_NAME GITHUB_REPO = "StarMists/gamma_SSM_S4_enhanced" token = os.environ.get("GITHUB_TOKEN") or os.environ.get("GH_TOKEN") if not token: try: token = userdata.get("GITHUB_TOKEN") except Exception: token = None if not token: token = getpass("GitHub personal access token for private repo access: ").strip() clone_url = f"https://{token}@github.com/{GITHUB_REPO}.git" if REPO_DIR.exists(): subprocess.run(["git", "-C", str(REPO_DIR), "fetch", "origin"], check=True) subprocess.run(["git", "-C", str(REPO_DIR), "checkout", "main"], check=True) subprocess.run(["git", "-C", str(REPO_DIR), "reset", "--hard", "origin/main"], check=True) else: subprocess.run(["git", "clone", clone_url, str(REPO_DIR)], check=True) os.chdir(REPO_DIR) sys.path.insert(0, str(REPO_DIR)) else: print("Running locally from", REPO_DIR) sys.path.insert(0, str(REPO_DIR)) importlib.invalidate_caches() for name in list(sys.modules): if ( name == "gamma_space_model" or name.startswith("gamma_space_model.") or name == "csrc" or name.startswith("csrc.") or name == "tilelang" or name.startswith("tilelang.") ): del sys.modules[name] try: commit = subprocess.check_output(["git", "-C", str(REPO_DIR), "rev-parse", "--short", "HEAD"], text=True).strip() print("Repo commit:", commit) except Exception: pass """ ) def imports_cell(): return code( r""" import math import random import time import urllib.request import matplotlib.pyplot as plt import numpy as np import pandas as pd import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader, TensorDataset from gamma_space_model import GammaSingleBlock, GammaS4Block, GammaS4MinimalBlock, S4TernaryDPLRBlock SEED = 7 random.seed(SEED) np.random.seed(SEED) torch.manual_seed(SEED) if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") USE_AMP = DEVICE.type == "cuda" if hasattr(torch, "amp") and hasattr(torch.amp, "GradScaler"): scaler = torch.amp.GradScaler("cuda", enabled=USE_AMP) else: scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP) def synchronize(): if DEVICE.type == "cuda": torch.cuda.synchronize() print("Device:", DEVICE) print("Deployment cache available on GammaS4Block:", hasattr(GammaS4Block, "allocate_deployment_cache")) if DEVICE.type != "cuda": print("WARNING: running on CPU. Treat speed numbers as smoke-test only, not main benchmark evidence.") """ ) def shared_helpers_cell(): return code( r""" def make_forecasting_split(config, split, seed): rng = np.random.default_rng(seed) count = config["train_samples"] if split == "train" else config["val_samples"] seq_len = config["seq_len"] features = config["features"] complexity = config["complexity"] t = np.linspace(0.0, 1.0, seq_len + 1, dtype=np.float32) data = np.zeros((count, seq_len + 1, features), dtype=np.float32) for i in range(count): phase = rng.uniform(0.0, 2.0 * np.pi) chirp = np.sin(2.0 * np.pi * (1.0 + 1.75 * t**2) * rng.uniform(0.9, 1.15) + phase) slow = np.sin(2.0 * np.pi * (0.4 + 0.25 * complexity) * t + 0.7 * phase) medium = np.sin(2.0 * np.pi * (2.5 + complexity) * t + 1.1 * phase) fast = np.cos(2.0 * np.pi * (5.0 + 2.0 * complexity) * t + 1.5 * phase) bursts = (np.sin(2.0 * np.pi * (3.0 + complexity) * t + 0.4 * phase) > 0.8).astype(np.float32) bursts = bursts * np.sin(2.0 * np.pi * (10.0 + complexity) * t + 0.3 * phase) delayed = np.roll(chirp, 4 * complexity) modulated = medium * (1.0 + 0.35 * slow) components = [chirp, slow, medium, fast, bursts, delayed, modulated] for channel in range(features): weights = rng.normal(0.0, 1.0, size=len(components)).astype(np.float32) weights[: 2 + complexity] *= 1.2 signal = sum(w * c for w, c in zip(weights, components)) if complexity >= 2: signal += 0.20 * np.tanh(np.roll(signal, channel + 1)) if complexity >= 3: signal += 0.10 * np.sin(signal * (0.5 + 0.08 * channel)) signal += rng.normal(0.0, 0.03 + 0.01 * complexity, size=seq_len + 1).astype(np.float32) data[i, :, channel] = signal data -= data.mean(axis=1, keepdims=True) data /= data.std(axis=1, keepdims=True) + 1e-5 return TensorDataset(torch.from_numpy(data[:, :-1, :]), torch.from_numpy(data[:, 1:, :])) class SequenceForecaster(nn.Module): def __init__(self, input_dim, model_dim, layers, block_factory): super().__init__() self.in_proj = nn.Linear(input_dim, model_dim) self.layers = nn.ModuleList([block_factory(model_dim) for _ in range(layers)]) self.out_proj = nn.Linear(model_dim, input_dim) def forward(self, x): x = self.in_proj(x) for layer in self.layers: x, _ = layer(x, state=None, return_state=False) return self.out_proj(x), None def build_forecasting_model(kind, config, overrides=None): overrides = overrides or {} d_model = config["d_model"] hidden_dim = config["hidden_dim"] num_layers = config["num_layers"] input_dim = config["features"] if kind == "gamma_baseline": block_factory = lambda width: GammaSingleBlock( d_model=width, hidden_dim=hidden_dim, dropout=0.0, ) elif kind == "gamma_s4_enhanced": block_factory = lambda width: GammaS4Block( d_model=width, hidden_dim=hidden_dim, kernel_mode=overrides.get("kernel_mode", "auto"), kernel_threshold=overrides.get("kernel_threshold", 384), discretization=overrides.get("discretization", "bilinear"), gate=overrides.get("gate", True), input_gate=overrides.get("input_gate", True), activation=overrides.get("activation", "gelu"), use_D=overrides.get("use_D", True), layer_scale_init=overrides.get("layer_scale_init", 0.1), ) elif kind == "gamma_s4_minimal": block_factory = lambda width: GammaS4MinimalBlock( d_model=width, hidden_dim=hidden_dim, kernel_mode=overrides.get("kernel_mode", "auto"), kernel_threshold=overrides.get("kernel_threshold", 384), discretization=overrides.get("discretization", "bilinear"), use_D=overrides.get("use_D", True), ) elif kind == "s4_ternary_dplr_ssm": block_factory = lambda width: S4TernaryDPLRBlock( d_model=width, hidden_dim=hidden_dim, rank=overrides.get("rank", 1), kernel_mode=overrides.get("kernel_mode", "auto"), kernel_threshold=overrides.get("kernel_threshold", 256), gate=overrides.get("gate", True), input_gate=overrides.get("input_gate", True), activation=overrides.get("activation", "gelu"), use_D=overrides.get("use_D", True), layer_scale_init=overrides.get("layer_scale_init", 0.1), ) else: raise ValueError(kind) return SequenceForecaster(input_dim, d_model, num_layers, block_factory).to(DEVICE) def profile_train_step(model, batch_x, batch_y): optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4) model.train() batch_x = batch_x.to(DEVICE) batch_y = batch_y.to(DEVICE) optimizer.zero_grad(set_to_none=True) synchronize() t0 = time.perf_counter() with torch.autocast(device_type=DEVICE.type, enabled=USE_AMP): pred, _ = model(batch_x) loss = F.mse_loss(pred, batch_y) synchronize() t1 = time.perf_counter() scaler.scale(loss).backward() synchronize() t2 = time.perf_counter() scaler.step(optimizer) scaler.update() synchronize() t3 = time.perf_counter() return { "forward_ms": 1000.0 * (t1 - t0), "backward_ms": 1000.0 * (t2 - t1), "optimizer_ms": 1000.0 * (t3 - t2), "loss": float(loss.detach().cpu()), } def run_epoch(model, loader, optimizer=None): training = optimizer is not None model.train(training) losses = [] synchronize() start = time.perf_counter() for batch_x, batch_y in loader: batch_x = batch_x.to(DEVICE) batch_y = batch_y.to(DEVICE) if training: optimizer.zero_grad(set_to_none=True) with torch.autocast(device_type=DEVICE.type, enabled=USE_AMP): pred, _ = model(batch_x) loss = F.mse_loss(pred, batch_y) if training: scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() losses.append(loss.detach().item()) synchronize() return float(np.mean(losses)), time.perf_counter() - start def benchmark_inference(model, sample_x): model.eval() sample_x = sample_x.to(DEVICE) with torch.no_grad(): for _ in range(2): _ = model(sample_x) synchronize() t0 = time.perf_counter() pred, _ = model(sample_x) synchronize() full_latency = time.perf_counter() - t0 hidden = model.in_proj(sample_x) states, caches, outputs = [], [], [] synchronize() t_cache = time.perf_counter() for layer in model.layers: ssm = getattr(layer, "ssm", None) if ssm is None: states.append(None) caches.append(None) else: states.append(ssm.init_state(sample_x.size(0), DEVICE, hidden.dtype)) if hasattr(layer, "allocate_inference_cache"): caches.append(layer.allocate_inference_cache(sample_x.size(0), sample_x.size(1), DEVICE, hidden.dtype)) else: caches.append(None) synchronize() cache_setup = time.perf_counter() - t_cache synchronize() t1 = time.perf_counter() for step in range(sample_x.size(1)): token = hidden[:, step, :] new_outputs = token for idx, layer in enumerate(model.layers): if caches[idx] is None: new_outputs, states[idx] = layer.step(new_outputs, states[idx]) else: new_outputs, states[idx] = layer.step(new_outputs, states[idx], cache=caches[idx]) outputs.append(new_outputs) recurrent = model.out_proj(torch.stack(outputs, dim=1)) synchronize() recurrent_latency = time.perf_counter() - t1 def run_cached_recurrent(cache_allocator_name): hidden_local = model.in_proj(sample_x) states_local, caches_local, outputs_local = [], [], [] synchronize() cache_start = time.perf_counter() for layer in model.layers: ssm = getattr(layer, "ssm", None) if ssm is None: states_local.append(None) caches_local.append(None) else: states_local.append(ssm.init_state(sample_x.size(0), DEVICE, hidden_local.dtype)) if hasattr(layer, cache_allocator_name): allocator = getattr(layer, cache_allocator_name) caches_local.append(allocator(sample_x.size(0), sample_x.size(1), DEVICE, hidden_local.dtype)) elif hasattr(layer, "allocate_deployment_cache"): caches_local.append(layer.allocate_deployment_cache(sample_x.size(0), sample_x.size(1), DEVICE, hidden_local.dtype)) elif hasattr(layer, "allocate_inference_cache"): caches_local.append(layer.allocate_inference_cache(sample_x.size(0), sample_x.size(1), DEVICE, hidden_local.dtype)) else: caches_local.append(None) synchronize() cache_elapsed = time.perf_counter() - cache_start synchronize() start = time.perf_counter() for step in range(sample_x.size(1)): token = hidden_local[:, step, :] new_outputs = token for idx, layer in enumerate(model.layers): if caches_local[idx] is None: new_outputs, states_local[idx] = layer.step(new_outputs, states_local[idx]) else: new_outputs, states_local[idx] = layer.step(new_outputs, states_local[idx], cache=caches_local[idx]) outputs_local.append(new_outputs) recurrent_out = model.out_proj(torch.stack(outputs_local, dim=1)) synchronize() elapsed = time.perf_counter() - start return cache_elapsed, elapsed, recurrent_out lightweight_latency = float("nan") lightweight_tokens_per_s = float("nan") deploy_supported = any(hasattr(layer, "allocate_deployment_cache") for layer in model.layers) if deploy_supported: cache_setup_lightweight, lightweight_latency, recurrent_light = run_cached_recurrent("allocate_deployment_cache") lightweight_tokens_per_s = (sample_x.shape[0] * sample_x.shape[1]) / max(lightweight_latency, 1e-9) else: cache_setup_lightweight = float("nan") recurrent_light = None balanced_latency = float("nan") balanced_tokens_per_s = float("nan") balanced_supported = any(hasattr(layer, "allocate_balanced_deployment_cache") for layer in model.layers) if balanced_supported: cache_setup_balanced, balanced_latency, recurrent_balanced = run_cached_recurrent("allocate_balanced_deployment_cache") balanced_tokens_per_s = (sample_x.shape[0] * sample_x.shape[1]) / max(balanced_latency, 1e-9) else: cache_setup_balanced = float("nan") recurrent_balanced = None tokens = sample_x.shape[0] * sample_x.shape[1] return { "full_latency_ms": 1000.0 * full_latency, "full_tokens_per_s": tokens / max(full_latency, 1e-9), "cache_setup_ms": 1000.0 * cache_setup, "recurrent_latency_ms": 1000.0 * recurrent_latency, "recurrent_tokens_per_s": tokens / max(recurrent_latency, 1e-9), "recurrent_match_mse": float(F.mse_loss(recurrent, pred).detach().cpu()), "deploy_supported": deploy_supported, "deploy_cache_setup_ms": 1000.0 * cache_setup_lightweight, "deploy_recurrent_latency_ms": 1000.0 * lightweight_latency if lightweight_latency == lightweight_latency else float("nan"), "deploy_recurrent_tokens_per_s": lightweight_tokens_per_s, "deploy_match_mse": float(F.mse_loss(recurrent_light, pred).detach().cpu()) if recurrent_light is not None else float("nan"), "balanced_deploy_supported": balanced_supported, "balanced_deploy_cache_setup_ms": 1000.0 * cache_setup_balanced, "balanced_deploy_recurrent_latency_ms": 1000.0 * balanced_latency if balanced_latency == balanced_latency else float("nan"), "balanced_deploy_recurrent_tokens_per_s": balanced_tokens_per_s, "balanced_deploy_match_mse": float(F.mse_loss(recurrent_balanced, pred).detach().cpu()) if recurrent_balanced is not None else float("nan"), "prediction": pred.detach().cpu(), "recurrent_prediction": recurrent.detach().cpu(), } def show_benchmark_tables(df, title="Benchmark"): if df.empty: display(df) return def available(columns): return [col for col in columns if col in df.columns] ordered = df.sort_values(["task", "kind"]).reset_index(drop=True) normal_cols = available([ "task", "kind", "params", "train_loss", "val_loss", "mean_epoch_time_s", "expected_full_mode", "forward_ms", "backward_ms", "optimizer_ms", "full_latency_ms", "full_tokens_per_s", "cache_setup_ms", "recurrent_latency_ms", "recurrent_tokens_per_s", "recurrent_match_mse", ]) deploy_cols = available([ "task", "kind", "deploy_supported", "deploy_cache_setup_ms", "deploy_recurrent_latency_ms", "deploy_recurrent_tokens_per_s", "deploy_match_mse", ]) balanced_cols = available([ "task", "kind", "balanced_deploy_supported", "balanced_deploy_cache_setup_ms", "balanced_deploy_recurrent_latency_ms", "balanced_deploy_recurrent_tokens_per_s", "balanced_deploy_match_mse", ]) print(f"{title}: normal/full-sequence and full recurrent") display(ordered[normal_cols]) print(f"{title}: deployment-lite recurrent") display(ordered[deploy_cols]) print(f"{title}: balanced deployment recurrent") display(ordered[balanced_cols]) """ ) def quick_notebook(): cells = [ md( """ # Gamma Baseline vs Gamma S4 Enhanced Quick Benchmark This is the fast notebook for day-to-day iteration. It is intentionally narrow: - compare `gamma_baseline`, `gamma_s4_enhanced`, and `s4_ternary_dplr_ssm` - use practical sequence lengths - keep `kernel_mode` conservative - report training speed, inference speed, and one-step profiling """ ), setup_cell(), imports_cell(), code( r""" QUICK_TASKS = { "simple": dict(seq_len=192, features=4, train_samples=256, val_samples=64, epochs=4, batch_size=32, d_model=48, hidden_dim=64, num_layers=2, complexity=1), "moderate": dict(seq_len=320, features=6, train_samples=320, val_samples=80, epochs=5, batch_size=24, d_model=64, hidden_dim=96, num_layers=2, complexity=2), } MODEL_OVERRIDES = { "gamma_baseline": {}, "gamma_s4_enhanced": { "kernel_mode": "auto", "kernel_threshold": 384, "discretization": "bilinear", "gate": True, "input_gate": True, "activation": "gelu", "use_D": True, "layer_scale_init": 0.1, }, "s4_ternary_dplr_ssm": { "kernel_mode": "auto", "kernel_threshold": 256, "rank": 1, "gate": True, "input_gate": True, "activation": "gelu", "use_D": True, "layer_scale_init": 0.1, }, } ACTIVE_TASKS = ["simple", "moderate"] MODELS = ["gamma_baseline", "gamma_s4_enhanced", "s4_ternary_dplr_ssm"] """ ), shared_helpers_cell(), code( r""" def train_and_benchmark(task_name, kind): cfg = QUICK_TASKS[task_name] train_ds = make_forecasting_split(cfg, "train", seed=SEED + 11) val_ds = make_forecasting_split(cfg, "val", seed=SEED + 29) train_loader = DataLoader(train_ds, batch_size=cfg["batch_size"], shuffle=True) val_loader = DataLoader(val_ds, batch_size=cfg["batch_size"], shuffle=False) model = build_forecasting_model(kind, cfg, overrides=MODEL_OVERRIDES.get(kind)) optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=1e-4) history = {"train_loss": [], "val_loss": [], "epoch_time_s": []} first_batch_x, first_batch_y = next(iter(train_loader)) profile = profile_train_step(build_forecasting_model(kind, cfg, overrides=MODEL_OVERRIDES.get(kind)), first_batch_x, first_batch_y) for epoch in range(cfg["epochs"]): train_loss, epoch_time = run_epoch(model, train_loader, optimizer=optimizer) val_loss, _ = run_epoch(model, val_loader) history["train_loss"].append(train_loss) history["val_loss"].append(val_loss) history["epoch_time_s"].append(epoch_time) print(f"{task_name} | {kind} | epoch={epoch+1:02d} train={train_loss:.6f} val={val_loss:.6f}") sample_x, sample_y = next(iter(val_loader)) sample_x = sample_x[:2] sample_y = sample_y[:2] inf = benchmark_inference(model, sample_x) metrics = { "task": task_name, "kind": kind, "params": sum(p.numel() for p in model.parameters()), "train_loss": history["train_loss"][-1], "val_loss": history["val_loss"][-1], "mean_epoch_time_s": float(np.mean(history["epoch_time_s"])), **profile, "sample_target": sample_y.cpu(), **inf, } return metrics, history, model """ ), md( """ ## Run the Quick Experiment This is the cell you will usually run in Colab first. """ ), code( r""" all_metrics = [] histories = {} trained_models = {} for task_name in ACTIVE_TASKS: for kind in MODELS: metrics, history, model = train_and_benchmark(task_name, kind) all_metrics.append({k: v for k, v in metrics.items() if k not in {"sample_target", "prediction", "recurrent_prediction"}}) histories[(task_name, kind)] = history trained_models[(task_name, kind)] = metrics summary_df = pd.DataFrame(all_metrics).sort_values(["task", "val_loss"]).reset_index(drop=True) show_benchmark_tables(summary_df, title="Quick benchmark") """ ), code( r""" fig, axes = plt.subplots(2, 2, figsize=(14, 8)) metrics_to_plot = ["val_loss", "mean_epoch_time_s", "full_tokens_per_s", "recurrent_tokens_per_s"] for ax, metric in zip(axes.flatten(), metrics_to_plot): pivot = summary_df.pivot(index="task", columns="kind", values=metric).loc[ACTIVE_TASKS] pivot.plot(ax=ax, marker="o") ax.set_title(metric.replace("_", " ").title()) ax.grid(alpha=0.2) plt.tight_layout() plt.show() """ ), md( """ The benchmark tables above are intentionally split into normal, deployment-lite, and balanced deployment views so the important columns remain visible in Colab. """ ), code( r""" PLOT_TASK = ACTIVE_TASKS[-1] baseline = trained_models[(PLOT_TASK, "gamma_baseline")] target = baseline["sample_target"][0].numpy() baseline_pred = baseline["prediction"][0].numpy() comparison_kinds = [kind for kind in MODELS if kind != "gamma_baseline"] for compare_kind in comparison_kinds: candidate = trained_models[(PLOT_TASK, compare_kind)] candidate_pred = candidate["prediction"][0].numpy() channels = range(min(3, target.shape[-1])) time_axis = np.arange(target.shape[0]) fig, axes = plt.subplots(len(list(channels)), 1, figsize=(12, 3.5 * len(list(channels))), sharex=True) if target.shape[-1] == 1: axes = [axes] for row, channel in enumerate(channels): ax = axes[row] ax.plot(time_axis, target[:, channel], label="ground truth", linewidth=2) ax.plot(time_axis, baseline_pred[:, channel], label="baseline", alpha=0.9) ax.plot(time_axis, candidate_pred[:, channel], label=compare_kind, alpha=0.9) ax.set_title(f"{PLOT_TASK} task, channel {channel}") ax.grid(alpha=0.2) if row == 0: ax.legend() plt.tight_layout() plt.show() """ ), md( """ ## Reading This Notebook Use this notebook for fast feedback: - if either S4-style model loses badly here, do not trust it on bigger tasks - if one of them is competitive here, then move to the research notebook """ ), ] return cells def research_notebook(): cells = [ md( """ # Gamma S4 Practical Benchmark This notebook is the second benchmark track after the quick notebook. It is meant to be closer to practical sequence modeling while staying reasonable on Colab: - one harder long-context synthetic benchmark - enhanced-model ablations on that harder task - an optional lightweight token benchmark """ ), setup_cell(), imports_cell(), code( r""" PRACTICAL_CONFIGS = { "current_reference": dict(seq_len=320, features=6, train_samples=320, val_samples=80, epochs=5, batch_size=24, d_model=64, hidden_dim=96, num_layers=2, complexity=2), "long_context": dict(seq_len=768, features=8, train_samples=256, val_samples=64, epochs=4, batch_size=12, d_model=80, hidden_dim=128, num_layers=3, complexity=3), } RUN_PRACTICAL_SWEEP = True RUN_ABLATIONS = False RUN_TOKEN_TASK = True PRACTICAL_MODELS = ["gamma_baseline", "gamma_s4_enhanced", "s4_ternary_dplr_ssm"] MODEL_OVERRIDES = { "gamma_baseline": {}, "gamma_s4_enhanced": { "kernel_mode": "auto", "kernel_threshold": 384, "discretization": "bilinear", "gate": True, "input_gate": True, "activation": "gelu", "use_D": True, "layer_scale_init": 0.1, }, "gamma_s4_minimal": { "kernel_mode": "auto", "kernel_threshold": 512, "discretization": "bilinear", "use_D": True, }, "s4_ternary_dplr_ssm": { "kernel_mode": "auto", "kernel_threshold": 256, "rank": 1, "gate": True, "input_gate": True, "activation": "gelu", "use_D": True, "layer_scale_init": 0.1, }, } ABLATIONS = [ ("default", {}), ("no_input_gate", {"input_gate": False}), ("no_gate", {"gate": False}), ("no_skip_D", {"use_D": False}), ("euler", {"discretization": "euler"}), ] """ ), shared_helpers_cell(), code( r""" def train_practical_model(task_name, kind, overrides=None): cfg = PRACTICAL_CONFIGS[task_name] train_ds = make_forecasting_split(cfg, "train", seed=SEED + 101) val_ds = make_forecasting_split(cfg, "val", seed=SEED + 151) train_loader = DataLoader(train_ds, batch_size=cfg["batch_size"], shuffle=True) val_loader = DataLoader(val_ds, batch_size=cfg["batch_size"], shuffle=False) model = build_forecasting_model(kind, cfg, overrides=overrides or MODEL_OVERRIDES.get(kind)) optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=1e-4) history = {"train_loss": [], "val_loss": [], "epoch_time_s": []} for epoch in range(cfg["epochs"]): train_loss, epoch_time = run_epoch(model, train_loader, optimizer=optimizer) val_loss, _ = run_epoch(model, val_loader) history["train_loss"].append(train_loss) history["val_loss"].append(val_loss) history["epoch_time_s"].append(epoch_time) print(f"{task_name} | {kind} | epoch={epoch+1:02d} train={train_loss:.6f} val={val_loss:.6f}") sample_x, sample_y = next(iter(val_loader)) sample_x = sample_x[:2] sample_y = sample_y[:2] inf = benchmark_inference(model, sample_x) result = { "task": task_name, "kind": kind, "params": sum(p.numel() for p in model.parameters()), "train_loss": history["train_loss"][-1], "val_loss": history["val_loss"][-1], "mean_epoch_time_s": float(np.mean(history["epoch_time_s"])), "expected_full_mode": "conv" if (kind != "gamma_baseline" and cfg["seq_len"] >= MODEL_OVERRIDES.get(kind, {}).get("kernel_threshold", 10**9)) else "recurrent_like", "sample_target": sample_y.cpu(), **inf, } return result, history, model """ ), md( """ ## Practical Sweep """ ), code( r""" practical_rows = [] practical_artifacts = {} practical_models = {} if RUN_PRACTICAL_SWEEP: for task_name in PRACTICAL_CONFIGS: for kind in PRACTICAL_MODELS: result, _, model = train_practical_model(task_name, kind) practical_artifacts[(task_name, kind)] = result practical_models[(task_name, kind)] = model practical_rows.append({k: v for k, v in result.items() if k not in {"sample_target", "prediction", "recurrent_prediction"}}) practical_df = pd.DataFrame(practical_rows) show_benchmark_tables(practical_df, title="Practical sweep") """ ), code( r""" if not practical_df.empty: fig, axes = plt.subplots(2, 2, figsize=(14, 8)) for ax, metric in zip(axes.flatten(), ["val_loss", "mean_epoch_time_s", "full_tokens_per_s", "recurrent_tokens_per_s"]): pivot = practical_df.pivot(index="task", columns="kind", values=metric) pivot.plot(ax=ax, marker="o") ax.set_title(metric.replace("_", " ").title()) ax.grid(alpha=0.2) plt.tight_layout() plt.show() """ ), md( """ ## Practical Inference Scaling This section benchmarks trained practical models across sequence lengths and batch sizes. It separates full-sequence prefill-like throughput from recurrent/deployment decode-like throughput. """ ), code( r""" inference_scaling_rows = [] if RUN_PRACTICAL_SWEEP and practical_models: INFERENCE_SEQ_LENS = [64, 192, 384, 768] INFERENCE_BATCH_SIZES = [1, 2] for (task_name, kind), model in practical_models.items(): feature_dim = model.out_proj.out_features for seq_len in INFERENCE_SEQ_LENS: for batch_size in INFERENCE_BATCH_SIZES: sample_x = torch.randn(batch_size, seq_len, feature_dim) metrics = benchmark_inference(model, sample_x) inference_scaling_rows.append({ "task": task_name, "kind": kind, "batch_size": batch_size, "seq_len": seq_len, "full_latency_ms": metrics["full_latency_ms"], "full_tokens_per_s": metrics["full_tokens_per_s"], "recurrent_latency_ms": metrics["recurrent_latency_ms"], "recurrent_tokens_per_s": metrics["recurrent_tokens_per_s"], "deploy_recurrent_latency_ms": metrics["deploy_recurrent_latency_ms"], "deploy_recurrent_tokens_per_s": metrics["deploy_recurrent_tokens_per_s"], "balanced_deploy_recurrent_latency_ms": metrics["balanced_deploy_recurrent_latency_ms"], "balanced_deploy_recurrent_tokens_per_s": metrics["balanced_deploy_recurrent_tokens_per_s"], "recurrent_match_mse": metrics["recurrent_match_mse"], "deploy_match_mse": metrics["deploy_match_mse"], "balanced_deploy_match_mse": metrics["balanced_deploy_match_mse"], }) inference_scaling_df = pd.DataFrame(inference_scaling_rows) display(inference_scaling_df.sort_values(["task", "kind", "seq_len", "batch_size"])) """ ), code( r""" if "inference_scaling_df" in globals() and not inference_scaling_df.empty: for task_name in inference_scaling_df["task"].unique(): subset = inference_scaling_df[(inference_scaling_df["task"] == task_name) & (inference_scaling_df["batch_size"] == 1)] fig, axes = plt.subplots(1, 2, figsize=(14, 4)) for kind in subset["kind"].unique(): model_subset = subset[subset["kind"] == kind].sort_values("seq_len") axes[0].plot(model_subset["seq_len"], model_subset["full_tokens_per_s"], marker="o", label=kind) axes[1].plot(model_subset["seq_len"], model_subset["recurrent_tokens_per_s"], marker="o", label=kind) axes[0].set_title(f"{task_name}: full-sequence throughput, batch=1") axes[1].set_title(f"{task_name}: recurrent throughput, batch=1") for ax in axes: ax.set_xlabel("seq_len") ax.set_ylabel("tokens/s") ax.grid(alpha=0.25) ax.legend() plt.tight_layout() plt.show() """ ), md( """ ## Task Visual Preview These plots show what the synthetic tasks actually look like before we discuss scores. """ ), code( r""" if RUN_PRACTICAL_SWEEP: fig, axes = plt.subplots(len(PRACTICAL_CONFIGS), 1, figsize=(14, 3.5 * len(PRACTICAL_CONFIGS)), sharex=False) if len(PRACTICAL_CONFIGS) == 1: axes = [axes] for ax, task_name in zip(axes, PRACTICAL_CONFIGS): cfg = PRACTICAL_CONFIGS[task_name] preview_ds = make_forecasting_split(cfg, "val", seed=SEED + 151) preview_x, preview_y = preview_ds[0] channels = range(min(3, preview_y.shape[-1])) time_axis = np.arange(preview_y.shape[0]) for channel in channels: ax.plot(time_axis, preview_y[:, channel].numpy(), label=f"channel {channel}", linewidth=1.5) ax.set_title(f"{task_name} target preview") ax.grid(alpha=0.2) if task_name == list(PRACTICAL_CONFIGS.keys())[0]: ax.legend(ncol=min(3, preview_y.shape[-1])) plt.tight_layout() plt.show() """ ), md( """ ## Prediction Comparison Plots These are the most presentation-friendly plots in the notebook: ground truth vs baseline vs each S4-style model on the same held-out sample. """ ), code( r""" if practical_artifacts: for task_name in PRACTICAL_CONFIGS: baseline = practical_artifacts[(task_name, "gamma_baseline")] target = baseline["sample_target"][0].numpy() baseline_pred = baseline["prediction"][0].numpy() for compare_kind in [kind for kind in PRACTICAL_MODELS if kind != "gamma_baseline"]: candidate = practical_artifacts[(task_name, compare_kind)] candidate_pred = candidate["prediction"][0].numpy() channels = range(min(3, target.shape[-1])) time_axis = np.arange(target.shape[0]) fig, axes = plt.subplots(len(list(channels)), 1, figsize=(14, 3.5 * len(list(channels))), sharex=True) if target.shape[-1] == 1: axes = [axes] for row, channel in enumerate(channels): ax = axes[row] ax.plot(time_axis, target[:, channel], label="ground truth", linewidth=2) ax.plot(time_axis, baseline_pred[:, channel], label="baseline", alpha=0.9) ax.plot(time_axis, candidate_pred[:, channel], label=compare_kind, alpha=0.9) ax.set_title(f"{task_name} prediction comparison, channel {channel}") ax.grid(alpha=0.2) if row == 0: ax.legend() plt.tight_layout() plt.show() """ ), md( """ ## Error Comparison Plots These show where each model is missing the target signal. Lower absolute error should visually hug zero. """ ), code( r""" if practical_artifacts: for task_name in PRACTICAL_CONFIGS: baseline = practical_artifacts[(task_name, "gamma_baseline")] target = baseline["sample_target"][0].numpy() baseline_err = baseline["prediction"][0].numpy() - target for compare_kind in [kind for kind in PRACTICAL_MODELS if kind != "gamma_baseline"]: candidate = practical_artifacts[(task_name, compare_kind)] candidate_err = candidate["prediction"][0].numpy() - target channels = range(min(2, target.shape[-1])) time_axis = np.arange(target.shape[0]) fig, axes = plt.subplots(len(list(channels)), 1, figsize=(14, 3.0 * len(list(channels))), sharex=True) if len(list(channels)) == 1: axes = [axes] for row, channel in enumerate(channels): ax = axes[row] ax.plot(time_axis, baseline_err[:, channel], label="baseline error", alpha=0.9) ax.plot(time_axis, candidate_err[:, channel], label=f"{compare_kind} error", alpha=0.9) ax.axhline(0.0, color="black", linewidth=1, alpha=0.5) ax.set_title(f"{task_name} error comparison, channel {channel}") ax.grid(alpha=0.2) if row == 0: ax.legend() plt.tight_layout() plt.show() """ ), md( """ ## Enhanced Ablations On Long Context """ ), code( r""" ablation_rows = [] if RUN_ABLATIONS: for name, override in ABLATIONS: merged = {**MODEL_OVERRIDES["gamma_s4_enhanced"], **override} result, _, _ = train_practical_model("long_context", "gamma_s4_enhanced", overrides=merged) row = {k: v for k, v in result.items() if k not in {"sample_target", "prediction", "recurrent_prediction"}} row["ablation"] = name ablation_rows.append(row) ablation_df = pd.DataFrame(ablation_rows) ablation_df """ ), code( r""" if not ablation_df.empty: fig, axes = plt.subplots(1, 3, figsize=(16, 4)) ablation_df.sort_values("val_loss").plot.bar(x="ablation", y="val_loss", ax=axes[0], legend=False) ablation_df.sort_values("mean_epoch_time_s").plot.bar(x="ablation", y="mean_epoch_time_s", ax=axes[1], legend=False) ablation_df.sort_values("recurrent_tokens_per_s").plot.bar(x="ablation", y="recurrent_tokens_per_s", ax=axes[2], legend=False) axes[0].set_title("Validation Loss") axes[1].set_title("Mean Epoch Time") axes[2].set_title("Recurrent Tokens / s") for ax in axes: ax.tick_params(axis="x", rotation=25) ax.grid(alpha=0.2) plt.tight_layout() plt.show() """ ), md( """ ## Optional Token-Lite Task """ ), code( r""" @torch.no_grad() def benchmark_token_inference(model, token_batch, target_batch, vocab_size, repeats=6): model.eval() x = token_batch.to(DEVICE) y = target_batch.to(DEVICE) batch, seq_len = x.shape token_count = batch * seq_len rows = [] def reset_memory(): if DEVICE.type == "cuda": torch.cuda.reset_peak_memory_stats() def max_memory_mb(): if DEVICE.type != "cuda": return float("nan") return torch.cuda.max_memory_allocated() / (1024 ** 2) reset_memory() with torch.autocast(device_type=DEVICE.type, enabled=USE_AMP): logits = model(x) loss = F.cross_entropy(logits.reshape(-1, vocab_size), y.reshape(-1)) synchronize() start = time.perf_counter() for _ in range(repeats): with torch.autocast(device_type=DEVICE.type, enabled=USE_AMP): logits = model(x) synchronize() elapsed = time.perf_counter() - start full_logits = logits.detach() rows.append({ "mode": "prefill_full_sequence", "latency_ms": 1000.0 * elapsed / repeats, "tokens_per_s": token_count * repeats / max(elapsed, 1e-9), "ce": float(loss.detach().cpu()), "match_mse": 0.0, "max_memory_mb": max_memory_mb(), }) hidden = model.embed(x) states = [] caches = [] cache_start = time.perf_counter() for layer in model.layers: state = layer.ssm.init_state(hidden.size(0), DEVICE, hidden.dtype) states.append(state) if hasattr(layer, "allocate_inference_cache"): caches.append(layer.allocate_inference_cache(hidden.size(0), seq_len, DEVICE, hidden.dtype)) else: caches.append(None) synchronize() cache_setup_ms = 1000.0 * (time.perf_counter() - cache_start) def recurrent_pass(cache_list): local_states = [state.clone() for state in states] outputs = [] for t in range(seq_len): step_x = hidden[:, t, :] for layer_idx, layer in enumerate(model.layers): cache = None if cache_list is None else cache_list[layer_idx] try: step_x, local_states[layer_idx] = layer.step(step_x, local_states[layer_idx], cache=cache) except TypeError: step_x, local_states[layer_idx] = layer.step(step_x, local_states[layer_idx]) outputs.append(model.head(step_x)) return torch.stack(outputs, dim=1) reset_memory() recurrent_logits = recurrent_pass(caches) synchronize() start = time.perf_counter() for _ in range(repeats): recurrent_logits = recurrent_pass(caches) synchronize() elapsed = time.perf_counter() - start recurrent_loss = F.cross_entropy(recurrent_logits.reshape(-1, vocab_size), y.reshape(-1)) rows.append({ "mode": "decode_recurrent_exact", "cache_setup_ms": cache_setup_ms, "latency_ms": 1000.0 * elapsed / repeats, "tokens_per_s": token_count * repeats / max(elapsed, 1e-9), "ce": float(recurrent_loss.detach().cpu()), "match_mse": float(F.mse_loss(recurrent_logits, full_logits).detach().cpu()), "max_memory_mb": max_memory_mb(), }) deploy_supported = all(hasattr(layer, "allocate_deployment_cache") for layer in model.layers) if deploy_supported: cache_start = time.perf_counter() deploy_caches = [ layer.allocate_deployment_cache(hidden.size(0), seq_len, DEVICE, hidden.dtype) for layer in model.layers ] synchronize() deploy_cache_setup_ms = 1000.0 * (time.perf_counter() - cache_start) reset_memory() deploy_logits = recurrent_pass(deploy_caches) synchronize() start = time.perf_counter() for _ in range(repeats): deploy_logits = recurrent_pass(deploy_caches) synchronize() elapsed = time.perf_counter() - start deploy_loss = F.cross_entropy(deploy_logits.reshape(-1, vocab_size), y.reshape(-1)) rows.append({ "mode": "decode_deploy_lite", "cache_setup_ms": deploy_cache_setup_ms, "latency_ms": 1000.0 * elapsed / repeats, "tokens_per_s": token_count * repeats / max(elapsed, 1e-9), "ce": float(deploy_loss.detach().cpu()), "match_mse": float(F.mse_loss(deploy_logits, full_logits).detach().cpu()), "max_memory_mb": max_memory_mb(), }) balanced_supported = all(hasattr(layer, "allocate_balanced_deployment_cache") for layer in model.layers) if balanced_supported: cache_start = time.perf_counter() balanced_caches = [ layer.allocate_balanced_deployment_cache(hidden.size(0), seq_len, DEVICE, hidden.dtype) for layer in model.layers ] synchronize() balanced_cache_setup_ms = 1000.0 * (time.perf_counter() - cache_start) reset_memory() balanced_logits = recurrent_pass(balanced_caches) synchronize() start = time.perf_counter() for _ in range(repeats): balanced_logits = recurrent_pass(balanced_caches) synchronize() elapsed = time.perf_counter() - start balanced_loss = F.cross_entropy(balanced_logits.reshape(-1, vocab_size), y.reshape(-1)) rows.append({ "mode": "decode_balanced", "cache_setup_ms": balanced_cache_setup_ms, "latency_ms": 1000.0 * elapsed / repeats, "tokens_per_s": token_count * repeats / max(elapsed, 1e-9), "ce": float(balanced_loss.detach().cpu()), "match_mse": float(F.mse_loss(balanced_logits, full_logits).detach().cpu()), "max_memory_mb": max_memory_mb(), }) return rows """ ), code( r""" if RUN_TOKEN_TASK: TOKEN_DATA_URL = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" TOKEN_DATA_PATH = Path("tmp/jupyter-notebook/tinyshakespeare.txt") TOKEN_DATA_PATH.parent.mkdir(parents=True, exist_ok=True) if not TOKEN_DATA_PATH.exists(): urllib.request.urlretrieve(TOKEN_DATA_URL, TOKEN_DATA_PATH) text = TOKEN_DATA_PATH.read_text(encoding="utf-8") vocab = sorted(set(text)) stoi = {ch: i for i, ch in enumerate(vocab)} tokens = torch.tensor([stoi[ch] for ch in text], dtype=torch.long) TOKEN_CFG = { "seq_len": 192, "train_samples": 1200, "val_samples": 240, "epochs": 2, "batch_size": 12 if DEVICE.type == "cuda" else 6, } def make_token_split(seq_len, train_samples, val_samples): max_start = len(tokens) - seq_len - 1 starts = torch.linspace(0, max_start - 1, steps=train_samples + val_samples).long() x = torch.stack([tokens[s : s + seq_len] for s in starts]) y = torch.stack([tokens[s + 1 : s + seq_len + 1] for s in starts]) return TensorDataset(x[:train_samples], y[:train_samples]), TensorDataset(x[train_samples:], y[train_samples:]) class TokenForecaster(nn.Module): def __init__(self, vocab_size, kind): super().__init__() self.embed = nn.Embedding(vocab_size, 64) if kind == "gamma_baseline": factory = lambda: GammaSingleBlock(d_model=64, hidden_dim=96, dropout=0.0) elif kind == "gamma_s4_enhanced": factory = lambda: GammaS4Block(d_model=64, hidden_dim=96, kernel_mode="auto", kernel_threshold=160) elif kind == "s4_ternary_dplr_ssm": factory = lambda: S4TernaryDPLRBlock(d_model=64, hidden_dim=96, kernel_mode="auto", kernel_threshold=160) else: factory = lambda: GammaS4MinimalBlock(d_model=64, hidden_dim=96, kernel_mode="auto", kernel_threshold=160) self.layers = nn.ModuleList([factory(), factory()]) self.head = nn.Linear(64, vocab_size) def forward(self, x): x = self.embed(x) for layer in self.layers: x, _ = layer(x, state=None, return_state=False) return self.head(x) token_train, token_val = make_token_split( seq_len=TOKEN_CFG["seq_len"], train_samples=TOKEN_CFG["train_samples"], val_samples=TOKEN_CFG["val_samples"], ) token_rows = [] token_inference_rows = [] for kind in PRACTICAL_MODELS: model = TokenForecaster(len(vocab), kind).to(DEVICE) optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=1e-4) train_loader = DataLoader(token_train, batch_size=TOKEN_CFG["batch_size"], shuffle=True) val_loader = DataLoader(token_val, batch_size=TOKEN_CFG["batch_size"], shuffle=False) history = [] for epoch in range(TOKEN_CFG["epochs"]): model.train() train_losses = [] for batch_x, batch_y in train_loader: batch_x = batch_x.to(DEVICE) batch_y = batch_y.to(DEVICE) optimizer.zero_grad(set_to_none=True) with torch.autocast(device_type=DEVICE.type, enabled=USE_AMP): logits = model(batch_x) loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), batch_y.reshape(-1)) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() train_losses.append(loss.detach().item()) model.eval() val_losses = [] with torch.no_grad(): for batch_x, batch_y in val_loader: batch_x = batch_x.to(DEVICE) batch_y = batch_y.to(DEVICE) logits = model(batch_x) loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), batch_y.reshape(-1)) val_losses.append(loss.detach().item()) history.append((float(np.mean(train_losses)), float(np.mean(val_losses)))) print(kind, epoch + 1, history[-1]) token_rows.append({ "kind": kind, "train_ce": history[-1][0], "val_ce": history[-1][1], "val_ppl": math.exp(history[-1][1]), "seq_len": TOKEN_CFG["seq_len"], "train_samples": TOKEN_CFG["train_samples"], }) sample_x, sample_y = next(iter(val_loader)) for row in benchmark_token_inference(model, sample_x[:2], sample_y[:2], len(vocab)): row["kind"] = kind row["seq_len"] = TOKEN_CFG["seq_len"] row["batch_size"] = min(2, sample_x.size(0)) token_inference_rows.append(row) token_df = pd.DataFrame(token_rows).sort_values("val_ce") display(token_df) token_inference_df = pd.DataFrame(token_inference_rows) display(token_inference_df.sort_values(["kind", "mode"])) """ ), md( """ Use this notebook after the quick benchmark. The `long_context` task is the more practical synthetic benchmark, and the optional token-lite section gives a small language-like check without making Colab costs too high. `RUN_ABLATIONS` is off by default because the long-context ablation sweep is still materially more expensive than the main comparison. """ ), ] return cells def write_notebook(path, cells): notebook = { "cells": cells, "metadata": { "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, "language_info": {"name": "python", "version": "3.11"}, }, "nbformat": 4, "nbformat_minor": 5, } path.parent.mkdir(parents=True, exist_ok=True) path.write_text(json.dumps(notebook, indent=2), encoding="utf-8") print(f"Wrote {path}") write_notebook(QUICK_NOTEBOOK_PATH, quick_notebook()) write_notebook(RESEARCH_NOTEBOOK_PATH, research_notebook())