# Source: TaoNet-mini-T2/code/Taotern_SSM/scripts/generate_gamma_benchmark_notebook.py
# Uploaded by StarMist0012 ("Add files using upload-large-folder tool", commit 388fd6e, verified).
import json
from pathlib import Path
def md(text):
    """Build a notebook markdown cell from a block of text.

    Leading and trailing newlines are trimmed from ``text``; the remainder is
    split into lines that keep their line endings.

    Args:
        text: Markdown source for the cell.

    Returns:
        dict with ``cell_type``, ``metadata`` and ``source`` entries.
    """
    trimmed = text.strip("\n")
    source_lines = trimmed.splitlines(keepends=True)
    return dict(cell_type="markdown", metadata={}, source=source_lines)
def code(text):
    """Build an unexecuted notebook code cell from a block of source text.

    Leading and trailing newlines are trimmed from ``text``; the remainder is
    split into lines that keep their line endings. The cell starts with no
    execution count and no outputs.

    Args:
        text: Source code for the cell.

    Returns:
        dict with ``cell_type``, ``execution_count``, ``metadata``,
        ``outputs`` and ``source`` entries.
    """
    body = text.strip("\n")
    cell = {"cell_type": "code", "execution_count": None}
    cell["metadata"] = {}
    cell["outputs"] = []
    cell["source"] = body.splitlines(keepends=True)
    return cell
# Relative path of the quick benchmark notebook file.
# NOTE(review): presumably the generated notebook is written here; the write
# site is outside this excerpt.
QUICK_NOTEBOOK_PATH = Path("output/jupyter-notebook/gamma-s4-sinewave-benchmark.ipynb")
# Relative path of the research benchmark notebook file (same caveat as above).
RESEARCH_NOTEBOOK_PATH = Path("output/jupyter-notebook/gamma-s4-research-benchmark.ipynb")
def setup_cell():
    """Return the notebook code cell that prepares the repository checkout.

    When the embedded source runs inside the notebook it:

    * in Google Colab, hard-resets an existing checkout under ``/content`` to
      ``origin/main`` or, if there is none, clones the GitHub repository using
      a personal access token (environment variable, Colab secret, or prompt),
      then ``chdir``s into the checkout and puts it first on ``sys.path``;
    * outside Colab, treats the current working directory as the repo root;
    * evicts already-imported ``gamma_space_model`` / ``csrc`` / ``tilelang``
      modules from ``sys.modules`` so later imports are resolved afresh;
    * prints the short commit hash on a best-effort basis.

    The token is only requested when a clone is actually needed, and a failed
    clone is re-raised without the command line so the token is not echoed
    into the notebook output.

    Returns:
        dict: the code-cell dictionary produced by ``code``.
    """
    return code(
        r"""
import os
import sys
import subprocess
import importlib
from pathlib import Path
IN_COLAB = "google.colab" in sys.modules
REPO_DIR = Path.cwd()
if IN_COLAB:
    print("Running in Google Colab")
    from google.colab import userdata
    from getpass import getpass
    REPO_NAME = "gamma_ssm_s4_v2"
    REPO_DIR = Path("/content") / REPO_NAME
    GITHUB_REPO = "StarMists/gamma_SSM_S4_enhanced"
    if REPO_DIR.exists():
        # These commands act on the remote already configured in the checkout,
        # so no token is needed (or requested) on this path.
        subprocess.run(["git", "-C", str(REPO_DIR), "fetch", "origin"], check=True)
        subprocess.run(["git", "-C", str(REPO_DIR), "checkout", "main"], check=True)
        subprocess.run(["git", "-C", str(REPO_DIR), "reset", "--hard", "origin/main"], check=True)
    else:
        token = os.environ.get("GITHUB_TOKEN") or os.environ.get("GH_TOKEN")
        if not token:
            try:
                token = userdata.get("GITHUB_TOKEN")
            except Exception:
                # Any failure to read the Colab secret falls through to the prompt below.
                token = None
        if not token:
            token = getpass("GitHub personal access token for private repo access: ").strip()
        clone_url = f"https://{token}@github.com/{GITHUB_REPO}.git"
        try:
            subprocess.run(["git", "clone", clone_url, str(REPO_DIR)], check=True)
        except subprocess.CalledProcessError as err:
            # CalledProcessError's message echoes the full command, token included;
            # re-raise without it so the secret never lands in the notebook output.
            raise RuntimeError(f"git clone of {GITHUB_REPO} failed with exit code {err.returncode}") from None
    os.chdir(REPO_DIR)
    sys.path.insert(0, str(REPO_DIR))
else:
    print("Running locally from", REPO_DIR)
    sys.path.insert(0, str(REPO_DIR))
importlib.invalidate_caches()
# Drop cached copies of the project packages (and their submodules) so they are re-imported.
for name in list(sys.modules):
    if name.split(".", 1)[0] in {"gamma_space_model", "csrc", "tilelang"}:
        del sys.modules[name]
try:
    commit = subprocess.check_output(["git", "-C", str(REPO_DIR), "rev-parse", "--short", "HEAD"], text=True).strip()
    print("Repo commit:", commit)
except (subprocess.CalledProcessError, OSError):
    # Best effort only: git may be missing or REPO_DIR may not be a checkout.
    print("Repo commit: unavailable")
"""
    )
def imports_cell():
    """Return the notebook code cell with imports and global run settings.

    When the embedded source runs inside the notebook it imports the numeric,
    plotting and torch modules plus the block classes from
    ``gamma_space_model``, seeds ``random`` / NumPy / torch with ``SEED``,
    selects ``DEVICE`` (CUDA when available, else CPU), enables AMP only on
    CUDA, creates the shared ``scaler`` and defines ``synchronize()``.

    The block structure of the embedded source (bodies of ``if`` / ``else`` /
    ``def``) is indented so the generated cell is valid Python.

    Returns:
        dict: the code-cell dictionary produced by ``code``.
    """
    return code(
        r"""
import math
import random
import time
import urllib.request
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from gamma_space_model import GammaSingleBlock, GammaS4Block, GammaS4MinimalBlock, S4TernaryDPLRBlock
SEED = 7
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
USE_AMP = DEVICE.type == "cuda"
if hasattr(torch, "amp") and hasattr(torch.amp, "GradScaler"):
    scaler = torch.amp.GradScaler("cuda", enabled=USE_AMP)
else:
    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)
def synchronize():
    if DEVICE.type == "cuda":
        torch.cuda.synchronize()
print("Device:", DEVICE)
print("Deployment cache available on GammaS4Block:", hasattr(GammaS4Block, "allocate_deployment_cache"))
if DEVICE.type != "cuda":
    print("WARNING: running on CPU. Treat speed numbers as smoke-test only, not main benchmark evidence.")
"""
    )
