# Source: TaoNet-mini-T2/code/Taotern_SSM/scripts/generate_gamma_challenge_benchmark_notebook.py
# Uploaded by StarMist0012 via the upload-large-folder tool (commit 388fd6e, verified).
import json
from pathlib import Path
def md(text):
    """Wrap *text* as an nbformat markdown cell.

    Surrounding newlines are trimmed and the remaining text is split into
    lines that keep their line endings, which is the multiline ``source``
    representation nbformat stores.
    """
    lines = text.strip("\n").splitlines(keepends=True)
    cell = {"cell_type": "markdown", "metadata": {}}
    cell["source"] = lines
    return cell
def code(text):
    """Wrap *text* as an unexecuted nbformat code cell with no outputs.

    Surrounding newlines are trimmed; the source is stored as a list of
    lines that keep their line endings.
    """
    return dict(
        cell_type="code",
        execution_count=None,
        metadata={},
        outputs=[],
        source=text.strip("\n").splitlines(keepends=True),
    )
# Destination of the generated notebook, relative to the current working directory.
NOTEBOOK_PATH = Path("output", "jupyter-notebook", "gamma-s4-challenge-benchmark.ipynb")
def setup_cell():
    """Return the bootstrap code cell.

    The emitted cell detects Colab. There it clones or refreshes the private
    repo under /content. Locally it uses the current directory. In both cases
    it puts the repo on ``sys.path``, evicts cached ``gamma_space_model`` /
    ``csrc`` / ``tilelang`` modules so edited sources are re-imported, and
    prints the checked-out commit when one is available.
    """
    return code(
        r"""
import os
import sys
import subprocess
import importlib
from pathlib import Path

IN_COLAB = "google.colab" in sys.modules
REPO_DIR = Path.cwd()

if IN_COLAB:
    print("Running in Google Colab")
    from google.colab import userdata
    from getpass import getpass

    REPO_NAME = "gamma_ssm_s4_v2"
    REPO_DIR = Path("/content") / REPO_NAME
    GITHUB_REPO = "StarMists/gamma_SSM_S4_enhanced"

    # Token lookup order: environment variables, Colab secret, interactive prompt.
    token = os.environ.get("GITHUB_TOKEN") or os.environ.get("GH_TOKEN")
    if not token:
        try:
            token = userdata.get("GITHUB_TOKEN")
        except Exception:
            # Secret missing or notebook access not granted: fall back to the prompt.
            token = None
    if not token:
        token = getpass("GitHub personal access token for private repo access: ").strip()
    # NOTE: the token becomes part of the clone's origin URL (stored in .git/config on this VM).
    clone_url = f"https://{token}@github.com/{GITHUB_REPO}.git"

    if REPO_DIR.exists():
        subprocess.run(["git", "-C", str(REPO_DIR), "fetch", "origin"], check=True)
        subprocess.run(["git", "-C", str(REPO_DIR), "checkout", "main"], check=True)
        subprocess.run(["git", "-C", str(REPO_DIR), "reset", "--hard", "origin/main"], check=True)
    else:
        try:
            subprocess.run(["git", "clone", clone_url, str(REPO_DIR)], check=True)
        except subprocess.CalledProcessError as exc:
            # CalledProcessError's message echoes argv, i.e. the tokenized URL, into the
            # (saved) notebook output. Raise a redacted error and drop the chained one.
            raise RuntimeError(
                f"git clone of {GITHUB_REPO} failed with exit code {exc.returncode}; "
                "check the token and repository access."
            ) from None

    os.chdir(REPO_DIR)
    sys.path.insert(0, str(REPO_DIR))
else:
    print("Running locally from", REPO_DIR)
    sys.path.insert(0, str(REPO_DIR))

# Evict previously imported project modules so re-running picks up fresh sources.
importlib.invalidate_caches()
for name in list(sys.modules):
    if (
        name == "gamma_space_model"
        or name.startswith("gamma_space_model.")
        or name == "csrc"
        or name.startswith("csrc.")
        or name == "tilelang"
        or name.startswith("tilelang.")
    ):
        del sys.modules[name]

try:
    commit = subprocess.check_output(
        ["git", "-C", str(REPO_DIR), "rev-parse", "--short", "HEAD"],
        text=True,
    ).strip()
    print("Repo commit:", commit)
except (subprocess.CalledProcessError, OSError):
    # Not a git checkout, or git is unavailable: the commit hash is informational only.
    print("Repo commit: unknown")
"""
    )
def imports_cell():
    """Return the code cell with imports, seeding, and device/AMP setup.

    The emitted cell relies on ``IN_COLAB``, ``sys``, ``subprocess`` and
    ``importlib`` defined by the setup cell. In Colab it installs torchvision
    if it is missing. It seeds every RNG with ``SEED``, picks the device,
    creates the shared ``scaler``, and defines ``synchronize()``.
    """
    return code(
        r"""
import math
import random
import time
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset

try:
    from torchvision import datasets, transforms
except ImportError:
    if IN_COLAB:
        subprocess.run([sys.executable, "-m", "pip", "install", "-q", "torchvision"], check=True)
        # Let the running interpreter see the freshly installed package.
        importlib.invalidate_caches()
        from torchvision import datasets, transforms
    else:
        raise

from gamma_space_model import GammaSingleBlock, GammaS4Block, S4TernaryDPLRBlock

SEED = 7
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
USE_AMP = DEVICE.type == "cuda"
# Prefer torch.amp.GradScaler when this torch build provides it; otherwise use torch.cuda.amp.
if hasattr(torch, "amp") and hasattr(torch.amp, "GradScaler"):
    scaler = torch.amp.GradScaler("cuda", enabled=USE_AMP)
else:
    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)


def synchronize():
    # Wait for queued CUDA work before reading wall-clock timers.
    if DEVICE.type == "cuda":
        torch.cuda.synchronize()


print("Device:", DEVICE)
if DEVICE.type != "cuda":
    print("WARNING: CPU run detected. Treat speed numbers as smoke-test only.")
"""
    )
