# TaoNet-mini-T2 / code / TaoTrain / scripts / benchmark_taonet_token_variants.py
# Uploaded by StarMist0012: "Add files using upload-large-folder tool"
# (commit e2bfccc, verified).
"""Token-level benchmark for TaoNet attention vs TaoNet-SSM.
The goal is to compare the two LLM wrappers with the same outer dimensions:
original MLA attention TaoNet versus TaoNet with an SSM mixer.
"""
from __future__ import annotations
import argparse
from contextlib import nullcontext
from contextlib import redirect_stdout
import csv
import io
import json
import os
from pathlib import Path
import platform
import subprocess
import sys
import time
from typing import Any
import torch
# Put the repository's ``src`` directory at the front of ``sys.path`` (once)
# so the project imports that follow resolve when the script is run directly.
REPO_ROOT = Path(__file__).resolve().parents[1]
SRC_ROOT = REPO_ROOT.joinpath("src")
if not any(entry == str(SRC_ROOT) for entry in sys.path):
    sys.path.insert(0, str(SRC_ROOT))
from taoTrain.config import ModelConfig
from taoTrain.models import get_model
# Accepted ``--dtype`` spellings: each torch dtype is reachable through its
# full name and a short alias.
DTYPES = {
    spelling: getattr(torch, full_name)
    for full_name, alias in (
        ("float32", "fp32"),
        ("float16", "fp16"),
        ("bfloat16", "bf16"),
    )
    for spelling in (full_name, alias)
}
def parse_int_list(value: str) -> list[int]:
    """Parse a comma-separated string such as ``"1, 4,8"`` into integers.

    Pieces that are empty or whitespace-only are skipped. Any other piece
    that ``int`` cannot parse raises ``ValueError``.
    """
    stripped_pieces = (piece.strip() for piece in value.split(","))
    return [int(token) for token in stripped_pieces if token]
def synchronize(device: torch.device) -> None:
    """Call ``torch.cuda.synchronize`` for a CUDA *device*; do nothing otherwise."""
    if device.type != "cuda":
        return
    torch.cuda.synchronize(device)
def reset_memory(device: torch.device) -> None:
    """Reset CUDA peak-memory statistics for *device*; do nothing off CUDA."""
    if device.type != "cuda":
        return
    torch.cuda.reset_peak_memory_stats(device)
