Buckets:
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| from typing import Tuple | |
| import torch | |
| from accelerated_scan.warp import warpscan_forward, warpscan_backward | |
def scan_fwd(
    gates: torch.Tensor,
    tokens: torch.Tensor,
    reverse: bool = False,
) -> torch.Tensor:
    """Run ``warpscan_forward`` over ``gates`` and ``tokens``.

    A zero-filled buffer shaped like ``tokens`` is allocated, handed to
    ``warpscan_forward`` together with both inputs and ``reverse``, and then
    returned.

    Args:
        gates: Contiguous 3-D tensor; its dimensions are treated as
            ``(B, dim, seq_len)``.
        tokens: Contiguous tensor with exactly the same shape as ``gates``.
        reverse: Forwarded unchanged to ``warpscan_forward``.
            NOTE(review): presumably selects the scan direction -- confirm
            against the kernel.

    Returns:
        The buffer that was passed to ``warpscan_forward`` as its output
        argument.

    Raises:
        ValueError: If ``gates`` is not 3-D.
        AssertionError: If the shapes differ or an input is not contiguous.
    """
    # The checks are explicit ``raise`` statements rather than ``assert`` so
    # they still guard the kernel call when Python runs with ``-O``. The
    # exception types match what the previous assert/unpack code raised.
    if gates.dim() != 3:
        raise ValueError(
            f"gates must be 3-D (B, dim, seq_len), got shape {tuple(gates.shape)}"
        )
    if tokens.shape != gates.shape:
        raise AssertionError(
            f"tokens shape {tuple(tokens.shape)} does not match "
            f"gates shape {tuple(gates.shape)}"
        )
    if not gates.is_contiguous():
        raise AssertionError("gates must be contiguous")
    if not tokens.is_contiguous():
        raise AssertionError("tokens must be contiguous")

    output = torch.zeros_like(tokens)
    warpscan_forward(gates, tokens, output, reverse)
    return output
def _scan_fwd_fake(gates, tokens, reverse=False):
    """Shape-only stand-in for ``scan_fwd``.

    Returns an uninitialised tensor with the same metadata as ``tokens``;
    ``gates`` and ``reverse`` are accepted but not used.
    """
    placeholder = torch.empty_like(tokens)
    return placeholder
def scan_bwd(
    dout: torch.Tensor,
    states: torch.Tensor,
    gates: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run ``warpscan_backward`` and return the two gradient buffers.

    ``dout`` is made contiguous (copying if necessary), two uninitialised
    buffers shaped like ``gates`` are allocated, and everything is passed to
    ``warpscan_backward``.

    Args:
        dout: Incoming gradient; any memory layout is accepted.
        states: Contiguous tensor passed to the kernel as its second argument.
            NOTE(review): the autograd bridge supplies the forward output
            here -- confirm this is what the kernel expects.
        gates: Contiguous tensor passed to the kernel as its first argument.

    Returns:
        ``(d_gates, d_tokens)``: the buffers handed to ``warpscan_backward``
        as its fourth and fifth arguments.

    Raises:
        AssertionError: If ``states`` or ``gates`` is not contiguous.
    """
    dout = dout.contiguous()

    # Explicit raises instead of ``assert`` so the checks survive ``python -O``;
    # the exception type matches what the previous assert statements raised.
    if not states.is_contiguous():
        raise AssertionError("states must be contiguous")
    if not gates.is_contiguous():
        raise AssertionError("gates must be contiguous")

    d_gates = torch.empty_like(gates)
    # Both gradient buffers take their shape and dtype from ``gates``.
    d_tokens = torch.empty_like(gates)
    warpscan_backward(gates, states, dout, d_gates, d_tokens)
    return d_gates, d_tokens
def _scan_bwd_fake(dout, states, gates):
    """Shape-only stand-in for ``scan_bwd``.

    Returns two separate uninitialised tensors, each with the same metadata
    as ``gates``; ``dout`` and ``states`` are accepted but not used.
    """
    d_gates = torch.empty_like(gates)
    d_tokens = torch.empty_like(gates)
    return d_gates, d_tokens
def scan_setup_context(ctx, inputs, output):
    """Stash what the backward pass needs on ``ctx``.

    ``inputs`` must unpack into exactly three items ``(gates, tokens,
    reverse)``. Only ``gates`` and ``output`` are saved, in that order;
    ``tokens`` and ``reverse`` are discarded.
    """
    saved_gates, _unused_tokens, _unused_reverse = inputs
    ctx.save_for_backward(saved_gates, output)
def scan_bwd_bridge(ctx, dout):
    """Autograd backward: map ``dout`` to gradients for the op's three inputs.

    Reads the two tensors stored by ``scan_setup_context`` (gates first, then
    the forward output), delegates to ``scan_bwd``, and returns
    ``(d_gates, d_tokens, None)``. The trailing ``None`` lines up with the
    non-tensor ``reverse`` input.

    NOTE(review): ``reverse`` is not saved for backward, so this path is the
    same whether the forward ran with ``reverse=True`` or ``False`` -- confirm
    that the gradients are correct for the reversed case.
    """
    saved_gates, saved_states = ctx.saved_tensors
    grad_gates, grad_tokens = scan_bwd(dout, saved_states, saved_gates)
    return grad_gates, grad_tokens, None
# Attach the backward formula to the "scan::scan_fwd" op: autograd calls
# scan_setup_context after the forward to save tensors, and scan_bwd_bridge
# to compute the input gradients.
torch.library.register_autograd(
    "scan::scan_fwd", scan_bwd_bridge, setup_context=scan_setup_context
)
def scan(gates: torch.Tensor, tokens: torch.Tensor, reverse: bool = False) -> torch.Tensor:
    """Public entry point: forward all arguments positionally to ``scan_fwd``.

    Args:
        gates: Passed through as the first argument.
        tokens: Passed through as the second argument.
        reverse: Passed through as the third argument.

    Returns:
        Whatever ``scan_fwd`` returns.
    """
    result = scan_fwd(gates, tokens, reverse)
    return result
Xet Storage Details
- Size:
- 1.91 kB
- Xet hash:
- 7061e2ded0e81d574fb90a7d4e52f4721dd7192e7d982f169239f39e0e92d4f5
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.