def config_cell():
    """Return the configuration code cell.

    The emitted cell defines the RUN_* toggles, the per-task hyperparameters in
    ``CHALLENGE_CONFIG``, the compared ``MODEL_KINDS``, and the easy/moderate/hard
    ``TOKEN_CURRICULUM`` tiers.
    """
    source = r"""
# Set any RUN_* flag to False if you want to skip a benchmark during a quick smoke run.
RUN_PERMUTED_MNIST = True
RUN_SELECTIVE_COPYING = True
RUN_INDUCTION_RECALL = True
RUN_TOKEN_CURRICULUM = True
CHALLENGE_CONFIG = {
"mnist": {
"epochs": 4,
"batch_size": 64,
"train_samples": 6000,
"val_samples": 1000,
"d_model": 48,
"hidden_dim": 96,
"num_layers": 2,
"lr": 2e-3,
"weight_decay": 1e-4,
},
"selective_copying": {
"epochs": 12,
"batch_size": 128,
"train_samples": 4096,
"val_samples": 1024,
"seq_len": 128,
"vocab_size": 32,
"num_mem_tokens": 6,
"d_model": 64,
"hidden_dim": 128,
"num_layers": 2,
"lr": 2e-3,
"weight_decay": 1e-4,
},
"induction_recall": {
"epochs": 12,
"batch_size": 128,
"train_samples": 4096,
"val_samples": 1024,
"seq_len": 128,
"num_keys": 32,
"num_values": 32,
"d_model": 64,
"hidden_dim": 128,
"num_layers": 2,
"lr": 2e-3,
"weight_decay": 1e-4,
},
}
MODEL_KINDS = ["gamma_baseline", "gamma_s4_enhanced", "s4_ternary_dplr_ssm"]
TOKEN_CURRICULUM = [
{
"difficulty": "easy",
"seq_len": 64,
"classes": 8,
"num_mem_tokens": 3,
"train_samples": 2048,
"val_samples": 512,
"epochs": 10,
},
{
"difficulty": "moderate",
"seq_len": 96,
"classes": 16,
"num_mem_tokens": 4,
"train_samples": 4096,
"val_samples": 1024,
"epochs": 10,
},
{
"difficulty": "hard",
"seq_len": 128,
"classes": 32,
"num_mem_tokens": 6,
"train_samples": 4096,
"val_samples": 1024,
"epochs": 12,
},
]
"""
    return code(source)
def model_cell():
    """Return the code cell defining the block factory and the three task models.

    The emitted cell defines ``make_block`` (maps a ``MODEL_KINDS`` name to a
    configured block), ``BlockStack``, and the MNIST, token-sequence, and
    token-classifier wrappers used by the benchmark cells.
    """
    return code(
        r"""
def make_block(kind, width, hidden_dim):
    # Build one freshly initialised sequence block for the given MODEL_KINDS name.
    if kind == "gamma_baseline":
        return GammaSingleBlock(
            d_model=width,
            hidden_dim=hidden_dim,
            dropout=0.0,
        )
    if kind == "gamma_s4_enhanced":
        return GammaS4Block(
            d_model=width,
            hidden_dim=hidden_dim,
            kernel_mode="auto",
            kernel_threshold=96,
            discretization="bilinear",
            gate=True,
            input_gate=True,
            activation="gelu",
            use_D=True,
            layer_scale_init=0.1,
        )
    if kind == "s4_ternary_dplr_ssm":
        return S4TernaryDPLRBlock(
            d_model=width,
            hidden_dim=hidden_dim,
            rank=1,
            kernel_mode="auto",
            kernel_threshold=96,
            gate=True,
            input_gate=True,
            activation="gelu",
            use_D=True,
            layer_scale_init=0.1,
        )
    raise ValueError(f"Unknown model kind: {kind}")


class BlockStack(nn.Module):
    # num_layers blocks of the same kind applied in sequence.
    def __init__(self, kind, d_model, hidden_dim, num_layers):
        super().__init__()
        self.layers = nn.ModuleList(
            [make_block(kind, d_model, hidden_dim) for _ in range(num_layers)]
        )

    def forward(self, x):
        # Each block returns (output, state); the state is not needed for full-sequence runs.
        for layer in self.layers:
            x, _ = layer(x, state=None, return_state=False)
        return x


class PermutedMNISTModel(nn.Module):
    # Projects each scalar pixel to d_model and classifies from the last time step.
    def __init__(self, kind, d_model, hidden_dim, num_layers):
        super().__init__()
        self.input_proj = nn.Linear(1, d_model)
        self.stack = BlockStack(kind, d_model, hidden_dim, num_layers)
        self.classifier = nn.Linear(d_model, 10)

    def forward(self, x):
        x = self.input_proj(x)
        x = self.stack(x)
        return self.classifier(x[:, -1])


class TokenSequenceModel(nn.Module):
    # Emits per-position logits over the full token vocabulary.
    def __init__(self, kind, vocab_size, d_model, hidden_dim, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.stack = BlockStack(kind, d_model, hidden_dim, num_layers)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, tokens):
        x = self.embedding(tokens)
        x = self.stack(x)
        return self.head(x)


class TokenClassifierModel(nn.Module):
    # Classifies a token sequence from the representation at its last position.
    def __init__(self, kind, vocab_size, num_classes, d_model, hidden_dim, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.stack = BlockStack(kind, d_model, hidden_dim, num_layers)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, tokens):
        x = self.embedding(tokens)
        x = self.stack(x)
        return self.head(x[:, -1])
"""
    )
