# TaoNet-mini-T2: code/TaoTrain/scripts/profile_taonet_components.py
# Uploaded by StarMist0012 ("Add files using upload-large-folder tool"), commit e2bfccc.
"""Profile TaoNet and TaoNet-SSM component costs on synthetic token batches.

The real-token benchmark measures end-to-end quality and throughput. This
script is its companion microscope: it times whole forward and
forward+backward passes, then uses module hooks to time forward components
such as the SSM core, gates, projections, FFN, embeddings, and output head, so
that hardware optimization work targets the largest measured costs.
"""
from __future__ import annotations
import argparse
from collections import defaultdict
from contextlib import nullcontext
from contextlib import redirect_stdout
import io
import json
import os
from pathlib import Path
import platform
import sys
import time
from typing import Any
import torch
# The script sits one directory below the repository root, so parents[1] is the root.
REPO_ROOT = Path(__file__).resolve().parents[1]
SRC_ROOT = REPO_ROOT.joinpath("src")
# Put the in-repo source tree first on the import path (once) so the
# ``taoTrain`` imports that follow resolve when this file is run as a script.
if not any(entry == str(SRC_ROOT) for entry in sys.path):
    sys.path.insert(0, str(SRC_ROOT))
from taoTrain.config import ModelConfig
from taoTrain.models import get_model
# CLI spelling -> torch dtype. Each dtype accepts its long name and a short alias.
DTYPES = {
    alias: dtype
    for dtype, aliases in (
        (torch.float32, ("float32", "fp32")),
        (torch.float16, ("float16", "fp16")),
        (torch.bfloat16, ("bfloat16", "bf16")),
    )
    for alias in aliases
}
def synchronize(device: torch.device) -> None:
    """Wait for all queued CUDA work on ``device``; do nothing for other device types."""
    if device.type != "cuda":
        return
    torch.cuda.synchronize(device)
def reset_memory(device: torch.device) -> None:
    """Reset CUDA peak-memory counters for ``device``; do nothing for other device types."""
    if device.type != "cuda":
        return
    torch.cuda.reset_peak_memory_stats(device)
def memory_stats(device: torch.device) -> dict[str, float | None]:
    """Return peak allocated/reserved CUDA memory in MiB.

    Both values are ``None`` when ``device`` is not a CUDA device.
    """
    keys = ("peak_allocated_mb", "peak_reserved_mb")
    if device.type != "cuda":
        return dict.fromkeys(keys)
    bytes_per_mib = 1024**2
    allocated = torch.cuda.max_memory_allocated(device)
    reserved = torch.cuda.max_memory_reserved(device)
    return {
        "peak_allocated_mb": allocated / bytes_per_mib,
        "peak_reserved_mb": reserved / bytes_per_mib,
    }
