Text Generation
Transformers
PyTorch
English
taonet_mini_t2
taonet
taotern
ssm
state-space-model
dplr
custom_code
experimental
Instructions to use TaoTern/TaoNet-mini-T2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use TaoTern/TaoNet-mini-T2 with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="TaoTern/TaoNet-mini-T2", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("TaoTern/TaoNet-mini-T2", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use TaoTern/TaoNet-mini-T2 with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "TaoTern/TaoNet-mini-T2"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "TaoTern/TaoNet-mini-T2",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
Use Docker
docker model run hf.co/TaoTern/TaoNet-mini-T2
- SGLang
How to use TaoTern/TaoNet-mini-T2 with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "TaoTern/TaoNet-mini-T2" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "TaoTern/TaoNet-mini-T2",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "TaoTern/TaoNet-mini-T2" \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "TaoTern/TaoNet-mini-T2",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
- Docker Model Runner
How to use TaoTern/TaoNet-mini-T2 with Docker Model Runner:
docker model run hf.co/TaoTern/TaoNet-mini-T2
| """Profile the DPLR direct path with and without the finite-tail correction. | |
| This diagnostic does not change model behavior. It answers whether the exact | |
| finite convolution term | |
| C @ response - z^L (C @ A^L) @ response | |
| is a promising speed target or a mathematically important part we should keep. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import sys | |
| import time | |
| from pathlib import Path | |
| from typing import Any | |
| import torch | |
# Put the repository root (the parent of this script's directory) at the
# front of ``sys.path`` so the local ``gamma_space_model`` import resolves
# when the script is run directly; do nothing if it is already listed.
REPO_ROOT = Path(__file__).resolve().parents[1]
if not any(entry == str(REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(REPO_ROOT))
| from gamma_space_model import S4TernaryDPLRSSM | |
# Accepted ``--dtype`` spellings mapped to torch dtypes. Each dtype has a
# short and a long alias that resolve to the same value.
DTYPES = {
    alias: torch_dtype
    for torch_dtype, aliases in (
        (torch.float32, ("fp32", "float32")),
        (torch.bfloat16, ("bf16", "bfloat16")),
        (torch.float16, ("fp16", "float16")),
    )
    for alias in aliases
}
def synchronize(device: torch.device) -> None:
    """Block until queued CUDA work on ``device`` has finished.

    Used around timed regions so host-side timers observe completed CUDA
    work. For any device whose type is not ``"cuda"`` this does nothing.
    """
    if device.type != "cuda":
        return
    torch.cuda.synchronize(device)
def summarize(latencies: list[float], tokens: int) -> dict[str, float]:
    """Reduce per-repeat wall-clock latencies to summary statistics.

    Args:
        latencies: Elapsed times in seconds, one per timed repeat.
        tokens: Number of tokens processed by each repeat.

    Returns:
        ``mean_ms`` and ``min_ms`` (milliseconds), plus ``tokens_per_s``
        computed from the mean latency.

    Raises:
        ValueError: If ``latencies`` is empty (e.g. zero timed repeats),
            rather than an opaque ZeroDivisionError from averaging nothing.
    """
    if not latencies:
        raise ValueError("summarize() needs at least one latency sample")
    mean_s = sum(latencies) / len(latencies)
    return {
        "mean_ms": mean_s * 1000.0,
        "min_ms": min(latencies) * 1000.0,
        # Clamp the denominator so a zero-duration measurement cannot
        # divide by zero.
        "tokens_per_s": tokens / max(mean_s, 1e-12),
    }
def dplr_direct(
    model: S4TernaryDPLRSSM,
    x: torch.Tensor,
    *,
    finite_tail: bool,
) -> torch.Tensor:
    """Apply the DPLR SSM to ``x`` as an FFT convolution solved per frequency.

    For each rfft bin with root ``z``, the state response
    ``(I - z A)^{-1} B u`` is formed from the diagonal part plus a Woodbury
    correction for the low-rank part. The Woodbury algebra below corresponds
    to ``A = diag(diag) - U V^T``; this presumably matches
    ``model._dense_discrete_A_from_params`` (confirm against the model).

    Args:
        model: SSM providing ``_discrete_params``,
            ``_dense_discrete_A_from_params``, ``_frequency_roots``, ``C``,
            ``D`` and ``rank``.
        x: Input of shape ``(batch, seq_len, channels)``. The ``D`` skip term
            is added elementwise per input channel, so ``C``'s output size
            must match (or broadcast with) ``channels``.
        finite_tail: If True, subtract ``z^L C A^L (I - z A)^{-1} B u`` with
            ``L = seq_len``. Assuming ``roots_power`` is ``z**seq_len``, this
            makes the kernel exactly the first ``seq_len`` taps ``C A^j B``.
            If False, that term is skipped together with the dense ``A`` and
            ``A^L`` it requires, so the ablated variant is timed without them.

    Returns:
        Tensor of shape ``(batch, seq_len, out_channels)`` in ``x``'s dtype.
        fp16/bf16 inputs are computed in float32 with autocast disabled.
    """
    seq_len = x.shape[1]
    original_dtype = x.dtype
    target_dtype = torch.float32 if x.dtype in {torch.float16, torch.bfloat16} else x.dtype
    # A power of two >= 2 * seq_len - 1: with this much zero padding, the
    # circular FFT convolution equals the linear one on the first seq_len outputs.
    fft_len = 1 << max(1, (2 * seq_len - 1).bit_length())
    device = x.device
    with torch.autocast(device_type=device.type, enabled=False):
        # (batch, channels, seq_len) -> rfft over time -> (batch, channels, freq)
        u_channels = x.transpose(1, 2).to(dtype=target_dtype)
        u_f = torch.fft.rfft(u_channels, n=fft_len)
        diag, U, V, B_disc = model._discrete_params(dtype=target_dtype, device=device)
        C = model.C.to(device=device, dtype=target_dtype)
        D = model.D.to(device=device, dtype=target_dtype)
        # The dense A and its seq_len-th power feed only the finite-tail term.
        # Building them unconditionally would bill their cost to the ablated
        # variant too and understate what the correction costs.
        A_power = None
        if finite_tail:
            A_dense = model._dense_discrete_A_from_params(diag, U, V)
            A_power = torch.linalg.matrix_power(A_dense, seq_len)
        complex_dtype = torch.complex64 if target_dtype != torch.float64 else torch.complex128
        freq_count = fft_len // 2 + 1
        # roots: one z per rfft bin; roots_power: presumably z**seq_len (TODO confirm).
        roots, roots_power = model._frequency_roots(seq_len, fft_len, target_dtype, device)
        diag_complex = diag.to(dtype=complex_dtype)
        U_complex = U.to(dtype=complex_dtype)
        V_complex = V.to(dtype=complex_dtype)
        B_complex = B_disc.to(dtype=complex_dtype)
        C_complex = C.to(dtype=complex_dtype)
        # (freq, batch, channels)
        u_freq = u_f.permute(2, 0, 1).to(dtype=complex_dtype)
        # Diagonal resolvent (I - z * diag)^{-1}, kept as a (freq, state) array.
        inv_diag = (1.0 - roots[:, None] * diag_complex[None, :]).reciprocal()
        # (I - z diag)^{-1} B u                      -> (freq, batch, state)
        inv_input = inv_diag[:, None, :] * torch.einsum("nd,fbd->fbn", B_complex, u_freq)
        # (I - z diag)^{-1} z U                      -> (freq, state, rank)
        inv_u = inv_diag[:, :, None] * (roots[:, None, None] * U_complex[None, :, :])
        # W = V^T (I - z diag)^{-1} z U, indexed [freq, V column, U column]
        vt_inv_u = torch.einsum("nr,fns->frs", V_complex, inv_u)
        # V^T (I - z diag)^{-1} B u                  -> (freq, batch, rank)
        vt_inv_input = torch.einsum("nr,fbn->fbr", V_complex, inv_input)
        if model.rank == 1:
            # Scalar Woodbury capacitance: avoids a batched matrix inverse.
            middle = (1.0 + vt_inv_u[:, 0, 0]).reciprocal()
            correction = (
                inv_u[:, None, :, 0]
                * middle.view(freq_count, 1, 1)
                * vt_inv_input[:, :, 0].unsqueeze(-1)
            )
        else:
            rank_eye = torch.eye(model.rank, device=device, dtype=complex_dtype).expand(
                freq_count, -1, -1
            )
            middle = torch.linalg.inv(rank_eye + vt_inv_u)
            # Woodbury needs inv_u @ (I + W)^{-1} @ vt_inv_input. W is indexed
            # [V column, U column], so its inverse is indexed [U column,
            # V column]: the row index contracts with inv_u and the column
            # index with vt_inv_input ("fsr"). Contracting it as "frs" would
            # apply the transpose, which is wrong unless W is symmetric.
            correction = torch.einsum("fns,fsr,fbr->fbn", inv_u, middle, vt_inv_input)
        # (I - z A)^{-1} B u with A = diag(diag) - U V^T, via Woodbury.
        response = inv_input - correction
        y_freq = torch.einsum("on,fbn->fbo", C_complex, response)
        if finite_tail:
            assert A_power is not None
            powered_readout = torch.matmul(C_complex, A_power.to(dtype=complex_dtype))
            y_freq = y_freq - (
                roots_power.view(freq_count, 1, 1)
                * torch.einsum("on,fbn->fbo", powered_readout, response)
            )
        # Skip term D * u: pointwise per channel in time, hence also per bin.
        y_freq = y_freq + u_freq * D.to(dtype=complex_dtype).view(1, 1, -1)
        y = torch.fft.irfft(y_freq.permute(1, 2, 0), n=fft_len)[..., :seq_len]
        return y.transpose(1, 2).to(dtype=original_dtype)
def time_variant(
    fn,
    *,
    device: torch.device,
    warmup: int,
    repeats: int,
    tokens: int,
) -> dict[str, float]:
    """Benchmark a zero-argument callable and summarize its latency.

    ``fn`` is first called ``warmup`` times without timing. It is then called
    ``repeats`` times, with ``synchronize(device)`` immediately before and
    after each call so the measured interval covers the queued device work.

    Args:
        fn: Zero-argument callable to benchmark; its return value is ignored.
        device: Device passed to ``synchronize`` around each timed call.
        warmup: Number of untimed calls before measurement starts.
        repeats: Number of timed calls.
        tokens: Tokens processed per call, forwarded to ``summarize``.

    Returns:
        The ``summarize`` statistics over the ``repeats`` measured intervals.
    """
    for _ in range(warmup):
        fn()
    synchronize(device)

    def timed_once() -> float:
        synchronize(device)
        began = time.perf_counter()
        fn()
        synchronize(device)
        return time.perf_counter() - began

    return summarize([timed_once() for _ in range(repeats)], tokens)
def _positive_int(text: str) -> int:
    """``argparse`` type converter that accepts only integers >= 1."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be an integer >= 1, got {value}")
    return value


