"""Generate the Gamma S4 challenge-benchmark Jupyter notebook.

Each ``*_cell`` function returns one nbformat-v4 cell dictionary whose source is
kept verbatim in a raw string.  ``build_notebook`` assembles the cells and
``main`` writes the resulting JSON to ``NOTEBOOK_PATH``.

The cell sources are written at column 0 on purpose: ``md``/``code`` only strip
surrounding blank lines, they do not dedent.
"""

import json
from pathlib import Path


def _source_lines(text: str) -> list:
    """Split cell text into the line list nbformat stores under ``source``.

    Leading/trailing newlines are stripped; every line except possibly the last
    keeps its trailing newline.
    """
    return text.strip("\n").splitlines(keepends=True)


def md(text: str) -> dict:
    """Return a markdown cell dictionary whose source is ``text``."""
    return {
        "cell_type": "markdown",
        "metadata": {},
        "source": _source_lines(text),
    }


def code(text: str) -> dict:
    """Return an unexecuted code cell dictionary whose source is ``text``."""
    return {
        "cell_type": "code",
        "execution_count": None,
        "metadata": {},
        "outputs": [],
        "source": _source_lines(text),
    }


NOTEBOOK_PATH = Path("output/jupyter-notebook/gamma-s4-challenge-benchmark.ipynb")


def setup_cell() -> dict:
    """Return the cell that locates (or clones, on Colab) the repository.

    The generated code defines ``IN_COLAB`` and ``REPO_DIR``, puts the repo on
    ``sys.path``, evicts previously imported project modules from
    ``sys.modules`` and prints the current commit when it can be determined.
    """
    return code(
        r"""
import os
import sys
import subprocess
import importlib
from pathlib import Path

IN_COLAB = "google.colab" in sys.modules
REPO_DIR = Path.cwd()

if IN_COLAB:
    print("Running in Google Colab")
    from google.colab import userdata
    from getpass import getpass

    REPO_NAME = "gamma_ssm_s4_v2"
    REPO_DIR = Path("/content") / REPO_NAME
    GITHUB_REPO = "StarMists/gamma_SSM_S4_enhanced"

    if REPO_DIR.exists():
        subprocess.run(["git", "-C", str(REPO_DIR), "fetch", "origin"], check=True)
        subprocess.run(["git", "-C", str(REPO_DIR), "checkout", "main"], check=True)
        subprocess.run(["git", "-C", str(REPO_DIR), "reset", "--hard", "origin/main"], check=True)
    else:
        # Credentials are only needed for the initial clone, so only ask for them here.
        token = os.environ.get("GITHUB_TOKEN") or os.environ.get("GH_TOKEN")
        if not token:
            try:
                token = userdata.get("GITHUB_TOKEN")
            except Exception:
                token = None
        if not token:
            token = getpass("GitHub personal access token for private repo access: ").strip()

        clone_url = f"https://{token}@github.com/{GITHUB_REPO}.git"
        try:
            subprocess.run(["git", "clone", clone_url, str(REPO_DIR)], check=True)
        except subprocess.CalledProcessError as err:
            # CalledProcessError renders the full command line, which embeds the token.
            # Re-raise a sanitized error so the secret never lands in saved notebook output.
            raise RuntimeError(
                f"git clone of {GITHUB_REPO} failed with exit status {err.returncode}"
            ) from None
    os.chdir(REPO_DIR)
    sys.path.insert(0, str(REPO_DIR))
else:
    print("Running locally from", REPO_DIR)
    sys.path.insert(0, str(REPO_DIR))

importlib.invalidate_caches()
for name in list(sys.modules):
    if (
        name == "gamma_space_model"
        or name.startswith("gamma_space_model.")
        or name == "csrc"
        or name.startswith("csrc.")
        or name == "tilelang"
        or name.startswith("tilelang.")
    ):
        del sys.modules[name]

try:
    commit = subprocess.check_output(["git", "-C", str(REPO_DIR), "rev-parse", "--short", "HEAD"], text=True).strip()
    print("Repo commit:", commit)
except Exception as err:
    # Best effort only: the benchmark still runs without a known commit.
    print("Repo commit: unavailable", f"({type(err).__name__})")
"""
    )


def imports_cell() -> dict:
    """Return the cell with third-party imports, seeding and device/AMP setup.

    The generated code defines ``SEED``, ``DEVICE``, ``USE_AMP``, the shared
    ``scaler`` and the ``synchronize`` helper used by the timing code.
    """
    return code(
        r"""
import math
import random
import time
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset

try:
    from torchvision import datasets, transforms
except Exception:
    if IN_COLAB:
        subprocess.run([sys.executable, "-m", "pip", "install", "-q", "torchvision"], check=True)
        from torchvision import datasets, transforms
    else:
        raise

from gamma_space_model import GammaSingleBlock, GammaS4Block, S4TernaryDPLRBlock

SEED = 7
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
USE_AMP = DEVICE.type == "cuda"
if hasattr(torch, "amp") and hasattr(torch.amp, "GradScaler"):
    scaler = torch.amp.GradScaler("cuda", enabled=USE_AMP)
else:
    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)


def synchronize():
    if DEVICE.type == "cuda":
        torch.cuda.synchronize()


print("Device:", DEVICE)
if DEVICE.type != "cuda":
    print("WARNING: CPU run detected. Treat speed numbers as smoke-test only.")
"""
    )


