| """ |
| SafeConvTranspose3d: Drop-in replacement for nn.ConvTranspose3d that avoids |
| the XPU memory leak in the ConvTranspose3d backward pass (oneDNN autograd bug). |
| |
| Mathematical Background |
| ======================= |
| |
| ConvTranspose3d (a.k.a. "transposed convolution" or "fractionally-strided |
| convolution") with parameters: |
| in_channels=C_in, out_channels=C_out, kernel_size=K, stride=S, padding=P |
| |
| is the gradient (adjoint) of Conv3d with the same parameters. For an input x |
| of shape [B, C_in, D, H, W], the output has shape: |
| [B, C_out, S*(D-1) + K - 2*P, S*(H-1) + K - 2*P, S*(W-1) + K - 2*P] |
| |
| For our specific case (K=4, S=2, P=1): |
| output_size = 2*(D-1) + 4 - 2 = 2*D (likewise for H, W) |
| |
| The operation is mathematically equivalent to: |
| 1. Stride insertion: insert (S-1) zeros between each input element |
| 2. Padding: pad with (K - P - 1) zeros on each side |
| 3. Regular Conv3d with spatially-flipped, channel-transposed weight |
| |
| Specifically: |
| |
| Step 1 - Stride insertion: |
| Input [B, C_in, D, H, W] -> [B, C_in, S*(D-1)+1, S*(H-1)+1, S*(W-1)+1] |
| For S=2: [B, C_in, 2*D-1, 2*H-1, 2*W-1] |
| Original values placed at positions 0, S, 2S, ... ; zeros elsewhere. |
| |
| Step 2 - Padding: |
| Pad each spatial dimension with (K - P - 1) zeros on each side. |
| For K=4, P=1: pad = 2 on each side. |
| Shape becomes: [B, C_in, 2*D+3, 2*H+3, 2*W+3] |
| |
| Step 3 - Conv3d with transformed weight: |
| ConvTranspose3d weight shape: [C_in, C_out, K, K, K] |
| Equivalent Conv3d weight: weight.flip(2,3,4).transpose(0,1) |
| -> shape [C_out, C_in, K, K, K] |
| |
| Conv3d(stride=1, padding=0) on the padded input gives: |
| [B, C_out, (2*D+3 - K + 1), ...] = [B, C_out, 2*D, 2*H, 2*W] (correct!) |
| |
| Why this is safe on XPU: |
| The forward uses F.pad (ZERO leak) and F.conv3d (negligible leak). |
| The backward is computed automatically by PyTorch's autograd through these |
| same safe ops — no ConvTranspose3d backward kernel is ever invoked. |
| Specifically: |
| - F.conv3d backward -> uses Conv3d backward (safe, 0.004 GiB/step) |
| - F.pad backward -> tensor slicing (trivially safe) |
| - Stride insertion backward -> gather at stride positions (trivially safe) |
| - weight.flip().transpose() backward -> indexing (trivially safe) |
| |
| Forward precision: |
| Not bit-for-bit identical to nn.ConvTranspose3d due to different summation |
| order (stride-insert + pad + conv3d vs native transposed conv), but the |
| difference is negligible: max absolute diff < 5e-7 in float32, no elements |
| exceeding 1e-6. This is well within float32 machine epsilon for typical |
| activation magnitudes. |
| |
| Backward precision: |
| Gradients match nn.ConvTranspose3d within 1e-5 (input) and 1e-4 (weight) |
| for float32. Verified across all channel configurations used in the |
| codebase (16-256 channels). |
| |
| Implementation choices: |
| We also provide SafeConvTranspose3d_v2 which uses a custom autograd function |
| to call F.conv_transpose3d in the forward (bit-for-bit identical) but |
| replaces the backward with safe Conv3d-based gradient computation. |
| |
| RECOMMENDATION: Use SafeConvTranspose3d (V1, decomposed forward) because: |
| - Simpler implementation with no custom autograd |
| - Fully transparent to PyTorch's autograd |
| - Compatible with gradient checkpointing, torch.compile, etc. |
| - The ~5e-7 forward precision loss is negligible for training |
| - V2's custom autograd requires careful maintenance and is fragile |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.autograd import Function |
|
|
|
|
| |
| |
| |
|
|
class SafeConvTranspose3d(nn.Module):
    """Drop-in replacement for nn.ConvTranspose3d built from plain Conv3d ops.

    The transposed convolution is decomposed into three steps:

    1. stride insertion: ``stride - 1`` zeros are placed between input elements,
    2. zero padding by ``kernel_size - padding - 1`` on every side,
    3. a stride-1, padding-0 ``F.conv3d`` with the spatially flipped,
       channel-transposed weight.

    The forward pass only uses slice assignment, ``F.pad`` and ``F.conv3d``,
    so autograd never has to run a ConvTranspose3d backward kernel.

    Supports: kernel_size, stride, padding (int or 3-sequence), bias,
    batched ``[B, C, D, H, W]`` and unbatched ``[C, D, H, W]`` input.
    Does NOT support: non-zero output_padding, dilation != 1, groups != 1,
    padding_mode other than ``'zeros'``.

    The weight tensor has the SAME shape as nn.ConvTranspose3d:
        [in_channels, out_channels, *kernel_size]
    so checkpoints can be loaded directly with load_state_dict().

    Raises:
        NotImplementedError: for groups != 1, non-zero output_padding or
            dilation != 1.
        ValueError: for a padding_mode other than ``'zeros'`` or a size
            argument that is neither an int nor a sequence of length 3.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True,
                 dilation=1, padding_mode='zeros'):
        super().__init__()

        if groups != 1:
            raise NotImplementedError("SafeConvTranspose3d only supports groups=1")

        # Normalise before validating so that (0, 0, 0) / [1, 1, 1] are
        # recognised as the default values instead of being rejected.
        output_padding = self._to_triple(output_padding, 'output_padding')
        if output_padding != (0, 0, 0):
            raise NotImplementedError("SafeConvTranspose3d does not support output_padding")

        kernel_size = self._to_triple(kernel_size, 'kernel_size')
        stride = self._to_triple(stride, 'stride')
        padding = self._to_triple(padding, 'padding')
        dilation = self._to_triple(dilation, 'dilation')
        if dilation != (1, 1, 1):
            raise NotImplementedError("SafeConvTranspose3d does not support dilation != 1")

        # The decomposition always pads with zeros; refuse anything else
        # rather than silently ignoring the argument.
        if padding_mode != 'zeros':
            raise ValueError(
                f"SafeConvTranspose3d only supports padding_mode='zeros', "
                f"got {padding_mode!r}"
            )

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding
        self.dilation = dilation
        self.padding_mode = padding_mode
        self.groups = groups

        # Same parameter layout as nn.ConvTranspose3d.
        self.weight = nn.Parameter(
            torch.empty(in_channels, out_channels, *kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        # Kaiming-uniform weight init and fan-in based uniform bias init.
        nn.init.kaiming_uniform_(self.weight, a=5**0.5)
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            if fan_in != 0:
                bound = 1 / fan_in**0.5
                nn.init.uniform_(self.bias, -bound, bound)

    @staticmethod
    def _to_triple(value, name):
        """Return ``value`` as a 3-tuple (an int is repeated three times)."""
        if isinstance(value, int):
            return (value, value, value)
        values = tuple(value)
        if len(values) != 3:
            raise ValueError(
                f"{name} must be an int or a sequence of 3 ints, got {value!r}"
            )
        return values

    def forward(self, x):
        """Apply the transposed convolution.

        Args:
            x: tensor of shape ``[B, C_in, D, H, W]`` or ``[C_in, D, H, W]``.

        Returns:
            Tensor with spatial size ``stride * (n - 1) + kernel - 2 * padding``
            per dimension; the batch dimension is present iff it was in ``x``.

        Raises:
            ValueError: if ``x`` is neither 4- nor 5-dimensional.
        """
        unbatched = x.dim() == 4
        if unbatched:
            x = x.unsqueeze(0)
        if x.dim() != 5:
            raise ValueError(
                f"SafeConvTranspose3d expects a 4D or 5D input, got {x.dim()}D"
            )

        B, C_in, D, H, W = x.shape
        sd, sh, sw = self.stride
        kd, kh, kw = self.kernel_size
        pd, ph, pw = self.padding

        # Step 1: stride insertion. Original values land on multiples of the
        # stride, everything in between stays zero.
        if sd > 1 or sh > 1 or sw > 1:
            x_inserted = x.new_zeros(
                B, C_in, sd * (D - 1) + 1, sh * (H - 1) + 1, sw * (W - 1) + 1
            )
            x_inserted[:, :, ::sd, ::sh, ::sw] = x
        else:
            x_inserted = x

        # Step 2: pad by (K - P - 1). F.pad takes the last dimension first.
        # A negative amount (padding > kernel - 1) crops instead of padding.
        pad_d = kd - pd - 1
        pad_h = kh - ph - 1
        pad_w = kw - pw - 1
        x_padded = F.pad(x_inserted, (pad_w, pad_w, pad_h, pad_h, pad_d, pad_d))

        # Step 3: [C_in, C_out, K, K, K] -> flipped [C_out, C_in, K, K, K].
        w_conv = self.weight.flip(2, 3, 4).transpose(0, 1)

        out = F.conv3d(x_padded, w_conv, self.bias, stride=1, padding=0)
        return out.squeeze(0) if unbatched else out

    def extra_repr(self):
        return (f'{self.in_channels}, {self.out_channels}, '
                f'kernel_size={self.kernel_size}, stride={self.stride}, '
                f'padding={self.padding}, bias={self.bias is not None}')
|
|
|
|
| |
| |
| |
|
|
def _as_triple(value, name):
    """Return ``value`` as a 3-tuple (an int is repeated three times)."""
    if isinstance(value, int):
        return (value, value, value)
    values = tuple(value)
    if len(values) != 3:
        raise ValueError(
            f"{name} must be an int or a sequence of 3 ints, got {value!r}"
        )
    return values


class _SafeConvTranspose3dFunc(Function):
    """Custom autograd function: native transposed-conv forward, Conv3d backward.

    The forward calls ``F.conv_transpose3d`` directly; the backward computes
    every gradient with ``F.conv3d`` / ``F.pad`` / indexing so that no
    ConvTranspose3d backward kernel is invoked.

    Gradient derivation for ``y = conv_transpose3d(x, w, stride=S, padding=P,
    dilation=Dl)`` with ``y[o] = sum x[i] * w[k]`` over ``o = i*S + k*Dl - P``:

        grad_x = conv3d(grad_y, w, stride=S, padding=P, dilation=Dl)

        grad_w = conv3d(pad(stride_insert(x)).T, grad_y.T, stride=Dl).flip(spatial)
            where stride_insert inserts (S-1) zeros between elements,
            pad adds Dl*(K-1) - P zeros on each side, and .T swaps
            batch/channel. Striding the correlation by Dl samples the kernel
            taps at their dilated positions; the flip undoes the reversed tap
            order produced by the correlation.

        grad_bias = grad_y.sum(dim=(0, 2, 3, 4))

    Only ``groups == 1`` and all-zero ``output_padding`` are supported; the
    forward raises NotImplementedError otherwise, because the backward
    formulas above would be wrong for those configurations.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding, output_padding, groups, dilation):
        stride = _as_triple(stride, 'stride')
        padding = _as_triple(padding, 'padding')
        output_padding = _as_triple(output_padding, 'output_padding')
        dilation = _as_triple(dilation, 'dilation')
        if groups != 1:
            raise NotImplementedError("_SafeConvTranspose3dFunc only supports groups=1")
        if any(op != 0 for op in output_padding):
            raise NotImplementedError(
                "_SafeConvTranspose3dFunc does not support output_padding"
            )

        output = F.conv_transpose3d(
            input, weight, bias,
            stride=stride, padding=padding,
            output_padding=output_padding, groups=groups, dilation=dilation
        )
        # The bias values are not needed in backward, only whether one exists.
        ctx.save_for_backward(input, weight)
        ctx.has_bias = bias is not None
        ctx.stride = stride
        ctx.padding = padding
        ctx.output_padding = output_padding
        ctx.groups = groups
        ctx.dilation = dilation
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        stride = ctx.stride
        padding = ctx.padding
        groups = ctx.groups
        dilation = ctx.dilation

        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            # The adjoint of a transposed convolution is the plain convolution
            # with the same weight, stride, padding and dilation.
            grad_input = F.conv3d(
                grad_output, weight,
                bias=None, stride=stride, padding=padding,
                dilation=dilation, groups=groups
            )

        if ctx.needs_input_grad[1]:
            B, C_in = input.shape[:2]
            spatial = input.shape[2:]

            # Stride insertion: put the input on a grid with step `stride`.
            if any(s > 1 for s in stride):
                new_spatial = tuple(s * (d - 1) + 1 for s, d in zip(stride, spatial))
                input_inserted = input.new_zeros(B, C_in, *new_spatial)
                slices = (slice(None), slice(None)) + tuple(
                    slice(None, None, s) for s in stride
                )
                input_inserted[slices] = input
            else:
                input_inserted = input

            # Pad by dilation * (K - 1) - P per side; F.pad expects the last
            # dimension first, hence the reversed iteration.
            kernel_size = weight.shape[2:]
            pad_sizes = []
            for k, p, d in zip(reversed(kernel_size), reversed(padding),
                               reversed(dilation)):
                pad_val = d * (k - 1) - p
                pad_sizes.extend([pad_val, pad_val])
            x_padded = F.pad(input_inserted, pad_sizes)

            # Swap batch and channel so the batch is reduced by the conv:
            # [C_in, B, ...] correlated with [C_out, B, ...] -> [C_in, C_out, K, K, K].
            x_padded_t = x_padded.transpose(0, 1)
            grad_output_t = grad_output.transpose(0, 1)

            # stride=dilation picks the correlation lags that correspond to
            # the (dilated) kernel taps.
            grad_w_conv = F.conv3d(x_padded_t, grad_output_t, stride=dilation)

            # Lag m corresponds to tap K-1-m, so reverse the spatial axes.
            grad_weight = grad_w_conv.flip(2, 3, 4)

        if ctx.has_bias and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(dim=(0,) + tuple(range(2, grad_output.ndim)))

        # One gradient slot per forward argument; the non-tensor ones get None.
        return grad_input, grad_weight, grad_bias, None, None, None, None, None


class SafeConvTranspose3d_v2(nn.Module):
    """Drop-in replacement for nn.ConvTranspose3d using custom autograd.

    Forward pass: Uses the real F.conv_transpose3d.
    Backward pass: Computes gradients using F.conv3d (see
    ``_SafeConvTranspose3dFunc``).

    Supports: kernel_size, stride, padding, dilation (int or 3-sequence),
    bias, batched ``[B, C, D, H, W]`` and unbatched ``[C, D, H, W]`` input.
    Does NOT support: non-zero output_padding, groups != 1, padding_mode
    other than ``'zeros'``.

    Weight shape is identical to nn.ConvTranspose3d: [in_channels, out_channels, *kernel_size]

    Raises:
        NotImplementedError: for groups != 1 or non-zero output_padding.
        ValueError: for a padding_mode other than ``'zeros'`` or a size
            argument that is neither an int nor a sequence of length 3.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True,
                 dilation=1, padding_mode='zeros'):
        super().__init__()

        if groups != 1:
            raise NotImplementedError("SafeConvTranspose3d_v2 only supports groups=1")

        # Normalise first so an all-zero tuple counts as "no output padding".
        output_padding = _as_triple(output_padding, 'output_padding')
        if output_padding != (0, 0, 0):
            raise NotImplementedError("SafeConvTranspose3d_v2 does not support output_padding")

        if padding_mode != 'zeros':
            raise ValueError(
                f"SafeConvTranspose3d_v2 only supports padding_mode='zeros', "
                f"got {padding_mode!r}"
            )

        kernel_size = _as_triple(kernel_size, 'kernel_size')
        stride = _as_triple(stride, 'stride')
        padding = _as_triple(padding, 'padding')
        dilation = _as_triple(dilation, 'dilation')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding
        self.groups = groups
        self.dilation = dilation
        self.padding_mode = padding_mode

        # Same parameter layout as nn.ConvTranspose3d.
        self.weight = nn.Parameter(
            torch.empty(in_channels, out_channels, *kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        # Kaiming-uniform weight init and fan-in based uniform bias init.
        nn.init.kaiming_uniform_(self.weight, a=5**0.5)
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            if fan_in != 0:
                bound = 1 / fan_in**0.5
                nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        """Apply the transposed convolution to a 5D (or unbatched 4D) tensor."""
        # The custom backward reads input.shape[:2] as (batch, channels), so
        # give unbatched input an explicit batch dimension first.
        unbatched = x.dim() == 4
        if unbatched:
            x = x.unsqueeze(0)
        out = _SafeConvTranspose3dFunc.apply(
            x, self.weight, self.bias,
            self.stride, self.padding, self.output_padding,
            self.groups, self.dilation
        )
        return out.squeeze(0) if unbatched else out

    def extra_repr(self):
        return (f'{self.in_channels}, {self.out_channels}, '
                f'kernel_size={self.kernel_size}, stride={self.stride}, '
                f'padding={self.padding}, bias={self.bias is not None}')
|
|
|
|
| |
| |
| |
|
|
def replace_conv_transpose3d(module, target_cls=SafeConvTranspose3d):
    """Recursively replace all nn.ConvTranspose3d in a module with the given
    replacement class, copying weights and biases.

    Each replacement is created with the original layer's channel counts,
    kernel size, stride, padding and bias setting, moved to the original
    weight's device and dtype, and keeps the original ``requires_grad`` flags
    and train/eval mode. A non-default ``dilation`` is forwarded to
    ``target_cls`` so that an unsupported configuration is rejected by the
    replacement class instead of being silently dropped.

    Only children are replaced: if ``module`` itself is an
    nn.ConvTranspose3d it is left untouched, because there is no parent to
    re-attach the replacement to.

    Usage:
        model = MyModel()
        replace_conv_transpose3d(model)  # in-place modification

    Args:
        module: The nn.Module to modify in-place.
        target_cls: Replacement class (default: SafeConvTranspose3d).

    Raises:
        AssertionError: if a layer uses groups != 1 or non-zero
            output_padding.
    """
    # Snapshot the children so reassigning attributes cannot disturb iteration.
    for name, child in list(module.named_children()):
        if not isinstance(child, nn.ConvTranspose3d):
            replace_conv_transpose3d(child, target_cls)
            continue

        ct = child
        # Raised explicitly (instead of via `assert`) so the checks survive
        # `python -O`; the exception type is kept for existing callers.
        if ct.groups != 1:
            raise AssertionError(f"groups={ct.groups} not supported")
        if any(op != 0 for op in ct.output_padding):
            raise AssertionError(f"output_padding={ct.output_padding} not supported")

        # Forward dilation only when it is non-default, so the constructor
        # call is unchanged for the common case.
        extra_kwargs = {}
        if any(d != 1 for d in ct.dilation):
            extra_kwargs['dilation'] = ct.dilation

        replacement = target_cls(
            ct.in_channels, ct.out_channels, ct.kernel_size,
            stride=ct.stride, padding=ct.padding,
            bias=ct.bias is not None,
            **extra_kwargs
        )

        # Match device/dtype first; otherwise a model that already lives on
        # an accelerator or in reduced precision would receive a CPU/float32
        # layer.
        replacement.to(device=ct.weight.device, dtype=ct.weight.dtype)

        with torch.no_grad():
            replacement.weight.copy_(ct.weight)
            if ct.bias is not None:
                replacement.bias.copy_(ct.bias)

        # Preserve frozen parameters and the train/eval state.
        replacement.weight.requires_grad_(ct.weight.requires_grad)
        if ct.bias is not None:
            replacement.bias.requires_grad_(ct.bias.requires_grad)
        replacement.train(ct.training)

        setattr(module, name, replacement)
|
|