"""Token-level benchmark for TaoNet attention vs TaoNet-SSM. The goal is to compare the two LLM wrappers with the same outer dimensions: original MLA attention TaoNet versus TaoNet with an SSM mixer. """ from __future__ import annotations import argparse from contextlib import nullcontext from contextlib import redirect_stdout import csv import io import json import os from pathlib import Path import platform import subprocess import sys import time from typing import Any import torch REPO_ROOT = Path(__file__).resolve().parents[1] SRC_ROOT = REPO_ROOT / "src" if str(SRC_ROOT) not in sys.path: sys.path.insert(0, str(SRC_ROOT)) from taoTrain.config import ModelConfig from taoTrain.models import get_model DTYPES = { "float32": torch.float32, "fp32": torch.float32, "float16": torch.float16, "fp16": torch.float16, "bfloat16": torch.bfloat16, "bf16": torch.bfloat16, } def parse_int_list(value: str) -> list[int]: return [int(item.strip()) for item in value.split(",") if item.strip()] def synchronize(device: torch.device) -> None: if device.type == "cuda": torch.cuda.synchronize(device) def reset_memory(device: torch.device) -> None: if device.type == "cuda": torch.cuda.reset_peak_memory_stats(device) def memory_stats(device: torch.device) -> dict[str, float | None]: if device.type != "cuda": return { "peak_allocated_mb": None, "peak_reserved_mb": None, } return { "peak_allocated_mb": torch.cuda.max_memory_allocated(device) / (1024**2), "peak_reserved_mb": torch.cuda.max_memory_reserved(device) / (1024**2), } def nvidia_smi_snapshot() -> str | None: try: completed = subprocess.run( [ "nvidia-smi", "--query-gpu=name,memory.used,memory.total,utilization.gpu,utilization.memory,power.draw,temperature.gpu", "--format=csv,noheader,nounits", ], check=False, capture_output=True, text=True, timeout=5, ) except (OSError, subprocess.TimeoutExpired): return None if completed.returncode != 0: return None return completed.stdout.strip() def make_token_batch( *, batch_size: int, seq_len: int, vocab_size: int, device: torch.device, task: str = "random", ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if task == "random": input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) labels = torch.empty_like(input_ids) labels[:, :-1] = input_ids[:, 1:] labels[:, -1] = torch.randint(0, vocab_size, (batch_size,), device=device) elif task == "increment": starts = torch.randint(0, vocab_size, (batch_size, 1), device=device) offsets = torch.arange(seq_len, device=device).view(1, seq_len) input_ids = (starts + offsets) % vocab_size labels = (input_ids + 1) % vocab_size elif task == "previous": input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) labels = torch.empty_like(input_ids) labels[:, 0] = -100 labels[:, 1:] = input_ids[:, :-1] else: raise ValueError(f"Unsupported token task '{task}'.") attention_mask = torch.ones_like(input_ids) return input_ids, labels, attention_mask def token_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float: valid = labels != -100 if not torch.any(valid): return float("nan") predictions = torch.argmax(logits, dim=-1) correct = (predictions == labels) & valid return float(correct.sum().detach().cpu() / valid.sum().detach().cpu()) def build_config(args: argparse.Namespace, architecture: str) -> ModelConfig: d_latent_kv = args.d_latent_kv if args.d_latent_kv is not None else int(args.hidden_dim * 0.75) d_rope = args.d_rope if args.d_rope is not None else args.hidden_dim // args.num_heads hidden_dim_ff = args.hidden_dim_ff if args.hidden_dim_ff is not None else args.hidden_dim * 4 return ModelConfig( architecture_type=architecture, vocab_size=args.vocab_size, hidden_dim=args.hidden_dim, num_layers=args.num_layers, num_heads=args.num_heads, max_seq_length=max(parse_int_list(args.seq_lens)), d_latent_kv=d_latent_kv, d_rope=d_rope, hidden_dim_ff=hidden_dim_ff, dropout=args.dropout, gqa_groups=args.gqa_groups, rope_scale=args.rope_scale, yarn_alpha=args.yarn_alpha, init_std=args.init_std, ssm_core=args.ssm_core, ssm_hidden_dim=args.ssm_hidden_dim or d_latent_kv, ssm_mixer_dim=args.ssm_mixer_dim, ssm_rank=args.ssm_rank, ssm_max_low_rank_scale=args.ssm_max_low_rank_scale, ssm_kernel_mode=args.ssm_kernel_mode, ssm_kernel_threshold=args.ssm_kernel_threshold, ssm_dt_min=args.ssm_dt_min, ssm_dt_max=args.ssm_dt_max, ssm_dt_init=args.ssm_dt_init, ssm_use_padding_mask=args.ssm_use_padding_mask, ssm_activation=args.ssm_activation, ssm_gate=args.ssm_gate, ssm_input_gate=args.ssm_input_gate, ssm_layer_scale_init=args.ssm_layer_scale_init, ssm_local_shift=args.ssm_local_shift, ssm_local_shift_init=args.ssm_local_shift_init, ssm_local_shift_per_channel=args.ssm_local_shift_per_channel, ) def count_params(model: torch.nn.Module) -> tuple[int, int]: total = sum(param.numel() for param in model.parameters()) trainable = sum(param.numel() for param in model.parameters() if param.requires_grad) return total, trainable def time_repeats(fn, *, device: torch.device, warmup: int, repeats: int) -> tuple[float, float, float]: last_loss = float("nan") for _ in range(warmup): last_loss = fn() synchronize(device) latencies = [] for _ in range(repeats): reset_memory(device) synchronize(device) start = time.perf_counter() last_loss = fn() synchronize(device) latencies.append(time.perf_counter() - start) return sum(latencies) / len(latencies), min(latencies), last_loss def evaluate_model( model: torch.nn.Module, *, args: argparse.Namespace, batch_size: int, seq_len: int, device: torch.device, autocast_context, ) -> tuple[float, float]: model.eval() losses = [] accuracies = [] with torch.no_grad(): for _ in range(args.eval_batches): input_ids, labels, attention_mask = make_token_batch( batch_size=batch_size, seq_len=seq_len, vocab_size=args.vocab_size, device=device, task=args.token_task, ) with autocast_context(): outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) losses.append(float(outputs["loss"].detach().cpu())) accuracies.append(token_accuracy(outputs["logits"], labels)) model.train() return sum(losses) / len(losses), sum(accuracies) / len(accuracies) def train_model( model: torch.nn.Module, *, args: argparse.Namespace, batch_size: int, seq_len: int, device: torch.device, autocast_context, ) -> tuple[float | None, float | None]: if args.train_steps <= 0: return None, None model.train() optimizer = torch.optim.AdamW( model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, ) last_loss = float("nan") start = time.perf_counter() for _ in range(args.train_steps): input_ids, labels, attention_mask = make_token_batch( batch_size=batch_size, seq_len=seq_len, vocab_size=args.vocab_size, device=device, task=args.token_task, ) optimizer.zero_grad(set_to_none=True) with autocast_context(): outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) loss = outputs["loss"] loss.backward() optimizer.step() last_loss = float(loss.detach().cpu()) synchronize(device) return last_loss, time.perf_counter() - start def benchmark_case( *, args: argparse.Namespace, architecture: str, batch_size: int, seq_len: int, dtype: torch.dtype, device: torch.device, ) -> list[dict[str, Any]]: config = build_config(args, architecture) with redirect_stdout(io.StringIO()): model = get_model(config, device=device) model.train() total_params, trainable_params = count_params(model) tokens = batch_size * seq_len input_ids, labels, attention_mask = make_token_batch( batch_size=batch_size, seq_len=seq_len, vocab_size=args.vocab_size, device=device, task=args.token_task, ) autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16} def autocast_context(): if not autocast_enabled: return nullcontext() return torch.autocast(device_type=device.type, dtype=dtype, enabled=True) train_final_loss, train_seconds = train_model( model, args=args, batch_size=batch_size, seq_len=seq_len, device=device, autocast_context=autocast_context, ) eval_loss, eval_accuracy = evaluate_model( model, args=args, batch_size=batch_size, seq_len=seq_len, device=device, autocast_context=autocast_context, ) rows: list[dict[str, Any]] = [] def add_row(mode: str, mean_s: float, min_s: float, loss: float) -> None: rows.append( { "architecture": architecture, "ssm_core": args.ssm_core if architecture == "taonet_ssm" else None, "token_task": args.token_task, "train_steps": args.train_steps, "mode": mode, "batch_size": batch_size, "seq_len": seq_len, "tokens": tokens, "vocab_size": args.vocab_size, "hidden_dim": args.hidden_dim, "num_layers": args.num_layers, "num_heads": args.num_heads, "d_latent_kv": config.d_latent_kv, "ssm_hidden_dim": config.ssm_hidden_dim if architecture == "taonet_ssm" else None, "ssm_mixer_dim": config.ssm_mixer_dim if architecture == "taonet_ssm" else None, "ssm_rank": config.ssm_rank if architecture == "taonet_ssm" else None, "ssm_local_shift": config.ssm_local_shift if architecture == "taonet_ssm" else None, "ssm_local_shift_init": config.ssm_local_shift_init if architecture == "taonet_ssm" else None, "ssm_local_shift_per_channel": config.ssm_local_shift_per_channel if architecture == "taonet_ssm" else None, "dtype": str(dtype).replace("torch.", ""), "device": str(device), "total_params": total_params, "trainable_params": trainable_params, "mean_ms": mean_s * 1000.0, "min_ms": min_s * 1000.0, "tokens_per_s_mean": tokens / max(mean_s, 1e-12), "tokens_per_s_best": tokens / max(min_s, 1e-12), "loss": loss, "eval_loss": eval_loss, "eval_accuracy": eval_accuracy, "train_final_loss": train_final_loss, "train_seconds": train_seconds, **memory_stats(device), } ) def forward_only() -> float: with torch.no_grad(): with autocast_context(): outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) loss = outputs["loss"] return float(loss.detach().cpu()) mean_s, min_s, loss = time_repeats( forward_only, device=device, warmup=args.warmup, repeats=args.repeats, ) add_row("forward", mean_s, min_s, loss) if args.backward: def forward_backward() -> float: model.zero_grad(set_to_none=True) with autocast_context(): outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) loss = outputs["loss"] loss.backward() return float(loss.detach().cpu()) mean_s, min_s, loss = time_repeats( forward_backward, device=device, warmup=args.warmup, repeats=args.repeats, ) add_row("forward_backward", mean_s, min_s, loss) return rows def print_table(rows: list[dict[str, Any]]) -> None: columns = [ "architecture", "ssm_core", "token_task", "mode", "batch_size", "seq_len", "mean_ms", "tokens_per_s_mean", "peak_allocated_mb", "loss", "eval_loss", "eval_accuracy", ] print("\t".join(columns)) for row in rows: values = [] for column in columns: value = row[column] if isinstance(value, float): values.append(f"{value:.3f}") else: values.append(str(value)) print("\t".join(values)) def write_outputs(rows: list[dict[str, Any]], output_dir: Path, metadata: dict[str, Any]) -> None: output_dir.mkdir(parents=True, exist_ok=True) json_path = output_dir / "taonet_token_benchmark.json" csv_path = output_dir / "taonet_token_benchmark.csv" json_path.write_text(json.dumps({"metadata": metadata, "results": rows}, indent=2), encoding="utf-8") fieldnames = list(rows[0].keys()) if rows else [] with csv_path.open("w", newline="", encoding="utf-8") as handle: writer = csv.DictWriter(handle, fieldnames=fieldnames) writer.writeheader() writer.writerows(rows) print(f"Wrote {json_path}") print(f"Wrote {csv_path}") def main() -> None: parser = argparse.ArgumentParser(description="Benchmark TaoNet attention vs TaoNet-SSM on token batches.") parser.add_argument("--architectures", default="taonet,taonet_ssm") parser.add_argument("--batch-sizes", default="1,4") parser.add_argument("--seq-lens", default="128,512") parser.add_argument("--vocab-size", type=int, default=8192) parser.add_argument("--hidden-dim", type=int, default=256) parser.add_argument("--num-layers", type=int, default=4) parser.add_argument("--num-heads", type=int, default=4) parser.add_argument("--d-latent-kv", type=int, default=None) parser.add_argument("--d-rope", type=int, default=None) parser.add_argument("--hidden-dim-ff", type=int, default=None) parser.add_argument("--dropout", type=float, default=0.0) parser.add_argument("--gqa-groups", type=int, default=1) parser.add_argument("--rope-scale", type=float, default=40.0) parser.add_argument("--yarn-alpha", type=float, default=1.0) parser.add_argument("--init-std", type=float, default=0.02) parser.add_argument("--ssm-core", choices=["gamma_s4", "dplr"], default="dplr") parser.add_argument("--ssm-hidden-dim", type=int, default=None) parser.add_argument("--ssm-mixer-dim", type=int, default=None) parser.add_argument("--ssm-rank", type=int, default=1) parser.add_argument("--ssm-max-low-rank-scale", type=float, default=0.1) parser.add_argument("--ssm-kernel-mode", choices=["auto", "conv", "conv_transfer", "recurrent"], default="conv") parser.add_argument("--ssm-kernel-threshold", type=int, default=1) parser.add_argument("--ssm-dt-min", type=float, default=1e-3) parser.add_argument("--ssm-dt-max", type=float, default=1e-1) parser.add_argument("--ssm-dt-init", type=float, default=1e-2) parser.add_argument("--ssm-use-padding-mask", action="store_true") parser.add_argument("--ssm-activation", choices=["gelu", "silu", "identity", "linear"], default="gelu") parser.add_argument("--ssm-gate", action=argparse.BooleanOptionalAction, default=True) parser.add_argument("--ssm-input-gate", action=argparse.BooleanOptionalAction, default=True) parser.add_argument("--ssm-layer-scale-init", type=float, default=0.1) parser.add_argument("--ssm-local-shift", action=argparse.BooleanOptionalAction, default=False) parser.add_argument("--ssm-local-shift-init", type=float, default=0.1) parser.add_argument("--ssm-local-shift-per-channel", action=argparse.BooleanOptionalAction, default=False) parser.add_argument("--dtype", choices=sorted(DTYPES), default="bf16") parser.add_argument("--device", default="auto") parser.add_argument("--warmup", type=int, default=2) parser.add_argument("--repeats", type=int, default=5) parser.add_argument("--backward", action="store_true") parser.add_argument("--token-task", choices=["random", "increment", "previous"], default="random") parser.add_argument("--train-steps", type=int, default=0) parser.add_argument("--learning-rate", type=float, default=3e-4) parser.add_argument("--weight-decay", type=float, default=0.01) parser.add_argument("--eval-batches", type=int, default=1) parser.add_argument("--output-dir", default=os.environ.get("REPOBRIDGE_OUTPUT_DIR", "results/token-bench")) args = parser.parse_args() if args.device == "auto": device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: device = torch.device(args.device) dtype = DTYPES[args.dtype] if device.type != "cuda" and dtype == torch.float16: raise ValueError("float16 benchmark requires CUDA.") if device.type == "cuda": torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True architectures = [item.strip() for item in args.architectures.split(",") if item.strip()] rows: list[dict[str, Any]] = [] metadata = { "python": platform.python_version(), "platform": platform.platform(), "torch": torch.__version__, "cuda_available": torch.cuda.is_available(), "cuda_device": torch.cuda.get_device_name(device) if device.type == "cuda" else None, "nvidia_smi_before": nvidia_smi_snapshot(), "args": vars(args), } for architecture in architectures: for batch_size in parse_int_list(args.batch_sizes): for seq_len in parse_int_list(args.seq_lens): print(f"Benchmarking architecture={architecture} batch={batch_size} seq={seq_len}") rows.extend( benchmark_case( args=args, architecture=architecture, batch_size=batch_size, seq_len=seq_len, dtype=dtype, device=device, ) ) metadata["nvidia_smi_after"] = nvidia_smi_snapshot() print_table(rows) write_outputs(rows, Path(args.output_dir), metadata) if __name__ == "__main__": main()