def config_cell() -> dict:
    """Return the cell holding the ``RUN_*`` switches and benchmark hyperparameters."""
    return code(
        r"""
# Set any RUN_* flag to False if you want to skip a benchmark during a quick smoke run.
RUN_PERMUTED_MNIST = True
RUN_SELECTIVE_COPYING = True
RUN_INDUCTION_RECALL = True
RUN_TOKEN_CURRICULUM = True

CHALLENGE_CONFIG = {
    "mnist": {
        "epochs": 4,
        "batch_size": 64,
        "train_samples": 6000,
        "val_samples": 1000,
        "d_model": 48,
        "hidden_dim": 96,
        "num_layers": 2,
        "lr": 2e-3,
        "weight_decay": 1e-4,
    },
    "selective_copying": {
        "epochs": 12,
        "batch_size": 128,
        "train_samples": 4096,
        "val_samples": 1024,
        "seq_len": 128,
        "vocab_size": 32,
        "num_mem_tokens": 6,
        "d_model": 64,
        "hidden_dim": 128,
        "num_layers": 2,
        "lr": 2e-3,
        "weight_decay": 1e-4,
    },
    "induction_recall": {
        "epochs": 12,
        "batch_size": 128,
        "train_samples": 4096,
        "val_samples": 1024,
        "seq_len": 128,
        "num_keys": 32,
        "num_values": 32,
        "d_model": 64,
        "hidden_dim": 128,
        "num_layers": 2,
        "lr": 2e-3,
        "weight_decay": 1e-4,
    },
}

MODEL_KINDS = ["gamma_baseline", "gamma_s4_enhanced", "s4_ternary_dplr_ssm"]

TOKEN_CURRICULUM = [
    {
        "difficulty": "easy",
        "seq_len": 64,
        "classes": 8,
        "num_mem_tokens": 3,
        "train_samples": 2048,
        "val_samples": 512,
        "epochs": 10,
    },
    {
        "difficulty": "moderate",
        "seq_len": 96,
        "classes": 16,
        "num_mem_tokens": 4,
        "train_samples": 4096,
        "val_samples": 1024,
        "epochs": 10,
    },
    {
        "difficulty": "hard",
        "seq_len": 128,
        "classes": 32,
        "num_mem_tokens": 6,
        "train_samples": 4096,
        "val_samples": 1024,
        "epochs": 12,
    },
]
"""
    )


def model_cell() -> dict:
    """Return the cell defining ``make_block`` and the three task model wrappers."""
    return code(
        r"""
def make_block(kind, width, hidden_dim):
    if kind == "gamma_baseline":
        return GammaSingleBlock(
            d_model=width,
            hidden_dim=hidden_dim,
            dropout=0.0,
        )
    if kind == "gamma_s4_enhanced":
        return GammaS4Block(
            d_model=width,
            hidden_dim=hidden_dim,
            kernel_mode="auto",
            kernel_threshold=96,
            discretization="bilinear",
            gate=True,
            input_gate=True,
            activation="gelu",
            use_D=True,
            layer_scale_init=0.1,
        )
    if kind == "s4_ternary_dplr_ssm":
        return S4TernaryDPLRBlock(
            d_model=width,
            hidden_dim=hidden_dim,
            rank=1,
            kernel_mode="auto",
            kernel_threshold=96,
            gate=True,
            input_gate=True,
            activation="gelu",
            use_D=True,
            layer_scale_init=0.1,
        )
    raise ValueError(f"Unknown model kind: {kind}")


class BlockStack(nn.Module):
    def __init__(self, kind, d_model, hidden_dim, num_layers):
        super().__init__()
        self.layers = nn.ModuleList(
            [make_block(kind, d_model, hidden_dim) for _ in range(num_layers)]
        )

    def forward(self, x):
        for layer in self.layers:
            x, _ = layer(x, state=None, return_state=False)
        return x


class PermutedMNISTModel(nn.Module):
    def __init__(self, kind, d_model, hidden_dim, num_layers):
        super().__init__()
        self.input_proj = nn.Linear(1, d_model)
        self.stack = BlockStack(kind, d_model, hidden_dim, num_layers)
        self.classifier = nn.Linear(d_model, 10)

    def forward(self, x):
        x = self.input_proj(x)
        x = self.stack(x)
        return self.classifier(x[:, -1])


class TokenSequenceModel(nn.Module):
    def __init__(self, kind, vocab_size, d_model, hidden_dim, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.stack = BlockStack(kind, d_model, hidden_dim, num_layers)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, tokens):
        x = self.embedding(tokens)
        x = self.stack(x)
        return self.head(x)


class TokenClassifierModel(nn.Module):
    def __init__(self, kind, vocab_size, num_classes, d_model, hidden_dim, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.stack = BlockStack(kind, d_model, hidden_dim, num_layers)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, tokens):
        x = self.embedding(tokens)
        x = self.stack(x)
        return self.head(x[:, -1])
"""
    )