def data_cell():
    """Return the code cell defining the permuted-MNIST, selective-copying and
    induction-recall datasets and their DataLoader factories.
    """
    return code(
        r"""
class PermutedMNISTDataset(Dataset):
    # MNIST digits flattened to 784-step sequences and reordered by a fixed permutation.
    def __init__(self, root, train, permutation, limit):
        base = datasets.MNIST(
            root=root,
            train=train,
            download=True,
            transform=transforms.ToTensor(),
        )
        count = len(base) if limit is None else min(limit, len(base))
        self.base = Subset(base, list(range(count)))
        self.permutation = permutation

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        image, label = self.base[idx]
        # Flatten, permute, and add a trailing feature axis: one pixel per time step.
        seq = image.view(-1)[self.permutation].unsqueeze(-1)
        return seq, torch.tensor(label, dtype=torch.long)


def make_permuted_mnist_loaders(config):
    # The same seeded permutation is shared by the train and validation splits.
    generator = torch.Generator().manual_seed(SEED)
    permutation = torch.randperm(28 * 28, generator=generator)
    train_ds = PermutedMNISTDataset(
        root=str(REPO_DIR / "data"),
        train=True,
        permutation=permutation,
        limit=config["train_samples"],
    )
    val_ds = PermutedMNISTDataset(
        root=str(REPO_DIR / "data"),
        train=False,
        permutation=permutation,
        limit=config["val_samples"],
    )
    # NOTE(review): worker processes must unpickle this notebook-defined Dataset, which
    # may fail under the "spawn" start method (macOS/Windows); use num_workers=0 there.
    train_loader = DataLoader(
        train_ds,
        batch_size=config["batch_size"],
        shuffle=True,
        num_workers=2,
        pin_memory=DEVICE.type == "cuda",
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=config["batch_size"],
        shuffle=False,
        num_workers=2,
        pin_memory=DEVICE.type == "cuda",
    )
    return train_loader, val_loader


def make_selective_copying_dataset(samples, config, seed):
    # Marked values appear early; the model must reproduce them, in order, at the
    # trailing query positions. Targets elsewhere are -100 (ignored by cross_entropy).
    rng = np.random.default_rng(seed)
    seq_len = config["seq_len"]
    vocab_size = config["vocab_size"]
    num_mem = config["num_mem_tokens"]
    # Special tokens sit just above the content vocabulary.
    blank = vocab_size
    marker = vocab_size + 1
    query = vocab_size + 2
    model_vocab = vocab_size + 3

    x = np.full((samples, seq_len), blank, dtype=np.int64)
    y = np.full((samples, seq_len), -100, dtype=np.int64)
    memory_window = max(8, seq_len // 2)
    query_start = seq_len - num_mem
    # Value positions lie in [1, memory_window - 2] and are never adjacent, so every value
    # keeps its own marker at position - 1. (With adjacent picks, the second marker would be
    # overwritten by the first value.) Draw distinct sorted slots from a range shrunk by
    # num_mem - 1, then add each slot's rank to spread them at least 2 apart.
    slots = np.arange(1, memory_window - num_mem)
    spread = np.arange(num_mem)
    for i in range(samples):
        positions = np.sort(rng.choice(slots, size=num_mem, replace=False)) + spread
        values = rng.integers(0, vocab_size, size=num_mem, dtype=np.int64)
        x[i, positions - 1] = marker
        x[i, positions] = values
        x[i, query_start:] = query
        y[i, query_start:] = values
    return TensorDataset(torch.from_numpy(x), torch.from_numpy(y)), model_vocab


def make_selective_copying_loaders(config):
    train_ds, model_vocab = make_selective_copying_dataset(config["train_samples"], config, SEED + 11)
    val_ds, _ = make_selective_copying_dataset(config["val_samples"], config, SEED + 17)
    train_loader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=config["batch_size"], shuffle=False)
    return train_loader, val_loader, model_vocab


def make_induction_recall_dataset(samples, config, seed):
    # One key/value pair early in the sequence, distractor pairs with other keys later,
    # then [query_marker, key] at the end; the label is the value paired with that key.
    rng = np.random.default_rng(seed)
    seq_len = config["seq_len"]
    num_keys = config["num_keys"]
    num_values = config["num_values"]
    # Token layout: [keys | values | filler (num_keys ids) | query marker].
    key_offset = 0
    value_offset = num_keys
    filler_offset = num_keys + num_values
    query_marker = filler_offset + num_keys
    model_vocab = query_marker + 1

    x = rng.integers(filler_offset, filler_offset + num_keys, size=(samples, seq_len), dtype=np.int64)
    y = np.zeros(samples, dtype=np.int64)
    distractor_count = max(3, seq_len // 24)
    for i in range(samples):
        key = rng.integers(0, num_keys)
        value = rng.integers(0, num_values)
        # Target pair occupies (early, early + 1), so early + 1 <= seq_len // 3.
        early = rng.integers(1, seq_len // 3)
        x[i, early] = key_offset + key
        x[i, early + 1] = value_offset + value
        for _ in range(distractor_count):
            # Start strictly after seq_len // 3. A distractor key at exactly seq_len // 3
            # would overwrite the target value whenever early + 1 == seq_len // 3, which
            # silently corrupts the label. pos + 1 stays clear of the final query pair.
            pos = rng.integers(seq_len // 3 + 1, seq_len - 4)
            d_key = rng.integers(0, num_keys)
            d_value = rng.integers(0, num_values)
            if d_key == key:
                # Distractors must never reuse the queried key.
                d_key = (d_key + 1) % num_keys
            x[i, pos] = key_offset + d_key
            x[i, pos + 1] = value_offset + d_value
        x[i, -2] = query_marker
        x[i, -1] = key_offset + key
        y[i] = value
    return TensorDataset(torch.from_numpy(x), torch.from_numpy(y)), model_vocab, num_values


def make_induction_recall_loaders(config):
    train_ds, model_vocab, num_classes = make_induction_recall_dataset(config["train_samples"], config, SEED + 23)
    val_ds, _, _ = make_induction_recall_dataset(config["val_samples"], config, SEED + 29)
    train_loader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=config["batch_size"], shuffle=False)
    return train_loader, val_loader, model_vocab, num_classes
"""
    )
