Text Generation
Transformers
PyTorch
English
taonet_mini_t2
taonet
taotern
ssm
state-space-model
dplr
custom_code
experimental
Instructions to use TaoTern/TaoNet-mini-T2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use TaoTern/TaoNet-mini-T2 with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="TaoTern/TaoNet-mini-T2", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("TaoTern/TaoNet-mini-T2", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use TaoTern/TaoNet-mini-T2 with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm
# Start the vLLM server:
vllm serve "TaoTern/TaoNet-mini-T2"
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "TaoTern/TaoNet-mini-T2",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
Use Docker
docker model run hf.co/TaoTern/TaoNet-mini-T2
- SGLang
How to use TaoTern/TaoNet-mini-T2 with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang
# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "TaoTern/TaoNet-mini-T2" \
  --host 0.0.0.0 \
  --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "TaoTern/TaoNet-mini-T2",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
Use Docker images
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "TaoTern/TaoNet-mini-T2" \
    --host 0.0.0.0 \
    --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "TaoTern/TaoNet-mini-T2",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
- Docker Model Runner
How to use TaoTern/TaoNet-mini-T2 with Docker Model Runner:
docker model run hf.co/TaoTern/TaoNet-mini-T2
| """Profile TaoNet and TaoNet-SSM component costs on synthetic token batches. | |
| The real-token benchmark tells us end-to-end quality and throughput. This | |
| script is the companion microscope: it times forward components such as the | |
| SSM core, gates, projections, FFN, embeddings, and output head so hardware work | |
| targets the largest measured costs. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| from collections import defaultdict | |
| from contextlib import nullcontext | |
| from contextlib import redirect_stdout | |
| import io | |
| import json | |
| import os | |
| from pathlib import Path | |
| import platform | |
| import sys | |
| import time | |
| from typing import Any | |
| import torch | |
# Treat the grandparent directory of this file as the repository root and put
# its ``src`` directory first on sys.path, so the ``taoTrain`` imports that
# follow resolve from the checkout rather than from an installed copy.
REPO_ROOT = Path(__file__).resolve().parents[1]
SRC_ROOT = REPO_ROOT.joinpath("src")
if os.fspath(SRC_ROOT) not in sys.path:
    sys.path.insert(0, os.fspath(SRC_ROOT))
| from taoTrain.config import ModelConfig | |
| from taoTrain.models import get_model | |
# Accepted --dtype spellings (long and short form) mapped to torch dtypes.
DTYPES = {
    alias: dtype
    for dtype, aliases in (
        (torch.float32, ("float32", "fp32")),
        (torch.float16, ("float16", "fp16")),
        (torch.bfloat16, ("bfloat16", "bf16")),
    )
    for alias in aliases
}
def synchronize(device: torch.device) -> None:
    """Block the host until queued work on ``device`` finishes.

    Only CUDA devices are synchronized; any other device type is a no-op.
    """
    if device.type != "cuda":
        return
    torch.cuda.synchronize(device)
def reset_memory(device: torch.device) -> None:
    """Reset CUDA peak-memory counters for ``device``.

    Non-CUDA devices have no counters to reset, so this is a no-op there.
    """
    if device.type != "cuda":
        return
    torch.cuda.reset_peak_memory_stats(device)
def memory_stats(device: torch.device) -> dict[str, float | None]:
    """Return peak CUDA allocated/reserved memory for ``device``.

    Values are in MiB (bytes / 1024**2) despite the ``_mb`` key suffix. On
    non-CUDA devices both entries are ``None``.
    """
    if device.type == "cuda":
        bytes_per_mib = 1024**2
        return {
            "peak_allocated_mb": torch.cuda.max_memory_allocated(device) / bytes_per_mib,
            "peak_reserved_mb": torch.cuda.max_memory_reserved(device) / bytes_per_mib,
        }
    return dict.fromkeys(("peak_allocated_mb", "peak_reserved_mb"))
class ComponentTimer:
    """Record per-call forward latency of registered submodules via hooks.

    :meth:`add` installs a forward pre-hook that captures a start marker and a
    forward hook that turns it into elapsed milliseconds stored in
    ``records[name]``. On CUDA the markers are timing-enabled
    ``torch.cuda.Event`` objects and each measurement synchronizes on the end
    event; on other devices ``time.perf_counter`` is used.
    """

    def __init__(self, device: torch.device) -> None:
        self.device = device
        self.records: dict[str, list[float]] = defaultdict(list)
        # Pending start markers, one LIFO stack per (module id, name)
        # registration. Keying by registration keeps two names attached to
        # the same module from overwriting each other's start, and the stack
        # keeps re-entrant calls of one module paired with their own start.
        self._starts: dict[tuple[int, str], list[Any]] = defaultdict(list)
        self._handles: list[Any] = []

    def _start_marker(self) -> Any:
        """Return a start marker: a recorded CUDA event or a perf_counter value."""
        if self.device.type == "cuda":
            start = torch.cuda.Event(enable_timing=True)
            start.record()
            return start
        return time.perf_counter()

    def _record_ms(self, name: str, start: Any) -> None:
        """Append the milliseconds elapsed since ``start`` to ``records[name]``."""
        if self.device.type == "cuda":
            end = torch.cuda.Event(enable_timing=True)
            end.record()
            # elapsed_time requires the end event to have completed; this
            # blocks the host until the timed GPU work has finished.
            end.synchronize()
            self.records[name].append(float(start.elapsed_time(end)))
        else:
            self.records[name].append((time.perf_counter() - start) * 1000.0)

    def add(self, module: torch.nn.Module, name: str) -> None:
        """Time every forward call of ``module`` under the label ``name``."""
        key = (id(module), name)

        def pre_hook(mod, inputs):
            del mod, inputs
            self._starts[key].append(self._start_marker())

        def post_hook(mod, inputs, output):
            del mod, inputs, output
            pending = self._starts.get(key)
            if pending:
                self._record_ms(name, pending.pop())

        self._handles.append(module.register_forward_pre_hook(pre_hook))
        self._handles.append(module.register_forward_hook(post_hook))

    def close(self) -> None:
        """Remove all installed hooks; already collected records are kept."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()
        # Drop markers left behind by forwards that raised between hooks.
        self._starts.clear()

    def summary(self) -> list[dict[str, float | str | int]]:
        """Return one stats row per component, sorted by total time descending.

        Ties in ``total_ms`` keep alphabetical component order.
        """
        rows = []
        for name, values in sorted(self.records.items()):
            if not values:
                continue
            total = sum(values)
            rows.append(
                {
                    "component": name,
                    "calls": len(values),
                    "mean_ms": total / len(values),
                    "total_ms": total,
                    "min_ms": min(values),
                    "max_ms": max(values),
                }
            )
        rows.sort(key=lambda row: float(row["total_ms"]), reverse=True)
        return rows
def build_config(args: argparse.Namespace, architecture: str) -> ModelConfig:
    """Translate parsed CLI arguments into a ``ModelConfig`` for ``architecture``.

    Sizing options left as ``None`` fall back to values derived from
    ``hidden_dim``: ``d_latent_kv = int(hidden_dim * 0.75)``,
    ``d_rope = hidden_dim // num_heads`` and ``hidden_dim_ff = hidden_dim * 4``.
    ``max_seq_length`` comes from ``--seq-len`` and the SSM padding mask is
    always disabled (the profiler feeds an all-ones attention mask).
    """
    hidden = args.hidden_dim
    # Each fallback is evaluated only when its option was left unset.
    derived = {
        "d_latent_kv": int(hidden * 0.75) if args.d_latent_kv is None else args.d_latent_kv,
        "d_rope": hidden // args.num_heads if args.d_rope is None else args.d_rope,
        "hidden_dim_ff": hidden * 4 if args.hidden_dim_ff is None else args.hidden_dim_ff,
    }
    # Options copied verbatim from the namespace under the same name.
    passthrough = (
        "vocab_size",
        "hidden_dim",
        "num_layers",
        "num_heads",
        "dropout",
        "gqa_groups",
        "rope_scale",
        "yarn_alpha",
        "init_std",
        "ssm_core",
        "ssm_hidden_dim",
        "ssm_mixer_dim",
        "ssm_rank",
        "ssm_max_low_rank_scale",
        "ssm_kernel_mode",
        "ssm_kernel_threshold",
        "ssm_dt_min",
        "ssm_dt_max",
        "ssm_dt_init",
        "ssm_activation",
        "ssm_gate",
        "ssm_input_gate",
        "ssm_layer_scale_init",
        "ssm_local_shift",
        "ssm_local_shift_init",
        "ssm_local_shift_per_channel",
    )
    copied = {field: getattr(args, field) for field in passthrough}
    return ModelConfig(
        architecture_type=architecture,
        max_seq_length=args.seq_len,
        ssm_use_padding_mask=False,
        **derived,
        **copied,
    )
def add_component_hooks(model: torch.nn.Module, architecture: str, timer: ComponentTimer) -> None:
    """Register timing hooks on the model's top-level and per-block submodules.

    The embedding, final norm and output head are always hooked. For
    ``architecture == "taonet_ssm"`` each block's ``mixer`` parts are hooked,
    with the input/output gates skipped when they are ``None``; any other
    architecture is treated as attention and its ``mla`` parts are hooked.
    Every block's feed-forward parts are hooked in both cases.
    """
    top_level = (
        ("token_embedding", "embedding"),
        ("final_norm", "final_norm"),
        ("output_head", "output_head"),
    )
    for attr, label in top_level:
        timer.add(getattr(model, attr), label)

    # (attribute, label) pairs, listed in registration order.
    ssm_parts = (
        ("norm", "norm"),
        ("input_gate", "input_gate"),
        ("input_proj", "input_proj"),
        ("ssm", "ssm_core"),
        ("activation", "activation"),
        ("out_proj", "out_proj"),
        ("output_gate", "output_gate"),
        ("proj_dropout", "dropout"),
    )
    optional_ssm_parts = {"input_gate", "output_gate"}
    attention_parts = ("norm", "q_proj", "k_proj", "v_proj", "out_proj", "attn_dropout", "proj_dropout")
    ff_parts = (("ff_norm", "norm"), ("ff_gate", "gate"), ("ff_value", "value"), ("ff_out", "out"))

    is_ssm = architecture == "taonet_ssm"
    for index, block in enumerate(model.blocks):
        if is_ssm:
            mixer = block.mixer
            for attr, label in ssm_parts:
                part = getattr(mixer, attr)
                if part is None and attr in optional_ssm_parts:
                    continue
                timer.add(part, f"block{index}.mixer.{label}")
        else:
            mla = block.mla
            for attr in attention_parts:
                timer.add(getattr(mla, attr), f"block{index}.attention.{attr}")
        for attr, label in ff_parts:
            timer.add(getattr(block, attr), f"block{index}.ff.{label}")
def time_repeats(fn, *, device: torch.device, warmup: int, repeats: int) -> dict[str, float]:
    """Measure the wall-clock latency of ``fn`` after a warm-up phase.

    ``fn`` is called ``warmup`` times untimed, then ``repeats`` timed times.
    Each timed call is bracketed by :func:`synchronize` so queued CUDA work is
    included, and CUDA peak-memory counters are reset before each timed call,
    so peak stats read right afterwards describe the last timed call.

    Returns:
        ``{"mean_ms", "min_ms", "max_ms"}`` latency in milliseconds.

    Raises:
        ValueError: if ``repeats`` is less than 1 (nothing to average).
    """
    if repeats < 1:
        raise ValueError(f"repeats must be >= 1, got {repeats}")
    for _ in range(warmup):
        fn()
    synchronize(device)
    latencies = []
    for _ in range(repeats):
        reset_memory(device)
        synchronize(device)
        start = time.perf_counter()
        fn()
        synchronize(device)
        latencies.append(time.perf_counter() - start)
    mean_s = sum(latencies) / len(latencies)
    return {
        "mean_ms": mean_s * 1000.0,
        "min_ms": min(latencies) * 1000.0,
        "max_ms": max(latencies) * 1000.0,
    }
def profile_architecture(
    args: argparse.Namespace,
    *,
    architecture: str,
    device: torch.device,
    dtype: torch.dtype,
) -> dict[str, Any]:
    """Build one architecture and measure end-to-end and per-component cost.

    The model comes from ``build_config(args, architecture)`` and is run in
    train mode on random token ids/labels of shape
    ``(args.batch_size, args.seq_len)`` with an all-ones attention mask:

    1. forward-only latency under ``torch.no_grad`` (no hooks),
    2. forward+backward latency (no hooks),
    3. per-component forward latency with :class:`ComponentTimer` hooks,
       where warm-up calls are excluded from the reported table.

    Returns:
        A JSON-serializable dict with parameter counts, whether autocast was
        active, latency/throughput plus peak-memory stats for phases 1 and 2,
        and the component rows from ``ComponentTimer.summary()``.
    """
    torch.manual_seed(args.seed)
    if device.type == "cuda":
        torch.cuda.manual_seed_all(args.seed)

    config = build_config(args, architecture)
    # Swallow anything get_model prints so stdout carries only the JSON report.
    with redirect_stdout(io.StringIO()):
        model = get_model(config, device=device)
    model.train()

    batch_shape = (args.batch_size, args.seq_len)
    input_ids = torch.randint(low=0, high=args.vocab_size, size=batch_shape, device=device)
    labels = torch.randint(low=0, high=args.vocab_size, size=batch_shape, device=device)
    attention_mask = torch.ones_like(input_ids)

    # Autocast is only used on CUDA with a half-precision dtype; otherwise
    # the model runs in its own precision whatever ``dtype`` says, so the
    # flag is included in the result to keep the report honest.
    autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}

    def autocast_context():
        if not autocast_enabled:
            return nullcontext()
        return torch.autocast(device_type=device.type, dtype=dtype, enabled=True)

    def forward_only() -> torch.Tensor:
        with torch.no_grad():
            with autocast_context():
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        return outputs["loss"]

    def forward_backward() -> torch.Tensor:
        model.zero_grad(set_to_none=True)
        with autocast_context():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs["loss"]
        loss.backward()
        return loss

    no_timer_forward = time_repeats(
        forward_only,
        device=device,
        warmup=args.warmup,
        repeats=args.repeats,
    )
    # Read peaks right after each phase; later phases (including the hooked
    # component runs below) would otherwise be folded into these numbers.
    forward_memory = memory_stats(device)
    no_timer_backward = time_repeats(
        forward_backward,
        device=device,
        warmup=args.warmup,
        repeats=args.repeats,
    )
    backward_memory = memory_stats(device)

    timer = ComponentTimer(device)
    add_component_hooks(model, architecture, timer)
    try:
        for _ in range(args.component_warmup):
            forward_only()
        synchronize(device)
        # Hooks were already active during warm-up; discard those samples so
        # the table reflects only the timed repeats.
        timer.records.clear()
        for _ in range(args.component_repeats):
            forward_only()
        synchronize(device)
    finally:
        timer.close()

    tokens = args.batch_size * args.seq_len
    component_rows = timer.summary()
    return {
        "architecture": architecture,
        "autocast_enabled": autocast_enabled,
        "total_params": sum(param.numel() for param in model.parameters()),
        "trainable_params": sum(param.numel() for param in model.parameters() if param.requires_grad),
        "forward": {
            **no_timer_forward,
            "tokens_per_s": tokens / max(no_timer_forward["mean_ms"] / 1000.0, 1e-12),
            **forward_memory,
        },
        "forward_backward": {
            **no_timer_backward,
            "tokens_per_s": tokens / max(no_timer_backward["mean_ms"] / 1000.0, 1e-12),
            **backward_memory,
        },
        "components_forward": component_rows,
    }