def data_cell() -> dict:
    """Return the cell with dataset builders and loaders for all three tasks.

    Compared with earlier revisions, induction-recall distractors start one
    position later so they can no longer overwrite the target value token.
    """
    return code(
        r"""
class PermutedMNISTDataset(Dataset):
    def __init__(self, root, train, permutation, limit):
        base = datasets.MNIST(
            root=root,
            train=train,
            download=True,
            transform=transforms.ToTensor(),
        )
        indices = list(range(min(limit, len(base)))) if limit is not None else list(range(len(base)))
        self.base = Subset(base, indices)
        self.permutation = permutation

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        image, label = self.base[idx]
        seq = image.view(-1)[self.permutation].unsqueeze(-1)
        return seq, torch.tensor(label, dtype=torch.long)


def make_permuted_mnist_loaders(config):
    generator = torch.Generator().manual_seed(SEED)
    permutation = torch.randperm(28 * 28, generator=generator)
    train_ds = PermutedMNISTDataset(
        root=str(REPO_DIR / "data"),
        train=True,
        permutation=permutation,
        limit=config["train_samples"],
    )
    val_ds = PermutedMNISTDataset(
        root=str(REPO_DIR / "data"),
        train=False,
        permutation=permutation,
        limit=config["val_samples"],
    )
    train_loader = DataLoader(
        train_ds,
        batch_size=config["batch_size"],
        shuffle=True,
        num_workers=2,
        pin_memory=DEVICE.type == "cuda",
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=config["batch_size"],
        shuffle=False,
        num_workers=2,
        pin_memory=DEVICE.type == "cuda",
    )
    return train_loader, val_loader


def make_selective_copying_dataset(samples, config, seed):
    rng = np.random.default_rng(seed)
    seq_len = config["seq_len"]
    vocab_size = config["vocab_size"]
    num_mem = config["num_mem_tokens"]
    blank = vocab_size
    marker = vocab_size + 1
    query = vocab_size + 2
    model_vocab = vocab_size + 3

    x = np.full((samples, seq_len), blank, dtype=np.int64)
    y = np.full((samples, seq_len), -100, dtype=np.int64)
    memory_window = max(8, seq_len // 2)
    query_start = seq_len - num_mem

    for i in range(samples):
        positions = np.sort(rng.choice(np.arange(1, memory_window - 1), size=num_mem, replace=False))
        values = rng.integers(0, vocab_size, size=num_mem, dtype=np.int64)
        x[i, positions - 1] = marker
        x[i, positions] = values
        x[i, query_start:] = query
        y[i, query_start:] = values

    return TensorDataset(torch.from_numpy(x), torch.from_numpy(y)), model_vocab


def make_selective_copying_loaders(config):
    train_ds, model_vocab = make_selective_copying_dataset(config["train_samples"], config, SEED + 11)
    val_ds, _ = make_selective_copying_dataset(config["val_samples"], config, SEED + 17)
    train_loader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=config["batch_size"], shuffle=False)
    return train_loader, val_loader, model_vocab


def make_induction_recall_dataset(samples, config, seed):
    rng = np.random.default_rng(seed)
    seq_len = config["seq_len"]
    num_keys = config["num_keys"]
    num_values = config["num_values"]
    key_offset = 0
    value_offset = num_keys
    filler_offset = num_keys + num_values
    query_marker = filler_offset + num_keys
    model_vocab = query_marker + 1

    x = rng.integers(filler_offset, filler_offset + num_keys, size=(samples, seq_len), dtype=np.int64)
    y = np.zeros(samples, dtype=np.int64)

    for i in range(samples):
        key = rng.integers(0, num_keys)
        value = rng.integers(0, num_values)
        early = rng.integers(1, seq_len // 3)
        distractor_count = max(3, seq_len // 24)

        x[i, early] = key_offset + key
        x[i, early + 1] = value_offset + value

        for _ in range(distractor_count):
            # The target value may sit at index seq_len // 3 (early + 1 at its maximum),
            # so distractors start one slot later to avoid overwriting it.
            pos = rng.integers(seq_len // 3 + 1, seq_len - 4)
            d_key = rng.integers(0, num_keys)
            d_value = rng.integers(0, num_values)
            if d_key == key:
                d_key = (d_key + 1) % num_keys
            x[i, pos] = key_offset + d_key
            x[i, pos + 1] = value_offset + d_value

        x[i, -2] = query_marker
        x[i, -1] = key_offset + key
        y[i] = value

    return TensorDataset(torch.from_numpy(x), torch.from_numpy(y)), model_vocab, num_values


def make_induction_recall_loaders(config):
    train_ds, model_vocab, num_classes = make_induction_recall_dataset(config["train_samples"], config, SEED + 23)
    val_ds, _, _ = make_induction_recall_dataset(config["val_samples"], config, SEED + 29)
    train_loader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=config["batch_size"], shuffle=False)
    return train_loader, val_loader, model_vocab, num_classes
"""
    )