def training_cell():
    """Return the code cell with the train/eval loops and speed benchmarks.

    The emitted cell defines the per-epoch loops for classifier and per-token
    tasks, the forward-latency benchmark, ``train_challenge_model``, and
    ``benchmark_model_inference_modes``. The last one compares full-sequence
    prefill against token-by-token recurrent decoding with each available cache.
    """
    return code(
        r"""
import inspect


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def run_classifier_epoch(model, loader, optimizer=None):
    # One pass with a single label per sequence. Trains when an optimizer is given,
    # otherwise evaluates. Returns mean loss, accuracy, and token throughput.
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    total_correct = 0
    total_count = 0
    total_tokens = 0
    start = time.perf_counter()
    # Evaluation passes must not record autograd graphs (wasted memory and time).
    with torch.set_grad_enabled(training):
        for x, y in loader:
            x = x.to(DEVICE, non_blocking=True)
            y = y.to(DEVICE, non_blocking=True)
            if training:
                optimizer.zero_grad(set_to_none=True)
            with torch.autocast(device_type=DEVICE.type, dtype=torch.float16, enabled=USE_AMP):
                logits = model(x)
                loss = F.cross_entropy(logits, y)
            if training:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            total_loss += loss.detach().item() * y.numel()
            total_correct += (logits.argmax(dim=-1) == y).sum().item()
            total_count += y.numel()
            total_tokens += x.shape[0] * x.shape[1]
    synchronize()
    elapsed = time.perf_counter() - start
    return {
        "loss": total_loss / max(total_count, 1),
        "accuracy": total_correct / max(total_count, 1),
        "tokens_per_s": total_tokens / max(elapsed, 1e-9),
        "elapsed_s": elapsed,
    }


def run_sequence_epoch(model, loader, optimizer=None):
    # Like run_classifier_epoch, but with per-position targets; positions labelled
    # -100 are excluded from both the loss and the accuracy.
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    total_correct = 0
    total_count = 0
    total_tokens = 0
    start = time.perf_counter()
    with torch.set_grad_enabled(training):
        for x, y in loader:
            x = x.to(DEVICE, non_blocking=True)
            y = y.to(DEVICE, non_blocking=True)
            if training:
                optimizer.zero_grad(set_to_none=True)
            with torch.autocast(device_type=DEVICE.type, dtype=torch.float16, enabled=USE_AMP):
                logits = model(x)
                loss = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),
                    y.reshape(-1),
                    ignore_index=-100,
                )
            if training:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            valid = y != -100
            pred = logits.argmax(dim=-1)
            total_loss += loss.detach().item() * valid.sum().item()
            total_correct += ((pred == y) & valid).sum().item()
            total_count += valid.sum().item()
            total_tokens += x.numel()
    synchronize()
    elapsed = time.perf_counter() - start
    return {
        "loss": total_loss / max(total_count, 1),
        "accuracy": total_correct / max(total_count, 1),
        "tokens_per_s": total_tokens / max(elapsed, 1e-9),
        "elapsed_s": elapsed,
    }


@torch.no_grad()
def benchmark_forward(model, loader, repeats=8):
    # Mean full-sequence forward latency on the first batch of `loader`.
    model.eval()
    x, _ = next(iter(loader))
    x = x.to(DEVICE)
    # Clear any module-level kernel cache, then run one untimed warm-up call.
    for module in model.modules():
        if hasattr(module, "clear_kernel_cache"):
            module.clear_kernel_cache()
    with torch.autocast(device_type=DEVICE.type, dtype=torch.float16, enabled=USE_AMP):
        _ = model(x)
    synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        with torch.autocast(device_type=DEVICE.type, dtype=torch.float16, enabled=USE_AMP):
            _ = model(x)
    synchronize()
    elapsed = time.perf_counter() - start
    return {
        "forward_ms": elapsed * 1000.0 / repeats,
        "forward_tokens_per_s": (x.shape[0] * x.shape[1] * repeats) / max(elapsed, 1e-9),
    }


def train_challenge_model(name, model, train_loader, val_loader, config, epoch_fn):
    # Train with AdamW for config["epochs"] epochs, validating after each one. Returns
    # (final-epoch metrics + best/forward/param stats, per-epoch history, trained model).
    if config["epochs"] < 1:
        raise ValueError(f"{name}: config['epochs'] must be >= 1 to report final metrics")
    model = model.to(DEVICE)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config["lr"],
        weight_decay=config["weight_decay"],
    )
    rows = []
    for epoch in range(config["epochs"]):
        train_metrics = epoch_fn(model, train_loader, optimizer)
        val_metrics = epoch_fn(model, val_loader)
        row = {
            "epoch": epoch + 1,
            "train_loss": train_metrics["loss"],
            "train_accuracy": train_metrics["accuracy"],
            "val_loss": val_metrics["loss"],
            "val_accuracy": val_metrics["accuracy"],
            "train_tokens_per_s": train_metrics["tokens_per_s"],
            "val_tokens_per_s": val_metrics["tokens_per_s"],
            "epoch_s": train_metrics["elapsed_s"] + val_metrics["elapsed_s"],
        }
        rows.append(row)
        print(
            f"{name} | epoch={epoch + 1:02d} "
            f"train_loss={row['train_loss']:.4f} train_acc={row['train_accuracy']:.3f} "
            f"val_loss={row['val_loss']:.4f} val_acc={row['val_accuracy']:.3f}"
        )
    history = pd.DataFrame(rows)
    final = rows[-1].copy()
    final["best_val_loss"] = float(history["val_loss"].min())
    final["best_val_accuracy"] = float(history["val_accuracy"].max())
    final.update(benchmark_forward(model, val_loader))
    final["params"] = count_parameters(model)
    return final, history, model


def model_hidden_and_layers(model, x):
    # Split a task model into (input embedding of x, block list, output head, head kind)
    # so the blocks can be driven one token at a time.
    if isinstance(model, PermutedMNISTModel):
        return model.input_proj(x), model.stack.layers, model.classifier, "classifier"
    if isinstance(model, TokenSequenceModel):
        return model.embedding(x), model.stack.layers, model.head, "sequence"
    if isinstance(model, TokenClassifierModel):
        return model.embedding(x), model.stack.layers, model.head, "classifier"
    raise TypeError(f"Unsupported model type: {type(model).__name__}")


def step_accepts_cache(layer):
    # True when layer.step(...) can be called with a `cache=` keyword argument.
    try:
        params = inspect.signature(layer.step).parameters.values()
    except (TypeError, ValueError):
        return False
    return any(
        param.name == "cache" or param.kind is inspect.Parameter.VAR_KEYWORD
        for param in params
    )


@torch.no_grad()
def benchmark_model_inference_modes(model, loader, task_type, repeats=6):
    # Compare full-sequence prefill with token-by-token recurrent decoding (exact cache,
    # plus deployment / balanced caches when every layer offers them) on up to two
    # validation samples. Reports latency, throughput, CE, accuracy, peak CUDA memory,
    # and the MSE between recurrent and prefill logits.
    model.eval()
    x, y = next(iter(loader))
    x = x[: min(2, x.size(0))].to(DEVICE)
    y = y[: min(2, y.size(0))].to(DEVICE)
    token_count = x.shape[0] * x.shape[1]
    rows = []

    def reset_memory():
        if DEVICE.type == "cuda":
            torch.cuda.reset_peak_memory_stats()

    def max_memory_mb():
        if DEVICE.type != "cuda":
            return float("nan")
        return torch.cuda.max_memory_allocated() / (1024 ** 2)

    def score_logits(logits):
        if task_type == "sequence":
            valid = y != -100
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                y.reshape(-1),
                ignore_index=-100,
            )
            pred = logits.argmax(dim=-1)
            accuracy = ((pred == y) & valid).sum().float() / valid.sum().clamp_min(1)
        else:
            loss = F.cross_entropy(logits, y)
            accuracy = (logits.argmax(dim=-1) == y).float().mean()
        return float(loss.detach().cpu()), float(accuracy.detach().cpu())

    reset_memory()
    with torch.autocast(device_type=DEVICE.type, dtype=torch.float16, enabled=USE_AMP):
        full_logits = model(x)
    full_ce, full_acc = score_logits(full_logits)
    synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        with torch.autocast(device_type=DEVICE.type, dtype=torch.float16, enabled=USE_AMP):
            full_logits = model(x)
    synchronize()
    elapsed = time.perf_counter() - start
    rows.append({
        "mode": "prefill_full_sequence",
        "latency_ms": 1000.0 * elapsed / repeats,
        "tokens_per_s": token_count * repeats / max(elapsed, 1e-9),
        "ce": full_ce,
        "accuracy": full_acc,
        "match_mse": 0.0,
        "max_memory_mb": max_memory_mb(),
    })

    hidden, layers, head, output_kind = model_hidden_and_layers(model, x)
    # Decide once per layer whether step() takes `cache=`. Probing with try/except
    # TypeError on every token would add exception overhead inside the timed loop
    # and would also hide genuine TypeErrors raised inside step().
    pass_cache = [step_accepts_cache(layer) for layer in layers]
    states = []
    exact_caches = []
    cache_start = time.perf_counter()
    for layer in layers:
        state = layer.ssm.init_state(hidden.size(0), DEVICE, hidden.dtype)
        states.append(state)
        if hasattr(layer, "allocate_inference_cache"):
            exact_caches.append(layer.allocate_inference_cache(hidden.size(0), hidden.size(1), DEVICE, hidden.dtype))
        else:
            exact_caches.append(None)
    synchronize()
    cache_setup_ms = 1000.0 * (time.perf_counter() - cache_start)

    def recurrent_logits(cache_list):
        # Feed the sequence one position at a time through every layer, starting
        # from clones of the initial states so repeated calls are independent.
        local_states = [state.clone() for state in states]
        outputs = []
        for t in range(hidden.size(1)):
            step_x = hidden[:, t, :]
            for layer_idx, layer in enumerate(layers):
                if pass_cache[layer_idx]:
                    cache = None if cache_list is None else cache_list[layer_idx]
                    step_x, local_states[layer_idx] = layer.step(step_x, local_states[layer_idx], cache=cache)
                else:
                    step_x, local_states[layer_idx] = layer.step(step_x, local_states[layer_idx])
            outputs.append(step_x)
        hidden_seq = torch.stack(outputs, dim=1)
        if output_kind == "sequence":
            return head(hidden_seq)
        return head(hidden_seq[:, -1])

    def add_recurrent_row(mode, cache_list, setup_ms=None):
        reset_memory()
        logits = recurrent_logits(cache_list)
        ce, acc = score_logits(logits)
        synchronize()
        start = time.perf_counter()
        for _ in range(repeats):
            logits = recurrent_logits(cache_list)
        synchronize()
        elapsed = time.perf_counter() - start
        rows.append({
            "mode": mode,
            "cache_setup_ms": setup_ms,
            "latency_ms": 1000.0 * elapsed / repeats,
            "tokens_per_s": token_count * repeats / max(elapsed, 1e-9),
            "ce": ce,
            "accuracy": acc,
            # Prefill logits come from the autocast path and the recurrent logits do not;
            # compare explicitly in fp32 rather than relying on implicit dtype promotion.
            "match_mse": float(F.mse_loss(logits.float(), full_logits.float()).detach().cpu()),
            "max_memory_mb": max_memory_mb(),
        })

    add_recurrent_row("decode_recurrent_exact", exact_caches, setup_ms=cache_setup_ms)
    if all(hasattr(layer, "allocate_deployment_cache") for layer in layers):
        cache_start = time.perf_counter()
        deploy_caches = [
            layer.allocate_deployment_cache(hidden.size(0), hidden.size(1), DEVICE, hidden.dtype)
            for layer in layers
        ]
        synchronize()
        deploy_cache_setup_ms = 1000.0 * (time.perf_counter() - cache_start)
        add_recurrent_row("decode_deploy_lite", deploy_caches, setup_ms=deploy_cache_setup_ms)
    if all(hasattr(layer, "allocate_balanced_deployment_cache") for layer in layers):
        cache_start = time.perf_counter()
        balanced_caches = [
            layer.allocate_balanced_deployment_cache(hidden.size(0), hidden.size(1), DEVICE, hidden.dtype)
            for layer in layers
        ]
        synchronize()
        balanced_cache_setup_ms = 1000.0 * (time.perf_counter() - cache_start)
        add_recurrent_row("decode_balanced", balanced_caches, setup_ms=balanced_cache_setup_ms)
    return rows
"""
    )