def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser (the module docstring is the description)."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--architectures", default="taonet,taonet_ssm")
    parser.add_argument("--vocab-size", type=int, default=8192)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--seq-len", type=int, default=512)
    parser.add_argument("--hidden-dim", type=int, default=256)
    parser.add_argument("--num-layers", type=int, default=4)
    parser.add_argument("--num-heads", type=int, default=4)
    parser.add_argument("--d-latent-kv", type=int, default=None)
    parser.add_argument("--d-rope", type=int, default=None)
    parser.add_argument("--hidden-dim-ff", type=int, default=None)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--gqa-groups", type=int, default=1)
    parser.add_argument("--rope-scale", type=float, default=40.0)
    parser.add_argument("--yarn-alpha", type=float, default=1.0)
    parser.add_argument("--init-std", type=float, default=0.02)
    parser.add_argument("--ssm-core", choices=["gamma_s4", "dplr"], default="dplr")
    parser.add_argument("--ssm-hidden-dim", type=int, default=16)
    parser.add_argument("--ssm-mixer-dim", type=int, default=128)
    parser.add_argument("--ssm-rank", type=int, default=1)
    parser.add_argument("--ssm-max-low-rank-scale", type=float, default=0.1)
    parser.add_argument("--ssm-kernel-mode", choices=["auto", "conv", "conv_transfer", "recurrent"], default="conv")
    parser.add_argument("--ssm-kernel-threshold", type=int, default=1)
    parser.add_argument("--ssm-dt-min", type=float, default=1e-3)
    parser.add_argument("--ssm-dt-max", type=float, default=1e-1)
    parser.add_argument("--ssm-dt-init", type=float, default=1e-2)
    parser.add_argument("--ssm-activation", choices=["gelu", "silu", "identity", "linear"], default="gelu")
    parser.add_argument("--ssm-gate", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--ssm-input-gate", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--ssm-layer-scale-init", type=float, default=0.1)
    parser.add_argument("--ssm-local-shift", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--ssm-local-shift-init", type=float, default=0.1)
    parser.add_argument("--ssm-local-shift-per-channel", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--dtype", choices=sorted(DTYPES), default="bf16")
    parser.add_argument("--device", default="auto")
    parser.add_argument("--warmup", type=int, default=2)
    parser.add_argument("--repeats", type=int, default=5)
    parser.add_argument("--component-warmup", type=int, default=1)
    parser.add_argument("--component-repeats", type=int, default=3)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--output", type=Path, default=None)
    return parser


def _validate_args(parser: argparse.ArgumentParser, args: argparse.Namespace) -> list[str]:
    """Reject values that would otherwise fail deep inside profiling.

    Returns the stripped, non-empty architecture names. Invalid input is
    reported through ``parser.error``, which exits the process.
    """
    architectures = [name.strip() for name in args.architectures.split(",") if name.strip()]
    if not architectures:
        parser.error("--architectures must list at least one architecture")
    for flag, value in (
        ("--vocab-size", args.vocab_size),
        ("--batch-size", args.batch_size),
        ("--seq-len", args.seq_len),
        ("--repeats", args.repeats),
    ):
        if value < 1:
            parser.error(f"{flag} must be >= 1, got {value}")
    for flag, value in (
        ("--warmup", args.warmup),
        ("--component-warmup", args.component_warmup),
        ("--component-repeats", args.component_repeats),
    ):
        if value < 0:
            parser.error(f"{flag} must be >= 0, got {value}")
    return architectures


def main() -> int:
    """Parse CLI arguments, profile each requested architecture, emit JSON.

    The report is printed to stdout and, when ``--output`` is given, also
    written to that path (missing parent directories are created).

    Returns:
        0 on success. Invalid arguments exit via ``parser.error`` instead.
    """
    parser = _build_parser()
    args = parser.parse_args()
    architectures = _validate_args(parser, args)

    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        try:
            device = torch.device(args.device)
        except RuntimeError as err:
            parser.error(f"invalid --device {args.device!r}: {err}")
    if device.type == "cuda" and not torch.cuda.is_available():
        parser.error(f"--device {args.device} requested but CUDA is not available")
    dtype = DTYPES[args.dtype]

    if device.type == "cuda":
        # Permit TF32 math for fp32 matmuls/cuDNN ops on supporting GPUs.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    results = [
        profile_architecture(args, architecture=architecture, device=device, dtype=dtype)
        for architecture in architectures
    ]
    report = {
        "metadata": {
            "python": platform.python_version(),
            "platform": platform.platform(),
            "torch": torch.__version__,
            "cuda_available": torch.cuda.is_available(),
            "cuda_device": torch.cuda.get_device_name(device) if device.type == "cuda" else None,
            "device": str(device),
            "dtype": str(dtype).replace("torch.", ""),
            "args": vars(args) | {"output": str(args.output) if args.output else None},
        },
        "results": results,
    }
    text = json.dumps(report, indent=2, sort_keys=True, default=str)
    print(text)
    if args.output is not None:
        args.output.parent.mkdir(parents=True, exist_ok=True)
        args.output.write_text(text, encoding="utf-8")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())