def training_cell() -> dict:
    """Return the cell with the epoch loops, trainer and inference benchmarks.

    The epoch functions disable autograd whenever no optimizer is supplied, so
    validation passes do not build computation graphs.
    """
    return code(
        r"""
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def run_classifier_epoch(model, loader, optimizer=None):
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    total_correct = 0
    total_count = 0
    total_tokens = 0
    start = time.perf_counter()

    for x, y in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        if training:
            optimizer.zero_grad(set_to_none=True)
        # Skip autograd bookkeeping during evaluation passes (optimizer is None).
        with torch.set_grad_enabled(training):
            with torch.autocast(device_type=DEVICE.type, dtype=torch.float16, enabled=USE_AMP):
                logits = model(x)
                loss = F.cross_entropy(logits, y)
        if training:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        total_loss += loss.detach().item() * y.numel()
        total_correct += (logits.argmax(dim=-1) == y).sum().item()
        total_count += y.numel()
        total_tokens += x.shape[0] * x.shape[1]

    synchronize()
    elapsed = time.perf_counter() - start
    return {
        "loss": total_loss / max(total_count, 1),
        "accuracy": total_correct / max(total_count, 1),
        "tokens_per_s": total_tokens / max(elapsed, 1e-9),
        "elapsed_s": elapsed,
    }


def run_sequence_epoch(model, loader, optimizer=None):
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    total_correct = 0
    total_count = 0
    total_tokens = 0
    start = time.perf_counter()

    for x, y in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        if training:
            optimizer.zero_grad(set_to_none=True)
        # Skip autograd bookkeeping during evaluation passes (optimizer is None).
        with torch.set_grad_enabled(training):
            with torch.autocast(device_type=DEVICE.type, dtype=torch.float16, enabled=USE_AMP):
                logits = model(x)
                loss = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),
                    y.reshape(-1),
                    ignore_index=-100,
                )
        if training:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        valid = y != -100
        valid_count = valid.sum().item()
        pred = logits.argmax(dim=-1)
        total_loss += loss.detach().item() * valid_count
        total_correct += ((pred == y) & valid).sum().item()
        total_count += valid_count
        total_tokens += x.numel()

    synchronize()
    elapsed = time.perf_counter() - start
    return {
        "loss": total_loss / max(total_count, 1),
        "accuracy": total_correct / max(total_count, 1),
        "tokens_per_s": total_tokens / max(elapsed, 1e-9),
        "elapsed_s": elapsed,
    }


@torch.no_grad()
def benchmark_forward(model, loader, repeats=8):
    model.eval()
    x, _ = next(iter(loader))
    x = x.to(DEVICE)
    for module in model.modules():
        if hasattr(module, "clear_kernel_cache"):
            module.clear_kernel_cache()
    with torch.autocast(device_type=DEVICE.type, dtype=torch.float16, enabled=USE_AMP):
        _ = model(x)
    synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        with torch.autocast(device_type=DEVICE.type, dtype=torch.float16, enabled=USE_AMP):
            _ = model(x)
    synchronize()
    elapsed = time.perf_counter() - start
    return {
        "forward_ms": elapsed * 1000.0 / repeats,
        "forward_tokens_per_s": (x.shape[0] * x.shape[1] * repeats) / max(elapsed, 1e-9),
    }


def train_challenge_model(name, model, train_loader, val_loader, config, epoch_fn):
    model = model.to(DEVICE)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config["lr"],
        weight_decay=config["weight_decay"],
    )
    rows = []
    for epoch in range(config["epochs"]):
        train_metrics = epoch_fn(model, train_loader, optimizer)
        val_metrics = epoch_fn(model, val_loader)
        row = {
            "epoch": epoch + 1,
            "train_loss": train_metrics["loss"],
            "train_accuracy": train_metrics["accuracy"],
            "val_loss": val_metrics["loss"],
            "val_accuracy": val_metrics["accuracy"],
            "train_tokens_per_s": train_metrics["tokens_per_s"],
            "val_tokens_per_s": val_metrics["tokens_per_s"],
            "epoch_s": train_metrics["elapsed_s"] + val_metrics["elapsed_s"],
        }
        rows.append(row)
        print(
            f"{name} | epoch={epoch + 1:02d} "
            f"train_loss={row['train_loss']:.4f} train_acc={row['train_accuracy']:.3f} "
            f"val_loss={row['val_loss']:.4f} val_acc={row['val_accuracy']:.3f}"
        )

    history = pd.DataFrame(rows)
    final = rows[-1].copy()
    final["best_val_loss"] = float(history["val_loss"].min())
    final["best_val_accuracy"] = float(history["val_accuracy"].max())
    final.update(benchmark_forward(model, val_loader))
    final["params"] = count_parameters(model)
    return final, history, model


def model_hidden_and_layers(model, x):
    if isinstance(model, PermutedMNISTModel):
        return model.input_proj(x), model.stack.layers, model.classifier, "classifier"
    if isinstance(model, TokenSequenceModel):
        return model.embedding(x), model.stack.layers, model.head, "sequence"
    if isinstance(model, TokenClassifierModel):
        return model.embedding(x), model.stack.layers, model.head, "classifier"
    raise TypeError(f"Unsupported model type: {type(model).__name__}")


@torch.no_grad()
def benchmark_model_inference_modes(model, loader, task_type, repeats=6):
    model.eval()
    x, y = next(iter(loader))
    x = x[: min(2, x.size(0))].to(DEVICE)
    y = y[: min(2, y.size(0))].to(DEVICE)
    token_count = x.shape[0] * x.shape[1]
    rows = []

    def reset_memory():
        if DEVICE.type == "cuda":
            torch.cuda.reset_peak_memory_stats()

    def max_memory_mb():
        if DEVICE.type != "cuda":
            return float("nan")
        return torch.cuda.max_memory_allocated() / (1024 ** 2)

    def score_logits(logits):
        if task_type == "sequence":
            valid = y != -100
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                y.reshape(-1),
                ignore_index=-100,
            )
            pred = logits.argmax(dim=-1)
            accuracy = ((pred == y) & valid).sum().float() / valid.sum().clamp_min(1)
        else:
            loss = F.cross_entropy(logits, y)
            accuracy = (logits.argmax(dim=-1) == y).float().mean()
        return float(loss.detach().cpu()), float(accuracy.detach().cpu())

    reset_memory()
    with torch.autocast(device_type=DEVICE.type, dtype=torch.float16, enabled=USE_AMP):
        full_logits = model(x)
    full_ce, full_acc = score_logits(full_logits)
    synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        with torch.autocast(device_type=DEVICE.type, dtype=torch.float16, enabled=USE_AMP):
            full_logits = model(x)
    synchronize()
    elapsed = time.perf_counter() - start
    rows.append({
        "mode": "prefill_full_sequence",
        "latency_ms": 1000.0 * elapsed / repeats,
        "tokens_per_s": token_count * repeats / max(elapsed, 1e-9),
        "ce": full_ce,
        "accuracy": full_acc,
        "match_mse": 0.0,
        "max_memory_mb": max_memory_mb(),
    })

    hidden, layers, head, output_kind = model_hidden_and_layers(model, x)
    states = []
    exact_caches = []
    cache_start = time.perf_counter()
    for layer in layers:
        state = layer.ssm.init_state(hidden.size(0), DEVICE, hidden.dtype)
        states.append(state)
        if hasattr(layer, "allocate_inference_cache"):
            exact_caches.append(layer.allocate_inference_cache(hidden.size(0), hidden.size(1), DEVICE, hidden.dtype))
        else:
            exact_caches.append(None)
    synchronize()
    cache_setup_ms = 1000.0 * (time.perf_counter() - cache_start)

    def recurrent_logits(cache_list):
        local_states = [state.clone() for state in states]
        outputs = []
        for t in range(hidden.size(1)):
            step_x = hidden[:, t, :]
            for layer_idx, layer in enumerate(layers):
                cache = None if cache_list is None else cache_list[layer_idx]
                try:
                    step_x, local_states[layer_idx] = layer.step(step_x, local_states[layer_idx], cache=cache)
                except TypeError:
                    step_x, local_states[layer_idx] = layer.step(step_x, local_states[layer_idx])
            outputs.append(step_x)
        hidden_seq = torch.stack(outputs, dim=1)
        if output_kind == "sequence":
            return head(hidden_seq)
        return head(hidden_seq[:, -1])

    def add_recurrent_row(mode, cache_list, setup_ms=None):
        reset_memory()
        logits = recurrent_logits(cache_list)
        ce, acc = score_logits(logits)
        synchronize()
        start = time.perf_counter()
        for _ in range(repeats):
            logits = recurrent_logits(cache_list)
        synchronize()
        elapsed = time.perf_counter() - start
        rows.append({
            "mode": mode,
            "cache_setup_ms": setup_ms,
            "latency_ms": 1000.0 * elapsed / repeats,
            "tokens_per_s": token_count * repeats / max(elapsed, 1e-9),
            "ce": ce,
            "accuracy": acc,
            "match_mse": float(F.mse_loss(logits, full_logits).detach().cpu()),
            "max_memory_mb": max_memory_mb(),
        })

    add_recurrent_row("decode_recurrent_exact", exact_caches, setup_ms=cache_setup_ms)

    if all(hasattr(layer, "allocate_deployment_cache") for layer in layers):
        cache_start = time.perf_counter()
        deploy_caches = [
            layer.allocate_deployment_cache(hidden.size(0), hidden.size(1), DEVICE, hidden.dtype)
            for layer in layers
        ]
        synchronize()
        deploy_cache_setup_ms = 1000.0 * (time.perf_counter() - cache_start)
        add_recurrent_row("decode_deploy_lite", deploy_caches, setup_ms=deploy_cache_setup_ms)

    if all(hasattr(layer, "allocate_balanced_deployment_cache") for layer in layers):
        cache_start = time.perf_counter()
        balanced_caches = [
            layer.allocate_balanced_deployment_cache(hidden.size(0), hidden.size(1), DEVICE, hidden.dtype)
            for layer in layers
        ]
        synchronize()
        balanced_cache_setup_ms = 1000.0 * (time.perf_counter() - cache_start)
        add_recurrent_row("decode_balanced", balanced_caches, setup_ms=balanced_cache_setup_ms)

    return rows
"""
    )