class ComponentTimer:
    """Accumulate per-module forward latencies using forward pre/post hooks.

    On CUDA, each call is bracketed by ``torch.cuda.Event`` records and the end
    event is synchronized before reading the elapsed time. This serializes the
    stream, which is acceptable for profiling. On other devices the host clock
    ``time.perf_counter`` is used.
    """

    def __init__(self, device: torch.device) -> None:
        self.device = device
        # component label -> per-call durations in milliseconds
        self.records: dict[str, list[float]] = defaultdict(list)
        self._handles: list[Any] = []

    def _now(self) -> Any:
        """Return a start marker: a recorded CUDA event, or a perf_counter reading."""
        if self.device.type == "cuda":
            marker = torch.cuda.Event(enable_timing=True)
            marker.record()
            return marker
        return time.perf_counter()

    def _record_ms(self, name: str, start: Any) -> None:
        """Append the milliseconds elapsed since ``start`` to ``records[name]``."""
        if self.device.type == "cuda":
            end = torch.cuda.Event(enable_timing=True)
            end.record()
            end.synchronize()  # elapsed_time needs both events to have completed
            self.records[name].append(float(start.elapsed_time(end)))
        else:
            self.records[name].append((time.perf_counter() - start) * 1000.0)

    def add(self, module: torch.nn.Module, name: str) -> None:
        """Time every forward call of ``module`` under the label ``name``.

        The hooks stay registered until :meth:`close` is called.
        """
        # Each registration owns its own LIFO stack of start markers. A dict
        # keyed by id(module) would let two registrations of the same module
        # overwrite each other's start (one label then silently gets no
        # samples) and would mis-pair re-entrant calls.
        pending: list[Any] = []

        def pre_hook(mod, inputs):
            del mod, inputs
            pending.append(self._now())
            # Returning None leaves the module inputs unchanged.

        def post_hook(mod, inputs, output):
            del mod, inputs, output
            if pending:
                self._record_ms(name, pending.pop())
            # Returning None leaves the module output unchanged.

        self._handles.append(module.register_forward_pre_hook(pre_hook))
        self._handles.append(module.register_forward_hook(post_hook))

    def close(self) -> None:
        """Remove all registered hooks. Recorded samples are kept."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

    def summary(self) -> list[dict[str, float | str | int]]:
        """Return one stats row per component, sorted by total time (descending).

        Ties on total time keep alphabetical order by component label.
        """
        rows: list[dict[str, float | str | int]] = []
        for name, values in sorted(self.records.items()):
            if not values:
                continue
            total = sum(values)
            rows.append(
                {
                    "component": name,
                    "calls": len(values),
                    "mean_ms": total / len(values),
                    "total_ms": total,
                    "min_ms": min(values),
                    "max_ms": max(values),
                }
            )
        rows.sort(key=lambda row: float(row["total_ms"]), reverse=True)
        return rows
def build_config(args: argparse.Namespace, architecture: str) -> ModelConfig:
    """Translate parsed CLI arguments into a ``ModelConfig`` for ``architecture``.

    Optional sizes fall back to values derived from ``--hidden-dim``:
    ``d_latent_kv = int(0.75 * hidden_dim)``, ``d_rope = hidden_dim // num_heads``
    and ``hidden_dim_ff = 4 * hidden_dim``. ``max_seq_length`` is taken from
    ``--seq-len``, and ``ssm_use_padding_mask`` is always ``False``.
    """
    hidden = args.hidden_dim
    derived = {
        "d_latent_kv": int(hidden * 0.75) if args.d_latent_kv is None else args.d_latent_kv,
        "d_rope": hidden // args.num_heads if args.d_rope is None else args.d_rope,
        "hidden_dim_ff": hidden * 4 if args.hidden_dim_ff is None else args.hidden_dim_ff,
    }
    # CLI options whose argparse dest is identical to the ModelConfig field name.
    passthrough = {
        name: getattr(args, name)
        for name in (
            "vocab_size",
            "hidden_dim",
            "num_layers",
            "num_heads",
            "dropout",
            "gqa_groups",
            "rope_scale",
            "yarn_alpha",
            "init_std",
            "ssm_core",
            "ssm_hidden_dim",
            "ssm_mixer_dim",
            "ssm_rank",
            "ssm_max_low_rank_scale",
            "ssm_kernel_mode",
            "ssm_kernel_threshold",
            "ssm_dt_min",
            "ssm_dt_max",
            "ssm_dt_init",
            "ssm_activation",
            "ssm_gate",
            "ssm_input_gate",
            "ssm_layer_scale_init",
            "ssm_local_shift",
            "ssm_local_shift_init",
            "ssm_local_shift_per_channel",
        )
    }
    return ModelConfig(
        architecture_type=architecture,
        max_seq_length=args.seq_len,
        ssm_use_padding_mask=False,
        **derived,
        **passthrough,
    )
def add_component_hooks(model: torch.nn.Module, architecture: str, timer: ComponentTimer) -> None:
    """Register ``timer`` hooks on the model's top-level and per-block submodules.

    Top-level labels are ``embedding``, ``final_norm`` and ``output_head``.
    Each block gets labels of the form ``block{i}.<group>.<part>``:

    * for ``taonet_ssm``, the block's ``mixer`` parts (the input/output gates
      only when they are not ``None``);
    * for any other architecture, the block's ``mla`` attention parts;
    * for every block, the feed-forward parts (``ff_norm``, ``ff_gate``,
      ``ff_value``, ``ff_out``).
    """

    def register(owner: Any, attr: str, label: str, *, optional: bool = False) -> None:
        # Hook owner.<attr> under ``label``; optional slots may be None and are then skipped.
        submodule = getattr(owner, attr)
        if optional and submodule is None:
            return
        timer.add(submodule, label)

    register(model, "token_embedding", "embedding")
    register(model, "final_norm", "final_norm")
    register(model, "output_head", "output_head")

    for index, block in enumerate(model.blocks):
        tag = f"block{index}"
        if architecture == "taonet_ssm":
            mixer = block.mixer
            register(mixer, "norm", f"{tag}.mixer.norm")
            register(mixer, "input_gate", f"{tag}.mixer.input_gate", optional=True)
            register(mixer, "input_proj", f"{tag}.mixer.input_proj")
            register(mixer, "ssm", f"{tag}.mixer.ssm_core")
            register(mixer, "activation", f"{tag}.mixer.activation")
            register(mixer, "out_proj", f"{tag}.mixer.out_proj")
            register(mixer, "output_gate", f"{tag}.mixer.output_gate", optional=True)
            register(mixer, "proj_dropout", f"{tag}.mixer.dropout")
        else:
            attention = block.mla
            for part in (
                "norm",
                "q_proj",
                "k_proj",
                "v_proj",
                "out_proj",
                "attn_dropout",
                "proj_dropout",
            ):
                register(attention, part, f"{tag}.attention.{part}")
        for attr, part in (
            ("ff_norm", "norm"),
            ("ff_gate", "gate"),
            ("ff_value", "value"),
            ("ff_out", "out"),
        ):
            register(block, attr, f"{tag}.ff.{part}")
def time_repeats(fn, *, device: torch.device, warmup: int, repeats: int) -> dict[str, float]:
    """Time ``fn()`` with the host clock, synchronizing ``device`` around each call.

    Runs ``warmup`` untimed calls, then ``repeats`` timed calls. Each timed
    call is bracketed by device synchronization, so queued CUDA work is
    included in its latency. CUDA peak-memory counters are reset before each
    timed call, so peak stats read afterwards describe only the last timed call.

    Returns:
        ``mean_ms``, ``min_ms`` and ``max_ms`` over the timed calls.

    Raises:
        ValueError: if ``repeats`` < 1. This is checked before any warmup so a
            misconfiguration fails fast instead of dividing by zero at the end.
    """
    if repeats < 1:
        raise ValueError(f"repeats must be >= 1, got {repeats}")
    for _ in range(warmup):
        fn()
    synchronize(device)
    latencies: list[float] = []
    for _ in range(repeats):
        reset_memory(device)
        synchronize(device)  # do not charge earlier queued work to this call
        start = time.perf_counter()
        fn()
        synchronize(device)
        latencies.append(time.perf_counter() - start)
    mean_s = sum(latencies) / len(latencies)
    return {
        "mean_ms": mean_s * 1000.0,
        "min_ms": min(latencies) * 1000.0,
        "max_ms": max(latencies) * 1000.0,
    }
def profile_architecture(
    args: argparse.Namespace,
    *,
    architecture: str,
    device: torch.device,
    dtype: torch.dtype,
) -> dict[str, Any]:
    """Build ``architecture`` from ``args`` and profile it on one synthetic batch.

    One model and one random ``(batch_size, seq_len)`` batch of token ids and
    labels (seeded from ``args.seed``) are reused for every measurement. The
    model stays in ``train()`` mode throughout. Autocast to ``dtype`` is enabled
    only on CUDA for fp16/bf16; on other devices the model runs without
    autocast, whatever ``dtype`` is.

    Returns a dict with parameter counts and three measurements:

    * ``forward``: hook-free, no-grad forward latency and tokens/s.
    * ``forward_backward``: hook-free forward+backward latency and tokens/s,
      plus peak CUDA memory read right after that timing (``None`` off-CUDA).
    * ``components_forward``: per-module no-grad forward timings from
      ``ComponentTimer`` hooks. Component warmup passes are excluded.
    """
    torch.manual_seed(args.seed)
    if device.type == "cuda":
        torch.cuda.manual_seed_all(args.seed)
    config = build_config(args, architecture)
    # Swallow anything get_model prints so stdout carries only the JSON report.
    with redirect_stdout(io.StringIO()):
        model = get_model(config, device=device)
    model.train()
    shape = (args.batch_size, args.seq_len)
    input_ids = torch.randint(low=0, high=args.vocab_size, size=shape, device=device)
    labels = torch.randint(low=0, high=args.vocab_size, size=shape, device=device)
    attention_mask = torch.ones_like(input_ids)
    autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}

    def autocast_context():
        if not autocast_enabled:
            return nullcontext()
        return torch.autocast(device_type=device.type, dtype=dtype, enabled=True)

    def forward_only() -> torch.Tensor:
        with torch.no_grad():
            with autocast_context():
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        return outputs["loss"]

    def forward_backward() -> torch.Tensor:
        model.zero_grad(set_to_none=True)
        with autocast_context():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs["loss"]
        loss.backward()
        return loss

    forward_timing = time_repeats(
        forward_only,
        device=device,
        warmup=args.warmup,
        repeats=args.repeats,
    )
    backward_timing = time_repeats(
        forward_backward,
        device=device,
        warmup=args.warmup,
        repeats=args.repeats,
    )
    # Read the peak-memory counters now, before the component passes below can
    # touch them.
    backward_memory = memory_stats(device)

    timer = ComponentTimer(device)
    add_component_hooks(model, architecture, timer)
    try:
        for _ in range(args.component_warmup):
            forward_only()
        synchronize(device)
        # The warmup passes ran through the hooks too. Discard their samples so
        # the component stats cover only the timed passes.
        timer.records.clear()
        for _ in range(args.component_repeats):
            forward_only()
        synchronize(device)
    finally:
        timer.close()

    tokens = args.batch_size * args.seq_len
    return {
        "architecture": architecture,
        "total_params": sum(param.numel() for param in model.parameters()),
        "trainable_params": sum(
            param.numel() for param in model.parameters() if param.requires_grad
        ),
        "forward": {
            **forward_timing,
            "tokens_per_s": tokens / max(forward_timing["mean_ms"] / 1000.0, 1e-12),
        },
        "forward_backward": {
            **backward_timing,
            "tokens_per_s": tokens / max(backward_timing["mean_ms"] / 1000.0, 1e-12),
            **backward_memory,
        },
        "components_forward": timer.summary(),
    }
def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this profiler."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--architectures", default="taonet,taonet_ssm")
    parser.add_argument("--vocab-size", type=int, default=8192)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--seq-len", type=int, default=512)
    parser.add_argument("--hidden-dim", type=int, default=256)
    parser.add_argument("--num-layers", type=int, default=4)
    parser.add_argument("--num-heads", type=int, default=4)
    parser.add_argument("--d-latent-kv", type=int, default=None)
    parser.add_argument("--d-rope", type=int, default=None)
    parser.add_argument("--hidden-dim-ff", type=int, default=None)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--gqa-groups", type=int, default=1)
    parser.add_argument("--rope-scale", type=float, default=40.0)
    parser.add_argument("--yarn-alpha", type=float, default=1.0)
    parser.add_argument("--init-std", type=float, default=0.02)
    parser.add_argument("--ssm-core", choices=["gamma_s4", "dplr"], default="dplr")
    parser.add_argument("--ssm-hidden-dim", type=int, default=16)
    parser.add_argument("--ssm-mixer-dim", type=int, default=128)
    parser.add_argument("--ssm-rank", type=int, default=1)
    parser.add_argument("--ssm-max-low-rank-scale", type=float, default=0.1)
    parser.add_argument(
        "--ssm-kernel-mode",
        choices=["auto", "conv", "conv_transfer", "recurrent"],
        default="conv",
    )
    parser.add_argument("--ssm-kernel-threshold", type=int, default=1)
    parser.add_argument("--ssm-dt-min", type=float, default=1e-3)
    parser.add_argument("--ssm-dt-max", type=float, default=1e-1)
    parser.add_argument("--ssm-dt-init", type=float, default=1e-2)
    parser.add_argument(
        "--ssm-activation",
        choices=["gelu", "silu", "identity", "linear"],
        default="gelu",
    )
    parser.add_argument("--ssm-gate", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--ssm-input-gate", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--ssm-layer-scale-init", type=float, default=0.1)
    parser.add_argument("--ssm-local-shift", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--ssm-local-shift-init", type=float, default=0.1)
    parser.add_argument(
        "--ssm-local-shift-per-channel",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    parser.add_argument("--dtype", choices=sorted(DTYPES), default="bf16")
    parser.add_argument("--device", default="auto")
    parser.add_argument("--warmup", type=int, default=2)
    parser.add_argument("--repeats", type=int, default=5)
    parser.add_argument("--component-warmup", type=int, default=1)
    parser.add_argument("--component-repeats", type=int, default=3)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--output", type=Path, default=None)
    return parser


def main() -> int:
    """Parse the CLI, profile each requested architecture, and emit a JSON report.

    The report is printed to stdout and, when ``--output`` is given, also
    written there (creating parent directories). Invalid argument combinations
    are rejected through ``parser.error``. Returns 0 on success.
    """
    parser = _build_parser()
    args = parser.parse_args()
    # Timing statistics are averaged over the timed repeats, so at least one is needed.
    if args.repeats < 1:
        parser.error("--repeats must be >= 1")
    architectures = [name.strip() for name in args.architectures.split(",") if name.strip()]
    if not architectures:
        parser.error("--architectures must name at least one architecture")

    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)
    if device.type == "cuda" and not torch.cuda.is_available():
        parser.error(f"--device {args.device} requested but CUDA is not available")
    dtype = DTYPES[args.dtype]
    if device.type == "cuda":
        # Allow TF32 for fp32 matmuls and cuDNN ops (faster, slightly less precise).
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    results = [
        profile_architecture(args, architecture=architecture, device=device, dtype=dtype)
        for architecture in architectures
    ]
    report = {
        "metadata": {
            "python": platform.python_version(),
            "platform": platform.platform(),
            "torch": torch.__version__,
            "cuda_available": torch.cuda.is_available(),
            "cuda_device": torch.cuda.get_device_name(device) if device.type == "cuda" else None,
            "device": str(device),
            "dtype": str(dtype).replace("torch.", ""),
            "args": vars(args) | {"output": str(args.output) if args.output else None},
        },
        "results": results,
    }
    text = json.dumps(report, indent=2, sort_keys=True, default=str)
    print(text)
    if args.output is not None:
        args.output.parent.mkdir(parents=True, exist_ok=True)
        args.output.write_text(text, encoding="utf-8")
    return 0
if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return value as the exit status.
    sys.exit(main())