def _non_negative_int(text: str) -> int:
    """``argparse`` type converter that accepts only integers >= 0."""
    value = int(text)
    if value < 0:
        raise argparse.ArgumentTypeError(f"must be an integer >= 0, got {value}")
    return value


def main() -> int:
    """Run the finite-tail profiling diagnostic and emit a JSON report.

    Builds an ``S4TernaryDPLRSSM`` with ``kernel_mode="conv"``, then times
    ``dplr_direct`` with and without the finite-tail term. Each variant is
    timed forward-only and as forward plus backward of a mean-square loss.
    It also records how far the two outputs differ from each other and from
    the model's ``_forward_convolutional`` output. The report is printed and,
    if ``--output`` is given, also written to that file.

    Returns:
        Process exit status, 0 on success. Invalid arguments make argparse
        exit with a usage error before any work is done.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--dtype", choices=sorted(DTYPES), default="bf16")
    # Reject values that cannot produce a report. An empty input would make
    # ``diff.max()`` fail on an empty tensor, and zero repeats leave nothing
    # to average. These are now reported up front as usage errors.
    parser.add_argument("--batch-size", type=_positive_int, default=32)
    parser.add_argument("--seq-len", type=_positive_int, default=512)
    parser.add_argument("--d-model", type=_positive_int, default=64)
    parser.add_argument("--hidden-dim", type=int, default=256)
    parser.add_argument("--rank", type=int, default=1)
    parser.add_argument("--warmup", type=_non_negative_int, default=3)
    parser.add_argument("--repeats", type=_positive_int, default=10)
    parser.add_argument("--output", type=Path, default=None)
    args = parser.parse_args()

    device = torch.device(args.device)
    dtype = DTYPES[args.dtype]
    model = S4TernaryDPLRSSM(
        state_dim=args.d_model,
        hidden_dim=args.hidden_dim,
        rank=args.rank,
        kernel_mode="conv",
        kernel_threshold=1,
    ).to(device=device)
    model.train()
    x = torch.randn(args.batch_size, args.seq_len, args.d_model, device=device, dtype=dtype)
    tokens = args.batch_size * args.seq_len

    def exact_forward() -> torch.Tensor:
        return dplr_direct(model, x, finite_tail=True)

    def ablated_forward() -> torch.Tensor:
        return dplr_direct(model, x, finite_tail=False)

    def exact_backward() -> None:
        model.zero_grad(set_to_none=True)
        y = exact_forward()
        y.square().mean().backward()

    def ablated_backward() -> None:
        model.zero_grad(set_to_none=True)
        y = ablated_forward()
        y.square().mean().backward()

    def bench(fn) -> dict[str, float]:
        # Every variant uses the same warmup/repeat/token settings.
        return time_variant(
            fn,
            device=device,
            warmup=args.warmup,
            repeats=args.repeats,
            tokens=tokens,
        )

    # Accuracy check: exact vs. ablated, and exact vs. the model's
    # ``_forward_convolutional`` output.
    with torch.no_grad():
        y_exact = exact_forward()
        y_ablated = ablated_forward()
        y_reference, _ = model._forward_convolutional(x, return_state=False)
    diff = (y_exact.float() - y_ablated.float()).abs()
    reference_diff = (y_exact.float() - y_reference.float()).abs()
    exact_norm = y_exact.float().norm().item()
    diff_norm = diff.norm().item()
    # Dict-literal evaluation order fixes the benchmark order:
    # forward exact, forward ablated, then the two forward+backward variants.
    report: dict[str, Any] = {
        "config": vars(args) | {"device": str(device), "dtype": str(dtype).replace("torch.", "")},
        "forward": {
            "exact": bench(exact_forward),
            "finite_tail_ablated": bench(ablated_forward),
        },
        "forward_backward": {
            "exact": bench(exact_backward),
            "finite_tail_ablated": bench(ablated_backward),
        },
        "difference": {
            "max_abs": diff.max().item(),
            "mean_abs": diff.mean().item(),
            "exact_norm": exact_norm,
            "diff_norm": diff_norm,
            "relative_l2": diff_norm / max(exact_norm, 1e-12),
            "exact_vs_production_max_abs": reference_diff.max().item(),
        },
        "frequency_grid_cache_entries": len(model._frequency_grid_cache),
    }
    # default=str stringifies values json cannot encode natively, such as the
    # ``--output`` Path stored in ``config``.
    text = json.dumps(report, indent=2, sort_keys=True, default=str)
    print(text)
    if args.output is not None:
        args.output.parent.mkdir(parents=True, exist_ok=True)
        args.output.write_text(text, encoding="utf-8")
    return 0
if __name__ == "__main__":
    # main() returns the exit status; sys.exit raises SystemExit carrying it.
    sys.exit(main())