def run_cell() -> dict:
    """Return the cell that trains every model kind on each enabled task.

    The summary tables are guarded so the cell still runs when every ``RUN_*``
    flag is disabled and no result rows exist.
    """
    return code(
        r"""
all_results = []
all_histories = {}
trained_models = {}
challenge_inference_rows = []

if RUN_PERMUTED_MNIST:
    config = CHALLENGE_CONFIG["mnist"]
    train_loader, val_loader = make_permuted_mnist_loaders(config)
    for kind in MODEL_KINDS:
        print(f"\n=== Permuted MNIST | {kind} ===")
        model = PermutedMNISTModel(
            kind=kind,
            d_model=config["d_model"],
            hidden_dim=config["hidden_dim"],
            num_layers=config["num_layers"],
        )
        metrics, history, model = train_challenge_model(
            f"permuted_mnist/{kind}",
            model,
            train_loader,
            val_loader,
            config,
            run_classifier_epoch,
        )
        metrics.update({"task": "permuted_mnist", "model": kind, "chance_accuracy": 0.1})
        all_results.append(metrics)
        all_histories[(metrics["task"], kind)] = history
        trained_models[(metrics["task"], kind)] = model
        for row in benchmark_model_inference_modes(model, val_loader, task_type="classifier"):
            row.update({"task": "permuted_mnist", "model": kind})
            challenge_inference_rows.append(row)

if RUN_SELECTIVE_COPYING:
    config = CHALLENGE_CONFIG["selective_copying"]
    train_loader, val_loader, vocab_size = make_selective_copying_loaders(config)
    for kind in MODEL_KINDS:
        print(f"\n=== Selective Copying | {kind} ===")
        model = TokenSequenceModel(
            kind=kind,
            vocab_size=vocab_size,
            d_model=config["d_model"],
            hidden_dim=config["hidden_dim"],
            num_layers=config["num_layers"],
        )
        metrics, history, model = train_challenge_model(
            f"selective_copying/{kind}",
            model,
            train_loader,
            val_loader,
            config,
            run_sequence_epoch,
        )
        metrics.update({"task": "selective_copying", "model": kind, "chance_accuracy": 1.0 / config["vocab_size"]})
        all_results.append(metrics)
        all_histories[(metrics["task"], kind)] = history
        trained_models[(metrics["task"], kind)] = model
        for row in benchmark_model_inference_modes(model, val_loader, task_type="sequence"):
            row.update({"task": "selective_copying", "model": kind})
            challenge_inference_rows.append(row)

if RUN_INDUCTION_RECALL:
    config = CHALLENGE_CONFIG["induction_recall"]
    train_loader, val_loader, vocab_size, num_classes = make_induction_recall_loaders(config)
    for kind in MODEL_KINDS:
        print(f"\n=== Induction Recall | {kind} ===")
        model = TokenClassifierModel(
            kind=kind,
            vocab_size=vocab_size,
            num_classes=num_classes,
            d_model=config["d_model"],
            hidden_dim=config["hidden_dim"],
            num_layers=config["num_layers"],
        )
        metrics, history, model = train_challenge_model(
            f"induction_recall/{kind}",
            model,
            train_loader,
            val_loader,
            config,
            run_classifier_epoch,
        )
        metrics.update({"task": "induction_recall", "model": kind, "chance_accuracy": 1.0 / config["num_values"]})
        all_results.append(metrics)
        all_histories[(metrics["task"], kind)] = history
        trained_models[(metrics["task"], kind)] = model
        for row in benchmark_model_inference_modes(model, val_loader, task_type="classifier"):
            row.update({"task": "induction_recall", "model": kind})
            challenge_inference_rows.append(row)

results_df = pd.DataFrame(all_results)
# With every RUN_* flag disabled the frame has no columns, so selecting or sorting
# by name would raise KeyError; show the empty frame instead.
if not results_df.empty:
    display(results_df[
        [
            "task",
            "model",
            "val_loss",
            "val_accuracy",
            "best_val_loss",
            "best_val_accuracy",
            "chance_accuracy",
            "epoch_s",
            "forward_ms",
            "forward_tokens_per_s",
            "params",
        ]
    ].sort_values(["task", "model"]))
else:
    display(results_df)

challenge_inference_df = pd.DataFrame(challenge_inference_rows)
if not challenge_inference_df.empty:
    display(challenge_inference_df.sort_values(["task", "model", "mode"]))
else:
    display(challenge_inference_df)
"""
    )


