Text Generation
Transformers
PyTorch
English
taonet_mini_t2
taonet
taotern
ssm
state-space-model
dplr
custom_code
experimental
Instructions to use TaoTern/TaoNet-mini-T2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use TaoTern/TaoNet-mini-T2 with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="TaoTern/TaoNet-mini-T2", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("TaoTern/TaoNet-mini-T2", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use TaoTern/TaoNet-mini-T2 with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm
# Start the vLLM server:
vllm serve "TaoTern/TaoNet-mini-T2"
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "TaoTern/TaoNet-mini-T2",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
Use Docker
docker model run hf.co/TaoTern/TaoNet-mini-T2
- SGLang
How to use TaoTern/TaoNet-mini-T2 with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang
# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "TaoTern/TaoNet-mini-T2" \
  --host 0.0.0.0 \
  --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "TaoTern/TaoNet-mini-T2",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
Use Docker images
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "TaoTern/TaoNet-mini-T2" \
    --host 0.0.0.0 \
    --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "TaoTern/TaoNet-mini-T2",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
- Docker Model Runner
How to use TaoTern/TaoNet-mini-T2 with Docker Model Runner:
docker model run hf.co/TaoTern/TaoNet-mini-T2
| """Lightweight script benchmarks for Gamma SSM variants. | |
| This replaces the notebook-only timing loop for quick local/remote feedback. | |
| It focuses on the model kernels themselves: full-sequence forward, optional | |
| forward+backward, and optional recurrent decode. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| from contextlib import nullcontext | |
| import csv | |
| import json | |
| import os | |
| import platform | |
| import subprocess | |
| import sys | |
| import time | |
| from pathlib import Path | |
| from typing import Any, Iterable | |
| import torch | |
| import torch.nn as nn | |
| REPO_ROOT = Path(__file__).resolve().parents[1] | |
| if str(REPO_ROOT) not in sys.path: | |
| sys.path.insert(0, str(REPO_ROOT)) | |
| from gamma_space_model import S4TernaryDPLRSSM, SSMGamma, SSMGammaS4 | |
# Accepted ``--dtype`` spellings mapped to torch dtypes. Every dtype is
# reachable through its long name and a short alias.
DTYPES = {
    alias: torch_dtype
    for torch_dtype, aliases in (
        (torch.float32, ("float32", "fp32")),
        (torch.float16, ("float16", "fp16")),
        (torch.bfloat16, ("bfloat16", "bf16")),
    )
    for alias in aliases
}
def parse_int_list(value: str) -> list[int]:
    """Parse a comma-separated string into a list of integers.

    Whitespace around each entry is stripped and entries that are empty
    after stripping are skipped. A non-empty entry that is not a valid
    integer raises ``ValueError`` (from ``int``).
    """
    numbers: list[int] = []
    for piece in value.split(","):
        text = piece.strip()
        if text:
            numbers.append(int(text))
    return numbers
def synchronize(device: torch.device) -> None:
    """Call ``torch.cuda.synchronize`` for CUDA devices; do nothing otherwise."""
    if device.type != "cuda":
        return
    torch.cuda.synchronize(device)
def cuda_memory(device: torch.device) -> dict[str, float | None]:
    """Report the peak CUDA memory statistics for ``device`` in MiB.

    Returns a dict with the keys ``peak_allocated_mb`` and
    ``peak_reserved_mb`` (in that order). Both values are ``None`` when
    ``device`` is not a CUDA device.
    """
    bytes_per_mb = 1024**2
    on_cuda = device.type == "cuda"
    allocated = torch.cuda.max_memory_allocated(device) / bytes_per_mb if on_cuda else None
    reserved = torch.cuda.max_memory_reserved(device) / bytes_per_mb if on_cuda else None
    return {
        "peak_allocated_mb": allocated,
        "peak_reserved_mb": reserved,
    }
def reset_cuda_memory(device: torch.device) -> None:
    """Reset the CUDA peak-memory counters for ``device``; no-op off CUDA."""
    if device.type != "cuda":
        return
    torch.cuda.reset_peak_memory_stats(device)
def nvidia_smi_snapshot() -> str | None:
    """Return one ``nvidia-smi`` CSV query result, or ``None`` on failure.

    The query asks for GPU name, memory used/total, GPU and memory
    utilization, power draw and temperature, without header or units.
    ``None`` is returned when the command cannot be launched, does not
    finish within 5 seconds, or exits with a non-zero status.
    """
    command = [
        "nvidia-smi",
        "--query-gpu=name,memory.used,memory.total,utilization.gpu,utilization.memory,power.draw,temperature.gpu",
        "--format=csv,noheader,nounits",
    ]
    try:
        result = subprocess.run(
            command,
            check=False,
            capture_output=True,
            text=True,
            timeout=5,
        )
    except (OSError, subprocess.TimeoutExpired):
        return None
    return result.stdout.strip() if result.returncode == 0 else None
def make_model(name: str, d_model: int, hidden_dim: int, rank: int, kernel_mode: str) -> nn.Module:
    """Instantiate the SSM variant selected by ``name``.

    Supported names:

    * ``"baseline"`` -> ``SSMGamma``
    * ``"gamma_s4"`` -> ``SSMGammaS4`` with ``discretization="bilinear"``
    * ``"dplr"``     -> ``S4TernaryDPLRSSM`` with the given ``rank``

    ``d_model`` is passed to every constructor as ``state_dim``. ``rank``
    is only used by ``"dplr"`` and ``kernel_mode`` is ignored by
    ``"baseline"``.

    Raises:
        ValueError: if ``name`` is not one of the supported names.
    """
    size_kwargs = {"state_dim": d_model, "hidden_dim": hidden_dim}
    # Both kernel-based variants share the same kernel settings.
    kernel_kwargs = {"kernel_mode": kernel_mode, "kernel_threshold": 1}

    if name == "baseline":
        model = SSMGamma(**size_kwargs)
    elif name == "gamma_s4":
        model = SSMGammaS4(discretization="bilinear", **size_kwargs, **kernel_kwargs)
    elif name == "dplr":
        model = S4TernaryDPLRSSM(rank=rank, **size_kwargs, **kernel_kwargs)
    else:
        raise ValueError(f"Unknown model '{name}'.")
    return model
def model_forward(model: nn.Module, x: torch.Tensor, return_state: bool = False) -> torch.Tensor:
    """Run a full-sequence forward pass and return only the output tensor.

    ``SSMGamma`` instances are called with the input alone; every other
    model is additionally passed ``return_state``. The model is expected
    to return a pair, of which the second element is discarded.
    """
    if isinstance(model, SSMGamma):
        result = model(x)
    else:
        result = model(x, return_state=return_state)
    output, _ = result
    return output
def init_state(model: nn.Module, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Return ``model.init_state(...)`` for the given batch size, device and dtype."""
    state_kwargs = {"batch_size": batch_size, "device": device, "dtype": dtype}
    return model.init_state(**state_kwargs)
def allocate_cache(
    model: nn.Module,
    batch_size: int,
    seq_len: int,
    device: torch.device,
    dtype: torch.dtype,
) -> dict[str, torch.Tensor] | None:
    """Return the model's inference cache, if it provides one.

    Models exposing an ``allocate_inference_cache`` attribute are asked to
    build a cache for the given batch size, sequence length, device and
    dtype. ``None`` is returned for models without that attribute.
    """
    if not hasattr(model, "allocate_inference_cache"):
        return None
    cache_kwargs = {
        "batch_size": batch_size,
        "seq_len": seq_len,
        "device": device,
        "dtype": dtype,
    }
    return model.allocate_inference_cache(**cache_kwargs)
def recurrent_forward(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Run the model one time step at a time and stack the step outputs.

    ``x`` is unpacked as ``(batch, seq_len, features)``. A fresh state (and,
    when the model offers one, an inference cache) is created on ``x``'s
    device and dtype, then ``model.step`` is called once per position along
    dimension 1. The per-step outputs are stacked along dimension 1.
    """
    batch_size, seq_len, _ = x.shape
    placement = {"batch_size": batch_size, "device": x.device, "dtype": x.dtype}
    state = init_state(model, **placement)
    cache = allocate_cache(model, seq_len=seq_len, **placement)

    # Only pass ``cache`` to ``step`` when the model actually provided one.
    step_kwargs = {} if cache is None else {"cache": cache}

    step_outputs = []
    for position in range(seq_len):
        y, state = model.step(x[:, position, :], state, **step_kwargs)
        step_outputs.append(y)
    return torch.stack(step_outputs, dim=1)
def timed_repeats(
    fn,
    *,
    device: torch.device,
    warmup: int,
    repeats: int,
) -> tuple[float, float]:
    """Time ``fn`` over several runs and return ``(mean, min)`` in seconds.

    ``fn`` is first called ``warmup`` times without timing. It is then
    called ``repeats`` times; each timed call is bracketed by device
    synchronization so the wall-clock time covers queued CUDA work, and
    the CUDA peak-memory counters are reset before every timed call (so
    the counters reflect the last timed call only).

    Args:
        fn: Zero-argument callable to benchmark.
        device: Device passed to ``synchronize`` and ``reset_cuda_memory``.
        warmup: Number of untimed calls made before measuring.
        repeats: Number of timed calls; must be at least 1.

    Returns:
        A ``(mean_seconds, min_seconds)`` tuple over the timed calls.

    Raises:
        ValueError: if ``repeats`` is less than 1.
    """
    # Validate before doing any work: with no timed runs there is nothing
    # to average, and failing here avoids running the warmup only to hit a
    # ZeroDivisionError afterwards.
    if repeats < 1:
        raise ValueError(f"repeats must be at least 1, got {repeats}.")

    for _ in range(warmup):
        fn()
    synchronize(device)

    latencies = []
    for _ in range(repeats):
        reset_cuda_memory(device)
        synchronize(device)
        start = time.perf_counter()
        fn()
        synchronize(device)
        latencies.append(time.perf_counter() - start)

    mean_s = sum(latencies) / len(latencies)
    min_s = min(latencies)
    return mean_s, min_s
def benchmark_case(
    *,
    model_name: str,
    batch_size: int,
    seq_len: int,
    d_model: int,
    hidden_dim: int,
    rank: int,
    kernel_mode: str,
    dtype: torch.dtype,
    device: torch.device,
    warmup: int,
    repeats: int,
    run_backward: bool,
    run_recurrent: bool,
) -> list[dict[str, Any]]:
    """Benchmark one (model, batch size, sequence length) configuration.

    The model is built with ``make_model``, moved to ``device`` and put in
    train mode. A random ``(batch_size, seq_len, d_model)`` input of the
    requested dtype is then timed in up to three modes:

    * ``"forward"``: always run, under ``torch.no_grad``.
    * ``"forward_backward"``: only when ``run_backward`` is true.
    * ``"recurrent"``: only when ``run_recurrent`` is true; the model is
      switched to eval mode first and ``max(1, warmup // 2)`` warmup
      iterations are used.

    Returns:
        One result row (dict) per mode that was run, in the order above.
    """
    model = make_model(
        model_name,
        d_model=d_model,
        hidden_dim=hidden_dim,
        rank=rank,
        kernel_mode=kernel_mode,
    ).to(device=device)
    model.train()

    x = torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype)
    token_count = batch_size * seq_len

    # Autocast is only engaged for half-precision dtypes on CUDA.
    use_autocast = device.type == "cuda" and dtype in (torch.float16, torch.bfloat16)

    def autocast_context():
        if use_autocast:
            return torch.autocast(device_type=device.type, dtype=dtype, enabled=True)
        return nullcontext()

    # Columns that are the same for every row produced by this case.
    config_columns = {
        "batch_size": batch_size,
        "seq_len": seq_len,
        "d_model": d_model,
        "hidden_dim": hidden_dim,
        "rank": rank if model_name == "dplr" else None,
        "kernel_mode": kernel_mode,
        "dtype": str(dtype).replace("torch.", ""),
        "device": str(device),
    }
    rows: list[dict[str, Any]] = []

    def measure(mode: str, fn, warmup_iters: int) -> None:
        """Time ``fn`` and append the corresponding result row."""
        mean_s, min_s = timed_repeats(fn, device=device, warmup=warmup_iters, repeats=repeats)
        # Peak counters are read right after timing, before anything else runs.
        memory_columns = cuda_memory(device)
        row: dict[str, Any] = {"model": model_name, "mode": mode}
        row.update(config_columns)
        row.update(
            {
                "mean_ms": mean_s * 1000.0,
                "min_ms": min_s * 1000.0,
                # Clamp the divisor so a zero timing cannot divide by zero.
                "tokens_per_s_mean": token_count / max(mean_s, 1e-12),
                "tokens_per_s_best": token_count / max(min_s, 1e-12),
            }
        )
        row.update(memory_columns)
        rows.append(row)

    def forward_only() -> None:
        with torch.no_grad(), autocast_context():
            y = model_forward(model, x, return_state=False)
            _ = y.sum()

    def forward_backward() -> None:
        model.zero_grad(set_to_none=True)
        with autocast_context():
            y = model_forward(model, x, return_state=False)
            loss = y.square().mean()
        # Backward runs outside the autocast region.
        loss.backward()

    def recurrent() -> None:
        # NOTE(review): unlike the two modes above, this path does not enter
        # autocast_context() -- confirm that is intentional.
        with torch.no_grad():
            y = recurrent_forward(model, x)
            _ = y.sum()

    measure("forward", forward_only, warmup)

    if run_backward:
        measure("forward_backward", forward_backward, warmup)

    if run_recurrent:
        model.eval()
        measure("recurrent", recurrent, max(1, warmup // 2))

    return rows
def write_outputs(rows: list[dict[str, Any]], output_dir: Path, metadata: dict[str, Any]) -> None:
    """Write benchmark results to JSON and CSV files inside ``output_dir``.

    ``output_dir`` is created (with parents) if missing. The JSON file
    holds ``{"metadata": ..., "results": rows}``; the CSV file has one line
    per row with its header taken from the keys of the first row (an empty
    header when ``rows`` is empty). Both paths are printed afterwards.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    json_path = output_dir / "ssm_variant_benchmark.json"
    csv_path = output_dir / "ssm_variant_benchmark.csv"

    document = json.dumps({"metadata": metadata, "results": rows}, indent=2)
    json_path.write_text(document, encoding="utf-8")

    header = list(rows[0]) if rows else []
    # newline="" lets the csv module control line endings itself.
    with csv_path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=header)
        writer.writeheader()
        for row in rows:
            writer.writerow(row)

    for written in (json_path, csv_path):
        print(f"Wrote {written}")
def print_table(rows: Iterable[dict[str, Any]]) -> None:
    """Print a tab-separated summary of selected result columns.

    A header line is printed first, then one line per row. Float values are
    formatted with three decimals; everything else goes through ``str``.
    """
    columns = (
        "model",
        "mode",
        "batch_size",
        "seq_len",
        "mean_ms",
        "tokens_per_s_mean",
        "peak_allocated_mb",
    )

    def render(value: Any) -> str:
        return f"{value:.3f}" if isinstance(value, float) else str(value)

    print("\t".join(columns))
    for row in rows:
        print("\t".join([render(row[column]) for column in columns]))
def main() -> None:
    """Command-line entry point: run the benchmark sweep and save the results.

    Parses the command line, resolves the device and dtype, benchmarks
    every (model, batch size, sequence length) combination, prints a
    summary table and writes JSON/CSV output to ``--output-dir``.

    Raises:
        ValueError: if float16 is requested on a non-CUDA device.
    """
    parser = argparse.ArgumentParser(description="Benchmark Gamma SSM variants.")
    option_specs = [
        ("--models", {"default": "dplr,gamma_s4,baseline"}),
        ("--batch-sizes", {"default": "1,4"}),
        ("--seq-lens", {"default": "128,512"}),
        ("--d-model", {"type": int, "default": 128}),
        ("--hidden-dim", {"type": int, "default": 256}),
        ("--rank", {"type": int, "default": 1}),
        ("--kernel-mode", {"choices": ["auto", "conv", "recurrent"], "default": "conv"}),
        ("--dtype", {"choices": sorted(DTYPES), "default": "bf16"}),
        ("--device", {"default": "auto"}),
        ("--warmup", {"type": int, "default": 2}),
        ("--repeats", {"type": int, "default": 5}),
        ("--backward", {"action": "store_true"}),
        ("--recurrent", {"action": "store_true"}),
        # The output location can be preset through the environment.
        ("--output-dir", {"default": os.environ.get("REPOBRIDGE_OUTPUT_DIR", "output/benchmarks")}),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # "auto" picks CUDA when available and falls back to CPU.
    if args.device != "auto":
        device = torch.device(args.device)
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    on_cuda = device.type == "cuda"

    dtype = DTYPES[args.dtype]
    if dtype == torch.float16 and not on_cuda:
        raise ValueError("float16 benchmark requires CUDA.")
    if on_cuda:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    models = [name for name in (part.strip() for part in args.models.split(",")) if name]
    batch_sizes = parse_int_list(args.batch_sizes)
    seq_lens = parse_int_list(args.seq_lens)

    metadata = {
        "python": platform.python_version(),
        "platform": platform.platform(),
        "torch": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        "cuda_device": torch.cuda.get_device_name(device) if on_cuda else None,
        "nvidia_smi_before": nvidia_smi_snapshot(),
        "args": vars(args),
    }

    # Sweep order: models outermost, then batch sizes, then sequence lengths.
    grid = [
        (model_name, batch_size, seq_len)
        for model_name in models
        for batch_size in batch_sizes
        for seq_len in seq_lens
    ]
    rows: list[dict[str, Any]] = []
    for model_name, batch_size, seq_len in grid:
        print(f"Benchmarking model={model_name} batch={batch_size} seq={seq_len}")
        case_rows = benchmark_case(
            model_name=model_name,
            batch_size=batch_size,
            seq_len=seq_len,
            d_model=args.d_model,
            hidden_dim=args.hidden_dim,
            rank=args.rank,
            kernel_mode=args.kernel_mode,
            dtype=dtype,
            device=device,
            warmup=args.warmup,
            repeats=args.repeats,
            run_backward=args.backward,
            run_recurrent=args.recurrent,
        )
        rows.extend(case_rows)

    metadata["nvidia_smi_after"] = nvidia_smi_snapshot()
    print_table(rows)
    write_outputs(rows, Path(args.output_dir), metadata)


if __name__ == "__main__":
    main()