import torch
from torch import Tensor

import triton
import triton.language as tl
@triton.jit
def _causal_conv1d_varlen_states(
    X,
    CU_SEQLENS,
    STATES,
    state_len,
    dim,
    stride_x_seqlen, stride_x_dim,
    stride_states_batch, stride_states_seqlen, stride_states_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr
):
    # One program per (dim tile, state-row tile, sequence): axis 2 selects the sequence.
    batch_idx = tl.program_id(2)
    STATES += batch_idx * stride_states_batch
    # This sequence occupies X[start_idx:end_idx); clamp start_idx so that at
    # most state_len trailing tokens are copied.
    end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
    start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
    # Walk backwards from the end of the sequence, one BLOCK_M tile of rows at a time.
    rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
    # Rows before start_idx belong to a different sequence: mask them out and
    # substitute 0, which zero-pads the state of sequences shorter than state_len.
    x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
                mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
                other=0)
    # Matching destination rows, right-aligned within the state buffer.
    rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
             x,
             mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
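
# Worked example of the index mapping above (illustrative, not from the original
# file): with state_len = 8 and BLOCK_M = 4, a sequence spanning tokens [17, 25)
# of X is copied as follows: program_id(1) == 0 moves X rows [21, 25) into
# STATES rows [4, 8), and program_id(1) == 1 moves X rows [17, 21) into STATES
# rows [0, 4). If the sequence instead started at token 24, rows 21-23 would
# fail the `rows >= start_idx` mask and load as 0, zero-padding the state.
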
def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each sequence, how many of its trailing elements from x should be copied to the state.
            If a sequence is shorter than state_len, the leading positions of its state are filled with zeros.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # Allocate as (batch, state_len, dim) and transpose so the returned view is
    # (batch, dim, state_len). torch.empty is safe here: the kernel writes every
    # in-bounds position, with masked loads contributing zeros.
    states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    BLOCK_M = min(triton.next_power_of_2(state_len), 16)
    BLOCK_N = min(triton.next_power_of_2(dim), 256)
    grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
    with torch.cuda.device(x.device.index):
        _causal_conv1d_varlen_states[grid](
            x,
            cu_seqlens,
            states,
            state_len,
            dim,
            x.stride(0), x.stride(1),
            states.stride(0), states.stride(2), states.stride(1),
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
        )
    return states
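
# Example usage (a minimal sketch, not from the original file; the tensors must
# live on a CUDA device since the kernel is launched through Triton):
#
#     x = torch.randn(30, 64, device="cuda")                 # 30 total tokens, dim = 64
#     cu_seqlens = torch.tensor([0, 10, 30], device="cuda")  # two sequences, lengths 10 and 20
#     states = causal_conv1d_varlen_states(x, cu_seqlens, state_len=4)
#     # states.shape == (2, 64, 4): the last 4 tokens of each sequence, transposed
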
def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Reference implementation in pure PyTorch. Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each sequence, how many of its trailing elements from x should be copied to the state.
            If a sequence is shorter than state_len, the leading positions of its state are filled with zeros.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # torch.zeros (not empty): positions not overwritten below must stay zero.
    states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    for i in range(batch):
        end_idx = cu_seqlens[i + 1]
        start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
        # Right-align the last (end_idx - start_idx) tokens of sequence i.
        states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
    return states
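
# Minimal self-check (an added sketch, not part of the original file): compares
# the Triton kernel against the pure-PyTorch reference on random variable-length
# sequences. Assumes a CUDA device is available.
if __name__ == "__main__":
    torch.manual_seed(0)
    seqlens = torch.tensor([3, 17, 1, 9], device="cuda")
    cu_seqlens = torch.cat([seqlens.new_zeros(1), seqlens.cumsum(0)])
    dim, state_len = 64, 8
    x = torch.randn(int(seqlens.sum()), dim, device="cuda")
    out = causal_conv1d_varlen_states(x, cu_seqlens, state_len)
    out_ref = causal_conv1d_varlen_states_ref(x, cu_seqlens, state_len)
    assert out.shape == (seqlens.shape[0], dim, state_len)
    torch.testing.assert_close(out, out_ref)
    print("Triton kernel matches the reference implementation")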