def plotting_cell() -> dict:
    """Return the cell that plots validation loss/accuracy per task and model."""
    return code(
        r"""
if len(all_histories) > 0:
    tasks = sorted({task for task, _ in all_histories})
    fig, axes = plt.subplots(len(tasks), 2, figsize=(12, 3.5 * len(tasks)), squeeze=False)
    for row, task in enumerate(tasks):
        for kind in MODEL_KINDS:
            hist = all_histories.get((task, kind))
            if hist is None:
                continue
            axes[row, 0].plot(hist["epoch"], hist["val_loss"], marker="o", label=kind)
            axes[row, 1].plot(hist["epoch"], hist["val_accuracy"], marker="o", label=kind)
        axes[row, 0].set_title(f"{task}: validation loss")
        axes[row, 1].set_title(f"{task}: validation accuracy")
        axes[row, 0].set_xlabel("epoch")
        axes[row, 1].set_xlabel("epoch")
        axes[row, 0].grid(alpha=0.25)
        axes[row, 1].grid(alpha=0.25)
        axes[row, 0].legend()
        axes[row, 1].legend()
    plt.tight_layout()
    plt.show()
"""
    )


def curriculum_cell() -> dict:
    """Return the cell that runs the easy/moderate/hard token-memory curriculum."""
    return code(
        r"""
curriculum_rows = []
curriculum_histories = {}

if RUN_TOKEN_CURRICULUM:
    for spec in TOKEN_CURRICULUM:
        shared = {
            "epochs": spec["epochs"],
            "batch_size": 128,
            "train_samples": spec["train_samples"],
            "val_samples": spec["val_samples"],
            "seq_len": spec["seq_len"],
            "d_model": 64,
            "hidden_dim": 128,
            "num_layers": 2,
            "lr": 2e-3,
            "weight_decay": 1e-4,
        }

        copy_cfg = {
            **shared,
            "vocab_size": spec["classes"],
            "num_mem_tokens": spec["num_mem_tokens"],
        }
        train_loader, val_loader, vocab_size = make_selective_copying_loaders(copy_cfg)
        for kind in MODEL_KINDS:
            print(f"\n=== Curriculum selective_copying/{spec['difficulty']} | {kind} ===")
            model = TokenSequenceModel(
                kind=kind,
                vocab_size=vocab_size,
                d_model=copy_cfg["d_model"],
                hidden_dim=copy_cfg["hidden_dim"],
                num_layers=copy_cfg["num_layers"],
            )
            metrics, history, model = train_challenge_model(
                f"curriculum_selective_copying_{spec['difficulty']}/{kind}",
                model,
                train_loader,
                val_loader,
                copy_cfg,
                run_sequence_epoch,
            )
            metrics.update({
                "task": "selective_copying",
                "difficulty": spec["difficulty"],
                "model": kind,
                "classes": spec["classes"],
                "seq_len": spec["seq_len"],
                "chance_accuracy": 1.0 / spec["classes"],
            })
            curriculum_rows.append(metrics)
            curriculum_histories[(metrics["task"], spec["difficulty"], kind)] = history
            if spec["difficulty"] in {"easy", "hard"}:
                for row in benchmark_model_inference_modes(model, val_loader, task_type="sequence"):
                    row.update({
                        "task": "selective_copying",
                        "difficulty": spec["difficulty"],
                        "model": kind,
                    })
                    challenge_inference_rows.append(row)

        induction_cfg = {
            **shared,
            "num_keys": spec["classes"],
            "num_values": spec["classes"],
        }
        train_loader, val_loader, vocab_size, num_classes = make_induction_recall_loaders(induction_cfg)
        for kind in MODEL_KINDS:
            print(f"\n=== Curriculum induction_recall/{spec['difficulty']} | {kind} ===")
            model = TokenClassifierModel(
                kind=kind,
                vocab_size=vocab_size,
                num_classes=num_classes,
                d_model=induction_cfg["d_model"],
                hidden_dim=induction_cfg["hidden_dim"],
                num_layers=induction_cfg["num_layers"],
            )
            metrics, history, model = train_challenge_model(
                f"curriculum_induction_recall_{spec['difficulty']}/{kind}",
                model,
                train_loader,
                val_loader,
                induction_cfg,
                run_classifier_epoch,
            )
            metrics.update({
                "task": "induction_recall",
                "difficulty": spec["difficulty"],
                "model": kind,
                "classes": spec["classes"],
                "seq_len": spec["seq_len"],
                "chance_accuracy": 1.0 / spec["classes"],
            })
            curriculum_rows.append(metrics)
            curriculum_histories[(metrics["task"], spec["difficulty"], kind)] = history
            if spec["difficulty"] in {"easy", "hard"}:
                for row in benchmark_model_inference_modes(model, val_loader, task_type="classifier"):
                    row.update({
                        "task": "induction_recall",
                        "difficulty": spec["difficulty"],
                        "model": kind,
                    })
                    challenge_inference_rows.append(row)

curriculum_df = pd.DataFrame(curriculum_rows)
if not curriculum_df.empty:
    display(curriculum_df[
        [
            "task",
            "difficulty",
            "model",
            "classes",
            "seq_len",
            "val_loss",
            "val_accuracy",
            "best_val_accuracy",
            "chance_accuracy",
            "epoch_s",
            "forward_tokens_per_s",
        ]
    ].sort_values(["task", "difficulty", "model"]))
else:
    display(curriculum_df)

if challenge_inference_rows:
    challenge_inference_df = pd.DataFrame(challenge_inference_rows)
    sort_cols = ["task"]
    if "difficulty" in challenge_inference_df.columns:
        sort_cols.append("difficulty")
    sort_cols.extend(["model", "mode"])
    display(challenge_inference_df.sort_values(sort_cols))
"""
    )


