TaoNet-mini-T2 / code /Taotern_SSM /gamma_space_model /ops /selective_scan_interface.py
StarMist0012's picture
Add files using upload-large-folder tool
e2bfccc verified
"""
Selective Scan Operations Interface with TileLang Acceleration.
Provides unified interface for TileLang-accelerated SSM operations with automatic
fallback to pure PyTorch CPU implementation.
Features:
- TileLang-accelerated parallel scan for efficient GPU utilization
- Automatic GPU/CPU dispatch
- Autograd support via custom backward pass
- Memory-efficient implementation
- Support for fp32, fp16, and bf16
"""
import torch
import torch.nn.functional as F
from typing import Optional, Tuple
import sys
from pathlib import Path
# Import TileLang accelerated operations
try:
from csrc.tilelang import (
HAS_TILELANG_ACCELERATION,
TILELANG_BACKEND,
ssm_gamma_forward_tilelang,
)
HAS_TILELANG_OPS = HAS_TILELANG_ACCELERATION
except ImportError:
try:
# Alternative import path
import sys
sys.path.insert(0, str(Path(__file__).parent.parent.parent / 'csrc'))
from tilelang import (
HAS_TILELANG_ACCELERATION,
TILELANG_BACKEND,
ssm_gamma_forward_tilelang,
)
HAS_TILELANG_OPS = HAS_TILELANG_ACCELERATION
except ImportError:
HAS_TILELANG_OPS = False
TILELANG_BACKEND = "unavailable"
ssm_gamma_forward_tilelang = None
class SSMGammaFunction(torch.autograd.Function):
"""
Custom autograd function for SSM gamma forward/backward.
Uses TileLang kernels if available, falls back to PyTorch operations.
"""
@staticmethod
def forward(ctx, u, A, B, C, delta_t):
"""
Forward pass.
Args:
u: Input (batch, seq_len, state_dim)
A: State matrix (hidden_dim, hidden_dim)
B: Input matrix (hidden_dim, state_dim)
C: Output matrix (state_dim, hidden_dim)
delta_t: Discretization step
Returns:
output: (batch, seq_len, state_dim)
h_final: (batch, hidden_dim)
"""
if HAS_TILELANG_OPS and u.is_cuda:
# Use TileLang-accelerated kernels
y, h_final = ssm_gamma_forward_tilelang(
u.contiguous(), A.contiguous(), B.contiguous(), C.contiguous(), delta_t
)
else:
# Fallback to PyTorch implementation
y, h_final = _ssm_gamma_forward_pytorch(u, A, B, C, delta_t)
ctx.save_for_backward(u, A, B, C, y, h_final)
ctx.delta_t = delta_t
return y, h_final
@staticmethod
def backward(ctx, grad_y, grad_h_final):
"""Backward pass with gradient computation."""
u, A, B, C, y, h_final = ctx.saved_tensors
delta_t = ctx.delta_t
# Compute hidden state history from forward pass
h_history = _compute_state_history(u, A, B, C, delta_t)
if HAS_TILELANG_OPS and grad_y.is_cuda:
# Use TileLang kernels for backward
# For now, use the unified backward from tilelang
grad_u, grad_A, grad_B, grad_C, grad_h_init = _ssm_gamma_backward_pytorch(
grad_y, u, h_history, A, B, C, delta_t
)
else:
# Fallback to PyTorch backward
grad_u, grad_A, grad_B, grad_C, grad_h_init = _ssm_gamma_backward_pytorch(
grad_y, u, h_history, A, B, C, delta_t
)
return grad_u, grad_A, grad_B, grad_C, None
def ssm_gamma_forward(
u: torch.Tensor,
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
delta_t: float = 0.1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
SSM Gamma forward pass using parallel scan optimization.
Implements: h[t+1] = h[t] + dt * (A @ h[t] + B @ u[t]), y[t] = C @ h[t]
Args:
u: Input tensor (batch, seq_len, state_dim)
A: State matrix (hidden_dim, hidden_dim) - tridiagonal HiPPO-Gamma
B: Input matrix (hidden_dim, state_dim)
C: Output matrix (state_dim, hidden_dim)
delta_t: Time discretization step
Returns:
output: (batch, seq_len, state_dim)
final_state: (batch, hidden_dim)
Example:
>>> u = torch.randn(2, 32, 64) # batch=2, seq=32, state_dim=64
>>> A = torch.eye(128) # 128-dim hidden state
>>> B = torch.randn(128, 64)
>>> C = torch.randn(64, 128)
>>> y, h = ssm_gamma_forward(u, A, B, C, delta_t=0.1)
>>> y.shape
torch.Size([2, 32, 64])
"""
# Always use custom function for autograd support
return SSMGammaFunction.apply(u, A, B, C, delta_t)
def _ssm_gamma_forward_pytorch(
u: torch.Tensor,
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
delta_t: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Pure PyTorch implementation of SSM forward pass."""
batch_size, seq_len, state_dim = u.shape
hidden_dim = A.shape[0]
device = u.device
dtype = u.dtype
# Initialize hidden state
h = torch.zeros(batch_size, hidden_dim, device=device, dtype=dtype)
outputs = []
# Process sequence
for t in range(seq_len):
u_t = u[:, t, :] # (batch, state_dim)
# State update: h_new = h + delta_t * (A @ h + B @ u)
A_h = torch.matmul(h, A.T) # (batch, hidden_dim)
B_u = torch.matmul(u_t, B.T) # (batch, hidden_dim)
h_new = h + delta_t * (A_h + B_u)
# Output: y = C @ h_new
y_t = torch.matmul(h_new, C.T) # (batch, state_dim)
outputs.append(y_t)
h = h_new
y = torch.stack(outputs, dim=1) # (batch, seq_len, state_dim)
return y, h
def _compute_state_history(
u: torch.Tensor,
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
delta_t: float,
) -> torch.Tensor:
"""Compute full state history for backward pass."""
batch_size, seq_len, state_dim = u.shape
hidden_dim = A.shape[0]
device = u.device
compute_dtype = torch.promote_types(u.dtype, A.dtype)
u_compute = u.to(dtype=compute_dtype)
A_compute = A.to(device=device, dtype=compute_dtype)
B_compute = B.to(device=device, dtype=compute_dtype)
h_history = torch.zeros(batch_size, seq_len, hidden_dim, device=device, dtype=compute_dtype)
h = torch.zeros(batch_size, hidden_dim, device=device, dtype=compute_dtype)
for t in range(seq_len):
u_t = u_compute[:, t, :]
A_h = torch.matmul(h, A_compute.T)
B_u = torch.matmul(u_t, B_compute.T)
h = h + delta_t * (A_h + B_u)
h_history[:, t, :] = h
return h_history
def _ssm_gamma_backward_pytorch(
grad_y: torch.Tensor,
u: torch.Tensor,
h_history: torch.Tensor,
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
delta_t: float,
):
"""Pure PyTorch implementation of backward pass."""
batch_size, seq_len, state_dim = grad_y.shape
hidden_dim = A.shape[0]
device = grad_y.device
compute_dtype = torch.promote_types(
torch.promote_types(grad_y.dtype, A.dtype),
torch.promote_types(B.dtype, C.dtype),
)
grad_y_compute = grad_y.to(dtype=compute_dtype)
u_compute = u.to(device=device, dtype=compute_dtype)
h_history_compute = h_history.to(device=device, dtype=compute_dtype)
A_compute = A.to(device=device, dtype=compute_dtype)
B_compute = B.to(device=device, dtype=compute_dtype)
C_compute = C.to(device=device, dtype=compute_dtype)
grad_u = torch.zeros_like(u_compute)
grad_A = torch.zeros_like(A_compute)
grad_B = torch.zeros_like(B_compute)
grad_C = torch.zeros_like(C_compute)
grad_h = torch.zeros(batch_size, hidden_dim, device=device, dtype=compute_dtype)
identity = torch.eye(hidden_dim, device=device, dtype=compute_dtype)
# Backward pass through time
for t in range(seq_len - 1, -1, -1):
# Gradient from output
grad_h_from_y = torch.matmul(grad_y_compute[:, t, :], C_compute) # (batch, hidden_dim)
grad_h = grad_h + grad_h_from_y
# Gradient w.r.t. C
h_t = h_history_compute[:, t, :]
grad_C.add_(torch.matmul(grad_y_compute[:, t, :].T, h_t) / batch_size)
# Gradient w.r.t. B
u_t = u_compute[:, t, :]
grad_B.add_(delta_t * torch.matmul(grad_h.T, u_t) / batch_size)
# Gradient w.r.t. input
grad_u[:, t, :] = torch.matmul(grad_h, B_compute)
# Gradient w.r.t. A
prev_h = h_history_compute[:, t, :] if t > 0 else torch.zeros_like(h_history_compute[:, 0, :])
grad_A.add_(delta_t * torch.matmul(grad_h.T, prev_h) / batch_size)
# Propagate gradient backward through state transition
if t > 0:
grad_h = torch.matmul(grad_h, A_compute.T + identity)
grad_h_init = grad_h.to(dtype=u.dtype)
return (
grad_u.to(dtype=u.dtype),
grad_A.to(dtype=A.dtype),
grad_B.to(dtype=B.dtype),
grad_C.to(dtype=C.dtype),
grad_h_init,
)
def selective_scan_fwd(
u: torch.Tensor,
delta: torch.Tensor,
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
D: Optional[torch.Tensor] = None,
return_last_state: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Selective scan forward pass (Mamba-style).
Extended interface for more complex SSM variants.
Currently falls back to SSM gamma for compatibility.
Args:
u: Input (batch, seq_len, d_model)
delta: Time step (scalar or per-timestep)
A: State transition matrix
B: Input matrix
C: Output matrix
D: Direct term (optional)
return_last_state: Whether to return final state
Returns:
output: (batch, seq_len, d_model)
last_state: Final state if requested
"""
# Simplified version - redirect to SSM gamma
output, last_state = ssm_gamma_forward(u, A, B, C, delta_t=delta if isinstance(delta, (int, float)) else delta.item())
if return_last_state:
return output, last_state
else:
return output, None
def selective_scan_bwd():
"""Backward pass placeholder for extensibility."""
pass