| import torch |
| import torch.nn.functional as F |
|
|
| from kernels.benchmark import Benchmark |
|
|
|
|
def ms_deform_attn_reference(
    value: torch.Tensor,
    spatial_shapes: torch.Tensor,
    level_start_index: torch.Tensor,
    sampling_locations: torch.Tensor,
    attention_weights: torch.Tensor,
) -> torch.Tensor:
    """Pure-PyTorch reference for multi-scale deformable attention.

    Args:
        value: ``(batch, spatial_size, num_heads, channels)`` — the feature
            maps of all pyramid levels flattened and concatenated along dim 1.
        spatial_shapes: ``(num_levels, 2)`` integer tensor of ``(H, W)`` per
            level.
        level_start_index: ``(num_levels,)`` integer tensor giving each
            level's start offset into ``value``'s spatial dimension.
        sampling_locations: ``(batch, num_query, num_heads, num_levels,
            num_points, 2)`` sampling coordinates normalized to ``[0, 1]``.
        attention_weights: ``(batch, num_query, num_heads, num_levels,
            num_points)`` per-point attention weights.

    Returns:
        ``(batch, num_query, num_heads * channels)`` attended features.
    """
    batch, _, num_heads, channels = value.shape
    _, num_query, _, num_levels, num_points, _ = sampling_locations.shape

    # Pull shapes/offsets to plain Python ints once. This avoids a
    # host/device sync per tensor access inside the loop on GPU inputs, and
    # keeps view() arguments consistently int (the original mixed 0-dim
    # tensors with .item() calls).
    shapes = spatial_shapes.tolist()
    starts = level_start_index.tolist()

    output = torch.zeros(
        batch, num_query, num_heads, channels, device=value.device, dtype=value.dtype
    )

    # Single pass per level: slice, sample, and accumulate without
    # materializing every level's permuted feature map up front.
    for level_id in range(num_levels):
        H, W = shapes[level_id]
        start_idx = starts[level_id]
        end_idx = starts[level_id + 1] if level_id < num_levels - 1 else value.shape[1]

        # (batch, H, W, heads, channels) -> (batch*heads, channels, H, W)
        value_level = (
            value[:, start_idx:end_idx]
            .view(batch, H, W, num_heads, channels)
            .permute(0, 3, 4, 1, 2)
            .reshape(batch * num_heads, channels, H, W)
        )

        # Map [0, 1] sampling coords to grid_sample's [-1, 1] range.
        grid = 2.0 * sampling_locations[:, :, :, level_id] - 1.0
        # (batch, query, heads, points, 2) -> (batch*heads, query, points, 2)
        grid = (
            grid.permute(0, 2, 1, 3, 4)
            .contiguous()
            .view(batch * num_heads, num_query, num_points, 2)
        )

        # grid_sample treats (num_query, num_points) as the output H x W:
        # result is (batch*heads, channels, num_query, num_points).
        sampled = F.grid_sample(
            value_level,
            grid,
            mode="bilinear",
            padding_mode="zeros",
            align_corners=False,
        )
        # -> (batch, query, heads, points, channels)
        sampled = sampled.view(
            batch, num_heads, channels, num_query, num_points
        ).permute(0, 3, 1, 4, 2)

        # Weighted sum over sampling points, accumulated across levels.
        attn_level = attention_weights[:, :, :, level_id]
        output += (sampled * attn_level.unsqueeze(-1)).sum(dim=3)

    return output.view(batch, num_query, num_heads * channels)
|
|
|
|
class MSDeformAttnBenchmark(Benchmark):
    """Benchmark for a multi-scale deformable attention forward kernel.

    Two problem sizes share the same 4-level feature pyramid (64x64 down to
    8x8) and differ only in batch size and query count; all input
    construction lives in ``_setup_inputs`` so the two setups cannot drift
    apart.
    """

    seed: int = 42

    # (H, W) of each pyramid level, largest first; level start offsets and
    # the flattened spatial size are derived from this.
    _LEVEL_SHAPES = [(64, 64), (32, 32), (16, 16), (8, 8)]

    def _setup_inputs(self, batch: int, num_query: int) -> None:
        """Build all kernel inputs for the given problem size.

        Populates ``self.value``, ``self.spatial_shapes``,
        ``self.level_start_index``, ``self.sampling_loc``,
        ``self.attn_weight``, ``self.im2col_step``, and ``self.out``.
        """
        num_heads = 8
        channels = 32
        num_points = 4
        num_levels = len(self._LEVEL_SHAPES)

        self.spatial_shapes = torch.tensor(
            self._LEVEL_SHAPES, dtype=torch.int64, device=self.device
        )
        # Exclusive prefix sum of per-level areas gives each level's offset
        # into the flattened spatial dimension of `value`.
        areas = [h * w for h, w in self._LEVEL_SHAPES]
        self.level_start_index = torch.tensor(
            [sum(areas[:i]) for i in range(num_levels)],
            dtype=torch.int64,
            device=self.device,
        )
        spatial_size = sum(areas)

        self.value = torch.randn(
            batch,
            spatial_size,
            num_heads,
            channels,
            device=self.device,
            dtype=torch.float32,
        )
        # Sampling locations normalized to [0, 1], as the kernel expects.
        self.sampling_loc = torch.rand(
            batch,
            num_query,
            num_heads,
            num_levels,
            num_points,
            2,
            device=self.device,
            dtype=torch.float32,
        )
        attn_weight = torch.rand(
            batch,
            num_query,
            num_heads,
            num_levels,
            num_points,
            device=self.device,
            dtype=torch.float32,
        )
        # Normalize so weights over each level's sampling points sum to 1.
        self.attn_weight = attn_weight / attn_weight.sum(-1, keepdim=True)
        self.im2col_step = 64

        self.out = torch.empty(
            batch,
            num_query,
            num_heads * channels,
            device=self.device,
            dtype=torch.float32,
        )

    def setup(self):
        """Small problem size: batch of 2, 300 queries."""
        self._setup_inputs(batch=2, num_query=300)

    def benchmark_forward(self):
        self.out = self.kernel.ms_deform_attn_forward(
            self.value,
            self.spatial_shapes,
            self.level_start_index,
            self.sampling_loc,
            self.attn_weight,
            self.im2col_step,
        )

    def verify_forward(self) -> torch.Tensor:
        return ms_deform_attn_reference(
            self.value,
            self.spatial_shapes,
            self.level_start_index,
            self.sampling_loc,
            self.attn_weight,
        )

    def setup_large(self):
        """Large problem size: batch of 8, 900 queries."""
        self._setup_inputs(batch=8, num_query=900)

    def benchmark_large(self):
        self.out = self.kernel.ms_deform_attn_forward(
            self.value,
            self.spatial_shapes,
            self.level_start_index,
            self.sampling_loc,
            self.attn_weight,
            self.im2col_step,
        )

    def verify_large(self) -> torch.Tensor:
        return ms_deform_attn_reference(
            self.value,
            self.spatial_shapes,
            self.level_start_index,
            self.sampling_loc,
            self.attn_weight,
        )
|
|