def curriculum_plot_cell() -> dict:
    """Return the cell plotting best validation accuracy against difficulty.

    Tiers are plotted in ``TOKEN_CURRICULUM`` order, restricted to the tiers
    that actually produced rows for the task being plotted.
    """
    return code(
        r"""
if "curriculum_df" in globals() and not curriculum_df.empty:
    # Follow the configured tier order, but keep only tiers that actually ran so a
    # trimmed TOKEN_CURRICULUM does not raise KeyError while plotting.
    difficulty_order = [spec["difficulty"] for spec in TOKEN_CURRICULUM]
    for task in sorted(curriculum_df["task"].unique()):
        subset = curriculum_df[curriculum_df["task"] == task]
        present = set(subset["difficulty"])
        tiers = [name for name in difficulty_order if name in present]
        fig, ax = plt.subplots(figsize=(10, 4))
        for kind in MODEL_KINDS:
            model_subset = subset[subset["model"] == kind].set_index("difficulty").loc[tiers]
            ax.plot(model_subset.index, model_subset["best_val_accuracy"], marker="o", label=kind)
        chance = subset.drop_duplicates("difficulty").set_index("difficulty").loc[tiers]
        ax.plot(chance.index, chance["chance_accuracy"], linestyle="--", color="black", label="chance")
        ax.set_title(f"{task}: best validation accuracy by difficulty")
        ax.set_ylabel("accuracy")
        ax.grid(alpha=0.25)
        ax.legend()
        plt.show()
"""
    )