def run_cell():
    """Return the code cell that trains and benchmarks every model on each enabled task.

    The emitted cell fills ``all_results``, ``all_histories``, ``trained_models``
    and ``challenge_inference_rows``, then displays summary tables. When every
    task flag is off, it reports that instead of failing on empty DataFrames.
    """
    return code(
        r"""
all_results = []
all_histories = {}
trained_models = {}
challenge_inference_rows = []

if RUN_PERMUTED_MNIST:
    config = CHALLENGE_CONFIG["mnist"]
    train_loader, val_loader = make_permuted_mnist_loaders(config)
    for kind in MODEL_KINDS:
        print(f"\n=== Permuted MNIST | {kind} ===")
        model = PermutedMNISTModel(
            kind=kind,
            d_model=config["d_model"],
            hidden_dim=config["hidden_dim"],
            num_layers=config["num_layers"],
        )
        metrics, history, model = train_challenge_model(
            f"permuted_mnist/{kind}",
            model,
            train_loader,
            val_loader,
            config,
            run_classifier_epoch,
        )
        metrics.update({"task": "permuted_mnist", "model": kind, "chance_accuracy": 0.1})
        all_results.append(metrics)
        all_histories[(metrics["task"], kind)] = history
        trained_models[(metrics["task"], kind)] = model
        for row in benchmark_model_inference_modes(model, val_loader, task_type="classifier"):
            row.update({"task": "permuted_mnist", "model": kind})
            challenge_inference_rows.append(row)

if RUN_SELECTIVE_COPYING:
    config = CHALLENGE_CONFIG["selective_copying"]
    train_loader, val_loader, vocab_size = make_selective_copying_loaders(config)
    for kind in MODEL_KINDS:
        print(f"\n=== Selective Copying | {kind} ===")
        model = TokenSequenceModel(
            kind=kind,
            vocab_size=vocab_size,
            d_model=config["d_model"],
            hidden_dim=config["hidden_dim"],
            num_layers=config["num_layers"],
        )
        metrics, history, model = train_challenge_model(
            f"selective_copying/{kind}",
            model,
            train_loader,
            val_loader,
            config,
            run_sequence_epoch,
        )
        metrics.update({"task": "selective_copying", "model": kind, "chance_accuracy": 1.0 / config["vocab_size"]})
        all_results.append(metrics)
        all_histories[(metrics["task"], kind)] = history
        trained_models[(metrics["task"], kind)] = model
        for row in benchmark_model_inference_modes(model, val_loader, task_type="sequence"):
            row.update({"task": "selective_copying", "model": kind})
            challenge_inference_rows.append(row)

if RUN_INDUCTION_RECALL:
    config = CHALLENGE_CONFIG["induction_recall"]
    train_loader, val_loader, vocab_size, num_classes = make_induction_recall_loaders(config)
    for kind in MODEL_KINDS:
        print(f"\n=== Induction Recall | {kind} ===")
        model = TokenClassifierModel(
            kind=kind,
            vocab_size=vocab_size,
            num_classes=num_classes,
            d_model=config["d_model"],
            hidden_dim=config["hidden_dim"],
            num_layers=config["num_layers"],
        )
        metrics, history, model = train_challenge_model(
            f"induction_recall/{kind}",
            model,
            train_loader,
            val_loader,
            config,
            run_classifier_epoch,
        )
        metrics.update({"task": "induction_recall", "model": kind, "chance_accuracy": 1.0 / config["num_values"]})
        all_results.append(metrics)
        all_histories[(metrics["task"], kind)] = history
        trained_models[(metrics["task"], kind)] = model
        for row in benchmark_model_inference_modes(model, val_loader, task_type="classifier"):
            row.update({"task": "induction_recall", "model": kind})
            challenge_inference_rows.append(row)

results_df = pd.DataFrame(all_results)
if results_df.empty:
    # With every task flag off the DataFrame has no columns, so column selection would fail.
    print("No challenge task ran: RUN_PERMUTED_MNIST, RUN_SELECTIVE_COPYING and RUN_INDUCTION_RECALL are all False.")
else:
    display(results_df[
        [
            "task",
            "model",
            "val_loss",
            "val_accuracy",
            "best_val_loss",
            "best_val_accuracy",
            "chance_accuracy",
            "epoch_s",
            "forward_ms",
            "forward_tokens_per_s",
            "params",
        ]
    ].sort_values(["task", "model"]))

challenge_inference_df = pd.DataFrame(challenge_inference_rows)
if not challenge_inference_df.empty:
    display(challenge_inference_df.sort_values(["task", "model", "mode"]))
"""
    )