def memory_stats(device: torch.device) -> dict[str, float | None]:
if device.type != "cuda":
return {
"peak_allocated_mb": None,
"peak_reserved_mb": None,
}
return {
"peak_allocated_mb": torch.cuda.max_memory_allocated(device) / (1024**2),
"peak_reserved_mb": torch.cuda.max_memory_reserved(device) / (1024**2),
}
def nvidia_smi_snapshot() -> str | None:
try:
completed = subprocess.run(
[
"nvidia-smi",
"--query-gpu=name,memory.used,memory.total,utilization.gpu,utilization.memory,power.draw,temperature.gpu",
"--format=csv,noheader,nounits",
],
check=False,
capture_output=True,
text=True,
timeout=5,
)
except (OSError, subprocess.TimeoutExpired):
return None
if completed.returncode != 0:
return None
return completed.stdout.strip()
def make_token_batch(
    *,
    batch_size: int,
    seq_len: int,
    vocab_size: int,
    device: torch.device,
    task: str = "random",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Create a synthetic ``(input_ids, labels, attention_mask)`` batch.

    Supported tasks:

    * ``"random"``: random ids; each label is the next input id, and the
      final position gets an independent random label.
    * ``"increment"``: every row counts upward modulo ``vocab_size`` from a
      random start; each label is the input id plus one (modulo
      ``vocab_size``).
    * ``"previous"``: random ids; each label is the preceding input id, and
      the first position is labelled ``-100``.

    The attention mask is all ones with the same shape as ``input_ids``.

    Raises:
        ValueError: if ``task`` is not one of the names above.
    """
    if task == "increment":
        first_ids = torch.randint(0, vocab_size, (batch_size, 1), device=device)
        steps = torch.arange(seq_len, device=device).view(1, seq_len)
        input_ids = (first_ids + steps) % vocab_size
        labels = (input_ids + 1) % vocab_size
    elif task in ("random", "previous"):
        input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
        if task == "random":
            # Shift left so labels[:, t] == input_ids[:, t + 1]; the wrapped
            # last column is then overwritten with fresh random targets.
            labels = torch.roll(input_ids, shifts=-1, dims=1)
            labels[:, -1] = torch.randint(0, vocab_size, (batch_size,), device=device)
        else:
            # Shift right so labels[:, t] == input_ids[:, t - 1]; the wrapped
            # first column has no predecessor and is set to -100.
            labels = torch.roll(input_ids, shifts=1, dims=1)
            labels[:, 0] = -100
    else:
        raise ValueError(f"Unsupported token task '{task}'.")
    return input_ids, labels, torch.ones_like(input_ids)
def token_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Return the fraction of non-ignored positions whose argmax matches the label.

    Positions labelled ``-100`` are excluded. If every position is excluded
    the result is ``nan``.
    """
    mask = labels.ne(-100)
    if not bool(mask.any()):
        return float("nan")
    hits = logits.argmax(dim=-1).eq(labels).logical_and(mask)
    # The ratio is computed as a tensor division on CPU before conversion.
    ratio = hits.sum().detach().cpu() / mask.sum().detach().cpu()
    return float(ratio)
def build_config(args: argparse.Namespace, architecture: str) -> ModelConfig:
    """Translate parsed CLI options into a ``ModelConfig`` for *architecture*.

    Derived defaults applied when the option was left as ``None``:

    * ``d_latent_kv``: ``int(hidden_dim * 0.75)``
    * ``d_rope``: ``hidden_dim // num_heads``
    * ``hidden_dim_ff``: ``hidden_dim * 4``

    ``ssm_hidden_dim`` falls back to the resolved ``d_latent_kv`` when the
    option is falsy, and ``max_seq_length`` is the largest entry of
    ``args.seq_lens``.
    """
    hidden_dim = args.hidden_dim

    d_latent_kv = args.d_latent_kv
    if d_latent_kv is None:
        d_latent_kv = int(hidden_dim * 0.75)

    d_rope = args.d_rope
    if d_rope is None:
        d_rope = hidden_dim // args.num_heads

    hidden_dim_ff = args.hidden_dim_ff
    if hidden_dim_ff is None:
        hidden_dim_ff = hidden_dim * 4

    # Options forwarded unchanged under the same name.
    passthrough_names = (
        "vocab_size",
        "hidden_dim",
        "num_layers",
        "num_heads",
        "dropout",
        "gqa_groups",
        "rope_scale",
        "yarn_alpha",
        "init_std",
        "ssm_core",
        "ssm_mixer_dim",
        "ssm_rank",
        "ssm_max_low_rank_scale",
        "ssm_kernel_mode",
        "ssm_kernel_threshold",
        "ssm_dt_min",
        "ssm_dt_max",
        "ssm_dt_init",
        "ssm_use_padding_mask",
        "ssm_activation",
        "ssm_gate",
        "ssm_input_gate",
        "ssm_layer_scale_init",
        "ssm_local_shift",
        "ssm_local_shift_init",
        "ssm_local_shift_per_channel",
    )
    forwarded = {name: getattr(args, name) for name in passthrough_names}

    return ModelConfig(
        architecture_type=architecture,
        max_seq_length=max(parse_int_list(args.seq_lens)),
        d_latent_kv=d_latent_kv,
        d_rope=d_rope,
        hidden_dim_ff=hidden_dim_ff,
        ssm_hidden_dim=args.ssm_hidden_dim or d_latent_kv,
        **forwarded,
    )
def count_params(model: torch.nn.Module) -> tuple[int, int]:
    """Return ``(total, trainable)`` parameter element counts for *model*.

    A parameter counts as trainable when its ``requires_grad`` flag is set.
    """
    total = 0
    trainable = 0
    for parameter in model.parameters():
        size = parameter.numel()
        total += size
        if parameter.requires_grad:
            trainable += size
    return total, trainable
def time_repeats(fn, *, device: torch.device, warmup: int, repeats: int) -> tuple[float, float, float]:
    """Time ``fn`` over several synchronized runs.

    ``fn`` is called with no arguments. It is first called ``warmup`` times
    untimed, then ``repeats`` times with ``time.perf_counter`` around each
    call. The device is synchronized before and after every timed call so the
    measurement covers the whole call. Peak-memory statistics are reset
    before each timed call, so a later read reflects the final call only.

    Args:
        fn: zero-argument callable; its return value is reported as the loss.
        device: device to synchronize and whose memory statistics are reset.
        warmup: number of untimed calls made before measuring.
        repeats: number of timed calls; must be at least 1.

    Returns:
        ``(mean_seconds, min_seconds, last_value_returned_by_fn)``.

    Raises:
        ValueError: if ``repeats`` is less than 1, since no mean or minimum
            latency can be computed from zero measurements.
    """
    if repeats < 1:
        raise ValueError(f"repeats must be at least 1 to compute timings, got {repeats}.")

    last_loss = float("nan")
    for _ in range(warmup):
        last_loss = fn()
    synchronize(device)

    latencies: list[float] = []
    for _ in range(repeats):
        reset_memory(device)
        synchronize(device)
        start = time.perf_counter()
        last_loss = fn()
        synchronize(device)
        latencies.append(time.perf_counter() - start)
    return sum(latencies) / len(latencies), min(latencies), last_loss
def evaluate_model(
    model: torch.nn.Module,
    *,
    args: argparse.Namespace,
    batch_size: int,
    seq_len: int,
    device: torch.device,
    autocast_context,
) -> tuple[float, float]:
    """Evaluate *model* on freshly generated token batches.

    ``args.eval_batches`` batches are drawn with ``make_token_batch`` using
    ``args.vocab_size`` and ``args.token_task``. Each batch is run under
    ``torch.no_grad()`` and the context returned by ``autocast_context()``.
    The model is put in eval mode for the loop and switched back to training
    mode afterwards, even if a batch raises.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...,
            labels=...)`` and expected to return a mapping with ``"loss"``
            and ``"logits"`` entries.
        args: parsed CLI options (``eval_batches``, ``vocab_size``,
            ``token_task`` are read).
        batch_size: rows per generated batch.
        seq_len: tokens per row.
        device: device on which batches are created.
        autocast_context: zero-argument callable returning a context manager.

    Returns:
        ``(mean_loss, mean_accuracy)`` over the evaluated batches, or
        ``(nan, nan)`` when ``args.eval_batches`` is not positive.
    """
    if args.eval_batches <= 0:
        # Nothing to average; report NaN instead of dividing by zero.
        not_evaluated = float("nan")
        return not_evaluated, not_evaluated

    losses: list[float] = []
    accuracies: list[float] = []
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(args.eval_batches):
                input_ids, labels, attention_mask = make_token_batch(
                    batch_size=batch_size,
                    seq_len=seq_len,
                    vocab_size=args.vocab_size,
                    device=device,
                    task=args.token_task,
                )
                with autocast_context():
                    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                losses.append(float(outputs["loss"].detach().cpu()))
                accuracies.append(token_accuracy(outputs["logits"], labels))
    finally:
        model.train()
    return sum(losses) / len(losses), sum(accuracies) / len(accuracies)
def train_model(
model: torch.nn.Module,
*,
args: argparse.Namespace,
batch_size: int,
seq_len: int,
device: torch.device,
autocast_context,
) -> tuple[float | None, float | None]:
if args.train_steps <= 0:
return None, None
model.train()
optimizer = torch.optim.AdamW(
model.parameters(),
lr=args.learning_rate,
weight_decay=args.weight_decay,
)
last_loss = float("nan")
start = time.perf_counter()
for _ in range(args.train_steps):
input_ids, labels, attention_mask = make_token_batch(
batch_size=batch_size,
seq_len=seq_len,
vocab_size=args.vocab_size,
device=device,
task=args.token_task,
)
optimizer.zero_grad(set_to_none=True)
with autocast_context():
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs["loss"]
loss.backward()
optimizer.step()
last_loss = float(loss.detach().cpu())
synchronize(device)
return last_loss, time.perf_counter() - start
def benchmark_case(
    *,
    args: argparse.Namespace,
    architecture: str,
    batch_size: int,
    seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
) -> list[dict[str, Any]]:
    """Benchmark one architecture at a single ``(batch_size, seq_len)`` shape.

    In order, this builds the model (with anything it prints to stdout
    discarded), counts its parameters, draws one token batch that every timed
    call reuses, runs ``train_model`` and ``evaluate_model``, and then times a
    forward-only pass and, when ``args.backward`` is set, a forward+backward
    pass.

    Autocast is used only on CUDA with a float16/bfloat16 ``dtype``; in every
    other case the passes run inside ``nullcontext()``. The ``"dtype"`` column
    always records the requested ``dtype``.

    Returns:
        One result row per timed mode (``"forward"`` and optionally
        ``"forward_backward"``). SSM-specific columns are ``None`` unless
        ``architecture`` is ``"taonet_ssm"``.
    """
    config = build_config(args, architecture)
    with redirect_stdout(io.StringIO()):
        model = get_model(config, device=device)
    model.train()
    total_params, trainable_params = count_params(model)
    tokens = batch_size * seq_len
    input_ids, labels, attention_mask = make_token_batch(
        batch_size=batch_size,
        seq_len=seq_len,
        vocab_size=args.vocab_size,
        device=device,
        task=args.token_task,
    )

    use_autocast = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}

    def autocast_context():
        if use_autocast:
            return torch.autocast(device_type=device.type, dtype=dtype, enabled=True)
        return nullcontext()

    shared_options = {
        "args": args,
        "batch_size": batch_size,
        "seq_len": seq_len,
        "device": device,
        "autocast_context": autocast_context,
    }
    train_final_loss, train_seconds = train_model(model, **shared_options)
    eval_loss, eval_accuracy = evaluate_model(model, **shared_options)

    is_ssm = architecture == "taonet_ssm"

    def ssm_setting(name: str) -> Any:
        # Read the config attribute only for the SSM variant.
        return getattr(config, name) if is_ssm else None

    # Columns that are identical for every timed mode of this case.
    case_columns: dict[str, Any] = {
        "batch_size": batch_size,
        "seq_len": seq_len,
        "tokens": tokens,
        "vocab_size": args.vocab_size,
        "hidden_dim": args.hidden_dim,
        "num_layers": args.num_layers,
        "num_heads": args.num_heads,
        "d_latent_kv": config.d_latent_kv,
        "ssm_hidden_dim": ssm_setting("ssm_hidden_dim"),
        "ssm_mixer_dim": ssm_setting("ssm_mixer_dim"),
        "ssm_rank": ssm_setting("ssm_rank"),
        "ssm_local_shift": ssm_setting("ssm_local_shift"),
        "ssm_local_shift_init": ssm_setting("ssm_local_shift_init"),
        "ssm_local_shift_per_channel": ssm_setting("ssm_local_shift_per_channel"),
        "dtype": str(dtype).replace("torch.", ""),
        "device": str(device),
        "total_params": total_params,
        "trainable_params": trainable_params,
    }

    rows: list[dict[str, Any]] = []

    def add_row(mode: str, mean_s: float, min_s: float, loss: float) -> None:
        # Key insertion order is kept stable because it becomes the CSV header.
        row: dict[str, Any] = {
            "architecture": architecture,
            "ssm_core": args.ssm_core if is_ssm else None,
            "token_task": args.token_task,
            "train_steps": args.train_steps,
            "mode": mode,
        }
        row.update(case_columns)
        row.update(
            {
                "mean_ms": mean_s * 1000.0,
                "min_ms": min_s * 1000.0,
                "tokens_per_s_mean": tokens / max(mean_s, 1e-12),
                "tokens_per_s_best": tokens / max(min_s, 1e-12),
                "loss": loss,
                "eval_loss": eval_loss,
                "eval_accuracy": eval_accuracy,
                "train_final_loss": train_final_loss,
                "train_seconds": train_seconds,
            }
        )
        row.update(memory_stats(device))
        rows.append(row)

    def forward_only() -> float:
        with torch.no_grad(), autocast_context():
            loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)["loss"]
        return float(loss.detach().cpu())

    def forward_backward() -> float:
        model.zero_grad(set_to_none=True)
        with autocast_context():
            loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)["loss"]
        loss.backward()
        return float(loss.detach().cpu())

    timed_modes = [("forward", forward_only)]
    if args.backward:
        timed_modes.append(("forward_backward", forward_backward))

    for mode, step in timed_modes:
        mean_s, min_s, loss = time_repeats(
            step,
            device=device,
            warmup=args.warmup,
            repeats=args.repeats,
        )
        add_row(mode, mean_s, min_s, loss)
    return rows
