Text Generation
Transformers
PyTorch
English
taonet_mini_t2
taonet
taotern
ssm
state-space-model
dplr
custom_code
experimental
Instructions to use TaoTern/TaoNet-mini-T2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use TaoTern/TaoNet-mini-T2 with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="TaoTern/TaoNet-mini-T2", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("TaoTern/TaoNet-mini-T2", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use TaoTern/TaoNet-mini-T2 with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "TaoTern/TaoNet-mini-T2"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "TaoTern/TaoNet-mini-T2",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
Use Docker
docker model run hf.co/TaoTern/TaoNet-mini-T2
- SGLang
How to use TaoTern/TaoNet-mini-T2 with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "TaoTern/TaoNet-mini-T2" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "TaoTern/TaoNet-mini-T2",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
Use Docker images
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "TaoTern/TaoNet-mini-T2" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "TaoTern/TaoNet-mini-T2",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
- Docker Model Runner
How to use TaoTern/TaoNet-mini-T2 with Docker Model Runner:
docker model run hf.co/TaoTern/TaoNet-mini-T2
| """Token-level benchmark for TaoNet attention vs TaoNet-SSM. | |
| The goal is to compare the two LLM wrappers with the same outer dimensions: | |
| original MLA attention TaoNet versus TaoNet with an SSM mixer. | |
| """ | |
from __future__ import annotations

import argparse
from contextlib import nullcontext
from contextlib import redirect_stdout
import csv
import io
import json
import math
import os
from pathlib import Path
import platform
import subprocess
import sys
import time
from typing import Any

import torch
# This script lives one directory below the repository root; put the repo's
# ``src`` directory first on sys.path so the in-tree ``taoTrain`` package is
# importable without installing it.
REPO_ROOT = Path(__file__).resolve().parent.parent
SRC_ROOT = REPO_ROOT.joinpath("src")
if str(SRC_ROOT) not in sys.path:
    sys.path.insert(0, str(SRC_ROOT))
| from taoTrain.config import ModelConfig | |
| from taoTrain.models import get_model | |
# Accepted ``--dtype`` spellings: each torch dtype under its full name and its
# short alias (insertion order: full name first, then alias).
DTYPES = {
    alias: torch_dtype
    for torch_dtype, aliases in (
        (torch.float32, ("float32", "fp32")),
        (torch.float16, ("float16", "fp16")),
        (torch.bfloat16, ("bfloat16", "bf16")),
    )
    for alias in aliases
}
def parse_int_list(value: str) -> list[int]:
    """Parse a comma-separated string such as ``"1, 4,8"`` into a list of ints.

    Whitespace around items is ignored and empty items (e.g. from a trailing
    comma) are skipped; a non-integer item raises ``ValueError`` from ``int``.
    """
    pieces = (piece.strip() for piece in value.split(","))
    return [int(piece) for piece in pieces if piece]
def synchronize(device: torch.device) -> None:
    """Block until queued CUDA work on ``device`` has finished; no-op for non-CUDA devices."""
    if device.type != "cuda":
        return
    torch.cuda.synchronize(device)
def reset_memory(device: torch.device) -> None:
    """Reset CUDA peak-memory counters for ``device``; no-op for non-CUDA devices."""
    if device.type != "cuda":
        return
    torch.cuda.reset_peak_memory_stats(device)
def memory_stats(device: torch.device) -> dict[str, float | None]:
    """Return CUDA peak allocated/reserved memory in MiB since the last reset.

    Both values are ``None`` when ``device`` is not a CUDA device.
    """
    peak_allocated = peak_reserved = None
    if device.type == "cuda":
        bytes_per_mib = 1024**2
        peak_allocated = torch.cuda.max_memory_allocated(device) / bytes_per_mib
        peak_reserved = torch.cuda.max_memory_reserved(device) / bytes_per_mib
    return {"peak_allocated_mb": peak_allocated, "peak_reserved_mb": peak_reserved}
def nvidia_smi_snapshot() -> str | None:
    """Return the stripped CSV output of an ``nvidia-smi`` GPU query, or ``None``.

    The query asks for name, memory used/total, GPU and memory utilisation,
    power draw and temperature, formatted as CSV without header or units.
    ``None`` is returned when the tool cannot be launched, takes longer than
    5 seconds, or exits with a non-zero status.
    """
    command = [
        "nvidia-smi",
        "--query-gpu=name,memory.used,memory.total,utilization.gpu,utilization.memory,power.draw,temperature.gpu",
        "--format=csv,noheader,nounits",
    ]
    try:
        result = subprocess.run(command, check=False, capture_output=True, text=True, timeout=5)
    except (OSError, subprocess.TimeoutExpired):
        return None
    return result.stdout.strip() if result.returncode == 0 else None
def make_token_batch(
    *,
    batch_size: int,
    seq_len: int,
    vocab_size: int,
    device: torch.device,
    task: str = "random",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Generate a synthetic ``(input_ids, labels, attention_mask)`` batch.

    All three tensors have shape ``(batch_size, seq_len)``. Labels are built
    already aligned with the inputs (no further shift is applied here);
    ``-100`` marks positions that are not scored.

    Tasks:
        ``"random"``: uniform random tokens; ``labels[:, t] == input_ids[:, t + 1]``
            and the final label is a fresh random token.
        ``"increment"``: each row counts upward from a random start modulo
            ``vocab_size``; every label is its input plus one (mod ``vocab_size``).
        ``"previous"``: uniform random tokens; ``labels[:, t] == input_ids[:, t - 1]``
            and position 0 is ignored (``-100``).

    The attention mask is all ones (no padding).

    Raises:
        ValueError: if ``task`` is not one of the names above.
    """
    if task not in ("random", "increment", "previous"):
        raise ValueError(f"Unsupported token task '{task}'.")
    shape = (batch_size, seq_len)
    if task == "increment":
        first = torch.randint(0, vocab_size, (batch_size, 1), device=device)
        steps = torch.arange(seq_len, device=device).unsqueeze(0)
        input_ids = torch.remainder(first + steps, vocab_size)
        labels = torch.remainder(input_ids + 1, vocab_size)
    else:
        input_ids = torch.randint(0, vocab_size, shape, device=device)
        labels = torch.empty_like(input_ids)
        if task == "random":
            labels[:, :-1] = input_ids[:, 1:]
            labels[:, -1] = torch.randint(0, vocab_size, (batch_size,), device=device)
        else:  # "previous"
            labels[:, 0] = -100
            labels[:, 1:] = input_ids[:, :-1]
    return input_ids, labels, torch.ones_like(input_ids)
def token_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Fraction of scored positions where ``argmax(logits)`` equals the label.

    Positions labelled ``-100`` are excluded; if none remain, ``nan`` is returned.
    The ratio is computed as a tensor division before conversion to ``float``.
    """
    keep = labels.ne(-100)
    if not bool(keep.any()):
        return float("nan")
    hits = logits.argmax(dim=-1).eq(labels).logical_and(keep)
    return float(hits.sum().detach().cpu() / keep.sum().detach().cpu())
def build_config(args: argparse.Namespace, architecture: str) -> ModelConfig:
    """Translate parsed CLI arguments into a ``ModelConfig`` for ``architecture``.

    Unset sizes fall back to values derived from ``--hidden-dim``:
    ``d_latent_kv = int(0.75 * hidden_dim)``, ``d_rope = hidden_dim // num_heads``
    and ``hidden_dim_ff = 4 * hidden_dim``. ``ssm_hidden_dim`` falls back to the
    resolved ``d_latent_kv`` whenever ``--ssm-hidden-dim`` is falsy (unset or 0).
    ``max_seq_length`` is the largest value in ``--seq-lens``.
    """
    hidden = args.hidden_dim
    d_latent_kv = int(hidden * 0.75) if args.d_latent_kv is None else args.d_latent_kv
    d_rope = hidden // args.num_heads if args.d_rope is None else args.d_rope
    hidden_dim_ff = hidden * 4 if args.hidden_dim_ff is None else args.hidden_dim_ff
    longest_seq = max(parse_int_list(args.seq_lens))
    return ModelConfig(
        architecture_type=architecture,
        vocab_size=args.vocab_size,
        hidden_dim=hidden,
        num_layers=args.num_layers,
        num_heads=args.num_heads,
        max_seq_length=longest_seq,
        d_latent_kv=d_latent_kv,
        d_rope=d_rope,
        hidden_dim_ff=hidden_dim_ff,
        dropout=args.dropout,
        gqa_groups=args.gqa_groups,
        rope_scale=args.rope_scale,
        yarn_alpha=args.yarn_alpha,
        init_std=args.init_std,
        # SSM settings are always passed through; presumably ignored by the
        # attention-only architecture (not visible here).
        ssm_core=args.ssm_core,
        ssm_hidden_dim=args.ssm_hidden_dim or d_latent_kv,
        ssm_mixer_dim=args.ssm_mixer_dim,
        ssm_rank=args.ssm_rank,
        ssm_max_low_rank_scale=args.ssm_max_low_rank_scale,
        ssm_kernel_mode=args.ssm_kernel_mode,
        ssm_kernel_threshold=args.ssm_kernel_threshold,
        ssm_dt_min=args.ssm_dt_min,
        ssm_dt_max=args.ssm_dt_max,
        ssm_dt_init=args.ssm_dt_init,
        ssm_use_padding_mask=args.ssm_use_padding_mask,
        ssm_activation=args.ssm_activation,
        ssm_gate=args.ssm_gate,
        ssm_input_gate=args.ssm_input_gate,
        ssm_layer_scale_init=args.ssm_layer_scale_init,
        ssm_local_shift=args.ssm_local_shift,
        ssm_local_shift_init=args.ssm_local_shift_init,
        ssm_local_shift_per_channel=args.ssm_local_shift_per_channel,
    )
def count_params(model: torch.nn.Module) -> tuple[int, int]:
    """Return ``(total, trainable)`` element counts over ``model.parameters()``."""
    total = trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return total, trainable
def time_repeats(fn, *, device: torch.device, warmup: int, repeats: int) -> tuple[float, float, float]:
    """Time a zero-argument callable over ``repeats`` runs after ``warmup`` untimed runs.

    The device is synchronized before and after every timed call so queued
    CUDA work is included in the measurement. CUDA peak-memory counters are
    reset before each timed call, so peak stats read after this returns cover
    only the last timed call.

    Args:
        fn: zero-argument callable returning a float (the benchmark's loss).
        device: device whose work is synchronized around each call.
        warmup: number of untimed calls first (values <= 0 mean none).
        repeats: number of timed calls; must be at least 1.

    Returns:
        ``(mean_seconds, min_seconds, last_result)`` where ``last_result`` is
        the value returned by the final call.

    Raises:
        ValueError: if ``repeats < 1`` (there would be no timing to report).
    """
    if repeats < 1:
        raise ValueError(f"repeats must be >= 1, got {repeats}.")
    last_loss = float("nan")
    for _ in range(warmup):
        last_loss = fn()
    synchronize(device)
    latencies = []
    for _ in range(repeats):
        reset_memory(device)
        synchronize(device)
        start = time.perf_counter()
        last_loss = fn()
        synchronize(device)
        latencies.append(time.perf_counter() - start)
    return sum(latencies) / len(latencies), min(latencies), last_loss
def evaluate_model(
    model: torch.nn.Module,
    *,
    args: argparse.Namespace,
    batch_size: int,
    seq_len: int,
    device: torch.device,
    autocast_context,
) -> tuple[float | None, float | None]:
    """Return mean loss and mean token accuracy over ``args.eval_batches`` fresh batches.

    Each batch is drawn with ``make_token_batch`` for ``args.token_task`` and
    evaluated under ``torch.no_grad()`` with the model in eval mode. The
    model's previous train/eval mode is restored afterwards, even if a forward
    pass raises.

    Args:
        autocast_context: zero-argument factory returning a context manager
            entered around each forward pass.

    Returns:
        ``(mean_loss, mean_accuracy)``, or ``(None, None)`` when
        ``args.eval_batches <= 0`` (evaluation skipped, matching how
        ``train_model`` reports ``train_steps <= 0``).
    """
    if args.eval_batches <= 0:
        return None, None
    was_training = model.training
    losses: list[float] = []
    accuracies: list[float] = []
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(args.eval_batches):
                input_ids, labels, attention_mask = make_token_batch(
                    batch_size=batch_size,
                    seq_len=seq_len,
                    vocab_size=args.vocab_size,
                    device=device,
                    task=args.token_task,
                )
                with autocast_context():
                    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                losses.append(float(outputs["loss"].detach().cpu()))
                accuracies.append(token_accuracy(outputs["logits"], labels))
    finally:
        # Restore whatever mode the caller had, rather than forcing train mode.
        model.train(was_training)
    return sum(losses) / len(losses), sum(accuracies) / len(accuracies)
def train_model(
    model: torch.nn.Module,
    *,
    args: argparse.Namespace,
    batch_size: int,
    seq_len: int,
    device: torch.device,
    autocast_context,
) -> tuple[float | None, float | None]:
    """Run ``args.train_steps`` AdamW steps on freshly generated token batches.

    Returns ``(final_loss, wall_seconds)``. Both are ``None`` when
    ``args.train_steps <= 0``, in which case the model is not touched. The
    optimizer is created per call, so its state is discarded on return; the
    last step's ``.grad`` tensors remain on the parameters.
    """
    steps = args.train_steps
    if steps <= 0:
        return None, None

    model.train()
    optim = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    final_loss = float("nan")
    began = time.perf_counter()
    for _step in range(steps):
        batch_ids, batch_labels, batch_mask = make_token_batch(
            batch_size=batch_size,
            seq_len=seq_len,
            vocab_size=args.vocab_size,
            device=device,
            task=args.token_task,
        )
        optim.zero_grad(set_to_none=True)
        # NOTE(review): if autocast_context enables float16, no GradScaler is
        # used, so small gradients may underflow — confirm this is intended.
        with autocast_context():
            outputs = model(input_ids=batch_ids, attention_mask=batch_mask, labels=batch_labels)
            step_loss = outputs["loss"]
        step_loss.backward()
        optim.step()
        # Copying the loss to the host also waits for this step's device work.
        final_loss = float(step_loss.detach().cpu())
    synchronize(device)
    return final_loss, time.perf_counter() - began
def benchmark_case(
    *,
    args: argparse.Namespace,
    architecture: str,
    batch_size: int,
    seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
) -> list[dict[str, Any]]:
    """Benchmark one (architecture, batch_size, seq_len) combination.

    Builds a fresh model from the CLI args, optionally trains it for
    ``args.train_steps`` and evaluates it on ``args.eval_batches`` batches,
    then times a forward pass (and, with ``args.backward``, forward+backward)
    on one fixed token batch reused for every repeat.

    Mixed precision (``torch.autocast``) is enabled only when ``device`` is
    CUDA and ``dtype`` is float16/bfloat16. Each row records this in its
    ``"autocast"`` field, so a row whose ``"dtype"`` says bfloat16 but which
    was measured on CPU is not mistaken for a bf16 run.

    Returns:
        One result dict per timed mode (``"forward"``, then optionally
        ``"forward_backward"``), all with the same keys.
    """
    config = build_config(args, architecture)
    # Swallow anything model construction prints so the benchmark log stays readable.
    with redirect_stdout(io.StringIO()):
        model = get_model(config, device=device)
    model.train()
    total_params, trainable_params = count_params(model)
    tokens = batch_size * seq_len
    # One fixed batch, drawn before any train/eval batches, is reused for all timed repeats.
    input_ids, labels, attention_mask = make_token_batch(
        batch_size=batch_size,
        seq_len=seq_len,
        vocab_size=args.vocab_size,
        device=device,
        task=args.token_task,
    )
    autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}

    def autocast_context():
        """Return a fresh autocast context when mixed precision applies, else a no-op context."""
        if not autocast_enabled:
            return nullcontext()
        return torch.autocast(device_type=device.type, dtype=dtype, enabled=True)

    train_final_loss, train_seconds = train_model(
        model,
        args=args,
        batch_size=batch_size,
        seq_len=seq_len,
        device=device,
        autocast_context=autocast_context,
    )
    eval_loss, eval_accuracy = evaluate_model(
        model,
        args=args,
        batch_size=batch_size,
        seq_len=seq_len,
        device=device,
        autocast_context=autocast_context,
    )
    # Training leaves the last step's .grad tensors allocated; free them so the
    # forward-only peak-memory figures below do not include gradient storage.
    model.zero_grad(set_to_none=True)

    is_ssm = architecture == "taonet_ssm"
    rows: list[dict[str, Any]] = []

    def add_row(mode: str, mean_s: float, min_s: float, loss: float) -> None:
        """Append one result row; CUDA peak memory is read at call time (last timed repeat)."""
        rows.append(
            {
                "architecture": architecture,
                "ssm_core": args.ssm_core if is_ssm else None,
                "token_task": args.token_task,
                "train_steps": args.train_steps,
                "mode": mode,
                "batch_size": batch_size,
                "seq_len": seq_len,
                "tokens": tokens,
                "vocab_size": args.vocab_size,
                "hidden_dim": args.hidden_dim,
                "num_layers": args.num_layers,
                "num_heads": args.num_heads,
                "d_latent_kv": config.d_latent_kv,
                # SSM-specific settings are reported only for the SSM architecture.
                "ssm_hidden_dim": config.ssm_hidden_dim if is_ssm else None,
                "ssm_mixer_dim": config.ssm_mixer_dim if is_ssm else None,
                "ssm_rank": config.ssm_rank if is_ssm else None,
                "ssm_local_shift": config.ssm_local_shift if is_ssm else None,
                "ssm_local_shift_init": config.ssm_local_shift_init if is_ssm else None,
                "ssm_local_shift_per_channel": config.ssm_local_shift_per_channel if is_ssm else None,
                "dtype": str(dtype).replace("torch.", ""),
                "autocast": autocast_enabled,
                "device": str(device),
                "total_params": total_params,
                "trainable_params": trainable_params,
                "mean_ms": mean_s * 1000.0,
                "min_ms": min_s * 1000.0,
                "tokens_per_s_mean": tokens / max(mean_s, 1e-12),
                "tokens_per_s_best": tokens / max(min_s, 1e-12),
                "loss": loss,
                "eval_loss": eval_loss,
                "eval_accuracy": eval_accuracy,
                "train_final_loss": train_final_loss,
                "train_seconds": train_seconds,
                **memory_stats(device),
            }
        )

    def forward_only() -> float:
        """One no-grad forward pass on the fixed batch; returns the loss."""
        with torch.no_grad():
            with autocast_context():
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs["loss"]
        return float(loss.detach().cpu())

    mean_s, min_s, loss = time_repeats(
        forward_only,
        device=device,
        warmup=args.warmup,
        repeats=args.repeats,
    )
    add_row("forward", mean_s, min_s, loss)

    if args.backward:

        def forward_backward() -> float:
            """One forward pass plus backward on the fixed batch (grads reset first); returns the loss."""
            model.zero_grad(set_to_none=True)
            with autocast_context():
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs["loss"]
            loss.backward()
            return float(loss.detach().cpu())

        mean_s, min_s, loss = time_repeats(
            forward_backward,
            device=device,
            warmup=args.warmup,
            repeats=args.repeats,
        )
        add_row("forward_backward", mean_s, min_s, loss)
    return rows
def print_table(rows: list[dict[str, Any]]) -> None:
    """Print a tab-separated summary of selected columns, one line per row.

    A header line comes first; float values are shown with three decimals and
    everything else via ``str``. A missing column raises ``KeyError``.
    """
    columns = (
        "architecture",
        "ssm_core",
        "token_task",
        "mode",
        "batch_size",
        "seq_len",
        "mean_ms",
        "tokens_per_s_mean",
        "peak_allocated_mb",
        "loss",
        "eval_loss",
        "eval_accuracy",
    )

    def render(cell: Any) -> str:
        return f"{cell:.3f}" if isinstance(cell, float) else str(cell)

    print("\t".join(columns))
    for row in rows:
        print("\t".join(render(row[name]) for name in columns))
def _json_safe(value: Any) -> Any:
    """Return ``value`` with non-finite floats replaced by ``None``, recursing into dicts/lists/tuples."""
    if isinstance(value, float):
        return value if math.isfinite(value) else None
    if isinstance(value, dict):
        return {key: _json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(item) for item in value]
    return value


def write_outputs(rows: list[dict[str, Any]], output_dir: Path, metadata: dict[str, Any]) -> None:
    """Write benchmark results to ``output_dir`` (created if needed) as JSON and CSV.

    Files written:
        ``taonet_token_benchmark.json``: ``{"metadata": metadata, "results": rows}``.
        ``taonet_token_benchmark.csv``: one line per row, columns taken from the
        first row's keys (header only, with no columns, when ``rows`` is empty).

    NaN/inf values (e.g. an accuracy of NaN when no label was scored) are
    written as ``null`` in the JSON file so it is strictly valid JSON; the CSV
    keeps Python's ``nan``/``inf`` text.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    json_path = output_dir / "taonet_token_benchmark.json"
    csv_path = output_dir / "taonet_token_benchmark.csv"
    payload = _json_safe({"metadata": metadata, "results": rows})
    # allow_nan=False guarantees no bare NaN/Infinity tokens slip through.
    json_path.write_text(json.dumps(payload, indent=2, allow_nan=False), encoding="utf-8")
    fieldnames = list(rows[0].keys()) if rows else []
    with csv_path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    print(f"Wrote {json_path}")
    print(f"Wrote {csv_path}")
def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this benchmark."""
    parser = argparse.ArgumentParser(description="Benchmark TaoNet attention vs TaoNet-SSM on token batches.")
    # What to sweep: comma-separated lists of architectures and shapes.
    parser.add_argument("--architectures", default="taonet,taonet_ssm")
    parser.add_argument("--batch-sizes", default="1,4")
    parser.add_argument("--seq-lens", default="128,512")
    # Shared model dimensions.
    parser.add_argument("--vocab-size", type=int, default=8192)
    parser.add_argument("--hidden-dim", type=int, default=256)
    parser.add_argument("--num-layers", type=int, default=4)
    parser.add_argument("--num-heads", type=int, default=4)
    parser.add_argument("--d-latent-kv", type=int, default=None)
    parser.add_argument("--d-rope", type=int, default=None)
    parser.add_argument("--hidden-dim-ff", type=int, default=None)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--gqa-groups", type=int, default=1)
    parser.add_argument("--rope-scale", type=float, default=40.0)
    parser.add_argument("--yarn-alpha", type=float, default=1.0)
    parser.add_argument("--init-std", type=float, default=0.02)
    # SSM mixer settings (reported only for the taonet_ssm architecture).
    parser.add_argument("--ssm-core", choices=["gamma_s4", "dplr"], default="dplr")
    parser.add_argument("--ssm-hidden-dim", type=int, default=None)
    parser.add_argument("--ssm-mixer-dim", type=int, default=None)
    parser.add_argument("--ssm-rank", type=int, default=1)
    parser.add_argument("--ssm-max-low-rank-scale", type=float, default=0.1)
    parser.add_argument("--ssm-kernel-mode", choices=["auto", "conv", "conv_transfer", "recurrent"], default="conv")
    parser.add_argument("--ssm-kernel-threshold", type=int, default=1)
    parser.add_argument("--ssm-dt-min", type=float, default=1e-3)
    parser.add_argument("--ssm-dt-max", type=float, default=1e-1)
    parser.add_argument("--ssm-dt-init", type=float, default=1e-2)
    parser.add_argument("--ssm-use-padding-mask", action="store_true")
    parser.add_argument("--ssm-activation", choices=["gelu", "silu", "identity", "linear"], default="gelu")
    parser.add_argument("--ssm-gate", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--ssm-input-gate", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--ssm-layer-scale-init", type=float, default=0.1)
    parser.add_argument("--ssm-local-shift", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--ssm-local-shift-init", type=float, default=0.1)
    parser.add_argument("--ssm-local-shift-per-channel", action=argparse.BooleanOptionalAction, default=False)
    # Runtime, timing, optional training/eval and output.
    parser.add_argument("--dtype", choices=sorted(DTYPES), default="bf16")
    parser.add_argument("--device", default="auto")
    parser.add_argument("--warmup", type=int, default=2)
    parser.add_argument("--repeats", type=int, default=5)
    parser.add_argument("--backward", action="store_true")
    parser.add_argument("--token-task", choices=["random", "increment", "previous"], default="random")
    parser.add_argument("--train-steps", type=int, default=0)
    parser.add_argument("--learning-rate", type=float, default=3e-4)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--eval-batches", type=int, default=1)
    parser.add_argument("--output-dir", default=os.environ.get("REPOBRIDGE_OUTPUT_DIR", "results/token-bench"))
    return parser


def _resolve_device(name: str) -> torch.device:
    """Map the ``--device`` value to a torch device; ``"auto"`` picks CUDA when available, else CPU."""
    if name == "auto":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(name)


def main() -> None:
    """Parse CLI args, benchmark every architecture x batch size x seq len, then print and save results."""
    parser = _build_parser()
    args = parser.parse_args()
    if args.repeats < 1:
        parser.error("--repeats must be >= 1")
    # Parse the sweep lists once, up front, so malformed values fail before any model is built.
    try:
        batch_sizes = parse_int_list(args.batch_sizes)
        seq_lens = parse_int_list(args.seq_lens)
    except ValueError as exc:
        parser.error(f"--batch-sizes and --seq-lens must be comma-separated integers ({exc})")
    architectures = [item.strip() for item in args.architectures.split(",") if item.strip()]

    device = _resolve_device(args.device)
    dtype = DTYPES[args.dtype]
    if device.type != "cuda" and dtype == torch.float16:
        raise ValueError("float16 benchmark requires CUDA.")
    if device.type != "cuda" and dtype == torch.bfloat16:
        # benchmark_case only enables autocast on CUDA, so bf16 is not applied here.
        print(f"Note: autocast is only enabled on CUDA; --dtype {args.dtype} will not apply mixed precision on {device.type}.")
    if device.type == "cuda":
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    rows: list[dict[str, Any]] = []
    metadata = {
        "python": platform.python_version(),
        "platform": platform.platform(),
        "torch": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        "cuda_device": torch.cuda.get_device_name(device) if device.type == "cuda" else None,
        "nvidia_smi_before": nvidia_smi_snapshot(),
        "args": vars(args),
    }
    for architecture in architectures:
        for batch_size in batch_sizes:
            for seq_len in seq_lens:
                print(f"Benchmarking architecture={architecture} batch={batch_size} seq={seq_len}")
                rows.extend(
                    benchmark_case(
                        args=args,
                        architecture=architecture,
                        batch_size=batch_size,
                        seq_len=seq_len,
                        dtype=dtype,
                        device=device,
                    )
                )
    metadata["nvidia_smi_after"] = nvidia_smi_snapshot()
    print_table(rows)
    write_outputs(rows, Path(args.output_dir), metadata)
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()