"""Profile TaoNet and TaoNet-SSM component costs on synthetic token batches. The real-token benchmark tells us end-to-end quality and throughput. This script is the companion microscope: it times forward components such as the SSM core, gates, projections, FFN, embeddings, and output head so hardware work targets the largest measured costs. """ from __future__ import annotations import argparse from collections import defaultdict from contextlib import nullcontext from contextlib import redirect_stdout import io import json import os from pathlib import Path import platform import sys import time from typing import Any import torch REPO_ROOT = Path(__file__).resolve().parents[1] SRC_ROOT = REPO_ROOT / "src" if str(SRC_ROOT) not in sys.path: sys.path.insert(0, str(SRC_ROOT)) from taoTrain.config import ModelConfig from taoTrain.models import get_model DTYPES = { "float32": torch.float32, "fp32": torch.float32, "float16": torch.float16, "fp16": torch.float16, "bfloat16": torch.bfloat16, "bf16": torch.bfloat16, } def synchronize(device: torch.device) -> None: if device.type == "cuda": torch.cuda.synchronize(device) def reset_memory(device: torch.device) -> None: if device.type == "cuda": torch.cuda.reset_peak_memory_stats(device) def memory_stats(device: torch.device) -> dict[str, float | None]: if device.type != "cuda": return {"peak_allocated_mb": None, "peak_reserved_mb": None} return { "peak_allocated_mb": torch.cuda.max_memory_allocated(device) / (1024**2), "peak_reserved_mb": torch.cuda.max_memory_reserved(device) / (1024**2), } class ComponentTimer: def __init__(self, device: torch.device) -> None: self.device = device self.records: dict[str, list[float]] = defaultdict(list) self._starts: dict[int, Any] = {} self._handles = [] def _record_ms(self, name: str, start: Any) -> None: if self.device.type == "cuda": end = torch.cuda.Event(enable_timing=True) end.record() end.synchronize() self.records[name].append(float(start.elapsed_time(end))) else: self.records[name].append((time.perf_counter() - start) * 1000.0) def add(self, module: torch.nn.Module, name: str) -> None: def pre_hook(mod, inputs): del inputs if self.device.type == "cuda": start = torch.cuda.Event(enable_timing=True) start.record() else: start = time.perf_counter() self._starts[id(mod)] = start def post_hook(mod, inputs, output): del inputs, output start = self._starts.pop(id(mod), None) if start is not None: self._record_ms(name, start) self._handles.append(module.register_forward_pre_hook(pre_hook)) self._handles.append(module.register_forward_hook(post_hook)) def close(self) -> None: for handle in self._handles: handle.remove() self._handles.clear() def summary(self) -> list[dict[str, float | str | int]]: rows = [] for name, values in sorted(self.records.items()): if not values: continue rows.append( { "component": name, "calls": len(values), "mean_ms": sum(values) / len(values), "total_ms": sum(values), "min_ms": min(values), "max_ms": max(values), } ) rows.sort(key=lambda row: float(row["total_ms"]), reverse=True) return rows def build_config(args: argparse.Namespace, architecture: str) -> ModelConfig: d_latent_kv = args.d_latent_kv if args.d_latent_kv is not None else int(args.hidden_dim * 0.75) d_rope = args.d_rope if args.d_rope is not None else args.hidden_dim // args.num_heads hidden_dim_ff = args.hidden_dim_ff if args.hidden_dim_ff is not None else args.hidden_dim * 4 return ModelConfig( architecture_type=architecture, vocab_size=args.vocab_size, hidden_dim=args.hidden_dim, num_layers=args.num_layers, num_heads=args.num_heads, max_seq_length=args.seq_len, d_latent_kv=d_latent_kv, d_rope=d_rope, hidden_dim_ff=hidden_dim_ff, dropout=args.dropout, gqa_groups=args.gqa_groups, rope_scale=args.rope_scale, yarn_alpha=args.yarn_alpha, init_std=args.init_std, ssm_core=args.ssm_core, ssm_hidden_dim=args.ssm_hidden_dim, ssm_mixer_dim=args.ssm_mixer_dim, ssm_rank=args.ssm_rank, ssm_max_low_rank_scale=args.ssm_max_low_rank_scale, ssm_kernel_mode=args.ssm_kernel_mode, ssm_kernel_threshold=args.ssm_kernel_threshold, ssm_dt_min=args.ssm_dt_min, ssm_dt_max=args.ssm_dt_max, ssm_dt_init=args.ssm_dt_init, ssm_use_padding_mask=False, ssm_activation=args.ssm_activation, ssm_gate=args.ssm_gate, ssm_input_gate=args.ssm_input_gate, ssm_layer_scale_init=args.ssm_layer_scale_init, ssm_local_shift=args.ssm_local_shift, ssm_local_shift_init=args.ssm_local_shift_init, ssm_local_shift_per_channel=args.ssm_local_shift_per_channel, ) def add_component_hooks(model: torch.nn.Module, architecture: str, timer: ComponentTimer) -> None: timer.add(model.token_embedding, "embedding") timer.add(model.final_norm, "final_norm") timer.add(model.output_head, "output_head") for layer_index, block in enumerate(model.blocks): prefix = f"block{layer_index}" if architecture == "taonet_ssm": mixer = block.mixer timer.add(mixer.norm, f"{prefix}.mixer.norm") if mixer.input_gate is not None: timer.add(mixer.input_gate, f"{prefix}.mixer.input_gate") timer.add(mixer.input_proj, f"{prefix}.mixer.input_proj") timer.add(mixer.ssm, f"{prefix}.mixer.ssm_core") timer.add(mixer.activation, f"{prefix}.mixer.activation") timer.add(mixer.out_proj, f"{prefix}.mixer.out_proj") if mixer.output_gate is not None: timer.add(mixer.output_gate, f"{prefix}.mixer.output_gate") timer.add(mixer.proj_dropout, f"{prefix}.mixer.dropout") else: mla = block.mla timer.add(mla.norm, f"{prefix}.attention.norm") timer.add(mla.q_proj, f"{prefix}.attention.q_proj") timer.add(mla.k_proj, f"{prefix}.attention.k_proj") timer.add(mla.v_proj, f"{prefix}.attention.v_proj") timer.add(mla.out_proj, f"{prefix}.attention.out_proj") timer.add(mla.attn_dropout, f"{prefix}.attention.attn_dropout") timer.add(mla.proj_dropout, f"{prefix}.attention.proj_dropout") timer.add(block.ff_norm, f"{prefix}.ff.norm") timer.add(block.ff_gate, f"{prefix}.ff.gate") timer.add(block.ff_value, f"{prefix}.ff.value") timer.add(block.ff_out, f"{prefix}.ff.out") def time_repeats(fn, *, device: torch.device, warmup: int, repeats: int) -> dict[str, float]: for _ in range(warmup): fn() synchronize(device) latencies = [] for _ in range(repeats): reset_memory(device) synchronize(device) start = time.perf_counter() fn() synchronize(device) latencies.append(time.perf_counter() - start) mean_s = sum(latencies) / len(latencies) return { "mean_ms": mean_s * 1000.0, "min_ms": min(latencies) * 1000.0, "max_ms": max(latencies) * 1000.0, } def profile_architecture( args: argparse.Namespace, *, architecture: str, device: torch.device, dtype: torch.dtype, ) -> dict[str, Any]: torch.manual_seed(args.seed) if device.type == "cuda": torch.cuda.manual_seed_all(args.seed) config = build_config(args, architecture) with redirect_stdout(io.StringIO()): model = get_model(config, device=device) model.train() input_ids = torch.randint( low=0, high=args.vocab_size, size=(args.batch_size, args.seq_len), device=device, ) labels = torch.randint( low=0, high=args.vocab_size, size=(args.batch_size, args.seq_len), device=device, ) attention_mask = torch.ones_like(input_ids) autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16} def autocast_context(): if not autocast_enabled: return nullcontext() return torch.autocast(device_type=device.type, dtype=dtype, enabled=True) def forward_only() -> torch.Tensor: with torch.no_grad(): with autocast_context(): outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) return outputs["loss"] def forward_backward() -> torch.Tensor: model.zero_grad(set_to_none=True) with autocast_context(): outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) loss = outputs["loss"] loss.backward() return loss no_timer_forward = time_repeats( forward_only, device=device, warmup=args.warmup, repeats=args.repeats, ) no_timer_backward = time_repeats( forward_backward, device=device, warmup=args.warmup, repeats=args.repeats, ) timer = ComponentTimer(device) add_component_hooks(model, architecture, timer) try: for _ in range(args.component_warmup): forward_only() synchronize(device) for _ in range(args.component_repeats): forward_only() synchronize(device) finally: timer.close() tokens = args.batch_size * args.seq_len component_rows = timer.summary() return { "architecture": architecture, "total_params": sum(param.numel() for param in model.parameters()), "trainable_params": sum(param.numel() for param in model.parameters() if param.requires_grad), "forward": { **no_timer_forward, "tokens_per_s": tokens / max(no_timer_forward["mean_ms"] / 1000.0, 1e-12), }, "forward_backward": { **no_timer_backward, "tokens_per_s": tokens / max(no_timer_backward["mean_ms"] / 1000.0, 1e-12), **memory_stats(device), }, "components_forward": component_rows, } def main() -> int: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--architectures", default="taonet,taonet_ssm") parser.add_argument("--vocab-size", type=int, default=8192) parser.add_argument("--batch-size", type=int, default=32) parser.add_argument("--seq-len", type=int, default=512) parser.add_argument("--hidden-dim", type=int, default=256) parser.add_argument("--num-layers", type=int, default=4) parser.add_argument("--num-heads", type=int, default=4) parser.add_argument("--d-latent-kv", type=int, default=None) parser.add_argument("--d-rope", type=int, default=None) parser.add_argument("--hidden-dim-ff", type=int, default=None) parser.add_argument("--dropout", type=float, default=0.0) parser.add_argument("--gqa-groups", type=int, default=1) parser.add_argument("--rope-scale", type=float, default=40.0) parser.add_argument("--yarn-alpha", type=float, default=1.0) parser.add_argument("--init-std", type=float, default=0.02) parser.add_argument("--ssm-core", choices=["gamma_s4", "dplr"], default="dplr") parser.add_argument("--ssm-hidden-dim", type=int, default=16) parser.add_argument("--ssm-mixer-dim", type=int, default=128) parser.add_argument("--ssm-rank", type=int, default=1) parser.add_argument("--ssm-max-low-rank-scale", type=float, default=0.1) parser.add_argument("--ssm-kernel-mode", choices=["auto", "conv", "conv_transfer", "recurrent"], default="conv") parser.add_argument("--ssm-kernel-threshold", type=int, default=1) parser.add_argument("--ssm-dt-min", type=float, default=1e-3) parser.add_argument("--ssm-dt-max", type=float, default=1e-1) parser.add_argument("--ssm-dt-init", type=float, default=1e-2) parser.add_argument("--ssm-activation", choices=["gelu", "silu", "identity", "linear"], default="gelu") parser.add_argument("--ssm-gate", action=argparse.BooleanOptionalAction, default=True) parser.add_argument("--ssm-input-gate", action=argparse.BooleanOptionalAction, default=True) parser.add_argument("--ssm-layer-scale-init", type=float, default=0.1) parser.add_argument("--ssm-local-shift", action=argparse.BooleanOptionalAction, default=True) parser.add_argument("--ssm-local-shift-init", type=float, default=0.1) parser.add_argument("--ssm-local-shift-per-channel", action=argparse.BooleanOptionalAction, default=True) parser.add_argument("--dtype", choices=sorted(DTYPES), default="bf16") parser.add_argument("--device", default="auto") parser.add_argument("--warmup", type=int, default=2) parser.add_argument("--repeats", type=int, default=5) parser.add_argument("--component-warmup", type=int, default=1) parser.add_argument("--component-repeats", type=int, default=3) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--output", type=Path, default=None) args = parser.parse_args() if args.device == "auto": device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: device = torch.device(args.device) dtype = DTYPES[args.dtype] if device.type == "cuda": torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True results = [ profile_architecture(args, architecture=architecture.strip(), device=device, dtype=dtype) for architecture in args.architectures.split(",") if architecture.strip() ] report = { "metadata": { "python": platform.python_version(), "platform": platform.platform(), "torch": torch.__version__, "cuda_available": torch.cuda.is_available(), "cuda_device": torch.cuda.get_device_name(device) if device.type == "cuda" else None, "device": str(device), "dtype": str(dtype).replace("torch.", ""), "args": vars(args) | {"output": str(args.output) if args.output else None}, }, "results": results, } text = json.dumps(report, indent=2, sort_keys=True, default=str) print(text) if args.output is not None: args.output.parent.mkdir(parents=True, exist_ok=True) args.output.write_text(text, encoding="utf-8") return 0 if __name__ == "__main__": raise SystemExit(main())