def plotting_cell():
    """Return the code cell that plots validation loss and accuracy per task.

    The emitted cell draws one row per task in ``all_histories`` with a loss
    panel and an accuracy panel; each panel overlays one line per model kind.
    """
    return code(
        r"""
if len(all_histories) > 0:
    tasks = sorted({task for task, _ in all_histories})
    fig, axes = plt.subplots(len(tasks), 2, figsize=(12, 3.5 * len(tasks)), squeeze=False)
    for row, task in enumerate(tasks):
        for kind in MODEL_KINDS:
            hist = all_histories.get((task, kind))
            if hist is None:
                continue
            axes[row, 0].plot(hist["epoch"], hist["val_loss"], marker="o", label=kind)
            axes[row, 1].plot(hist["epoch"], hist["val_accuracy"], marker="o", label=kind)
        axes[row, 0].set_title(f"{task}: validation loss")
        axes[row, 1].set_title(f"{task}: validation accuracy")
        axes[row, 0].set_xlabel("epoch")
        axes[row, 1].set_xlabel("epoch")
        axes[row, 0].grid(alpha=0.25)
        axes[row, 1].grid(alpha=0.25)
        axes[row, 0].legend()
        axes[row, 1].legend()
    plt.tight_layout()
    plt.show()
"""
    )