def print_table(rows: list[dict[str, Any]]) -> None:
    """Print a tab-separated summary of *rows* to stdout.

    A header line with a fixed subset of columns is printed first, then one
    line per row. Float values are rendered with three decimals; everything
    else is rendered with ``str``.
    """
    columns = (
        "architecture",
        "ssm_core",
        "token_task",
        "mode",
        "batch_size",
        "seq_len",
        "mean_ms",
        "tokens_per_s_mean",
        "peak_allocated_mb",
        "loss",
        "eval_loss",
        "eval_accuracy",
    )

    def render(value: Any) -> str:
        return f"{value:.3f}" if isinstance(value, float) else str(value)

    print("\t".join(columns))
    for row in rows:
        print("\t".join(render(row[column]) for column in columns))
def write_outputs(rows: list[dict[str, Any]], output_dir: Path, metadata: dict[str, Any]) -> None:
    """Write benchmark results to JSON and CSV files under *output_dir*.

    The directory is created if needed. The JSON file holds
    ``{"metadata": ..., "results": rows}``. The CSV header is taken from the
    keys of the first row and is empty when there are no rows. Both written
    paths are printed.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    json_path = output_dir / "taonet_token_benchmark.json"
    csv_path = output_dir / "taonet_token_benchmark.csv"

    payload = {"metadata": metadata, "results": rows}
    json_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")

    header = list(rows[0]) if rows else []
    with csv_path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=header)
        writer.writeheader()
        for row in rows:
            writer.writerow(row)

    for written_path in (json_path, csv_path):
        print(f"Wrote {written_path}")
def main() -> None:
    """Parse CLI options, run every requested benchmark case, and save results.

    Cases are the product of ``--architectures``, ``--batch-sizes`` and
    ``--seq-lens`` (all comma-separated). Rows from every case are printed as
    a table and written, together with environment metadata and
    ``nvidia-smi`` snapshots taken before and after the sweep, to
    ``--output-dir``.

    Raises:
        ValueError: if float16 is requested on a non-CUDA device.
    """
    parser = argparse.ArgumentParser(description="Benchmark TaoNet attention vs TaoNet-SSM on token batches.")
    add = parser.add_argument
    on_off = argparse.BooleanOptionalAction

    # Sweep axes.
    add("--architectures", default="taonet,taonet_ssm")
    add("--batch-sizes", default="1,4")
    add("--seq-lens", default="128,512")

    # Shared model dimensions.
    add("--vocab-size", type=int, default=8192)
    add("--hidden-dim", type=int, default=256)
    add("--num-layers", type=int, default=4)
    add("--num-heads", type=int, default=4)
    add("--d-latent-kv", type=int, default=None)
    add("--d-rope", type=int, default=None)
    add("--hidden-dim-ff", type=int, default=None)
    add("--dropout", type=float, default=0.0)
    add("--gqa-groups", type=int, default=1)
    add("--rope-scale", type=float, default=40.0)
    add("--yarn-alpha", type=float, default=1.0)
    add("--init-std", type=float, default=0.02)

    # SSM mixer options.
    add("--ssm-core", choices=["gamma_s4", "dplr"], default="dplr")
    add("--ssm-hidden-dim", type=int, default=None)
    add("--ssm-mixer-dim", type=int, default=None)
    add("--ssm-rank", type=int, default=1)
    add("--ssm-max-low-rank-scale", type=float, default=0.1)
    add("--ssm-kernel-mode", choices=["auto", "conv", "conv_transfer", "recurrent"], default="conv")
    add("--ssm-kernel-threshold", type=int, default=1)
    add("--ssm-dt-min", type=float, default=1e-3)
    add("--ssm-dt-max", type=float, default=1e-1)
    add("--ssm-dt-init", type=float, default=1e-2)
    add("--ssm-use-padding-mask", action="store_true")
    add("--ssm-activation", choices=["gelu", "silu", "identity", "linear"], default="gelu")
    add("--ssm-gate", action=on_off, default=True)
    add("--ssm-input-gate", action=on_off, default=True)
    add("--ssm-layer-scale-init", type=float, default=0.1)
    add("--ssm-local-shift", action=on_off, default=False)
    add("--ssm-local-shift-init", type=float, default=0.1)
    add("--ssm-local-shift-per-channel", action=on_off, default=False)

    # Runtime, timing, and training/evaluation options.
    add("--dtype", choices=sorted(DTYPES), default="bf16")
    add("--device", default="auto")
    add("--warmup", type=int, default=2)
    add("--repeats", type=int, default=5)
    add("--backward", action="store_true")
    add("--token-task", choices=["random", "increment", "previous"], default="random")
    add("--train-steps", type=int, default=0)
    add("--learning-rate", type=float, default=3e-4)
    add("--weight-decay", type=float, default=0.01)
    add("--eval-batches", type=int, default=1)
    add("--output-dir", default=os.environ.get("REPOBRIDGE_OUTPUT_DIR", "results/token-bench"))
    args = parser.parse_args()

    # "auto" picks CUDA when available and falls back to CPU.
    if args.device != "auto":
        device = torch.device(args.device)
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    on_cuda = device.type == "cuda"

    dtype = DTYPES[args.dtype]
    if dtype == torch.float16 and not on_cuda:
        raise ValueError("float16 benchmark requires CUDA.")
    if on_cuda:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    architectures = [name.strip() for name in args.architectures.split(",") if name.strip()]
    rows: list[dict[str, Any]] = []
    metadata = {
        "python": platform.python_version(),
        "platform": platform.platform(),
        "torch": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        "cuda_device": torch.cuda.get_device_name(device) if on_cuda else None,
        "nvidia_smi_before": nvidia_smi_snapshot(),
        "args": vars(args),
    }

    for architecture in architectures:
        for batch_size in parse_int_list(args.batch_sizes):
            for seq_len in parse_int_list(args.seq_lens):
                print(f"Benchmarking architecture={architecture} batch={batch_size} seq={seq_len}")
                case_rows = benchmark_case(
                    args=args,
                    architecture=architecture,
                    batch_size=batch_size,
                    seq_len=seq_len,
                    dtype=dtype,
                    device=device,
                )
                rows.extend(case_rows)

    metadata["nvidia_smi_after"] = nvidia_smi_snapshot()
    print_table(rows)
    write_outputs(rows, Path(args.output_dir), metadata)


if __name__ == "__main__":
    main()