Text Generation
Transformers
PyTorch
English
taonet_mini_t2
taonet
taotern
ssm
state-space-model
dplr
custom_code
experimental
Instructions to use TaoTern/TaoNet-mini-T2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use TaoTern/TaoNet-mini-T2 with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("text-generation", model="TaoTern/TaoNet-mini-T2", trust_remote_code=True)# Load model directly from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained("TaoTern/TaoNet-mini-T2", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use TaoTern/TaoNet-mini-T2 with vLLM:
Install from pip and serve model
# Install vLLM from pip: pip install vllm # Start the vLLM server: vllm serve "TaoTern/TaoNet-mini-T2" # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:8000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "TaoTern/TaoNet-mini-T2", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker
docker model run hf.co/TaoTern/TaoNet-mini-T2
- SGLang
How to use TaoTern/TaoNet-mini-T2 with SGLang:
Install from pip and serve model
# Install SGLang from pip: pip install sglang # Start the SGLang server: python3 -m sglang.launch_server \ --model-path "TaoTern/TaoNet-mini-T2" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "TaoTern/TaoNet-mini-T2", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker images
docker run --gpus all \ --shm-size 32g \ -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=<secret>" \ --ipc=host \ lmsysorg/sglang:latest \ python3 -m sglang.launch_server \ --model-path "TaoTern/TaoNet-mini-T2" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "TaoTern/TaoNet-mini-T2", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }' - Docker Model Runner
How to use TaoTern/TaoNet-mini-T2 with Docker Model Runner:
docker model run hf.co/TaoTern/TaoNet-mini-T2
| """ | |
| Selective Scan Operations Interface with TileLang Acceleration. | |
| Provides unified interface for TileLang-accelerated SSM operations with automatic | |
| fallback to pure PyTorch CPU implementation. | |
| Features: | |
| - TileLang-accelerated parallel scan for efficient GPU utilization | |
| - Automatic GPU/CPU dispatch | |
| - Autograd support via custom backward pass | |
| - Memory-efficient implementation | |
| - Support for fp32, fp16, and bf16 | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| from typing import Optional, Tuple | |
| import sys | |
| from pathlib import Path | |
| # Import TileLang accelerated operations | |
| try: | |
| from csrc.tilelang import ( | |
| HAS_TILELANG_ACCELERATION, | |
| TILELANG_BACKEND, | |
| ssm_gamma_forward_tilelang, | |
| ) | |
| HAS_TILELANG_OPS = HAS_TILELANG_ACCELERATION | |
| except ImportError: | |
| try: | |
| # Alternative import path | |
| import sys | |
| sys.path.insert(0, str(Path(__file__).parent.parent.parent / 'csrc')) | |
| from tilelang import ( | |
| HAS_TILELANG_ACCELERATION, | |
| TILELANG_BACKEND, | |
| ssm_gamma_forward_tilelang, | |
| ) | |
| HAS_TILELANG_OPS = HAS_TILELANG_ACCELERATION | |
| except ImportError: | |
| HAS_TILELANG_OPS = False | |
| TILELANG_BACKEND = "unavailable" | |
| ssm_gamma_forward_tilelang = None | |
| class SSMGammaFunction(torch.autograd.Function): | |
| """ | |
| Custom autograd function for SSM gamma forward/backward. | |
| Uses TileLang kernels if available, falls back to PyTorch operations. | |
| """ | |
| def forward(ctx, u, A, B, C, delta_t): | |
| """ | |
| Forward pass. | |
| Args: | |
| u: Input (batch, seq_len, state_dim) | |
| A: State matrix (hidden_dim, hidden_dim) | |
| B: Input matrix (hidden_dim, state_dim) | |
| C: Output matrix (state_dim, hidden_dim) | |
| delta_t: Discretization step | |
| Returns: | |
| output: (batch, seq_len, state_dim) | |
| h_final: (batch, hidden_dim) | |
| """ | |
| if HAS_TILELANG_OPS and u.is_cuda: | |
| # Use TileLang-accelerated kernels | |
| y, h_final = ssm_gamma_forward_tilelang( | |
| u.contiguous(), A.contiguous(), B.contiguous(), C.contiguous(), delta_t | |
| ) | |
| else: | |
| # Fallback to PyTorch implementation | |
| y, h_final = _ssm_gamma_forward_pytorch(u, A, B, C, delta_t) | |
| ctx.save_for_backward(u, A, B, C, y, h_final) | |
| ctx.delta_t = delta_t | |
| return y, h_final | |
| def backward(ctx, grad_y, grad_h_final): | |
| """Backward pass with gradient computation.""" | |
| u, A, B, C, y, h_final = ctx.saved_tensors | |
| delta_t = ctx.delta_t | |
| # Compute hidden state history from forward pass | |
| h_history = _compute_state_history(u, A, B, C, delta_t) | |
| if HAS_TILELANG_OPS and grad_y.is_cuda: | |
| # Use TileLang kernels for backward | |
| # For now, use the unified backward from tilelang | |
| grad_u, grad_A, grad_B, grad_C, grad_h_init = _ssm_gamma_backward_pytorch( | |
| grad_y, u, h_history, A, B, C, delta_t | |
| ) | |
| else: | |
| # Fallback to PyTorch backward | |
| grad_u, grad_A, grad_B, grad_C, grad_h_init = _ssm_gamma_backward_pytorch( | |
| grad_y, u, h_history, A, B, C, delta_t | |
| ) | |
| return grad_u, grad_A, grad_B, grad_C, None | |
| def ssm_gamma_forward( | |
| u: torch.Tensor, | |
| A: torch.Tensor, | |
| B: torch.Tensor, | |
| C: torch.Tensor, | |
| delta_t: float = 0.1, | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| """ | |
| SSM Gamma forward pass using parallel scan optimization. | |
| Implements: h[t+1] = h[t] + dt * (A @ h[t] + B @ u[t]), y[t] = C @ h[t] | |
| Args: | |
| u: Input tensor (batch, seq_len, state_dim) | |
| A: State matrix (hidden_dim, hidden_dim) - tridiagonal HiPPO-Gamma | |
| B: Input matrix (hidden_dim, state_dim) | |
| C: Output matrix (state_dim, hidden_dim) | |
| delta_t: Time discretization step | |
| Returns: | |
| output: (batch, seq_len, state_dim) | |
| final_state: (batch, hidden_dim) | |
| Example: | |
| >>> u = torch.randn(2, 32, 64) # batch=2, seq=32, state_dim=64 | |
| >>> A = torch.eye(128) # 128-dim hidden state | |
| >>> B = torch.randn(128, 64) | |
| >>> C = torch.randn(64, 128) | |
| >>> y, h = ssm_gamma_forward(u, A, B, C, delta_t=0.1) | |
| >>> y.shape | |
| torch.Size([2, 32, 64]) | |
| """ | |
| # Always use custom function for autograd support | |
| return SSMGammaFunction.apply(u, A, B, C, delta_t) | |
| def _ssm_gamma_forward_pytorch( | |
| u: torch.Tensor, | |
| A: torch.Tensor, | |
| B: torch.Tensor, | |
| C: torch.Tensor, | |
| delta_t: float, | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| """Pure PyTorch implementation of SSM forward pass.""" | |
| batch_size, seq_len, state_dim = u.shape | |
| hidden_dim = A.shape[0] | |
| device = u.device | |
| dtype = u.dtype | |
| # Initialize hidden state | |
| h = torch.zeros(batch_size, hidden_dim, device=device, dtype=dtype) | |
| outputs = [] | |
| # Process sequence | |
| for t in range(seq_len): | |
| u_t = u[:, t, :] # (batch, state_dim) | |
| # State update: h_new = h + delta_t * (A @ h + B @ u) | |
| A_h = torch.matmul(h, A.T) # (batch, hidden_dim) | |
| B_u = torch.matmul(u_t, B.T) # (batch, hidden_dim) | |
| h_new = h + delta_t * (A_h + B_u) | |
| # Output: y = C @ h_new | |
| y_t = torch.matmul(h_new, C.T) # (batch, state_dim) | |
| outputs.append(y_t) | |
| h = h_new | |
| y = torch.stack(outputs, dim=1) # (batch, seq_len, state_dim) | |
| return y, h | |
| def _compute_state_history( | |
| u: torch.Tensor, | |
| A: torch.Tensor, | |
| B: torch.Tensor, | |
| C: torch.Tensor, | |
| delta_t: float, | |
| ) -> torch.Tensor: | |
| """Compute full state history for backward pass.""" | |
| batch_size, seq_len, state_dim = u.shape | |
| hidden_dim = A.shape[0] | |
| device = u.device | |
| compute_dtype = torch.promote_types(u.dtype, A.dtype) | |
| u_compute = u.to(dtype=compute_dtype) | |
| A_compute = A.to(device=device, dtype=compute_dtype) | |
| B_compute = B.to(device=device, dtype=compute_dtype) | |
| h_history = torch.zeros(batch_size, seq_len, hidden_dim, device=device, dtype=compute_dtype) | |
| h = torch.zeros(batch_size, hidden_dim, device=device, dtype=compute_dtype) | |
| for t in range(seq_len): | |
| u_t = u_compute[:, t, :] | |
| A_h = torch.matmul(h, A_compute.T) | |
| B_u = torch.matmul(u_t, B_compute.T) | |
| h = h + delta_t * (A_h + B_u) | |
| h_history[:, t, :] = h | |
| return h_history | |
| def _ssm_gamma_backward_pytorch( | |
| grad_y: torch.Tensor, | |
| u: torch.Tensor, | |
| h_history: torch.Tensor, | |
| A: torch.Tensor, | |
| B: torch.Tensor, | |
| C: torch.Tensor, | |
| delta_t: float, | |
| ): | |
| """Pure PyTorch implementation of backward pass.""" | |
| batch_size, seq_len, state_dim = grad_y.shape | |
| hidden_dim = A.shape[0] | |
| device = grad_y.device | |
| compute_dtype = torch.promote_types( | |
| torch.promote_types(grad_y.dtype, A.dtype), | |
| torch.promote_types(B.dtype, C.dtype), | |
| ) | |
| grad_y_compute = grad_y.to(dtype=compute_dtype) | |
| u_compute = u.to(device=device, dtype=compute_dtype) | |
| h_history_compute = h_history.to(device=device, dtype=compute_dtype) | |
| A_compute = A.to(device=device, dtype=compute_dtype) | |
| B_compute = B.to(device=device, dtype=compute_dtype) | |
| C_compute = C.to(device=device, dtype=compute_dtype) | |
| grad_u = torch.zeros_like(u_compute) | |
| grad_A = torch.zeros_like(A_compute) | |
| grad_B = torch.zeros_like(B_compute) | |
| grad_C = torch.zeros_like(C_compute) | |
| grad_h = torch.zeros(batch_size, hidden_dim, device=device, dtype=compute_dtype) | |
| identity = torch.eye(hidden_dim, device=device, dtype=compute_dtype) | |
| # Backward pass through time | |
| for t in range(seq_len - 1, -1, -1): | |
| # Gradient from output | |
| grad_h_from_y = torch.matmul(grad_y_compute[:, t, :], C_compute) # (batch, hidden_dim) | |
| grad_h = grad_h + grad_h_from_y | |
| # Gradient w.r.t. C | |
| h_t = h_history_compute[:, t, :] | |
| grad_C.add_(torch.matmul(grad_y_compute[:, t, :].T, h_t) / batch_size) | |
| # Gradient w.r.t. B | |
| u_t = u_compute[:, t, :] | |
| grad_B.add_(delta_t * torch.matmul(grad_h.T, u_t) / batch_size) | |
| # Gradient w.r.t. input | |
| grad_u[:, t, :] = torch.matmul(grad_h, B_compute) | |
| # Gradient w.r.t. A | |
| prev_h = h_history_compute[:, t, :] if t > 0 else torch.zeros_like(h_history_compute[:, 0, :]) | |
| grad_A.add_(delta_t * torch.matmul(grad_h.T, prev_h) / batch_size) | |
| # Propagate gradient backward through state transition | |
| if t > 0: | |
| grad_h = torch.matmul(grad_h, A_compute.T + identity) | |
| grad_h_init = grad_h.to(dtype=u.dtype) | |
| return ( | |
| grad_u.to(dtype=u.dtype), | |
| grad_A.to(dtype=A.dtype), | |
| grad_B.to(dtype=B.dtype), | |
| grad_C.to(dtype=C.dtype), | |
| grad_h_init, | |
| ) | |
| def selective_scan_fwd( | |
| u: torch.Tensor, | |
| delta: torch.Tensor, | |
| A: torch.Tensor, | |
| B: torch.Tensor, | |
| C: torch.Tensor, | |
| D: Optional[torch.Tensor] = None, | |
| return_last_state: bool = False, | |
| ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: | |
| """ | |
| Selective scan forward pass (Mamba-style). | |
| Extended interface for more complex SSM variants. | |
| Currently falls back to SSM gamma for compatibility. | |
| Args: | |
| u: Input (batch, seq_len, d_model) | |
| delta: Time step (scalar or per-timestep) | |
| A: State transition matrix | |
| B: Input matrix | |
| C: Output matrix | |
| D: Direct term (optional) | |
| return_last_state: Whether to return final state | |
| Returns: | |
| output: (batch, seq_len, d_model) | |
| last_state: Final state if requested | |
| """ | |
| # Simplified version - redirect to SSM gamma | |
| output, last_state = ssm_gamma_forward(u, A, B, C, delta_t=delta if isinstance(delta, (int, float)) else delta.item()) | |
| if return_last_state: | |
| return output, last_state | |
| else: | |
| return output, None | |
| def selective_scan_bwd(): | |
| """Backward pass placeholder for extensibility.""" | |
| pass | |