# Copyright 2025 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from torch.autograd import Function


# Helper function to ensure tensors are contiguous for Triton.
def ensure_contiguous(t: torch.Tensor) -> torch.Tensor:
    return t if t.is_contiguous() else t.contiguous()


# --- Forward Kernel (modified for optional cache) ---
@triton.jit
def _dynamic_conv_fwd_kernel(
    X_ptr, K_ptr, Out_ptr,
    Cache_ptr,  # New: Pointer to cache tensor
    B, T, D, T_CACHE: tl.constexpr,  # New: T is the shape of x, T_CACHE the shape of cache
    X_stride_b, X_stride_t, X_stride_d,
    K_stride_b, K_stride_t, K_stride_d, K_stride_w,
    Out_stride_b, Out_stride_t, Out_stride_d,
    Cache_stride_b, Cache_stride_t, Cache_stride_d,  # New: Strides for cache tensor
    W: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    pid_batch_time = tl.program_id(0)  # Covers B * T_out
    pid_d_block = tl.program_id(1)

    # T here is the time dimension of x and Out.
    batch_idx = tl.cast(pid_batch_time // T, tl.int64)
    time_idx = pid_batch_time % T  # Current output time step for x (0 to T-1)

    offs_d = pid_d_block * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
    d_mask = offs_d < D

    accumulator = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32)
    offs_w = tl.arange(0, W)  # Kernel window offsets [0, 1, ..., W-1]

    # Load Kernels (kernels are aligned with x's T dimension).
    # K_ptr is indexed by time_idx, which is the output time relative to x's start.
    k_ptrs = K_ptr + (batch_idx * K_stride_b + time_idx * K_stride_t +
                      offs_d[:, None] * K_stride_d + offs_w[None, :] * K_stride_w)
    k_vals = tl.load(k_ptrs, mask=d_mask[:, None], other=0.0)  # Shape: [BLOCK_SIZE_D, W]

    # --- Load Input from conceptual [Cache, X] tensor ---
    # `time_idx` is the current output time step (0 to T-1, where T is x.shape[1]).
    # `offs_w` is [0, ..., W-1].
    # Convolution input time indices relative to the *start of x*:
    # e.g., for W=3, offs_w - W + 1 gives [-2, -1, 0],
    # so input_time_indices_rel_to_x_start are [time_idx-2, time_idx-1, time_idx].
    input_time_indices_rel_to_x_start = time_idx + offs_w - W + 1  # Shape: [W]
    # Effective input time indices in the conceptual [Cache, X] sequence:
    # these indices range from 0 (start of cache) to T_CACHE + T - 1 (end of x).
    eff_t_indices = input_time_indices_rel_to_x_start + T_CACHE  # Shape: [W]
    # Overall mask for valid time indices within the conceptual [Cache, X] tensor.
    # The total effective length is T_CACHE (for cache) + T (for x).
    eff_t_valid_mask = (eff_t_indices >= 0) & (eff_t_indices < (T_CACHE + T))  # Shape: [W]
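    # Worked example: W=3, T_CACHE=2, time_idx=0 gives
    # input_time_indices_rel_to_x_start = [-2, -1, 0] and eff_t_indices = [0, 1, 2]:
    # the first two taps read cache[0] and cache[1], the last tap reads x[0].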
    # --- Load from Cache ---
    # Condition for loading from cache: index is valid AND index < T_CACHE
    # (eff_t_indices are 0-indexed from the start of the cache).
    cache_load_time_mask = eff_t_valid_mask & (eff_t_indices < T_CACHE)  # Shape: [W]
    cache_ptr_indices = eff_t_indices  # Used directly when in cache range
    cache_ptrs = Cache_ptr + (batch_idx * Cache_stride_b +
                              cache_ptr_indices[None, :] * Cache_stride_t +
                              offs_d[:, None] * Cache_stride_d)
    cache_final_load_mask = d_mask[:, None] & cache_load_time_mask[None, :]  # Shape: [BLOCK_SIZE_D, W]
    vals_from_cache = tl.load(cache_ptrs, mask=cache_final_load_mask, other=0.0)  # Shape: [BLOCK_SIZE_D, W]

    # --- Load from X ---
    # Condition for loading from X: index is valid AND index >= T_CACHE.
    x_load_time_mask = eff_t_valid_mask & (eff_t_indices >= T_CACHE)  # Shape: [W]
    # Adjust indices for X_ptr: X_ptr expects indices from 0 to T-1 (relative to the start of x).
    x_ptr_indices = eff_t_indices - T_CACHE  # Shape: [W]
    x_ptrs = X_ptr + (batch_idx * X_stride_b +
                      x_ptr_indices[None, :] * X_stride_t +
                      offs_d[:, None] * X_stride_d)
    x_final_load_mask = d_mask[:, None] & x_load_time_mask[None, :]  # Shape: [BLOCK_SIZE_D, W]
    vals_from_x = tl.load(x_ptrs, mask=x_final_load_mask, other=0.0)  # Shape: [BLOCK_SIZE_D, W]

    # Combine values. The masks ensure only one source contributes non-zero per element.
    # If T_CACHE == 0, cache_load_time_mask is all False, so vals_from_cache is 0.0.
    x_input_vals = vals_from_cache + vals_from_x  # Shape: [BLOCK_SIZE_D, W]

    # Compute and accumulate.
    product = k_vals * x_input_vals  # Element-wise product
    accumulator += tl.sum(product, axis=1)  # Sum over W dimension

    # Store the result.
    out_ptrs = Out_ptr + (batch_idx * Out_stride_b + time_idx * Out_stride_t +
                          offs_d * Out_stride_d)
    tl.store(out_ptrs, accumulator, mask=d_mask)


# --- Backward Kernel for Input Gradient (dX) ---
@triton.jit
def _dynamic_conv_bwd_dx_kernel(
    GradOut_ptr, K_ptr, GradX_ptr,  # Note: GradX is written directly
    B, T, D,
    GradOut_stride_b, GradOut_stride_t, GradOut_stride_d,
    K_stride_b, K_stride_t, K_stride_d, K_stride_w,
    GradX_stride_b, GradX_stride_t, GradX_stride_d,
    W: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    """
    Computes gradient w.r.t. input X.
    Grid: (B * T, cdiv(D, BLOCK_SIZE_D)) - covering GradX output
    GradX[b, t_x, d] = sum_{w=0}^{W-1} GradOut[b, t, d] * K[b, t, d, w]
    where t = t_x + W - 1 - w
    """
    pid_batch_time_x = tl.program_id(0)  # Covers B * T for output GradX
    pid_d_block = tl.program_id(1)

    batch_idx = pid_batch_time_x // T
    time_idx_x = pid_batch_time_x % T  # This is t_x

    offs_d = pid_d_block * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
    d_mask = offs_d < D

    # Accumulator for GradX elements
    accumulator = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32)
    offs_w = tl.arange(0, W)  # [W]

    # Contributions over W are computed in one vectorized pass.
    # Calculate the 't' index needed for GradOut and K based on t_x and w:
    # t = t_x + W - 1 - w
    t_k_gradout_offs = time_idx_x + W - 1 - offs_w  # Shape [W]
    # Mask for valid 't' indices [0, T)
    t_k_gradout_mask = (t_k_gradout_offs >= 0) & (t_k_gradout_offs < T)  # Shape [W]

    # --- Load GradOut ---
    # Pointers shape: [BLOCK_SIZE_D, W]
    gradout_ptrs = GradOut_ptr + (batch_idx * GradOut_stride_b +
                                  t_k_gradout_offs[None, :] * GradOut_stride_t +
                                  offs_d[:, None] * GradOut_stride_d)
    # Combined mask for loading GradOut (valid D and valid t)
    gradout_load_mask = d_mask[:, None] & t_k_gradout_mask[None, :]
    # Shape: [BLOCK_SIZE_D, W]
    gradout_vals = tl.load(gradout_ptrs, mask=gradout_load_mask, other=0.0)

    # --- Load Kernels ---
    # Pointers shape: [BLOCK_SIZE_D, W]
    k_ptrs = K_ptr + (batch_idx * K_stride_b +
                      t_k_gradout_offs[None, :] * K_stride_t +
                      offs_d[:, None] * K_stride_d +
                      offs_w[None, :] * K_stride_w)  # Index K with 't' and 'w'
    # Combined mask for loading K (valid D and valid t)
    k_load_mask = d_mask[:, None] & t_k_gradout_mask[None, :]
    # Shape: [BLOCK_SIZE_D, W]
    k_vals = tl.load(k_ptrs, mask=k_load_mask, other=0.0)

    # --- Compute product and accumulate ---
    # Shape: [BLOCK_SIZE_D, W]
    product = gradout_vals * k_vals
    # Sum contributions over the W dimension
    accumulator += tl.sum(product, axis=1)  # Shape: [BLOCK_SIZE_D]

    # --- Store accumulated gradients ---
    gradx_ptrs = GradX_ptr + (batch_idx * GradX_stride_b +
                              time_idx_x * GradX_stride_t +
                              offs_d * GradX_stride_d)
    tl.store(gradx_ptrs, accumulator, mask=d_mask)


# --- Backward Kernel for Kernel Gradient (dK) ---
@triton.jit
def _dynamic_conv_bwd_dk_kernel(
    GradOut_ptr, X_ptr, GradK_ptr,  # Note: GradK is written directly
    B, T, D,
    GradOut_stride_b, GradOut_stride_t, GradOut_stride_d,
    X_stride_b, X_stride_t, X_stride_d,
    GradK_stride_b, GradK_stride_t, GradK_stride_d, GradK_stride_w,
    W: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    """
    Computes gradient w.r.t. kernels K.
    Grid: (B * T, cdiv(D, BLOCK_SIZE_D)) - covering GradK output dims B, T, D
    GradK[b, t, d, w] = GradOut[b, t, d] * X[b, t + w - W + 1, d]
    """
    pid_batch_time = tl.program_id(0)  # Covers B * T for output GradK
    pid_d_block = tl.program_id(1)

    batch_idx = pid_batch_time // T
    time_idx = pid_batch_time % T  # This is 't' for GradK and GradOut

    offs_d = pid_d_block * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
    d_mask = offs_d < D
    offs_w = tl.arange(0, W)  # [W]

    # --- Load GradOut ---
    # Pointers shape: [BLOCK_SIZE_D] (only depends on b, t, d)
    gradout_ptrs = GradOut_ptr + (batch_idx * GradOut_stride_b +
                                  time_idx * GradOut_stride_t +
                                  offs_d * GradOut_stride_d)
    # Shape: [BLOCK_SIZE_D]
    gradout_vals = tl.load(gradout_ptrs, mask=d_mask, other=0.0)

    # --- Load Input X with implicit padding ---
    # Calculate X's time index: t_x = t + w - W + 1
    t_in_offs = time_idx + offs_w - W + 1  # Shape [W]
    # Mask for valid t_x index [0, T)
    t_in_mask = (t_in_offs >= 0) & (t_in_offs < T)  # Shape [W]
    # Pointers shape: [BLOCK_SIZE_D, W]
    x_ptrs = X_ptr + (batch_idx * X_stride_b +
                      t_in_offs[None, :] * X_stride_t +
                      offs_d[:, None] * X_stride_d)
    # Combined mask for loading X (valid D and valid t_x)
    x_load_mask = d_mask[:, None] & t_in_mask[None, :]  # Shape [BLOCK_SIZE_D, W]
    # Shape: [BLOCK_SIZE_D, W]
    x_vals = tl.load(x_ptrs, mask=x_load_mask, other=0.0)

    # --- Compute GradK = GradOut * X ---
    # Broadcast gradout_vals: [BLOCK_SIZE_D, 1] * [BLOCK_SIZE_D, W] -> [BLOCK_SIZE_D, W]
    gradk_vals = gradout_vals[:, None] * x_vals  # Shape [BLOCK_SIZE_D, W]

    # --- Store gradients for Kernels ---
    # Pointers shape: [BLOCK_SIZE_D, W]
    gradk_ptrs = GradK_ptr + (batch_idx * GradK_stride_b +
                              time_idx * GradK_stride_t +
                              offs_d[:, None] * GradK_stride_d +
                              offs_w[None, :] * GradK_stride_w)
    # A mask is only needed for the D dimension (the W dimension is fully computed).
    tl.store(gradk_ptrs, gradk_vals, mask=d_mask[:, None])


# --- Autograd Function ---
class DynamicConvTritonFunc(Function):
    @staticmethod
    def forward(ctx, x, kernels, cache=None):  # Added cache argument
        """
        Args:
            x: Input tensor [B, T, D]
            kernels: Kernels tensor [B, T, D, W]
            cache: Optional past context tensor [B, T_cache, D]
        """
        x = ensure_contiguous(x)
        kernels = ensure_contiguous(kernels)

        B, T, D = x.shape  # T is the time dimension of the current input x
        _B_k, _T_k, _D_k, W = kernels.shape  # Kernels are [B, T_x, D, W]
        assert B == _B_k and T == _T_k and D == _D_k, \
            f"Shape mismatch between x ({x.shape}) and kernels ({kernels.shape}) on B, T, or D dims"
        assert W <= 4, "Kernel W > 4 not expected for this version"

        out = torch.empty_like(x)  # Output shape [B, T, D], corresponds to x

        T_cache_val = 0
        # Use x's data pointer and zero strides as placeholders if cache is None.
        # These are never read by the kernel when T_CACHE is 0, due to masking.
        cache_ptr_val = x
        cache_s_b, cache_s_t, cache_s_d = 0, 0, 0
        if cache is not None:
            cache = ensure_contiguous(cache)
            B_c, T_c, D_c = cache.shape
            assert B_c == B, f"Batch size mismatch: x ({B}) vs cache ({B_c})"
            assert D_c == D, f"Dimension mismatch: x ({D}) vs cache ({D_c})"
            T_cache_val = T_c
            cache_ptr_val = cache
            cache_s_b, cache_s_t, cache_s_d = cache.stride(0), cache.stride(1), cache.stride(2)

        grid = lambda meta: (B * T, triton.cdiv(D, meta['BLOCK_SIZE_D']))
        BLOCK_SIZE_D = 128  # Consider tuning

        _dynamic_conv_fwd_kernel[grid](
            x, kernels, out,  # X, K, Out pointers
            cache_ptr_val,  # Cache pointer
            B, T, D, T_cache_val,  # Shapes: B, T_x, D, T_cache
            x.stride(0), x.stride(1), x.stride(2),  # X strides
            kernels.stride(0), kernels.stride(1), kernels.stride(2), kernels.stride(3),  # K strides
            out.stride(0), out.stride(1), out.stride(2),  # Out strides
            cache_s_b, cache_s_t, cache_s_d,  # Cache strides
            W=W,
            BLOCK_SIZE_D=BLOCK_SIZE_D,
        )

        # Save tensors needed for backward (cache is not needed for the current backward).
        ctx.save_for_backward(x, kernels)
        ctx.W = W
        ctx.BLOCK_SIZE_D = BLOCK_SIZE_D
        # ctx.T_cache = T_cache_val  # Not needed for the current backward
        return out
    @staticmethod
    def backward(ctx, grad_out):
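        # The _dynamic_conv_bwd_dx_kernel and _dynamic_conv_bwd_dk_kernel above
        # compute the cache-free gradients w.r.t. x and kernels, but they are
        # not wired into this cached variant.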
        raise NotImplementedError("Backward pass of the cached dynamic convolution is not implemented")


# --- User-facing function ---
def dynamic_conv_triton_cache(x: torch.Tensor, kernels: torch.Tensor,
                              cache: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Fused dynamic convolution using Triton kernels. Assumes W <= 4.
    The backward pass of this cached variant is not implemented.

    Args:
        x: Input tensor of shape [B, T, D].
        kernels: Dynamic kernels of shape [B, T, D, W].
        cache: Optional past context tensor of shape [B, T_cache, D].
            If provided, it is treated as being concatenated before x to form
            the convolution input.

    Returns:
        Output tensor of shape [B, T, D].
    """
    return DynamicConvTritonFunc.apply(x, kernels, cache)
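

# --- Reference implementation (illustrative) ---
# A minimal pure-PyTorch sketch of the same causal dynamic convolution, added
# here for illustration and testing; `_dynamic_conv_reference` is a
# hypothetical helper, not part of the original API. It assumes the same
# semantics as the forward kernel: output step t mixes input steps
# t + w - W + 1 (w = 0..W-1) of the conceptual [cache, x] sequence, with
# implicit zero-padding on the left.
def _dynamic_conv_reference(x: torch.Tensor, kernels: torch.Tensor,
                            cache: Optional[torch.Tensor] = None) -> torch.Tensor:
    B, T, D = x.shape
    W = kernels.shape[-1]
    seq = x if cache is None else torch.cat([cache, x], dim=1)  # [B, T_cache + T, D]
    # Left-pad the time dimension with W - 1 zeros so every output step sees a full window.
    seq = F.pad(seq, (0, 0, W - 1, 0))  # [B, T_cache + T + W - 1, D]
    # Sliding windows over time: window t covers (unpadded) input steps t - W + 1 .. t.
    windows = seq.unfold(1, W, 1)  # [B, T_cache + T, D, W]
    windows = windows[:, -T:]  # Keep only the windows aligned with x's output steps.
    return (windows * kernels).sum(dim=-1)


if __name__ == "__main__":
    # Smoke test comparing the Triton path against the reference sketch
    # (assumes a CUDA device is available; shapes are arbitrary).
    B, T, D, W, T_cache = 2, 16, 64, 4, 8
    x = torch.randn(B, T, D, device="cuda")
    kernels = torch.randn(B, T, D, W, device="cuda")
    cache = torch.randn(B, T_cache, D, device="cuda")
    out = dynamic_conv_triton_cache(x, kernels, cache)
    ref = _dynamic_conv_reference(x, kernels, cache)
    print("max abs err:", (out - ref).abs().max().item())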