def notes_cell() -> dict:
    """Return the closing markdown cell explaining how to interpret the results."""
    return md(
        """
## How To Read This Notebook

- Permuted MNIST stresses long-range sequence classification because each image is flattened to 784 tokens and then shuffled by a fixed random permutation.
- Selective copying stresses sparse memory recall: the model sees a few marked values early and must emit them at query positions near the end.
- Induction recall is an induction-head-style associative recall task: the model observes a key-value pair, later receives the same key as a query, and must predict the associated value.

The synthetic tasks are intentionally small enough for Colab iteration. The curriculum section is especially important: if a model cannot beat chance on the easy tier, the issue is probably model mechanism or task formulation rather than too few epochs. If it wins easy but fails hard, the next knobs are capacity, sample count, and sequence length.
"""
    )


def build_notebook() -> dict:
    """Assemble every cell into an nbformat 4.5 notebook dictionary.

    Returns:
        A JSON-serializable dict with ``cells``, ``metadata``, ``nbformat`` and
        ``nbformat_minor`` keys.
    """
    cells = [
        md(
            """
# Gamma S4 Challenge Benchmarks

This notebook adds three harder sequence benchmarks requested by the technical lead:

1. Permuted MNIST image classification, following the long-memory benchmark tradition used in the HiPPO paper.
2. Selective copying, a synthetic content-based recall benchmark discussed in the Mamba line of work.
3. Induction-style associative recall, a compact proxy for induction-head behavior.

The goal is not to claim state of the art from this first version. The goal is to create a repeatable stress test that tells us whether Gamma S4 is moving toward practical long-context and token-memory behavior.
"""
        ),
        setup_cell(),
        imports_cell(),
        config_cell(),
        model_cell(),
        data_cell(),
        training_cell(),
        run_cell(),
        plotting_cell(),
        md(
            """
## Token Memory Curriculum

The main challenge tasks are deliberately hard. This curriculum repeats selective copying and induction recall at easy, moderate, and hard settings so we can tell whether failures are caused by too few epochs, insufficient capacity, or a missing content-selective memory mechanism.
"""
        ),
        curriculum_cell(),
        curriculum_plot_cell(),
        notes_cell(),
    ]
    return {
        "cells": cells,
        "metadata": {
            "kernelspec": {
                "display_name": "Python 3",
                "language": "python",
                "name": "python3",
            },
            "language_info": {
                "name": "python",
                "pygments_lexer": "ipython3",
            },
            "colab": {
                "name": "gamma-s4-challenge-benchmark.ipynb",
            },
        },
        "nbformat": 4,
        "nbformat_minor": 5,
    }


def main() -> None:
    """Write the generated notebook to ``NOTEBOOK_PATH``, creating parent directories."""
    NOTEBOOK_PATH.parent.mkdir(parents=True, exist_ok=True)
    NOTEBOOK_PATH.write_text(json.dumps(build_notebook(), indent=2), encoding="utf-8")
    print(f"Wrote {NOTEBOOK_PATH}")


if __name__ == "__main__":
    main()