from typing import Tuple

import torch

from accelerated_scan.warp import warpscan_forward, warpscan_backward

# Forward op: computes the first-order linear recurrence
# h[t] = gates[t] * h[t-1] + tokens[t] along the last dimension; `reverse`
# flips the direction of the scan.
@torch.library.custom_op(
    "scan::scan_fwd",
    mutates_args=(),
    device_types="cuda",
)
def scan_fwd(
    gates: torch.Tensor,
    tokens: torch.Tensor,
    reverse: bool = False,
) -> torch.Tensor:
    B, dim, seq_len = gates.shape
    assert tokens.shape == (B, dim, seq_len)
    assert gates.is_contiguous()
    assert tokens.is_contiguous()

    output = torch.zeros_like(tokens)
    warpscan_forward(gates, tokens, output, reverse)
    return output

# Fake (meta) implementation: lets FakeTensor tracing (e.g. under
# torch.compile) propagate shapes and dtypes without launching the kernel.
@scan_fwd.register_fake
def _scan_fwd_fake(gates, tokens, reverse=False):
    return torch.empty_like(tokens)

# Backward op: given the forward states and the upstream gradients, the
# kernel produces gradients for both gates and tokens.
@torch.library.custom_op(
    "scan::scan_bwd",
    mutates_args=(),
    device_types="cuda",
)
def scan_bwd(
    dout: torch.Tensor,
    states: torch.Tensor,
    gates: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    dout = dout.contiguous()
    assert states.is_contiguous()
    assert gates.is_contiguous()

    # gates, states, and the original tokens all share one shape, so both
    # gradient buffers can be allocated from gates.
    d_gates = torch.empty_like(gates)
    d_tokens = torch.empty_like(gates)
    warpscan_backward(gates, states, dout, d_gates, d_tokens)
    return d_gates, d_tokens

# Fake (meta) implementation for the backward op.
@scan_bwd.register_fake
def _scan_bwd_fake(dout, states, gates):
    return torch.empty_like(gates), torch.empty_like(gates)

def scan_setup_context(ctx, inputs, output):
    # The backward pass needs the gates and the forward states (the op's
    # output); the raw tokens are not required. Note that `reverse` is not
    # saved, so the backward path implicitly assumes the default direction.
    gates, tokens, reverse = inputs
    ctx.save_for_backward(gates, output)


def scan_bwd_bridge(ctx, dout):
    gates, states = ctx.saved_tensors
    d_gates, d_tokens = scan_bwd(dout, states, gates)
    # One gradient per forward input: gates, tokens, and None for the
    # non-Tensor `reverse` flag.
    return d_gates, d_tokens, None

# Wire autograd for scan::scan_fwd through the bridge above, with the
# context populated at forward time.
torch.library.register_autograd(
    "scan::scan_fwd",
    scan_bwd_bridge,
    setup_context=scan_setup_context,
)

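# Registration check: a minimal sketch, assuming a CUDA device and installed
# accelerated_scan kernels. torch.library.opcheck exercises the schema,
# fake-tensor, and autograd registrations above; the helper name and the
# tensor sizes here are illustrative, not part of the original file.
def _check_scan_registration() -> None:
    gates = torch.rand(2, 8, 64, device="cuda", requires_grad=True)
    tokens = torch.randn(2, 8, 64, device="cuda", requires_grad=True)
    torch.library.opcheck(scan_fwd, (gates, tokens, False))
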
def scan(gates: torch.Tensor, tokens: torch.Tensor, reverse: bool = False) -> torch.Tensor:
    """Differentiable first-order linear scan: h[t] = gates[t] * h[t-1] + tokens[t]."""
    return scan_fwd(gates, tokens, reverse)
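
# Usage sketch, assuming a CUDA device: shapes follow the (batch, dim, seq_len)
# layout that the asserts in scan_fwd expect, and the sizes are illustrative.
if __name__ == "__main__":
    B, dim, seq_len = 2, 16, 128
    gates = torch.rand(B, dim, seq_len, device="cuda", requires_grad=True)
    tokens = torch.randn(B, dim, seq_len, device="cuda", requires_grad=True)

    out = scan(gates, tokens)
    out.sum().backward()
    print(out.shape, gates.grad.shape, tokens.grad.shape)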