def shared_helpers_cell():
    """Return the code cell defining the helpers shared by both notebooks.

    The embedded source defines, in order:

    * ``make_forecasting_split`` - synthetic multi-channel series, returned as
      a ``TensorDataset`` of (inputs, targets shifted one step ahead);
    * ``SequenceForecaster`` - ``in_proj`` -> stacked blocks -> ``out_proj``;
    * ``build_forecasting_model`` - builds a forecaster for a model ``kind``;
    * ``profile_train_step`` - times forward / backward / optimizer of one step;
    * ``run_epoch`` - one training epoch, or one evaluation pass when no
      optimizer is given (evaluation runs with autograd disabled);
    * ``benchmark_inference`` - full-sequence versus step-by-step recurrent
      latency, with optional deployment-cache variants;
    * ``show_benchmark_tables`` - displays the result frame as three tables.

    The embedded code relies on names created by the imports cell (``np``,
    ``torch``, ``nn``, ``F``, ``time``, ``TensorDataset``, ``DEVICE``,
    ``USE_AMP``, ``scaler``, ``synchronize`` and the block classes) and on the
    notebook's ``display``.

    Returns:
        dict: the code-cell dictionary produced by ``code``.
    """
    return code(
        r"""
def make_forecasting_split(config, split, seed):
    rng = np.random.default_rng(seed)
    count = config["train_samples"] if split == "train" else config["val_samples"]
    seq_len = config["seq_len"]
    features = config["features"]
    complexity = config["complexity"]
    t = np.linspace(0.0, 1.0, seq_len + 1, dtype=np.float32)
    data = np.zeros((count, seq_len + 1, features), dtype=np.float32)
    for i in range(count):
        phase = rng.uniform(0.0, 2.0 * np.pi)
        chirp = np.sin(2.0 * np.pi * (1.0 + 1.75 * t**2) * rng.uniform(0.9, 1.15) + phase)
        slow = np.sin(2.0 * np.pi * (0.4 + 0.25 * complexity) * t + 0.7 * phase)
        medium = np.sin(2.0 * np.pi * (2.5 + complexity) * t + 1.1 * phase)
        fast = np.cos(2.0 * np.pi * (5.0 + 2.0 * complexity) * t + 1.5 * phase)
        bursts = (np.sin(2.0 * np.pi * (3.0 + complexity) * t + 0.4 * phase) > 0.8).astype(np.float32)
        bursts = bursts * np.sin(2.0 * np.pi * (10.0 + complexity) * t + 0.3 * phase)
        delayed = np.roll(chirp, 4 * complexity)
        modulated = medium * (1.0 + 0.35 * slow)
        components = [chirp, slow, medium, fast, bursts, delayed, modulated]
        for channel in range(features):
            weights = rng.normal(0.0, 1.0, size=len(components)).astype(np.float32)
            weights[: 2 + complexity] *= 1.2
            signal = sum(w * c for w, c in zip(weights, components))
            if complexity >= 2:
                signal += 0.20 * np.tanh(np.roll(signal, channel + 1))
            if complexity >= 3:
                signal += 0.10 * np.sin(signal * (0.5 + 0.08 * channel))
            signal += rng.normal(0.0, 0.03 + 0.01 * complexity, size=seq_len + 1).astype(np.float32)
            data[i, :, channel] = signal
    data -= data.mean(axis=1, keepdims=True)
    data /= data.std(axis=1, keepdims=True) + 1e-5
    return TensorDataset(torch.from_numpy(data[:, :-1, :]), torch.from_numpy(data[:, 1:, :]))
class SequenceForecaster(nn.Module):
    def __init__(self, input_dim, model_dim, layers, block_factory):
        super().__init__()
        self.in_proj = nn.Linear(input_dim, model_dim)
        self.layers = nn.ModuleList([block_factory(model_dim) for _ in range(layers)])
        self.out_proj = nn.Linear(model_dim, input_dim)
    def forward(self, x):
        x = self.in_proj(x)
        for layer in self.layers:
            x, _ = layer(x, state=None, return_state=False)
        return self.out_proj(x), None
def build_forecasting_model(kind, config, overrides=None):
    overrides = overrides or {}
    d_model = config["d_model"]
    hidden_dim = config["hidden_dim"]
    num_layers = config["num_layers"]
    input_dim = config["features"]
    if kind == "gamma_baseline":
        block_factory = lambda width: GammaSingleBlock(
            d_model=width,
            hidden_dim=hidden_dim,
            dropout=0.0,
        )
    elif kind == "gamma_s4_enhanced":
        block_factory = lambda width: GammaS4Block(
            d_model=width,
            hidden_dim=hidden_dim,
            kernel_mode=overrides.get("kernel_mode", "auto"),
            kernel_threshold=overrides.get("kernel_threshold", 384),
            discretization=overrides.get("discretization", "bilinear"),
            gate=overrides.get("gate", True),
            input_gate=overrides.get("input_gate", True),
            activation=overrides.get("activation", "gelu"),
            use_D=overrides.get("use_D", True),
            layer_scale_init=overrides.get("layer_scale_init", 0.1),
        )
    elif kind == "gamma_s4_minimal":
        block_factory = lambda width: GammaS4MinimalBlock(
            d_model=width,
            hidden_dim=hidden_dim,
            kernel_mode=overrides.get("kernel_mode", "auto"),
            kernel_threshold=overrides.get("kernel_threshold", 384),
            discretization=overrides.get("discretization", "bilinear"),
            use_D=overrides.get("use_D", True),
        )
    elif kind == "s4_ternary_dplr_ssm":
        block_factory = lambda width: S4TernaryDPLRBlock(
            d_model=width,
            hidden_dim=hidden_dim,
            rank=overrides.get("rank", 1),
            kernel_mode=overrides.get("kernel_mode", "auto"),
            kernel_threshold=overrides.get("kernel_threshold", 256),
            gate=overrides.get("gate", True),
            input_gate=overrides.get("input_gate", True),
            activation=overrides.get("activation", "gelu"),
            use_D=overrides.get("use_D", True),
            layer_scale_init=overrides.get("layer_scale_init", 0.1),
        )
    else:
        raise ValueError(kind)
    return SequenceForecaster(input_dim, d_model, num_layers, block_factory).to(DEVICE)
def profile_train_step(model, batch_x, batch_y):
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    model.train()
    batch_x = batch_x.to(DEVICE)
    batch_y = batch_y.to(DEVICE)
    optimizer.zero_grad(set_to_none=True)
    synchronize()
    t0 = time.perf_counter()
    with torch.autocast(device_type=DEVICE.type, enabled=USE_AMP):
        pred, _ = model(batch_x)
        loss = F.mse_loss(pred, batch_y)
    synchronize()
    t1 = time.perf_counter()
    scaler.scale(loss).backward()
    synchronize()
    t2 = time.perf_counter()
    scaler.step(optimizer)
    scaler.update()
    synchronize()
    t3 = time.perf_counter()
    return {
        "forward_ms": 1000.0 * (t1 - t0),
        "backward_ms": 1000.0 * (t2 - t1),
        "optimizer_ms": 1000.0 * (t3 - t2),
        "loss": float(loss.detach().cpu()),
    }
def run_epoch(model, loader, optimizer=None):
    # Trains for one epoch when an optimizer is given; otherwise evaluates.
    # Returns (mean batch loss, elapsed seconds).
    training = optimizer is not None
    model.train(training)
    losses = []
    synchronize()
    start = time.perf_counter()
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(DEVICE)
        batch_y = batch_y.to(DEVICE)
        if training:
            optimizer.zero_grad(set_to_none=True)
        # Evaluation never calls backward(), so skip building the autograd graph there.
        with torch.set_grad_enabled(training), torch.autocast(device_type=DEVICE.type, enabled=USE_AMP):
            pred, _ = model(batch_x)
            loss = F.mse_loss(pred, batch_y)
        if training:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        losses.append(loss.detach().item())
    synchronize()
    return float(np.mean(losses)), time.perf_counter() - start
@torch.no_grad()
def benchmark_inference(model, sample_x):
    # Times one full-sequence forward pass and step-by-step recurrent passes
    # (default inference cache, then the optional deployment-cache variants),
    # and reports how closely each recurrent output matches the full pass.
    model.eval()
    sample_x = sample_x.to(DEVICE)
    batch_size, seq_len = sample_x.size(0), sample_x.size(1)
    tokens = batch_size * seq_len
    nan = float("nan")
    # Two untimed warm-up passes before the timed full-sequence pass.
    for _ in range(2):
        _ = model(sample_x)
    synchronize()
    t0 = time.perf_counter()
    pred, _ = model(sample_x)
    synchronize()
    full_latency = time.perf_counter() - t0
    def run_recurrent(allocator_names):
        # One recurrent pass. For each layer that has an `ssm`, the cache comes
        # from the first allocator in `allocator_names` the layer provides.
        # Returns (cache setup seconds, stepping seconds, projected outputs).
        hidden = model.in_proj(sample_x)
        states, caches, outputs = [], [], []
        synchronize()
        cache_start = time.perf_counter()
        for layer in model.layers:
            ssm = getattr(layer, "ssm", None)
            if ssm is None:
                states.append(None)
                caches.append(None)
                continue
            states.append(ssm.init_state(batch_size, DEVICE, hidden.dtype))
            allocator = next((getattr(layer, name) for name in allocator_names if hasattr(layer, name)), None)
            caches.append(None if allocator is None else allocator(batch_size, seq_len, DEVICE, hidden.dtype))
        synchronize()
        cache_elapsed = time.perf_counter() - cache_start
        synchronize()
        start = time.perf_counter()
        for step in range(seq_len):
            new_outputs = hidden[:, step, :]
            for idx, layer in enumerate(model.layers):
                if caches[idx] is None:
                    new_outputs, states[idx] = layer.step(new_outputs, states[idx])
                else:
                    new_outputs, states[idx] = layer.step(new_outputs, states[idx], cache=caches[idx])
            outputs.append(new_outputs)
        recurrent_out = model.out_proj(torch.stack(outputs, dim=1))
        synchronize()
        return cache_elapsed, time.perf_counter() - start, recurrent_out
    def run_deployment(allocator_name):
        # Runs only when at least one layer offers `allocator_name`; layers without
        # it fall back to the deployment cache, then the inference cache.
        # Returns (supported, setup seconds, latency seconds, tokens/s, outputs).
        if not any(hasattr(layer, allocator_name) for layer in model.layers):
            return False, nan, nan, nan, None
        setup, latency, output = run_recurrent((allocator_name, "allocate_deployment_cache", "allocate_inference_cache"))
        return True, setup, latency, tokens / max(latency, 1e-9), output
    def match_mse(output):
        return nan if output is None else float(F.mse_loss(output, pred).detach().cpu())
    cache_setup, recurrent_latency, recurrent = run_recurrent(("allocate_inference_cache",))
    deploy_supported, cache_setup_lightweight, lightweight_latency, lightweight_tokens_per_s, recurrent_light = run_deployment("allocate_deployment_cache")
    balanced_supported, cache_setup_balanced, balanced_latency, balanced_tokens_per_s, recurrent_balanced = run_deployment("allocate_balanced_deployment_cache")
    return {
        "full_latency_ms": 1000.0 * full_latency,
        "full_tokens_per_s": tokens / max(full_latency, 1e-9),
        "cache_setup_ms": 1000.0 * cache_setup,
        "recurrent_latency_ms": 1000.0 * recurrent_latency,
        "recurrent_tokens_per_s": tokens / max(recurrent_latency, 1e-9),
        "recurrent_match_mse": match_mse(recurrent),
        "deploy_supported": deploy_supported,
        "deploy_cache_setup_ms": 1000.0 * cache_setup_lightweight,
        "deploy_recurrent_latency_ms": 1000.0 * lightweight_latency,
        "deploy_recurrent_tokens_per_s": lightweight_tokens_per_s,
        "deploy_match_mse": match_mse(recurrent_light),
        "balanced_deploy_supported": balanced_supported,
        "balanced_deploy_cache_setup_ms": 1000.0 * cache_setup_balanced,
        "balanced_deploy_recurrent_latency_ms": 1000.0 * balanced_latency,
        "balanced_deploy_recurrent_tokens_per_s": balanced_tokens_per_s,
        "balanced_deploy_match_mse": match_mse(recurrent_balanced),
        "prediction": pred.detach().cpu(),
        "recurrent_prediction": recurrent.detach().cpu(),
    }
def show_benchmark_tables(df, title="Benchmark"):
    if df.empty:
        display(df)
        return
    def available(columns):
        return [col for col in columns if col in df.columns]
    ordered = df.sort_values(["task", "kind"]).reset_index(drop=True)
    normal_cols = available([
        "task",
        "kind",
        "params",
        "train_loss",
        "val_loss",
        "mean_epoch_time_s",
        "expected_full_mode",
        "forward_ms",
        "backward_ms",
        "optimizer_ms",
        "full_latency_ms",
        "full_tokens_per_s",
        "cache_setup_ms",
        "recurrent_latency_ms",
        "recurrent_tokens_per_s",
        "recurrent_match_mse",
    ])
    deploy_cols = available([
        "task",
        "kind",
        "deploy_supported",
        "deploy_cache_setup_ms",
        "deploy_recurrent_latency_ms",
        "deploy_recurrent_tokens_per_s",
        "deploy_match_mse",
    ])
    balanced_cols = available([
        "task",
        "kind",
        "balanced_deploy_supported",
        "balanced_deploy_cache_setup_ms",
        "balanced_deploy_recurrent_latency_ms",
        "balanced_deploy_recurrent_tokens_per_s",
        "balanced_deploy_match_mse",
    ])
    print(f"{title}: normal/full-sequence and full recurrent")
    display(ordered[normal_cols])
    print(f"{title}: deployment-lite recurrent")
    display(ordered[deploy_cols])
    print(f"{title}: balanced deployment recurrent")
    display(ordered[balanced_cols])
"""
    )
