PyTorch Native - Deformable DETR

GPU Info

Cell: nv | 0.25s
import subprocess
print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
Fri Dec 19 23:02:11 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.105.08             Driver Version: 580.105.08     CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    On  |   00000000:4D:00.0 Off |                    0 |
| N/A   42C    P0             83W /  350W |       0MiB /  46068MiB |     12%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

Deformable DETR Multi-Scale Deformable Attention Benchmark (PyTorch Native)
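
Multi-scale deformable attention (Zhu et al., "Deformable DETR", ICLR 2021) lets each query attend to a small, learned set of sampling points across the feature-pyramid levels instead of to every spatial position. For each query q and head m, the kernel bilinearly samples each level's value map at num_points locations and reduces the samples with learned attention weights; schematically:

    output[b, q, m, :] = sum over levels l and points k of
        attention_weights[b, q, m, l, k] * bilinear_sample(value_l[b, :, m, :], sampling_locations[b, q, m, l, k])

The cell below benchmarks a pure-PyTorch eager implementation of this sampling-and-weighting core on four workload shapes.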

Cell: benchmark | 9.26s
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch
import sys
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def torch_deformable_detr(
    value, spatial_shapes, level_start_index, sampling_locations, attention_weights, im2col_step=64
):
    """
    PyTorch native reference implementation of multi-scale deformable attention.
    Uses vectorized bilinear interpolation for reasonable performance.
    """
    bs, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
    _, _, _, channels = value.shape

    output = torch.zeros(bs, num_queries, num_heads, channels, device=value.device, dtype=value.dtype)

    # Split value tensor by levels
    value_list = value.split([int(h * w) for h, w in spatial_shapes.tolist()], dim=1)
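    # e.g. a (64x64, 32x32, 16x16, 8x8) pyramid (illustrative sizes, not the
    # benchmark's) would split the 5440 value tokens into chunks of 4096/1024/256/64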

    # Iterate over levels; each level has its own spatial resolution, so this loop is hard to vectorize away
    for level_idx in range(num_levels):
        h, w = spatial_shapes[level_idx].tolist()
        value_level = value_list[level_idx]  # (bs, h*w, num_heads, channels)

        # Reshape to spatial grid: (bs, num_heads, channels, h, w)
        value_spatial = value_level.reshape(bs, h, w, num_heads, channels).permute(0, 3, 4, 1, 2)

        # Get sampling locations and weights for this level
        # loc: (bs, num_queries, num_heads, num_points, 2)
        loc = sampling_locations[:, :, :, level_idx, :, :]
        # weight: (bs, num_queries, num_heads, num_points)
        weight = attention_weights[:, :, :, level_idx, :]

        # Convert normalized coordinates to pixel coordinates
        # loc[..., 0] is x (width), loc[..., 1] is y (height)
        x = loc[..., 0] * w - 0.5  # (bs, num_queries, num_heads, num_points)
        y = loc[..., 1] * h - 0.5

        # Get integer coordinates for bilinear interpolation
        x0 = torch.floor(x).long()
        y0 = torch.floor(y).long()
        x1 = x0 + 1
        y1 = y0 + 1

        # Compute fractional interpolation offsets BEFORE clamping: clamping first
        # would skew the weights for samples near the feature-map border
        lw = x - x0.float()  # weight for x direction
        lh = y - y0.float()  # weight for y direction
        hw = 1 - lw
        hh = 1 - lh
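        # Note: (hh + lh) == 1 and (hw + lw) == 1, so the four corner products
        # below form a convex combination of the corner values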

        # Mask sample locations that fall entirely outside the feature map (they contribute zero)
        valid = (y > -1) & (x > -1) & (y < h) & (x < w)

        # Create masks for each corner being in bounds
        mask_tl = ((y0 >= 0) & (x0 >= 0)).unsqueeze(-1).float()
        mask_tr = ((y0 >= 0) & (x1 <= w - 1)).unsqueeze(-1).float()
        mask_bl = ((y1 <= h - 1) & (x0 >= 0)).unsqueeze(-1).float()
        mask_br = ((y1 <= h - 1) & (x1 <= w - 1)).unsqueeze(-1).float()

        # Clamp coordinates for safe indexing
        x0_clamped = torch.clamp(x0, 0, w - 1)
        x1_clamped = torch.clamp(x1, 0, w - 1)
        y0_clamped = torch.clamp(y0, 0, h - 1)
        y1_clamped = torch.clamp(y1, 0, h - 1)

        # Bilinear interpolation weights for all 4 corners
        w_tl = (hh * hw).unsqueeze(-1)  # top-left: (bs, num_queries, num_heads, num_points, 1)
        w_tr = (hh * lw).unsqueeze(-1)  # top-right
        w_bl = (lh * hw).unsqueeze(-1)  # bottom-left
        w_br = (lh * lw).unsqueeze(-1)  # bottom-right

        # Gather values from the 4 corners using advanced indexing
        batch_idx = torch.arange(bs, device=value.device).view(bs, 1, 1, 1).expand(
            bs, num_queries, num_heads, num_points
        )
        head_idx = torch.arange(num_heads, device=value.device).view(1, 1, num_heads, 1).expand(
            bs, num_queries, num_heads, num_points
        )

        # Gather corner values with clamped indices, then apply corner masks
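        # With the advanced indices separated by the channel slice, the indexed dims
        # move to the front: each gather returns (bs, num_queries, num_heads, num_points, channels)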
        v_tl = value_spatial[batch_idx, head_idx, :, y0_clamped, x0_clamped] * mask_tl
        v_tr = value_spatial[batch_idx, head_idx, :, y0_clamped, x1_clamped] * mask_tr
        v_bl = value_spatial[batch_idx, head_idx, :, y1_clamped, x0_clamped] * mask_bl
        v_br = value_spatial[batch_idx, head_idx, :, y1_clamped, x1_clamped] * mask_br

        # Bilinear interpolation
        sampled = w_tl * v_tl + w_tr * v_tr + w_bl * v_bl + w_br * v_br

        # Apply valid mask (only accumulate if entire sample location is valid)
        sampled = sampled * valid.unsqueeze(-1).float()

        # Apply attention weights and sum over points
        # weight: (bs, num_queries, num_heads, num_points)
        # Expand weight: (bs, num_queries, num_heads, num_points, 1)
        weighted_sampled = sampled * weight.unsqueeze(-1)

        # Sum over points: (bs, num_queries, num_heads, channels)
        output += weighted_sampled.sum(dim=3)

    # Flatten last two dimensions to match kernel output
    return output.reshape(bs, num_queries, num_heads * channels)


run_benchmark(
    kernel_type=KernelTypeEnum.DEFORMABLE_DETR,
    impl_name="torch_eager",
    impl_tags={"family": "pytorch", "backend": "eager"},
    impl_func=torch_deformable_detr,
    dtype="float32",
)
Running deformable_detr benchmark on cuda with 4 workloads.

======================================================================
PROFILE TRACE: torch_eager | cuda_B1_Q100_H8_E256_L4_P4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            torch_eager         0.00%       0.000us         0.00%       0.000us       0.000us      19.976ms      1348.57%      19.976ms      19.976ms             1  
                                            torch_eager        20.04%       4.395ms        99.96%      21.929ms      21.929ms       0.000us         0.00%       1.482ms       1.482ms             1  
                                            aten::index         4.53%     992.766us        16.58%       3.638ms      75.786us     236.544us        15.97%     370.336us       7.715us            48  
                                            aten::copy_         4.69%       1.028ms        11.56%       2.535ms      11.576us     366.053us        24.71%     366.053us       1.671us           219  
                                              aten::mul         5.90%       1.295ms        10.04%       2.203ms      11.474us     293.531us        19.82%     293.531us       1.529us           192  
void at::native::index_elementwise_kernel<128, 4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     236.544us        15.97%     236.544us       4.928us            48  
                                               aten::to         0.58%     126.843us        11.27%       2.473ms      14.461us       0.000us         0.00%     232.261us       1.358us           171  
                                         aten::_to_copy         1.95%     426.950us        10.69%       2.346ms      19.073us       0.000us         0.00%     232.261us       1.888us           123  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     201.821us        13.62%     201.821us       1.682us           120  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us     167.778us        11.33%     167.778us       1.997us            84  
                                       aten::contiguous         0.36%      78.966us         8.52%       1.869ms      19.471us       0.000us         0.00%     133.792us       1.394us            96  
                                            aten::clone         0.74%     161.750us         8.16%       1.790ms      18.648us       0.000us         0.00%     133.792us       1.394us            96  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     133.792us         9.03%     133.792us       1.394us            96  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     115.553us         7.80%     115.553us       1.204us            96  
                                          aten::__and__         0.42%      91.609us         4.49%     984.808us      11.724us       0.000us         0.00%      99.041us       1.179us            84  
                                      aten::bitwise_and         2.54%     557.575us         4.07%     893.199us      10.633us      99.041us         6.69%      99.041us       1.179us            84  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      99.041us         6.69%      99.041us       1.179us            84  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      86.140us         5.82%      86.140us       1.196us            72  
                                              aten::sub         2.17%     475.165us         3.61%     791.992us      11.000us      79.197us         5.35%      79.197us       1.100us            72  
                                              aten::add         1.62%     354.490us         2.70%     592.103us       9.868us      74.334us         5.02%      74.334us       1.239us            60  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 21.937ms
Self CUDA time total: 1.481ms



======================================================================
PROFILE TRACE: torch_eager | cuda_B1_Q300_H8_E256_L4_P4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            torch_eager         0.00%       0.000us         0.00%       0.000us       0.000us      19.069ms      1196.67%      19.069ms      19.069ms             1  
                                            torch_eager        19.87%       4.152ms        99.97%      20.886ms      20.886ms       0.000us         0.00%       1.594ms       1.594ms             1  
                                            aten::index         4.48%     935.232us        16.67%       3.483ms      72.569us     249.668us        15.67%     382.147us       7.961us            48  
                                            aten::copy_         4.80%       1.003ms        11.85%       2.477ms      11.308us     366.556us        23.00%     366.556us       1.674us           219  
                                              aten::mul         6.04%       1.262ms        10.39%       2.170ms      11.304us     358.714us        22.51%     358.714us       1.868us           192  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     267.167us        16.77%     267.167us       2.226us           120  
void at::native::index_elementwise_kernel<128, 4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     249.668us        15.67%     249.668us       5.201us            48  
                                               aten::to         0.60%     125.408us        11.23%       2.347ms      13.724us       0.000us         0.00%     234.077us       1.369us           171  
                                         aten::_to_copy         1.87%     389.897us        10.63%       2.221ms      18.060us       0.000us         0.00%     234.077us       1.903us           123  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us     169.728us        10.65%     169.728us       2.021us            84  
                                       aten::contiguous         0.35%      74.120us         8.81%       1.840ms      19.167us       0.000us         0.00%     132.479us       1.380us            96  
                                            aten::clone         0.79%     164.425us         8.45%       1.766ms      18.395us       0.000us         0.00%     132.479us       1.380us            96  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     132.479us         8.31%     132.479us       1.380us            96  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     117.475us         7.37%     117.475us       1.224us            96  
                                          aten::__and__         0.44%      90.959us         4.50%     941.006us      11.202us       0.000us         0.00%     105.476us       1.256us            84  
                                      aten::bitwise_and         2.49%     520.216us         4.07%     850.047us      10.120us     105.476us         6.62%     105.476us       1.256us            84  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     105.476us         6.62%     105.476us       1.256us            84  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     104.197us         6.54%     104.197us       1.447us            72  
                                              aten::add         1.62%     338.151us         2.73%     570.998us       9.517us      91.678us         5.75%      91.678us       1.528us            60  
                                              aten::sub         2.14%     447.777us         3.61%     754.447us      10.478us      80.286us         5.04%      80.286us       1.115us            72  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 20.891ms
Self CUDA time total: 1.593ms



======================================================================
PROFILE TRACE: torch_eager | cuda_B2_Q100_H8_E256_L4_P4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            torch_eager         0.00%       0.000us         0.00%       0.000us       0.000us      19.677ms      1279.16%      19.677ms      19.677ms             1  
                                            torch_eager        19.82%       4.280ms        99.97%      21.590ms      21.590ms       0.000us         0.00%       1.539ms       1.539ms             1  
                                            aten::index         4.49%     970.701us        16.56%       3.576ms      74.506us     243.261us        15.81%     377.688us       7.868us            48  
                                            aten::copy_         4.67%       1.008ms        11.52%       2.487ms      11.356us     367.898us        23.92%     367.898us       1.680us           219  
                                              aten::mul         5.96%       1.287ms        10.22%       2.207ms      11.495us     324.384us        21.09%     324.384us       1.690us           192  
void at::native::index_elementwise_kernel<128, 4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     243.261us        15.81%     243.261us       5.068us            48  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     233.533us        15.18%     233.533us       1.946us           120  
                                               aten::to         0.57%     122.968us        11.17%       2.413ms      14.109us       0.000us         0.00%     233.471us       1.365us           171  
                                         aten::_to_copy         1.93%     415.801us        10.60%       2.290ms      18.615us       0.000us         0.00%     233.471us       1.898us           123  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us     169.053us        10.99%     169.053us       2.013us            84  
                                       aten::contiguous         0.37%      80.833us         8.61%       1.859ms      19.360us       0.000us         0.00%     134.427us       1.400us            96  
                                            aten::clone         0.74%     159.128us         8.23%       1.778ms      18.518us       0.000us         0.00%     134.427us       1.400us            96  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     134.427us         8.74%     134.427us       1.400us            96  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     115.871us         7.53%     115.871us       1.207us            96  
                                          aten::__and__         0.43%      92.507us         4.50%     971.781us      11.569us       0.000us         0.00%     104.160us       1.240us            84  
                                      aten::bitwise_and         2.49%     538.828us         4.07%     879.274us      10.468us     104.160us         6.77%     104.160us       1.240us            84  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     104.160us         6.77%     104.160us       1.240us            84  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      95.908us         6.23%      95.908us       1.332us            72  
                                              aten::add         1.64%     354.089us         2.75%     594.321us       9.905us      83.684us         5.44%      83.684us       1.395us            60  
                                              aten::sub         2.17%     468.302us         3.66%     789.975us      10.972us      79.297us         5.15%      79.297us       1.101us            72  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 21.596ms
Self CUDA time total: 1.538ms



======================================================================
PROFILE TRACE: torch_eager | cuda_B2_Q300_H8_E256_L4_P4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            torch_eager         0.00%       0.000us         0.00%       0.000us       0.000us      19.412ms      1097.11%      19.412ms      19.412ms             1  
                                            torch_eager        19.43%       4.188ms        99.97%      21.544ms      21.544ms       0.000us         0.00%       1.770ms       1.770ms             1  
                                              aten::mul         5.88%       1.267ms        10.26%       2.212ms      11.521us     450.496us        25.46%     450.496us       2.346us           192  
                                            aten::index         4.35%     938.379us        16.41%       3.536ms      73.661us     281.281us        15.90%     418.917us       8.727us            48  
                                            aten::copy_         4.72%       1.017ms        12.00%       2.587ms      11.811us     371.333us        20.99%     371.333us       1.696us           219  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     355.809us        20.11%     355.809us       2.965us           120  
void at::native::index_elementwise_kernel<128, 4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     281.281us        15.90%     281.281us       5.860us            48  
                                               aten::to         0.57%     122.376us        11.15%       2.403ms      14.050us       0.000us         0.00%     233.697us       1.367us           171  
                                         aten::_to_copy         1.79%     386.738us        10.58%       2.280ms      18.538us       0.000us         0.00%     233.697us       1.900us           123  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us     167.937us         9.49%     167.937us       1.999us            84  
                                       aten::contiguous         0.36%      77.297us         8.74%       1.884ms      19.624us       0.000us         0.00%     137.636us       1.434us            96  
                                            aten::clone         0.72%     155.217us         8.38%       1.807ms      18.819us       0.000us         0.00%     137.636us       1.434us            96  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     137.636us         7.78%     137.636us       1.434us            96  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     130.211us         7.36%     130.211us       1.808us            72  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     119.940us         6.78%     119.940us       1.249us            96  
                                              aten::add         1.56%     336.953us         2.72%     585.265us       9.754us     114.431us         6.47%     114.431us       1.907us            60  
                                          aten::__and__         0.41%      88.309us         4.45%     959.250us      11.420us       0.000us         0.00%     108.994us       1.298us            84  
                                      aten::bitwise_and         2.40%     517.417us         4.04%     870.941us      10.368us     108.994us         6.16%     108.994us       1.298us            84  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     108.994us         6.16%     108.994us       1.298us            84  
                                              aten::sub         2.15%     464.219us         3.68%     792.358us      11.005us      84.546us         4.78%      84.546us       1.174us            72  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 21.550ms
Self CUDA time total: 1.769ms


impl                     wl                  p50(ms)  ok
torch_eager              cuda_B1_Q100_H8_E256_L4_P4     3.38  True
torch_eager              cuda_B1_Q300_H8_E256_L4_P4     4.08  True
torch_eager              cuda_B2_Q100_H8_E256_L4_P4     4.16  True
torch_eager              cuda_B2_Q300_H8_E256_L4_P4     4.17  True
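
For reference, the function can be exercised outside the harness. A minimal sketch, assuming the workload tag cuda_B1_Q100_H8_E256_L4_P4 encodes batch 1, 100 queries, 8 heads, embed dim 256 (so 32 channels per head), 4 levels, and 4 points; the pyramid sizes below are illustrative, not necessarily what the harness generates:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
bs, num_queries, num_heads, embed_dim, num_levels, num_points = 1, 100, 8, 256, 4, 4
channels = embed_dim // num_heads  # 32 channels per head

# Illustrative 4-level feature pyramid, (h, w) per level
spatial_shapes = torch.tensor([[64, 64], [32, 32], [16, 16], [8, 8]], device=device)
areas = spatial_shapes[:, 0] * spatial_shapes[:, 1]
level_start_index = torch.cat([areas.new_zeros(1), areas.cumsum(0)[:-1]])

value = torch.randn(bs, int(areas.sum()), num_heads, channels, device=device)
sampling_locations = torch.rand(
    bs, num_queries, num_heads, num_levels, num_points, 2, device=device
)
# Weights normalized over the (levels x points) axis, as in Deformable DETR
attention_weights = torch.softmax(
    torch.randn(bs, num_queries, num_heads, num_levels * num_points, device=device), dim=-1
).view(bs, num_queries, num_heads, num_levels, num_points)

out = torch_deformable_detr(
    value, spatial_shapes, level_start_index, sampling_locations, attention_weights
)
print(out.shape)  # torch.Size([1, 100, 256])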
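An alternative eager formulation replaces the manual corner gathers with F.grid_sample, in the style of the PyTorch fallback shipped with the original Deformable DETR codebase (and mmcv's multi_scale_deformable_attn_pytorch). With align_corners=False and zero padding, grid_sample applies the same `loc * size - 0.5` coordinate mapping and per-corner zero masking as the manual implementation above. A sketch, not benchmarked here:

import torch
import torch.nn.functional as F

def deformable_attn_grid_sample(value, spatial_shapes, sampling_locations, attention_weights):
    bs, _, num_heads, channels = value.shape
    _, num_queries, _, num_levels, num_points, _ = sampling_locations.shape
    value_list = value.split([int(h * w) for h, w in spatial_shapes.tolist()], dim=1)
    sampling_grids = 2 * sampling_locations - 1  # grid_sample expects [-1, 1] coordinates
    sampled_levels = []
    for level_idx, (h, w) in enumerate(spatial_shapes.tolist()):
        # (bs, h*w, heads, C) -> (bs*heads, C, h, w)
        value_l = value_list[level_idx].permute(0, 2, 3, 1).reshape(bs * num_heads, channels, h, w)
        # (bs, Q, heads, P, 2) -> (bs*heads, Q, P, 2)
        grid_l = sampling_grids[:, :, :, level_idx].permute(0, 2, 1, 3, 4).reshape(
            bs * num_heads, num_queries, num_points, 2
        )
        # Bilinear sampling with zero padding outside the map: (bs*heads, C, Q, P)
        sampled_levels.append(
            F.grid_sample(value_l, grid_l, mode="bilinear", padding_mode="zeros", align_corners=False)
        )
    # (bs, Q, heads, L, P) -> (bs*heads, 1, Q, L*P)
    weights = attention_weights.permute(0, 2, 1, 3, 4).reshape(
        bs * num_heads, 1, num_queries, num_levels * num_points
    )
    # Stack levels, weight, and reduce over (levels x points): (bs*heads, C, Q)
    output = (torch.stack(sampled_levels, dim=-2).flatten(-2) * weights).sum(-1)
    return output.view(bs, num_heads * channels, num_queries).transpose(1, 2).contiguous()

On the random inputs from the previous sketch, torch.allclose(deformable_attn_grid_sample(value, spatial_shapes, sampling_locations, attention_weights), out, atol=1e-6) should hold, apart from samples landing exactly on the validity boundary.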