def curriculum_cell():
    """Return the code cell that runs the token-memory curriculum.

    For each ``TOKEN_CURRICULUM`` tier, the emitted cell trains every model kind
    on selective copying and on induction recall. It records results in
    ``curriculum_rows`` / ``curriculum_histories`` and adds inference-mode rows
    for the easy and hard tiers to ``challenge_inference_rows``.
    """
    return code(
        r"""
curriculum_rows = []
curriculum_histories = {}
# Normally created by the main challenge cell; fall back to a fresh list so this
# cell can also run on its own.
challenge_inference_rows = globals().get("challenge_inference_rows", [])

if RUN_TOKEN_CURRICULUM:
    for spec in TOKEN_CURRICULUM:
        # Model and optimiser settings shared by both curriculum tasks at this tier.
        shared = {
            "epochs": spec["epochs"],
            "batch_size": 128,
            "train_samples": spec["train_samples"],
            "val_samples": spec["val_samples"],
            "seq_len": spec["seq_len"],
            "d_model": 64,
            "hidden_dim": 128,
            "num_layers": 2,
            "lr": 2e-3,
            "weight_decay": 1e-4,
        }

        copy_cfg = {
            **shared,
            "vocab_size": spec["classes"],
            "num_mem_tokens": spec["num_mem_tokens"],
        }
        train_loader, val_loader, vocab_size = make_selective_copying_loaders(copy_cfg)
        for kind in MODEL_KINDS:
            print(f"\n=== Curriculum selective_copying/{spec['difficulty']} | {kind} ===")
            model = TokenSequenceModel(
                kind=kind,
                vocab_size=vocab_size,
                d_model=copy_cfg["d_model"],
                hidden_dim=copy_cfg["hidden_dim"],
                num_layers=copy_cfg["num_layers"],
            )
            metrics, history, model = train_challenge_model(
                f"curriculum_selective_copying_{spec['difficulty']}/{kind}",
                model,
                train_loader,
                val_loader,
                copy_cfg,
                run_sequence_epoch,
            )
            metrics.update({
                "task": "selective_copying",
                "difficulty": spec["difficulty"],
                "model": kind,
                "classes": spec["classes"],
                "seq_len": spec["seq_len"],
                "chance_accuracy": 1.0 / spec["classes"],
            })
            curriculum_rows.append(metrics)
            curriculum_histories[(metrics["task"], spec["difficulty"], kind)] = history
            # Profile inference modes only on the easiest and hardest tiers.
            if spec["difficulty"] in {"easy", "hard"}:
                for row in benchmark_model_inference_modes(model, val_loader, task_type="sequence"):
                    row.update({
                        "task": "selective_copying",
                        "difficulty": spec["difficulty"],
                        "model": kind,
                    })
                    challenge_inference_rows.append(row)

        induction_cfg = {
            **shared,
            "num_keys": spec["classes"],
            "num_values": spec["classes"],
        }
        train_loader, val_loader, vocab_size, num_classes = make_induction_recall_loaders(induction_cfg)
        for kind in MODEL_KINDS:
            print(f"\n=== Curriculum induction_recall/{spec['difficulty']} | {kind} ===")
            model = TokenClassifierModel(
                kind=kind,
                vocab_size=vocab_size,
                num_classes=num_classes,
                d_model=induction_cfg["d_model"],
                hidden_dim=induction_cfg["hidden_dim"],
                num_layers=induction_cfg["num_layers"],
            )
            metrics, history, model = train_challenge_model(
                f"curriculum_induction_recall_{spec['difficulty']}/{kind}",
                model,
                train_loader,
                val_loader,
                induction_cfg,
                run_classifier_epoch,
            )
            metrics.update({
                "task": "induction_recall",
                "difficulty": spec["difficulty"],
                "model": kind,
                "classes": spec["classes"],
                "seq_len": spec["seq_len"],
                "chance_accuracy": 1.0 / spec["classes"],
            })
            curriculum_rows.append(metrics)
            curriculum_histories[(metrics["task"], spec["difficulty"], kind)] = history
            if spec["difficulty"] in {"easy", "hard"}:
                for row in benchmark_model_inference_modes(model, val_loader, task_type="classifier"):
                    row.update({
                        "task": "induction_recall",
                        "difficulty": spec["difficulty"],
                        "model": kind,
                    })
                    challenge_inference_rows.append(row)

curriculum_df = pd.DataFrame(curriculum_rows)
if not curriculum_df.empty:
    display(curriculum_df[
        [
            "task",
            "difficulty",
            "model",
            "classes",
            "seq_len",
            "val_loss",
            "val_accuracy",
            "best_val_accuracy",
            "chance_accuracy",
            "epoch_s",
            "forward_tokens_per_s",
        ]
    ].sort_values(["task", "difficulty", "model"]))
else:
    display(curriculum_df)

if challenge_inference_rows:
    challenge_inference_df = pd.DataFrame(challenge_inference_rows)
    sort_cols = ["task"]
    if "difficulty" in challenge_inference_df.columns:
        sort_cols.append("difficulty")
    sort_cols.extend(["model", "mode"])
    display(challenge_inference_df.sort_values(sort_cols))
"""
    )