def quick_notebook():
    """Assemble the cells of the quick benchmark notebook.

    The cell order is: title text, repository setup, imports, task and model
    configuration, shared helpers, the ``train_and_benchmark`` definition, the
    run loop with summary tables, a metric overview plot, and per-model
    prediction plots for the last active task, followed by reading notes.

    The embedded cell sources are indented so each generated cell is valid
    Python. The prediction plot requests a 2-D axes array
    (``squeeze=False``) so a single-channel task needs no special case.

    Returns:
        list: cell dictionaries built with ``md`` and ``code``.
    """
    cells = [
        md(
            """
# Gamma Baseline vs Gamma S4 Enhanced Quick Benchmark
This is the fast notebook for day-to-day iteration.
It is intentionally narrow:
- compare `gamma_baseline`, `gamma_s4_enhanced`, and `s4_ternary_dplr_ssm`
- use practical sequence lengths
- keep `kernel_mode` conservative
- report training speed, inference speed, and one-step profiling
"""
        ),
        setup_cell(),
        imports_cell(),
        code(
            r"""
QUICK_TASKS = {
    "simple": dict(seq_len=192, features=4, train_samples=256, val_samples=64, epochs=4, batch_size=32, d_model=48, hidden_dim=64, num_layers=2, complexity=1),
    "moderate": dict(seq_len=320, features=6, train_samples=320, val_samples=80, epochs=5, batch_size=24, d_model=64, hidden_dim=96, num_layers=2, complexity=2),
}
MODEL_OVERRIDES = {
    "gamma_baseline": {},
    "gamma_s4_enhanced": {
        "kernel_mode": "auto",
        "kernel_threshold": 384,
        "discretization": "bilinear",
        "gate": True,
        "input_gate": True,
        "activation": "gelu",
        "use_D": True,
        "layer_scale_init": 0.1,
    },
    "s4_ternary_dplr_ssm": {
        "kernel_mode": "auto",
        "kernel_threshold": 256,
        "rank": 1,
        "gate": True,
        "input_gate": True,
        "activation": "gelu",
        "use_D": True,
        "layer_scale_init": 0.1,
    },
}
ACTIVE_TASKS = ["simple", "moderate"]
MODELS = ["gamma_baseline", "gamma_s4_enhanced", "s4_ternary_dplr_ssm"]
"""
        ),
        shared_helpers_cell(),
        code(
            r"""
def train_and_benchmark(task_name, kind):
    cfg = QUICK_TASKS[task_name]
    train_ds = make_forecasting_split(cfg, "train", seed=SEED + 11)
    val_ds = make_forecasting_split(cfg, "val", seed=SEED + 29)
    train_loader = DataLoader(train_ds, batch_size=cfg["batch_size"], shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=cfg["batch_size"], shuffle=False)
    model = build_forecasting_model(kind, cfg, overrides=MODEL_OVERRIDES.get(kind))
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=1e-4)
    history = {"train_loss": [], "val_loss": [], "epoch_time_s": []}
    first_batch_x, first_batch_y = next(iter(train_loader))
    profile = profile_train_step(build_forecasting_model(kind, cfg, overrides=MODEL_OVERRIDES.get(kind)), first_batch_x, first_batch_y)
    for epoch in range(cfg["epochs"]):
        train_loss, epoch_time = run_epoch(model, train_loader, optimizer=optimizer)
        val_loss, _ = run_epoch(model, val_loader)
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["epoch_time_s"].append(epoch_time)
        print(f"{task_name} | {kind} | epoch={epoch+1:02d} train={train_loss:.6f} val={val_loss:.6f}")
    sample_x, sample_y = next(iter(val_loader))
    sample_x = sample_x[:2]
    sample_y = sample_y[:2]
    inf = benchmark_inference(model, sample_x)
    metrics = {
        "task": task_name,
        "kind": kind,
        "params": sum(p.numel() for p in model.parameters()),
        "train_loss": history["train_loss"][-1],
        "val_loss": history["val_loss"][-1],
        "mean_epoch_time_s": float(np.mean(history["epoch_time_s"])),
        **profile,
        "sample_target": sample_y.cpu(),
        **inf,
    }
    return metrics, history, model
"""
        ),
        md(
            """
## Run the Quick Experiment
This is the cell you will usually run in Colab first.
"""
        ),
        code(
            r"""
all_metrics = []
histories = {}
trained_models = {}
for task_name in ACTIVE_TASKS:
    for kind in MODELS:
        metrics, history, model = train_and_benchmark(task_name, kind)
        all_metrics.append({k: v for k, v in metrics.items() if k not in {"sample_target", "prediction", "recurrent_prediction"}})
        histories[(task_name, kind)] = history
        # NOTE: despite its name, trained_models stores each run's full metrics dict
        # (including predictions), which the plotting cell below reads.
        trained_models[(task_name, kind)] = metrics
summary_df = pd.DataFrame(all_metrics).sort_values(["task", "val_loss"]).reset_index(drop=True)
show_benchmark_tables(summary_df, title="Quick benchmark")
"""
        ),
        code(
            r"""
fig, axes = plt.subplots(2, 2, figsize=(14, 8))
metrics_to_plot = ["val_loss", "mean_epoch_time_s", "full_tokens_per_s", "recurrent_tokens_per_s"]
for ax, metric in zip(axes.flatten(), metrics_to_plot):
    pivot = summary_df.pivot(index="task", columns="kind", values=metric).loc[ACTIVE_TASKS]
    pivot.plot(ax=ax, marker="o")
    ax.set_title(metric.replace("_", " ").title())
    ax.grid(alpha=0.2)
plt.tight_layout()
plt.show()
"""
        ),
        md(
            """
The benchmark tables above are intentionally split into normal, deployment-lite, and balanced deployment views so the important columns remain visible in Colab.
"""
        ),
        code(
            r"""
PLOT_TASK = ACTIVE_TASKS[-1]
baseline = trained_models[(PLOT_TASK, "gamma_baseline")]
target = baseline["sample_target"][0].numpy()
baseline_pred = baseline["prediction"][0].numpy()
comparison_kinds = [kind for kind in MODELS if kind != "gamma_baseline"]
for compare_kind in comparison_kinds:
    candidate = trained_models[(PLOT_TASK, compare_kind)]
    candidate_pred = candidate["prediction"][0].numpy()
    channels = range(min(3, target.shape[-1]))
    time_axis = np.arange(target.shape[0])
    # squeeze=False always yields a 2-D axes array, so one channel needs no special case.
    fig, axes = plt.subplots(len(channels), 1, figsize=(12, 3.5 * len(channels)), sharex=True, squeeze=False)
    for row, channel in enumerate(channels):
        ax = axes[row, 0]
        ax.plot(time_axis, target[:, channel], label="ground truth", linewidth=2)
        ax.plot(time_axis, baseline_pred[:, channel], label="baseline", alpha=0.9)
        ax.plot(time_axis, candidate_pred[:, channel], label=compare_kind, alpha=0.9)
        ax.set_title(f"{PLOT_TASK} task, channel {channel}")
        ax.grid(alpha=0.2)
        if row == 0:
            ax.legend()
    plt.tight_layout()
    plt.show()
"""
        ),
        md(
            """
## Reading This Notebook
Use this notebook for fast feedback:
- if either S4-style model loses badly here, do not trust it on bigger tasks
- if one of them is competitive here, then move to the research notebook
"""
        ),
    ]
    return cells
