Buckets:
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| from typing import Tuple | |
| import torch | |
| from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states | |
| from apps.mamba.component.causal_conv1d_compilable import ( | |
| causal_conv1d_fn, | |
| causal_conv1d_update, | |
| ) | |
| from apps.fastRNN.component.compilable_scan import scan as accelerated_scan | |
| # from accelerated_scan.triton import scan as triton_scan | |
| from accelerated_scan.ref import scan as ref_scan | |
def conv1d(
    x: torch.Tensor,
    conv_weight: torch.Tensor,
    tok_idx: torch.Tensor,
    cu_seqlens: torch.Tensor,
    impl: str = "parallel",
    cache=None,
) -> torch.Tensor:
    """Apply a SiLU-activated causal 1D convolution (no bias) to ``x``.

    Args:
        x: Input tensor. Its leading dimension is squeezed away before the
            varlen-state and single-step kernels are called
            (assumes a leading batch dimension of size 1 -- TODO confirm).
        conv_weight: Convolution weight forwarded to the kernels.
        tok_idx: Forwarded as ``seq_idx`` to ``causal_conv1d_fn``; only used
            by the ``"parallel"`` implementation.
        cu_seqlens: Forwarded to ``causal_conv1d_varlen_states``; only used
            by the ``"parallel"`` implementation when ``cache`` is given.
        impl: ``"parallel"`` to process the whole input at once, or
            ``"sequential"`` for a single update step against ``cache``.
        cache: Optional convolution state. In ``"parallel"`` mode it is
            overwritten in place with the states computed from ``x``; in
            ``"sequential"`` mode it is handed to ``causal_conv1d_update``
            as ``conv_state``.

    Returns:
        The convolved tensor.

    Raises:
        NotImplementedError: If ``impl`` is not one of the supported values.
    """
    if impl == "sequential":
        # The update kernel gets the input with dim 0 dropped and the two
        # remaining axes swapped; the result is converted back afterwards.
        step_input = x.squeeze(0).transpose(0, 1)
        step_output = causal_conv1d_update(
            x=step_input,
            conv_state=cache,
            weight=conv_weight,
            bias=None,
            activation="silu",
        )
        return step_output.transpose(0, 1).unsqueeze(0)

    if impl != "parallel":
        raise NotImplementedError(
            f"causal_conv1d implementation {impl} not supported"
        )

    if cache is not None:
        # Refresh the cache from the raw input before x is overwritten by
        # the convolution output; the state length is the cache's last dim.
        packed_tokens = x.squeeze(0).transpose(0, 1)
        fresh_states = causal_conv1d_varlen_states(
            packed_tokens, cu_seqlens, state_len=cache.shape[-1]
        )
        cache.copy_(fresh_states)

    return causal_conv1d_fn(
        x=x,
        weight=conv_weight,
        bias=None,
        seq_idx=tok_idx,
        activation="silu",
    )
| def _prepare_for_cache( | |
| a: torch.Tensor, b: torch.Tensor, cu_seqlen: torch.Tensor, seq_len: int | |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | |
| """This function reset the hidden state at the beginning of each sequence in the batch so that the hidden state is not carried over between sequences.""" | |
| num_seq = cu_seqlen.size(0) - 1 | |
| pow_2_seqlen = max(2 ** (seq_len + num_seq - 2).bit_length(), 32) | |
| _a = torch.zeros(*a.shape[:2], pow_2_seqlen, device=a.device, dtype=a.dtype) | |
| _b = torch.zeros(*b.shape[:2], pow_2_seqlen, device=b.device, dtype=b.dtype) | |
| mask = torch.zeros(pow_2_seqlen, dtype=torch.bool, device=a.device) | |
| offsets = torch.arange(0, num_seq, device=a.device) | |
| mask[cu_seqlen[1:-1] + offsets[:-1]] = True | |
| mask[(cu_seqlen[-1] + offsets[-1]) :] = True | |
| mask = (~mask).nonzero().flatten() | |
| for tensor_with_reset, tensor in zip((_a, _b), (a, b)): | |
| tensor_with_reset[..., mask] = tensor | |
| return _a, _b, cu_seqlen[1:] + offsets - 1, mask | |
def sequential_step(
    states: torch.Tensor, a: torch.Tensor, b: torch.Tensor
) -> torch.Tensor:
    """Advance the linear recurrence by one step: ``a * states + b``.

    Args:
        states: Current hidden state.
        a: Multiplicative coefficient applied to ``states``.
        b: Additive input for this step.

    Returns:
        The new hidden state; the inputs are not modified.
    """
    decayed = a * states
    return decayed + b
def scan(
    a: torch.Tensor,
    b: torch.Tensor,
    cu_seqlens: torch.Tensor,
    impl: str = "parallel",
    cache=None,
) -> torch.Tensor:
    """Evaluate the linear recurrence ``h = a * h + b`` over the inputs.

    Args:
        a: Multiplicative coefficients. In the cached parallel path the time
            axis is taken to be dimension 2 (``a.size(2)``).
        b: Additive inputs, laid out like ``a``.
        cu_seqlens: Cumulative sequence boundaries; only used by the
            ``"parallel"`` implementation when ``cache`` is given.
        impl: ``"parallel"`` to scan all time steps at once, or
            ``"sequential"`` to advance ``cache`` by a single step.
        cache: Hidden-state buffer. In ``"parallel"`` mode it is optional and,
            when given, is overwritten in place with the final state of each
            sequence. In ``"sequential"`` mode it is the current state and is
            overwritten in place with the new one.

    Returns:
        The hidden states: every time step for ``"parallel"``, the single
        updated state for ``"sequential"``.

    Raises:
        NotImplementedError: If ``impl`` is not one of the supported values
            (previously this surfaced as an ``UnboundLocalError``).
    """
    if impl == "parallel":
        if cache is None:
            return accelerated_scan(
                a.contiguous(),
                b.contiguous(),
            )
        # The reference scan is used on this path: per the original author's
        # note, accelerated_scan raised illegal memory access errors when
        # seqlen > ~2048.
        a, b, last_state_idx, mask = _prepare_for_cache(a, b, cu_seqlens, a.size(2))
        h = ref_scan(
            a.contiguous(),
            b.contiguous(),
        )
        # Save each sequence's final state, then drop the filler slots that
        # _prepare_for_cache inserted so the output lines up with the input.
        cache.copy_(h[:, :, last_state_idx])
        return h[:, :, mask]

    if impl == "sequential":
        h = sequential_step(cache, a, b)
        cache.copy_(h)
        return h

    # Mirror conv1d: fail loudly on an unknown implementation name instead of
    # falling through to an unbound local.
    raise NotImplementedError(f"scan implementation {impl} not supported")
Xet Storage Details
- Size:
- 3.52 kB
- Xet hash:
- 9db467aefa29b8f6289efa21600719e2e4914d48a21efd1211e6fab8b0450d68
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.