Text Generation
Transformers
PyTorch
English
taonet_mini_t2
taonet
taotern
ssm
state-space-model
dplr
custom_code
experimental
Instructions to use TaoTern/TaoNet-mini-T2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use TaoTern/TaoNet-mini-T2 with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="TaoTern/TaoNet-mini-T2", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("TaoTern/TaoNet-mini-T2", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use TaoTern/TaoNet-mini-T2 with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "TaoTern/TaoNet-mini-T2"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "TaoTern/TaoNet-mini-T2",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
Use Docker
docker model run hf.co/TaoTern/TaoNet-mini-T2
- SGLang
How to use TaoTern/TaoNet-mini-T2 with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "TaoTern/TaoNet-mini-T2" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "TaoTern/TaoNet-mini-T2",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
Use Docker images
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "TaoTern/TaoNet-mini-T2" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "TaoTern/TaoNet-mini-T2",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
- Docker Model Runner
How to use TaoTern/TaoNet-mini-T2 with Docker Model Runner:
docker model run hf.co/TaoTern/TaoNet-mini-T2
| """Break down the DPLR direct frequency path into timed forward stages. | |
| The whole-path profiler tells us whether the direct convolution path is fast, | |
| but not which internal tensor operation should become the next TileLang/Triton | |
| target. This script mirrors ``S4TernaryDPLRSSM._apply_frequency_response`` and | |
| records per-stage timings without changing model behavior. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import math | |
| import statistics | |
| import sys | |
| import time | |
| from pathlib import Path | |
| from typing import Any, Callable | |
| import torch | |
| REPO_ROOT = Path(__file__).resolve().parents[1] | |
| if str(REPO_ROOT) not in sys.path: | |
| sys.path.insert(0, str(REPO_ROOT)) | |
| from gamma_space_model import S4TernaryDPLRSSM | |
# Accepted ``--dtype`` spellings mapped to torch dtypes. Each precision is
# reachable through a short alias and through its full torch name.
DTYPES = dict(
    fp32=torch.float32,
    float32=torch.float32,
    bf16=torch.bfloat16,
    bfloat16=torch.bfloat16,
    fp16=torch.float16,
    float16=torch.float16,
)
def synchronize(device: torch.device) -> None:
    """Wait for queued work on ``device`` to finish.

    Only CUDA devices are synchronized; for any other device type this is a
    no-op.
    """
    if device.type != "cuda":
        return
    torch.cuda.synchronize(device)
def summarize(values: list[float]) -> dict[str, float]:
    """Reduce a list of millisecond samples to summary statistics.

    Args:
        values: Timing samples in milliseconds. May be empty.

    Returns:
        A dict with the keys ``mean_ms``, ``min_ms``, ``max_ms`` and
        ``stdev_ms`` (population standard deviation). When ``values`` is
        empty every entry is NaN instead of raising, so a run with zero
        timed repeats can still produce a report.
    """
    if not values:
        # statistics.fmean / min / max all raise on an empty sequence; report
        # "no data" explicitly rather than crashing the whole profiling run.
        return {
            "mean_ms": math.nan,
            "min_ms": math.nan,
            "max_ms": math.nan,
            "stdev_ms": math.nan,
        }
    return {
        "mean_ms": statistics.fmean(values),
        "min_ms": min(values),
        "max_ms": max(values),
        # A single sample has no spread; skip the pstdev call for it.
        "stdev_ms": statistics.pstdev(values) if len(values) > 1 else 0.0,
    }
class StageRecorder:
    """Collect one elapsed-time measurement (in milliseconds) per named stage.

    On a CUDA device each stage is bracketed by a pair of timing-enabled CUDA
    events, and the elapsed times are only read back in :meth:`results`, after
    a device synchronize. On any other device the stage is timed directly
    with ``time.perf_counter``.
    """

    def __init__(self, device: torch.device) -> None:
        self.device = device
        self.cuda = device.type == "cuda"
        # (stage name, start event, end event) triples, filled on CUDA only.
        self.events: list[tuple[str, torch.cuda.Event, torch.cuda.Event]] = []
        # (stage name, elapsed ms) pairs, filled on non-CUDA devices only.
        self.cpu_times: list[tuple[str, float]] = []

    def measure(self, name: str, fn: Callable[[], Any]) -> Any:
        """Run ``fn`` once, record its duration under ``name``, return its result."""
        if not self.cuda:
            began = time.perf_counter()
            outcome = fn()
            elapsed_ms = (time.perf_counter() - began) * 1000.0
            self.cpu_times.append((name, elapsed_ms))
            return outcome

        opened = torch.cuda.Event(enable_timing=True)
        closed = torch.cuda.Event(enable_timing=True)
        opened.record()
        outcome = fn()
        closed.record()
        # Elapsed time is queried later so that measuring does not force a
        # synchronize between stages.
        self.events.append((name, opened, closed))
        return outcome

    def results(self) -> dict[str, float]:
        """Return ``{stage name: elapsed ms}`` for everything measured so far."""
        if not self.cuda:
            return dict(self.cpu_times)

        # The events must have completed before elapsed_time can be read.
        torch.cuda.synchronize(self.device)
        timings: dict[str, float] = {}
        for name, opened, closed in self.events:
            timings[name] = opened.elapsed_time(closed)
        return timings
def run_profiled_direct(
    model: S4TernaryDPLRSSM,
    x: torch.Tensor,
    *,
    seq_len: int,
    fft_len: int,
    target_dtype: torch.dtype,
    device: torch.device,
) -> tuple[torch.Tensor, dict[str, float]]:
    """Run the direct frequency path once, timing each stage separately.

    Per the module docstring this mirrors
    ``S4TernaryDPLRSSM._apply_frequency_response``; each stage is wrapped in a
    closure so a ``StageRecorder`` can time it in isolation.

    Args:
        model: Module providing ``_discrete_params``,
            ``_dense_discrete_A_from_params``, ``_frequency_roots``, ``C``,
            ``D`` and ``rank``.
        x: Input tensor. ``main`` builds it as (batch, seq_len, d_model).
        seq_len: Number of time steps kept from the inverse FFT, and the
            exponent used for the dense matrix power.
        fft_len: FFT length handed to ``rfft`` / ``irfft``.
        target_dtype: Real dtype used for the computation.
        device: Device that parameters are moved to and stages are timed on.

    Returns:
        ``(y, stage_ms)``: the output cast back to ``x.dtype``, and a mapping
        from stage name to elapsed milliseconds.
    """
    timer = StageRecorder(device)
    complex_dtype = torch.complex128 if target_dtype == torch.float64 else torch.complex64
    freq_count = fft_len // 2 + 1  # number of bins returned by rfft

    # --- stage: input_fft ---------------------------------------------------
    def _stage_input_fft() -> tuple[torch.Tensor, torch.Tensor]:
        channels_first = x.transpose(1, 2).to(dtype=target_dtype)
        spectrum = torch.fft.rfft(channels_first, n=fft_len)
        return channels_first, spectrum

    u_channels, u_f = timer.measure("input_fft", _stage_input_fft)

    # --- stage: discrete_params ---------------------------------------------
    def _stage_discrete_params() -> Any:
        return model._discrete_params(dtype=target_dtype, device=device)

    diag, U, V, B_disc = timer.measure("discrete_params", _stage_discrete_params)

    # --- stage: dense_A_power_C_D -------------------------------------------
    def _stage_matrix_power() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        dense_a = model._dense_discrete_A_from_params(diag, U, V)
        powered = torch.linalg.matrix_power(dense_a, seq_len)
        readout = model.C.to(device=device, dtype=target_dtype)
        skip = model.D.to(device=device, dtype=target_dtype)
        return powered, readout, skip

    A_power, C, D = timer.measure("dense_A_power_C_D", _stage_matrix_power)

    # --- stage: roots_and_complex_casts -------------------------------------
    def _stage_roots_and_casts() -> tuple[torch.Tensor, ...]:
        grid, grid_power = model._frequency_roots(seq_len, fft_len, target_dtype, device)
        as_complex = tuple(
            tensor.to(dtype=complex_dtype) for tensor in (diag, U, V, B_disc, C)
        )
        return (grid, grid_power, *as_complex)

    (
        roots,
        roots_power,
        diag_complex,
        U_complex,
        V_complex,
        B_complex,
        C_complex,
    ) = timer.measure("roots_and_complex_casts", _stage_roots_and_casts)

    # --- stage: diagonal_input_solve ----------------------------------------
    def _stage_diagonal_input_solve() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Move the frequency axis to the front so it lines up with `roots`.
        spectrum = u_f.permute(2, 0, 1).to(dtype=complex_dtype)
        diag_inverse = (1.0 - roots[:, None] * diag_complex[None, :]).reciprocal()
        projected = torch.einsum("nd,fbd->fbn", B_complex, spectrum)
        return spectrum, diag_inverse, diag_inverse[:, None, :] * projected

    u_freq, inv_diag, inv_input = timer.measure(
        "diagonal_input_solve", _stage_diagonal_input_solve
    )

    # --- stage: low_rank_solve ----------------------------------------------
    def _stage_low_rank_solve() -> torch.Tensor:
        scaled_u = roots[:, None, None] * U_complex[None, :, :]
        inv_u = inv_diag[:, :, None] * scaled_u
        vt_inv_u = torch.einsum("nr,fns->frs", V_complex, inv_u)
        vt_inv_input = torch.einsum("nr,fbn->fbr", V_complex, inv_input)
        if model.rank == 1:
            # Rank one: the small system is a scalar per frequency, so a
            # reciprocal replaces the matrix inverse.
            middle = (1.0 + vt_inv_u[:, 0, 0]).reciprocal()
            correction = (
                inv_u[:, None, :, 0]
                * middle.view(freq_count, 1, 1)
                * vt_inv_input[:, :, 0].unsqueeze(-1)
            )
        else:
            identity = torch.eye(model.rank, device=device, dtype=complex_dtype)
            middle = torch.linalg.inv(identity.expand(freq_count, -1, -1) + vt_inv_u)
            # NOTE(review): `middle` is contracted as [r, s] here, which is the
            # transpose of what a textbook Woodbury product
            # inv_u @ middle @ vt_inv_input ("fns,fsr,fbr->fbn") would use.
            # Left untouched because this script is meant to mirror the model;
            # confirm against S4TernaryDPLRSSM._apply_frequency_response.
            correction = torch.einsum("fns,frs,fbr->fbn", inv_u, middle, vt_inv_input)
        return inv_input - correction

    response = timer.measure("low_rank_solve", _stage_low_rank_solve)

    # --- stage: powered_readout ---------------------------------------------
    def _stage_powered_readout() -> torch.Tensor:
        return torch.matmul(C_complex, A_power.to(dtype=complex_dtype))

    C_power = timer.measure("powered_readout", _stage_powered_readout)

    # --- stage: output_projection_and_skip ----------------------------------
    def _stage_output_projection() -> torch.Tensor:
        direct = torch.einsum("on,fbn->fbo", C_complex, response)
        powered = torch.einsum("on,fbn->fbo", C_power, response)
        projected = direct - roots_power.view(freq_count, 1, 1) * powered
        skip = u_freq * D.to(dtype=complex_dtype).view(1, 1, -1)
        return projected + skip

    y_freq = timer.measure("output_projection_and_skip", _stage_output_projection)

    # --- stage: inverse_fft -------------------------------------------------
    def _stage_inverse_fft() -> torch.Tensor:
        time_domain = torch.fft.irfft(y_freq.permute(1, 2, 0), n=fft_len)
        truncated = time_domain[..., :seq_len]
        return truncated.transpose(1, 2).to(dtype=x.dtype)

    y = timer.measure("inverse_fft", _stage_inverse_fft)

    # The channel-first copy of the input is not used after the FFT stage.
    del u_channels
    return y, timer.results()
def main() -> int:
    """Parse CLI options, profile the staged direct path, and emit a JSON report.

    The report is printed to stdout and, when ``--output`` is given, also
    written to that path (parent directories are created as needed).

    Returns:
        Process exit code; always ``0``.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    cli.add_argument("--device", default=default_device)
    cli.add_argument("--dtype", choices=sorted(DTYPES), default="bf16")
    integer_options = (
        ("--batch-size", 4),
        ("--seq-len", 512),
        ("--d-model", 64),
        ("--hidden-dim", 256),
        ("--rank", 1),
        ("--warmup", 3),
        ("--repeats", 10),
    )
    for flag, default in integer_options:
        cli.add_argument(flag, type=int, default=default)
    cli.add_argument("--output", type=Path, default=None)
    args = cli.parse_args()

    device = torch.device(args.device)
    dtype = DTYPES[args.dtype]

    model = S4TernaryDPLRSSM(
        state_dim=args.d_model,
        hidden_dim=args.hidden_dim,
        rank=args.rank,
        kernel_mode="conv",
        kernel_threshold=1,
    ).to(device=device)
    model.train()

    x = torch.randn(args.batch_size, args.seq_len, args.d_model, device=device, dtype=dtype)

    # Half-precision inputs are computed in float32; other dtypes are used as-is.
    if x.dtype in {torch.float16, torch.bfloat16}:
        target_dtype = torch.float32
    else:
        target_dtype = x.dtype

    # Power of two covering 2 * seq_len - 1 samples (and never below 2).
    fft_len = 1 << max(1, (2 * args.seq_len - 1).bit_length())

    call_kwargs = dict(
        seq_len=args.seq_len,
        fft_len=fft_len,
        target_dtype=target_dtype,
        device=device,
    )

    with torch.no_grad(), torch.autocast(device_type=device.type, enabled=False):
        # Untimed warm-up passes.
        for _ in range(args.warmup):
            run_profiled_direct(model, x, **call_kwargs)
        synchronize(device)

        stage_runs: dict[str, list[float]] = {}
        total_ms: list[float] = []
        profiled_y: torch.Tensor | None = None
        for _ in range(args.repeats):
            # Synchronize on both sides so the wall-clock span covers exactly
            # one profiled pass.
            synchronize(device)
            tick = time.perf_counter()
            profiled_y, stages = run_profiled_direct(model, x, **call_kwargs)
            synchronize(device)
            total_ms.append((time.perf_counter() - tick) * 1000.0)
            for name, value in stages.items():
                stage_runs.setdefault(name, []).append(value)

        # Compare the staged re-implementation against the model's own path.
        reference_y, _ = model._forward_convolutional(x, return_state=False)
        if profiled_y is None:
            # No timed repeat ran, so there is nothing to compare.
            max_abs_diff = math.nan
        else:
            max_abs_diff = (profiled_y - reference_y).abs().max().item()

    stage_summary = {name: summarize(samples) for name, samples in stage_runs.items()}
    stage_total_mean = sum(entry["mean_ms"] for entry in stage_summary.values())

    config = vars(args) | {"device": str(device), "dtype": str(dtype).replace("torch.", "")}
    report: dict[str, Any] = {
        "config": config,
        "fft_len": fft_len,
        "target_dtype": str(target_dtype).replace("torch.", ""),
        "total_wall": summarize(total_ms),
        "stage_total_mean_ms": stage_total_mean,
        "stages": stage_summary,
        "validation": {"max_abs_diff_vs_forward_convolutional": max_abs_diff},
        "frequency_grid_cache_entries": len(model._frequency_grid_cache),
    }

    # default=str covers values json cannot encode natively, such as the
    # Path stored under "output".
    text = json.dumps(report, indent=2, sort_keys=True, default=str)
    print(text)
    if args.output is not None:
        args.output.parent.mkdir(parents=True, exist_ok=True)
        args.output.write_text(text, encoding="utf-8")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())