def research_notebook():
cells = [
md(
"""
# Gamma S4 Practical Benchmark
This notebook is the second benchmark track after the quick notebook.
It is meant to be closer to practical sequence modeling while staying reasonable on Colab:
- one harder long-context synthetic benchmark
- enhanced-model ablations on that harder task
- an optional lightweight token benchmark
"""
),
setup_cell(),
imports_cell(),
code(
r"""
PRACTICAL_CONFIGS = {
"current_reference": dict(seq_len=320, features=6, train_samples=320, val_samples=80, epochs=5, batch_size=24, d_model=64, hidden_dim=96, num_layers=2, complexity=2),
"long_context": dict(seq_len=768, features=8, train_samples=256, val_samples=64, epochs=4, batch_size=12, d_model=80, hidden_dim=128, num_layers=3, complexity=3),
}
RUN_PRACTICAL_SWEEP = True
RUN_ABLATIONS = False
RUN_TOKEN_TASK = True
PRACTICAL_MODELS = ["gamma_baseline", "gamma_s4_enhanced", "s4_ternary_dplr_ssm"]
MODEL_OVERRIDES = {
"gamma_baseline": {},
"gamma_s4_enhanced": {
"kernel_mode": "auto",
"kernel_threshold": 384,
"discretization": "bilinear",
"gate": True,
"input_gate": True,
"activation": "gelu",
"use_D": True,
"layer_scale_init": 0.1,
},
"gamma_s4_minimal": {
"kernel_mode": "auto",
"kernel_threshold": 512,
"discretization": "bilinear",
"use_D": True,
},
"s4_ternary_dplr_ssm": {
"kernel_mode": "auto",
"kernel_threshold": 256,
"rank": 1,
"gate": True,
"input_gate": True,
"activation": "gelu",
"use_D": True,
"layer_scale_init": 0.1,
},
}
ABLATIONS = [
("default", {}),
("no_input_gate", {"input_gate": False}),
("no_gate", {"gate": False}),
("no_skip_D", {"use_D": False}),
("euler", {"discretization": "euler"}),
]
"""
),
shared_helpers_cell(),
code(
r"""
def train_practical_model(task_name, kind, overrides=None):
cfg = PRACTICAL_CONFIGS[task_name]
train_ds = make_forecasting_split(cfg, "train", seed=SEED + 101)
val_ds = make_forecasting_split(cfg, "val", seed=SEED + 151)
train_loader = DataLoader(train_ds, batch_size=cfg["batch_size"], shuffle=True)
val_loader = DataLoader(val_ds, batch_size=cfg["batch_size"], shuffle=False)
model = build_forecasting_model(kind, cfg, overrides=overrides or MODEL_OVERRIDES.get(kind))
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=1e-4)
history = {"train_loss": [], "val_loss": [], "epoch_time_s": []}
for epoch in range(cfg["epochs"]):
train_loss, epoch_time = run_epoch(model, train_loader, optimizer=optimizer)
val_loss, _ = run_epoch(model, val_loader)
history["train_loss"].append(train_loss)
history["val_loss"].append(val_loss)
history["epoch_time_s"].append(epoch_time)
print(f"{task_name} | {kind} | epoch={epoch+1:02d} train={train_loss:.6f} val={val_loss:.6f}")
sample_x, sample_y = next(iter(val_loader))
sample_x = sample_x[:2]
sample_y = sample_y[:2]
inf = benchmark_inference(model, sample_x)
result = {
"task": task_name,
"kind": kind,
"params": sum(p.numel() for p in model.parameters()),
"train_loss": history["train_loss"][-1],
"val_loss": history["val_loss"][-1],
"mean_epoch_time_s": float(np.mean(history["epoch_time_s"])),
"expected_full_mode": "conv" if (kind != "gamma_baseline" and cfg["seq_len"] >= MODEL_OVERRIDES.get(kind, {}).get("kernel_threshold", 10**9)) else "recurrent_like",
"sample_target": sample_y.cpu(),
**inf,
}
return result, history, model
"""
),
md(
"""
## Practical Sweep
"""
),
code(
r"""
practical_rows = []
practical_artifacts = {}
practical_models = {}
if RUN_PRACTICAL_SWEEP:
for task_name in PRACTICAL_CONFIGS:
for kind in PRACTICAL_MODELS:
result, _, model = train_practical_model(task_name, kind)
practical_artifacts[(task_name, kind)] = result
practical_models[(task_name, kind)] = model
practical_rows.append({k: v for k, v in result.items() if k not in {"sample_target", "prediction", "recurrent_prediction"}})
practical_df = pd.DataFrame(practical_rows)
show_benchmark_tables(practical_df, title="Practical sweep")
"""
),
code(
r"""
if not practical_df.empty:
fig, axes = plt.subplots(2, 2, figsize=(14, 8))
for ax, metric in zip(axes.flatten(), ["val_loss", "mean_epoch_time_s", "full_tokens_per_s", "recurrent_tokens_per_s"]):
pivot = practical_df.pivot(index="task", columns="kind", values=metric)
pivot.plot(ax=ax, marker="o")
ax.set_title(metric.replace("_", " ").title())
ax.grid(alpha=0.2)
plt.tight_layout()
plt.show()
"""
),
md(
"""
## Practical Inference Scaling
This section benchmarks trained practical models across sequence lengths and batch sizes. It separates full-sequence prefill-like throughput from recurrent/deployment decode-like throughput.
"""
),
code(
r"""
inference_scaling_rows = []
if RUN_PRACTICAL_SWEEP and practical_models:
INFERENCE_SEQ_LENS = [64, 192, 384, 768]
INFERENCE_BATCH_SIZES = [1, 2]
for (task_name, kind), model in practical_models.items():
feature_dim = model.out_proj.out_features
for seq_len in INFERENCE_SEQ_LENS:
for batch_size in INFERENCE_BATCH_SIZES:
sample_x = torch.randn(batch_size, seq_len, feature_dim)
metrics = benchmark_inference(model, sample_x)
inference_scaling_rows.append({
"task": task_name,
"kind": kind,
"batch_size": batch_size,
"seq_len": seq_len,
"full_latency_ms": metrics["full_latency_ms"],
"full_tokens_per_s": metrics["full_tokens_per_s"],
"recurrent_latency_ms": metrics["recurrent_latency_ms"],
"recurrent_tokens_per_s": metrics["recurrent_tokens_per_s"],
"deploy_recurrent_latency_ms": metrics["deploy_recurrent_latency_ms"],
"deploy_recurrent_tokens_per_s": metrics["deploy_recurrent_tokens_per_s"],
"balanced_deploy_recurrent_latency_ms": metrics["balanced_deploy_recurrent_latency_ms"],
"balanced_deploy_recurrent_tokens_per_s": metrics["balanced_deploy_recurrent_tokens_per_s"],
"recurrent_match_mse": metrics["recurrent_match_mse"],
"deploy_match_mse": metrics["deploy_match_mse"],
"balanced_deploy_match_mse": metrics["balanced_deploy_match_mse"],
})
inference_scaling_df = pd.DataFrame(inference_scaling_rows)
display(inference_scaling_df.sort_values(["task", "kind", "seq_len", "batch_size"]))
"""
),
code(
r"""
if "inference_scaling_df" in globals() and not inference_scaling_df.empty:
for task_name in inference_scaling_df["task"].unique():
subset = inference_scaling_df[(inference_scaling_df["task"] == task_name) & (inference_scaling_df["batch_size"] == 1)]
fig, axes = plt.subplots(1, 2, figsize=(14, 4))
for kind in subset["kind"].unique():
model_subset = subset[subset["kind"] == kind].sort_values("seq_len")
axes[0].plot(model_subset["seq_len"], model_subset["full_tokens_per_s"], marker="o", label=kind)
axes[1].plot(model_subset["seq_len"], model_subset["recurrent_tokens_per_s"], marker="o", label=kind)
axes[0].set_title(f"{task_name}: full-sequence throughput, batch=1")
axes[1].set_title(f"{task_name}: recurrent throughput, batch=1")
for ax in axes:
ax.set_xlabel("seq_len")
ax.set_ylabel("tokens/s")
ax.grid(alpha=0.25)
ax.legend()
plt.tight_layout()
plt.show()
"""
),
md(
"""
## Task Visual Preview
These plots show what the synthetic tasks actually look like before we discuss scores.
"""
),
code(
r"""
if RUN_PRACTICAL_SWEEP:
fig, axes = plt.subplots(len(PRACTICAL_CONFIGS), 1, figsize=(14, 3.5 * len(PRACTICAL_CONFIGS)), sharex=False)
if len(PRACTICAL_CONFIGS) == 1:
axes = [axes]
for ax, task_name in zip(axes, PRACTICAL_CONFIGS):
cfg = PRACTICAL_CONFIGS[task_name]
preview_ds = make_forecasting_split(cfg, "val", seed=SEED + 151)
preview_x, preview_y = preview_ds[0]
channels = range(min(3, preview_y.shape[-1]))
time_axis = np.arange(preview_y.shape[0])
for channel in channels:
ax.plot(time_axis, preview_y[:, channel].numpy(), label=f"channel {channel}", linewidth=1.5)
ax.set_title(f"{task_name} target preview")
ax.grid(alpha=0.2)
if task_name == list(PRACTICAL_CONFIGS.keys())[0]:
ax.legend(ncol=min(3, preview_y.shape[-1]))
plt.tight_layout()
plt.show()
"""
),
md(
"""
## Prediction Comparison Plots
These are the most presentation-friendly plots in the notebook:
ground truth vs baseline vs each S4-style model on the same held-out sample.
"""
),
code(
r"""
if practical_artifacts:
for task_name in PRACTICAL_CONFIGS:
baseline = practical_artifacts[(task_name, "gamma_baseline")]
target = baseline["sample_target"][0].numpy()
baseline_pred = baseline["prediction"][0].numpy()
for compare_kind in [kind for kind in PRACTICAL_MODELS if kind != "gamma_baseline"]:
candidate = practical_artifacts[(task_name, compare_kind)]
candidate_pred = candidate["prediction"][0].numpy()
channels = range(min(3, target.shape[-1]))
time_axis = np.arange(target.shape[0])
fig, axes = plt.subplots(len(list(channels)), 1, figsize=(14, 3.5 * len(list(channels))), sharex=True)
if target.shape[-1] == 1:
axes = [axes]
for row, channel in enumerate(channels):
ax = axes[row]
ax.plot(time_axis, target[:, channel], label="ground truth", linewidth=2)
ax.plot(time_axis, baseline_pred[:, channel], label="baseline", alpha=0.9)
ax.plot(time_axis, candidate_pred[:, channel], label=compare_kind, alpha=0.9)
ax.set_title(f"{task_name} prediction comparison, channel {channel}")
ax.grid(alpha=0.2)
if row == 0:
ax.legend()
plt.tight_layout()
plt.show()
"""
),
md(
"""
## Error Comparison Plots
These show where each model is missing the target signal. Lower absolute error should visually hug zero.
"""
),
code(
r"""
if practical_artifacts:
for task_name in PRACTICAL_CONFIGS:
baseline = practical_artifacts[(task_name, "gamma_baseline")]
target = baseline["sample_target"][0].numpy()
baseline_err = baseline["prediction"][0].numpy() - target
for compare_kind in [kind for kind in PRACTICAL_MODELS if kind != "gamma_baseline"]:
candidate = practical_artifacts[(task_name, compare_kind)]
candidate_err = candidate["prediction"][0].numpy() - target
channels = range(min(2, target.shape[-1]))
time_axis = np.arange(target.shape[0])
fig, axes = plt.subplots(len(list(channels)), 1, figsize=(14, 3.0 * len(list(channels))), sharex=True)
if len(list(channels)) == 1:
axes = [axes]
for row, channel in enumerate(channels):
ax = axes[row]
ax.plot(time_axis, baseline_err[:, channel], label="baseline error", alpha=0.9)
ax.plot(time_axis, candidate_err[:, channel], label=f"{compare_kind} error", alpha=0.9)
ax.axhline(0.0, color="black", linewidth=1, alpha=0.5)
ax.set_title(f"{task_name} error comparison, channel {channel}")
ax.grid(alpha=0.2)
if row == 0:
ax.legend()
plt.tight_layout()
plt.show()
"""
),
md(
"""
## Enhanced Ablations On Long Context
"""
),
code(
r"""
ablation_rows = []
if RUN_ABLATIONS:
for name, override in ABLATIONS:
merged = {**MODEL_OVERRIDES["gamma_s4_enhanced"], **override}
result, _, _ = train_practical_model("long_context", "gamma_s4_enhanced", overrides=merged)
row = {k: v for k, v in result.items() if k not in {"sample_target", "prediction", "recurrent_prediction"}}
row["ablation"] = name
ablation_rows.append(row)
ablation_df = pd.DataFrame(ablation_rows)
ablation_df
"""
),
code(
r"""
if not ablation_df.empty:
fig, axes = plt.subplots(1, 3, figsize=(16, 4))
ablation_df.sort_values("val_loss").plot.bar(x="ablation", y="val_loss", ax=axes[0], legend=False)
ablation_df.sort_values("mean_epoch_time_s").plot.bar(x="ablation", y="mean_epoch_time_s", ax=axes[1], legend=False)
ablation_df.sort_values("recurrent_tokens_per_s").plot.bar(x="ablation", y="recurrent_tokens_per_s", ax=axes[2], legend=False)
axes[0].set_title("Validation Loss")
axes[1].set_title("Mean Epoch Time")
axes[2].set_title("Recurrent Tokens / s")
for ax in axes:
ax.tick_params(axis="x", rotation=25)
ax.grid(alpha=0.2)
plt.tight_layout()
plt.show()
"""
),
md(
"""
## Optional Token-Lite Task
"""
),
code(
r"""
@torch.no_grad()
def benchmark_token_inference(model, token_batch, target_batch, vocab_size, repeats=6):
model.eval()
x = token_batch.to(DEVICE)
y = target_batch.to(DEVICE)
batch, seq_len = x.shape
token_count = batch * seq_len
rows = []
def reset_memory():
if DEVICE.type == "cuda":
torch.cuda.reset_peak_memory_stats()
def max_memory_mb():
if DEVICE.type != "cuda":
return float("nan")
return torch.cuda.max_memory_allocated() / (1024 ** 2)
reset_memory()
with torch.autocast(device_type=DEVICE.type, enabled=USE_AMP):
logits = model(x)
loss = F.cross_entropy(logits.reshape(-1, vocab_size), y.reshape(-1))
synchronize()
start = time.perf_counter()
for _ in range(repeats):
with torch.autocast(device_type=DEVICE.type, enabled=USE_AMP):
logits = model(x)
synchronize()
elapsed = time.perf_counter() - start
full_logits = logits.detach()
rows.append({
"mode": "prefill_full_sequence",
"latency_ms": 1000.0 * elapsed / repeats,
"tokens_per_s": token_count * repeats / max(elapsed, 1e-9),
"ce": float(loss.detach().cpu()),
"match_mse": 0.0,
"max_memory_mb": max_memory_mb(),
})
hidden = model.embed(x)
states = []
caches = []
cache_start = time.perf_counter()
for layer in model.layers:
state = layer.ssm.init_state(hidden.size(0), DEVICE, hidden.dtype)
states.append(state)
if hasattr(layer, "allocate_inference_cache"):
caches.append(layer.allocate_inference_cache(hidden.size(0), seq_len, DEVICE, hidden.dtype))
else:
caches.append(None)
synchronize()
cache_setup_ms = 1000.0 * (time.perf_counter() - cache_start)
def recurrent_pass(cache_list):
local_states = [state.clone() for state in states]
outputs = []
for t in range(seq_len):
step_x = hidden[:, t, :]
for layer_idx, layer in enumerate(model.layers):
cache = None if cache_list is None else cache_list[layer_idx]
try:
step_x, local_states[layer_idx] = layer.step(step_x, local_states[layer_idx], cache=cache)
except TypeError:
step_x, local_states[layer_idx] = layer.step(step_x, local_states[layer_idx])
outputs.append(model.head(step_x))
return torch.stack(outputs, dim=1)
reset_memory()
recurrent_logits = recurrent_pass(caches)
synchronize()
start = time.perf_counter()
for _ in range(repeats):
recurrent_logits = recurrent_pass(caches)
synchronize()
elapsed = time.perf_counter() - start
recurrent_loss = F.cross_entropy(recurrent_logits.reshape(-1, vocab_size), y.reshape(-1))
rows.append({
"mode": "decode_recurrent_exact",
"cache_setup_ms": cache_setup_ms,
"latency_ms": 1000.0 * elapsed / repeats,
"tokens_per_s": token_count * repeats / max(elapsed, 1e-9),
"ce": float(recurrent_loss.detach().cpu()),
"match_mse": float(F.mse_loss(recurrent_logits, full_logits).detach().cpu()),
"max_memory_mb": max_memory_mb(),
})
deploy_supported = all(hasattr(layer, "allocate_deployment_cache") for layer in model.layers)
if deploy_supported:
cache_start = time.perf_counter()
deploy_caches = [
layer.allocate_deployment_cache(hidden.size(0), seq_len, DEVICE, hidden.dtype)
for layer in model.layers
]
synchronize()
deploy_cache_setup_ms = 1000.0 * (time.perf_counter() - cache_start)
reset_memory()
deploy_logits = recurrent_pass(deploy_caches)
synchronize()
start = time.perf_counter()
for _ in range(repeats):
deploy_logits = recurrent_pass(deploy_caches)
synchronize()
elapsed = time.perf_counter() - start
deploy_loss = F.cross_entropy(deploy_logits.reshape(-1, vocab_size), y.reshape(-1))
rows.append({
"mode": "decode_deploy_lite",
"cache_setup_ms": deploy_cache_setup_ms,
"latency_ms": 1000.0 * elapsed / repeats,
"tokens_per_s": token_count * repeats / max(elapsed, 1e-9),
"ce": float(deploy_loss.detach().cpu()),
"match_mse": float(F.mse_loss(deploy_logits, full_logits).detach().cpu()),
"max_memory_mb": max_memory_mb(),
})
balanced_supported = all(hasattr(layer, "allocate_balanced_deployment_cache") for layer in model.layers)
if balanced_supported:
cache_start = time.perf_counter()
balanced_caches = [
layer.allocate_balanced_deployment_cache(hidden.size(0), seq_len, DEVICE, hidden.dtype)
for layer in model.layers
]
synchronize()
balanced_cache_setup_ms = 1000.0 * (time.perf_counter() - cache_start)
reset_memory()
balanced_logits = recurrent_pass(balanced_caches)
synchronize()
start = time.perf_counter()
for _ in range(repeats):
balanced_logits = recurrent_pass(balanced_caches)
synchronize()
elapsed = time.perf_counter() - start
balanced_loss = F.cross_entropy(balanced_logits.reshape(-1, vocab_size), y.reshape(-1))
rows.append({
"mode": "decode_balanced",
"cache_setup_ms": balanced_cache_setup_ms,
"latency_ms": 1000.0 * elapsed / repeats,
"tokens_per_s": token_count * repeats / max(elapsed, 1e-9),
"ce": float(balanced_loss.detach().cpu()),
"match_mse": float(F.mse_loss(balanced_logits, full_logits).detach().cpu()),
"max_memory_mb": max_memory_mb(),
})
return rows
"""
),
code(
r"""
if RUN_TOKEN_TASK:
TOKEN_DATA_URL = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
TOKEN_DATA_PATH = Path("tmp/jupyter-notebook/tinyshakespeare.txt")
TOKEN_DATA_PATH.parent.mkdir(parents=True, exist_ok=True)
if not TOKEN_DATA_PATH.exists():
urllib.request.urlretrieve(TOKEN_DATA_URL, TOKEN_DATA_PATH)
text = TOKEN_DATA_PATH.read_text(encoding="utf-8")
vocab = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(vocab)}
tokens = torch.tensor([stoi[ch] for ch in text], dtype=torch.long)
TOKEN_CFG = {
"seq_len": 192,
"train_samples": 1200,
"val_samples": 240,
"epochs": 2,
"batch_size": 12 if DEVICE.type == "cuda" else 6,
}
def make_token_split(seq_len, train_samples, val_samples):
max_start = len(tokens) - seq_len - 1
starts = torch.linspace(0, max_start - 1, steps=train_samples + val_samples).long()
x = torch.stack([tokens[s : s + seq_len] for s in starts])
y = torch.stack([tokens[s + 1 : s + seq_len + 1] for s in starts])
return TensorDataset(x[:train_samples], y[:train_samples]), TensorDataset(x[train_samples:], y[train_samples:])
class TokenForecaster(nn.Module):
def __init__(self, vocab_size, kind):
super().__init__()
self.embed = nn.Embedding(vocab_size, 64)
if kind == "gamma_baseline":
factory = lambda: GammaSingleBlock(d_model=64, hidden_dim=96, dropout=0.0)
elif kind == "gamma_s4_enhanced":
factory = lambda: GammaS4Block(d_model=64, hidden_dim=96, kernel_mode="auto", kernel_threshold=160)
elif kind == "s4_ternary_dplr_ssm":
factory = lambda: S4TernaryDPLRBlock(d_model=64, hidden_dim=96, kernel_mode="auto", kernel_threshold=160)
else:
factory = lambda: GammaS4MinimalBlock(d_model=64, hidden_dim=96, kernel_mode="auto", kernel_threshold=160)
self.layers = nn.ModuleList([factory(), factory()])
self.head = nn.Linear(64, vocab_size)
def forward(self, x):
x = self.embed(x)
for layer in self.layers:
x, _ = layer(x, state=None, return_state=False)
return self.head(x)
token_train, token_val = make_token_split(
seq_len=TOKEN_CFG["seq_len"],
train_samples=TOKEN_CFG["train_samples"],
val_samples=TOKEN_CFG["val_samples"],
)
token_rows = []
token_inference_rows = []
for kind in PRACTICAL_MODELS:
model = TokenForecaster(len(vocab), kind).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=1e-4)
train_loader = DataLoader(token_train, batch_size=TOKEN_CFG["batch_size"], shuffle=True)
val_loader = DataLoader(token_val, batch_size=TOKEN_CFG["batch_size"], shuffle=False)
history = []
for epoch in range(TOKEN_CFG["epochs"]):
model.train()
train_losses = []
for batch_x, batch_y in train_loader:
batch_x = batch_x.to(DEVICE)
batch_y = batch_y.to(DEVICE)
optimizer.zero_grad(set_to_none=True)
with torch.autocast(device_type=DEVICE.type, enabled=USE_AMP):
logits = model(batch_x)
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), batch_y.reshape(-1))
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
train_losses.append(loss.detach().item())
model.eval()
val_losses = []
with torch.no_grad():
for batch_x, batch_y in val_loader:
batch_x = batch_x.to(DEVICE)
batch_y = batch_y.to(DEVICE)
logits = model(batch_x)
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), batch_y.reshape(-1))
val_losses.append(loss.detach().item())
history.append((float(np.mean(train_losses)), float(np.mean(val_losses))))
print(kind, epoch + 1, history[-1])
token_rows.append({
"kind": kind,
"train_ce": history[-1][0],
"val_ce": history[-1][1],
"val_ppl": math.exp(history[-1][1]),
"seq_len": TOKEN_CFG["seq_len"],
"train_samples": TOKEN_CFG["train_samples"],
})
sample_x, sample_y = next(iter(val_loader))
for row in benchmark_token_inference(model, sample_x[:2], sample_y[:2], len(vocab)):
row["kind"] = kind
row["seq_len"] = TOKEN_CFG["seq_len"]
row["batch_size"] = min(2, sample_x.size(0))
token_inference_rows.append(row)
token_df = pd.DataFrame(token_rows).sort_values("val_ce")
display(token_df)
token_inference_df = pd.DataFrame(token_inference_rows)
display(token_inference_df.sort_values(["kind", "mode"]))
"""
),
md(
"""
Use this notebook after the quick benchmark. The `long_context` task is the more practical synthetic benchmark, and the optional token-lite section gives a small language-like check without making Colab costs too high.
`RUN_ABLATIONS` is off by default because the long-context ablation sweep is still materially more expensive than the main comparison.
"""
),
]
return cells
def _with_cell_ids(cells):
    """Return *cells* as a list in which every cell carries a unique ``id``.

    Cells that already have an ``id`` are passed through untouched. Cells
    without one are shallow-copied and given a deterministic ``cell-NNNN``
    identifier, so regenerating the notebook yields stable ids and the
    caller's dicts are never mutated.
    """
    cells = list(cells)  # accept any iterable; it is traversed twice below
    used_ids = {cell["id"] for cell in cells if "id" in cell}
    numbered = []
    next_number = 0
    for cell in cells:
        if "id" not in cell:
            cell_id = f"cell-{next_number:04d}"
            # Skip over identifiers a caller-supplied cell already claimed.
            while cell_id in used_ids:
                next_number += 1
                cell_id = f"cell-{next_number:04d}"
            used_ids.add(cell_id)
            next_number += 1
            cell = {"id": cell_id, **cell}
        numbered.append(cell)
    return numbered