def curriculum_plot_cell():
    """Return the code cell that plots best validation accuracy per curriculum tier.

    The emitted cell draws one figure per curriculum task, with one line per
    model kind plus a dashed chance baseline. Tiers are plotted in the order
    they were run, so missing tiers or models are skipped instead of raising
    ``KeyError``.
    """
    return code(
        r"""
if "curriculum_df" in globals() and not curriculum_df.empty:
    for task in sorted(curriculum_df["task"].unique()):
        subset = curriculum_df[curriculum_df["task"] == task]
        fig, ax = plt.subplots(figsize=(10, 4))
        # Rows were appended in TOKEN_CURRICULUM order, so plotting them as-is keeps
        # the easy -> hard ordering without hard-coding tier names.
        for kind in MODEL_KINDS:
            model_subset = subset[subset["model"] == kind]
            if model_subset.empty:
                continue
            ax.plot(model_subset["difficulty"], model_subset["best_val_accuracy"], marker="o", label=kind)
        chance = subset.drop_duplicates("difficulty")
        ax.plot(chance["difficulty"], chance["chance_accuracy"], linestyle="--", color="black", label="chance")
        ax.set_title(f"{task}: best validation accuracy by difficulty")
        ax.set_xlabel("difficulty")
        ax.set_ylabel("accuracy")
        ax.grid(alpha=0.25)
        ax.legend()
        plt.show()
"""
    )
def notes_cell():
    """Return the closing markdown cell explaining what each benchmark stresses.

    A blank line separates the bullet list from the following paragraph. Without
    it, Markdown treats the paragraph as a lazy continuation of the last bullet.
    """
    return md(
        """
## How To Read This Notebook

- Permuted MNIST stresses long-range sequence classification because each image is flattened to 784 tokens and then shuffled by a fixed random permutation.
- Selective copying stresses sparse memory recall: the model sees a few marked values early and must emit them at query positions near the end.
- Induction recall is an induction-head-style associative recall task: the model observes a key-value pair, later receives the same key as a query, and must predict the associated value.

The synthetic tasks are intentionally small enough for Colab iteration. The curriculum section is especially important: if a model cannot beat chance on the easy tier, the issue is probably model mechanism or task formulation rather than too few epochs. If it wins easy but fails hard, the next knobs are capacity, sample count, and sequence length.
"""
    )
def build_notebook():
    """Assemble the challenge-benchmark notebook as an nbformat 4.5 dict.

    Returns:
        dict: ready for ``json.dumps``. The cells are the intro markdown, the
        setup / import / config / model / data / training / run / plot code
        cells, the token-memory curriculum section, and the closing notes.
        Every cell gets a deterministic ``id``, as nbformat 4.5 requires.
    """
    cells = [
        md(
            """
# Gamma S4 Challenge Benchmarks

This notebook adds three harder sequence benchmarks requested by the technical lead:

1. Permuted MNIST image classification, following the long-memory benchmark tradition used in the HiPPO paper.
2. Selective copying, a synthetic content-based recall benchmark discussed in the Mamba line of work.
3. Induction-style associative recall, a compact proxy for induction-head behavior.

The goal is not to claim state of the art from this first version. The goal is to create a repeatable stress test that tells us whether Gamma S4 is moving toward practical long-context and token-memory behavior.
"""
        ),
        setup_cell(),
        imports_cell(),
        config_cell(),
        model_cell(),
        data_cell(),
        training_cell(),
        run_cell(),
        plotting_cell(),
        md(
            """
## Token Memory Curriculum

The main challenge tasks are deliberately hard. This curriculum repeats selective copying and induction recall at easy, moderate, and hard settings so we can tell whether failures are caused by too few epochs, insufficient capacity, or a missing content-selective memory mechanism.
"""
        ),
        curriculum_cell(),
        curriculum_plot_cell(),
        notes_cell(),
    ]
    # nbformat 4.5 (nbformat_minor=5 below) requires a unique "id" on every cell,
    # matching [a-zA-Z0-9-_]{1,64}. Index-based ids keep regenerated files
    # deterministic. md()/code() build a new dict per call, so mutating them is safe.
    for index, cell in enumerate(cells):
        cell["id"] = f"cell-{index:02d}"
    return {
        "cells": cells,
        "metadata": {
            "kernelspec": {
                "display_name": "Python 3",
                "language": "python",
                "name": "python3",
            },
            "language_info": {
                "name": "python",
                "pygments_lexer": "ipython3",
            },
            "colab": {
                "name": "gamma-s4-challenge-benchmark.ipynb",
            },
        },
        "nbformat": 4,
        "nbformat_minor": 5,
    }
def main():
    """Write the generated notebook as indented JSON to NOTEBOOK_PATH.

    Missing parent directories are created first; the path written is printed.
    """
    target = NOTEBOOK_PATH
    target.parent.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(build_notebook(), indent=2)
    target.write_text(payload, encoding="utf-8")
    print(f"Wrote {target}")


if __name__ == "__main__":
    main()