def write_notebook(path, cells):
    """Serialize *cells* as an nbformat 4.5 notebook and write it to *path*.

    Args:
        path: Destination file, given as a ``str`` or ``pathlib.Path``.
            Missing parent directories are created.
        cells: Iterable of notebook cell dicts (such as those built by
            ``md`` and ``code``). Cells lacking an ``id`` receive a generated
            one in the written file, because the declared minor version 5
            of the notebook format requires an id on every cell.

    The file is written as UTF-8 JSON with two-space indentation, and a
    confirmation line is printed once the write succeeds.
    """
    path = Path(path)  # tolerate plain string paths as well as Path objects
    notebook = {
        "cells": _with_cell_ids(cells),
        "metadata": {
            "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
            "language_info": {"name": "python", "version": "3.11"},
        },
        "nbformat": 4,
        "nbformat_minor": 5,
    }
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(notebook, indent=2), encoding="utf-8")
    print(f"Wrote {path}")
# Script driver: build each notebook's cell list and write it to its output
# path, quick benchmark first and research benchmark second.
for notebook_path, build_cells in (
    (QUICK_NOTEBOOK_PATH, quick_notebook),
    (RESEARCH_NOTEBOOK_PATH, research_notebook),
):
    write_notebook(notebook_path, build_cells())