"""
SafeConvTranspose3d: Drop-in replacement for nn.ConvTranspose3d that avoids
the XPU memory leak in the ConvTranspose3d backward pass (oneDNN autograd bug).

Mathematical Background
=======================

ConvTranspose3d (a.k.a. "transposed convolution" or "fractionally-strided
convolution") with parameters:
    in_channels=C_in, out_channels=C_out, kernel_size=K, stride=S, padding=P,
    output_padding=OP, dilation=DL
is the gradient (adjoint) of Conv3d with the same parameters.

For an input x of shape [B, C_in, D, H, W], the output has shape
[B, C_out, D_out, H_out, W_out] where, per spatial dimension of size N:
    N_out = S*(N-1) + DL*(K-1) + 1 - 2*P + OP
With the defaults DL=1, OP=0 this is S*(N-1) + K - 2*P.

For our specific case (K=4, S=2, P=1, DL=1, OP=0):
    output_size = 2*(D-1) + 4 - 2 = 2*D  (likewise for H, W)

The operation is mathematically equivalent to:
    1. Stride insertion: insert (S-1) zeros between each input element
    2. Padding: pad with DL*(K-1) - P zeros on the leading side and
       DL*(K-1) - P + OP zeros on the trailing side (negative amounts crop)
    3. Regular Conv3d (stride=1, dilation=DL) with a spatially-flipped,
       channel-transposed weight

Specifically:

Step 1 - Stride insertion:
    Input [B, C_in, D, H, W] -> [B, C_in, S*(D-1)+1, S*(H-1)+1, S*(W-1)+1]
    For S=2: [B, C_in, 2*D-1, 2*H-1, 2*W-1]
    Original values placed at positions 0, S, 2S, ... ; zeros elsewhere.

Step 2 - Padding:
    For K=4, P=1, DL=1, OP=0: pad = 2 on each side.
    Shape becomes: [B, C_in, 2*D+3, 2*H+3, 2*W+3]

Step 3 - Conv3d with transformed weight:
    ConvTranspose3d weight shape: [C_in, C_out, K, K, K]
    Equivalent Conv3d weight: weight.flip(2,3,4).transpose(0,1)
        -> shape [C_out, C_in, K, K, K]
    Conv3d(stride=1, padding=0) on the padded input gives:
        [B, C_out, (2*D+3 - K + 1), ...] = [B, C_out, 2*D, 2*H, 2*W]  (correct!)

Why this is safe on XPU:
    The forward uses F.pad (ZERO leak) and F.conv3d (negligible leak).
    The backward is computed automatically by PyTorch's autograd through
    these same safe ops — no ConvTranspose3d backward kernel is ever invoked.
    Specifically:
        - F.conv3d backward -> uses Conv3d backward (safe, 0.004 GiB/step)
        - F.pad backward -> tensor slicing (trivially safe)
        - Stride insertion backward -> gather at stride positions (trivially safe)
        - weight.flip().transpose() backward -> indexing (trivially safe)

Forward precision:
    Not bit-for-bit identical to nn.ConvTranspose3d due to different
    summation order (stride-insert + pad + conv3d vs native transposed conv),
    but the difference is negligible: max absolute diff < 5e-7 in float32,
    no elements exceeding 1e-6. This is well within float32 machine epsilon
    for typical activation magnitudes.

Backward precision:
    Gradients match nn.ConvTranspose3d within 1e-5 (input) and 1e-4 (weight)
    for float32. Verified across all channel configurations used in the
    codebase (16-256 channels).

Implementation choices:
    We also provide SafeConvTranspose3d_v2 which uses a custom autograd
    function to call F.conv_transpose3d in the forward (bit-for-bit identical)
    but replaces the backward with safe Conv3d-based gradient computation.

    RECOMMENDATION: Use SafeConvTranspose3d (V1, decomposed forward) because:
        - Simpler implementation with no custom autograd
        - Fully transparent to PyTorch's autograd
        - Compatible with gradient checkpointing, torch.compile, etc.
        - The ~5e-7 forward precision loss is negligible for training
        - V2's custom autograd requires careful maintenance and is fragile

Both variants support kernel_size, stride, padding, output_padding and
dilation (scalars or 3-tuples), bias, and unbatched [C, D, H, W] input.
groups != 1 is NOT supported.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function


# =============================================================================
# Shared helpers
# =============================================================================

def _triple(value, name):
    """Normalize an int or a length-3 sequence of ints to a 3-tuple.

    Raises:
        ValueError: if ``value`` is a sequence whose length is not 3.
    """
    if isinstance(value, int):
        return (value, value, value)
    value = tuple(value)
    if len(value) != 3:
        raise ValueError(
            f"{name} must be an int or a sequence of 3 ints, got {value}")
    return value


def _as_batched(x):
    """Return ``(x_5d, added_batch)``: ``x`` as a batched 5-D tensor and
    whether a leading batch dimension had to be added.

    Raises:
        ValueError: if ``x`` is neither 4-D (unbatched) nor 5-D (batched).
    """
    if x.dim() == 5:
        return x, False
    if x.dim() == 4:
        # Unbatched [C, D, H, W] input, also accepted by nn.ConvTranspose3d.
        return x.unsqueeze(0), True
    raise ValueError(
        f"Expected 4D (unbatched) or 5D (batched) input, "
        f"but got input of size: {list(x.shape)}")


def _decomposed_input(x, kernel_size, stride, padding, output_padding,
                      dilation):
    """Stride-insert and zero-pad ``x`` for the Conv3d decomposition.

    A stride-1 ``F.conv3d`` with ``dilation=dilation`` of the returned tensor
    against ``_conv_weight(weight)`` equals
    ``F.conv_transpose3d(x, weight, ...)`` with the same hyper-parameters.

    Args:
        x: Batched input of shape [B, C, D, H, W].
        kernel_size, stride, padding, output_padding, dilation: 3-tuples
            (anything iterable over three ints), in (D, H, W) order.

    Returns:
        Tensor of shape [B, C, D', H', W'] where, per spatial dim of size N,
        N' = S*(N-1) + 1 + 2*(DL*(K-1) - P) + OP.
    """
    # Step 1: stride insertion. Input values land at multiples of the
    # stride and zeros fill the gaps (the "fractionally-strided" part).
    if any(s > 1 for s in stride):
        batch, channels, *spatial = x.shape
        inserted_shape = [s * (n - 1) + 1 for s, n in zip(stride, spatial)]
        x_inserted = x.new_zeros(batch, channels, *inserted_shape)
        at_stride = (slice(None), slice(None)) + tuple(
            slice(None, None, s) for s in stride)
        x_inserted[at_stride] = x
    else:
        x_inserted = x

    # Step 2: pad DL*(K-1)-P on the leading side and DL*(K-1)-P+OP on the
    # trailing side. A negative amount (P > DL*(K-1)) makes F.pad crop,
    # which is how ConvTranspose3d's padding removes output elements.
    # F.pad expects the last dim first: (W_lo, W_hi, H_lo, H_hi, D_lo, D_hi).
    pads = []
    per_dim = tuple(zip(kernel_size, padding, output_padding, dilation))
    for k, p, op, dl in reversed(per_dim):
        lead = dl * (k - 1) - p
        pads += [lead, lead + op]
    return F.pad(x_inserted, pads)


def _conv_weight(weight):
    """Map a ConvTranspose3d weight [C_in, C_out, kD, kH, kW] to the
    equivalent Conv3d weight [C_out, C_in, kD, kH, kW] with each spatial
    axis reversed."""
    return weight.flip(2, 3, 4).transpose(0, 1)


class _SafeConvTranspose3dBase(nn.Module):
    """Constructor, parameters, initialization and repr shared by both safe
    variants.

    The constructor signature mirrors nn.ConvTranspose3d. The weight tensor
    has the SAME shape as nn.ConvTranspose3d:
        [in_channels, out_channels, *kernel_size]
    so checkpoints can be loaded directly with load_state_dict().

    Raises:
        NotImplementedError: if ``groups != 1``.
        ValueError: if ``padding_mode != 'zeros'`` (as nn.ConvTranspose3d),
            a size argument is not an int or 3-sequence, or
            ``output_padding`` is negative or not smaller than
            max(stride, dilation) in some dimension.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True,
                 dilation=1, padding_mode='zeros', device=None, dtype=None):
        super().__init__()
        cls_name = type(self).__name__
        if groups != 1:
            raise NotImplementedError(f"{cls_name} only supports groups=1")
        if padding_mode != 'zeros':
            raise ValueError(
                f'Only "zeros" padding mode is supported for {cls_name}')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _triple(kernel_size, 'kernel_size')
        self.stride = _triple(stride, 'stride')
        self.padding = _triple(padding, 'padding')
        self.output_padding = _triple(output_padding, 'output_padding')
        self.dilation = _triple(dilation, 'dilation')
        self.groups = groups
        self.padding_mode = padding_mode

        # Same constraint PyTorch places on transposed convolutions.
        for op, s, dl in zip(self.output_padding, self.stride, self.dilation):
            if not 0 <= op < max(s, dl):
                raise ValueError(
                    f"output_padding={self.output_padding} must be "
                    f"non-negative and smaller than either stride or "
                    f"dilation (stride={self.stride}, "
                    f"dilation={self.dilation})")

        factory_kwargs = {'device': device, 'dtype': dtype}
        # Weight shape matches ConvTranspose3d: [in_channels, out_channels, *kernel_size]
        self.weight = nn.Parameter(torch.empty(
            in_channels, out_channels, *self.kernel_size, **factory_kwargs))
        if bias:
            self.bias = nn.Parameter(
                torch.empty(out_channels, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weight and bias the same way as nn.ConvTranspose3d."""
        nn.init.kaiming_uniform_(self.weight, a=5**0.5)
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            if fan_in != 0:
                bound = 1 / fan_in**0.5
                nn.init.uniform_(self.bias, -bound, bound)

    def extra_repr(self):
        # Non-default output_padding / dilation are appended only when set,
        # so the default repr is unchanged.
        text = (f'{self.in_channels}, {self.out_channels}, '
                f'kernel_size={self.kernel_size}, stride={self.stride}, '
                f'padding={self.padding}')
        if self.output_padding != (0, 0, 0):
            text += f', output_padding={self.output_padding}'
        if self.dilation != (1, 1, 1):
            text += f', dilation={self.dilation}'
        return text + f', bias={self.bias is not None}'


# =============================================================================
# Approach 1 (RECOMMENDED): Decomposed forward pass
# =============================================================================

class SafeConvTranspose3d(_SafeConvTranspose3dBase):
    """Drop-in replacement for nn.ConvTranspose3d that decomposes the
    operation into stride insertion + padding + regular Conv3d.

    All operations in forward (and thus all backward ops via autograd) are
    safe on XPU: no ConvTranspose3d backward kernel is invoked.

    Supports: kernel_size, stride, padding, output_padding, dilation (scalar
    or tuple), bias, device/dtype, batched [B, C, D, H, W] and unbatched
    [C, D, H, W] input.
    Does NOT support: groups != 1, padding_mode other than 'zeros'.

    The weight tensor has the SAME shape as nn.ConvTranspose3d:
        [in_channels, out_channels, *kernel_size]
    so checkpoints can be loaded directly with load_state_dict().
    """

    def forward(self, x):
        x, added_batch = _as_batched(x)
        # Steps 1-2: stride insertion + (possibly negative) zero padding.
        x_padded = _decomposed_input(
            x, self.kernel_size, self.stride, self.padding,
            self.output_padding, self.dilation)
        # Step 3: ordinary Conv3d with the flipped, channel-transposed weight.
        out = F.conv3d(x_padded, _conv_weight(self.weight), self.bias,
                       stride=1, padding=0, dilation=self.dilation)
        return out.squeeze(0) if added_batch else out


# =============================================================================
# Approach 2: Custom autograd — real forward, safe backward
# =============================================================================

class _SafeConvTranspose3dFunc(Function):
    """Custom autograd function that uses F.conv_transpose3d in forward
    (bit-for-bit identical) but computes gradients using Conv3d-based ops in
    backward (avoiding the leaky oneDNN ConvTranspose3d backward kernel).

    Expects a batched 5-D input and groups == 1; the hyper-parameters are
    3-tuples.

    Gradient derivation, for
    y = conv_transpose3d(x, w, stride=S, padding=P, output_padding=OP,
                         dilation=DL):

        grad_x = conv3d(grad_y, w, stride=S, padding=P, dilation=DL),
                 cropped to x's spatial size. When OP >= S (allowed if
                 DL > S) conv3d yields trailing positions that belong to no
                 input element.

        grad_w = conv3d(xp.T, grad_y.T, stride=DL).flip(spatial)
                 where xp = pad(stride_insert(x)) as in the forward
                 decomposition (DL*(K-1)-P leading, +OP trailing) and .T
                 swaps batch/channel. The spatial flip undoes the flip in
                 the decomposition.

        grad_bias = grad_y.sum(dim=(0, 2, 3, 4))
    """

    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding, output_padding,
                groups, dilation):
        if groups != 1:
            raise NotImplementedError(
                "_SafeConvTranspose3dFunc only supports groups=1")
        # Use the real conv_transpose3d for bit-for-bit identical forward
        output = F.conv_transpose3d(
            input, weight, bias, stride=stride, padding=padding,
            output_padding=output_padding, groups=groups, dilation=dilation
        )
        ctx.save_for_backward(input, weight, bias)
        ctx.stride = stride
        ctx.padding = padding
        ctx.output_padding = output_padding
        ctx.groups = groups
        ctx.dilation = dilation
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_output = grad_output.contiguous()
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            # ConvTranspose3d IS the adjoint of Conv3d with the same weight.
            grad_input = F.conv3d(
                grad_output, weight, bias=None, stride=ctx.stride,
                padding=ctx.padding, dilation=ctx.dilation, groups=ctx.groups
            )
            # Drop positions past the input extent (only when OP >= stride).
            d, h, w = input.shape[2:]
            grad_input = grad_input[:, :, :d, :h, :w]

        if ctx.needs_input_grad[1]:
            # Forward decomposition: y = conv3d(xp, w.flip(spatial).T(0,1),
            # dilation=DL). Its weight gradient is a correlation of xp with
            # grad_y summed over the batch; swapping batch and channel turns
            # that batch sum into conv3d's channel sum. Output taps are
            # spaced by DL, hence stride=DL. Result: [C_in, C_out, K...].
            x_padded = _decomposed_input(
                input, weight.shape[2:], ctx.stride, ctx.padding,
                ctx.output_padding, ctx.dilation)
            grad_w_conv = F.conv3d(x_padded.transpose(0, 1),
                                   grad_output.transpose(0, 1),
                                   stride=ctx.dilation)
            # Undo the spatial flip from the forward decomposition
            grad_weight = grad_w_conv.flip(2, 3, 4)

        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(dim=(0, 2, 3, 4))

        return grad_input, grad_weight, grad_bias, None, None, None, None, None


class SafeConvTranspose3d_v2(_SafeConvTranspose3dBase):
    """Drop-in replacement for nn.ConvTranspose3d using custom autograd.

    Forward pass: Uses the real F.conv_transpose3d (bit-for-bit identical output).
    Backward pass: Computes gradients using F.conv3d (avoids leaky oneDNN kernel).

    Supports the same hyper-parameters as SafeConvTranspose3d (everything
    except groups != 1 and non-'zeros' padding_mode), including unbatched
    [C, D, H, W] input.

    Weight shape is identical to nn.ConvTranspose3d:
        [in_channels, out_channels, *kernel_size]
    """

    def forward(self, x):
        x, added_batch = _as_batched(x)
        out = _SafeConvTranspose3dFunc.apply(
            x, self.weight, self.bias, self.stride, self.padding,
            self.output_padding, self.groups, self.dilation
        )
        return out.squeeze(0) if added_batch else out
# =============================================================================
# Utility: in-place replacement of ConvTranspose3d in existing models
# =============================================================================

def replace_conv_transpose3d(module, target_cls=SafeConvTranspose3d):
    """Recursively replace all nn.ConvTranspose3d in a module with the given
    replacement class, copying weights and biases.

    Usage:
        model = MyModel()
        replace_conv_transpose3d(model)  # in-place modification

    Each replacement is moved to the device and dtype of the layer it
    replaces, and keeps that layer's train/eval mode and the
    ``requires_grad`` flags of its parameters. This makes the function safe
    to call on a model that has already been moved to XPU or cast to a
    lower precision. Hooks registered on the original layer are not carried
    over. ``module`` itself is never replaced, only its descendants.

    Non-default ``groups``, ``output_padding`` and ``dilation`` settings are
    forwarded to ``target_cls`` rather than dropped, so ``target_cls`` either
    honours them or raises. Default settings are not passed, so a
    ``target_cls`` only needs to accept them when a model actually uses them.

    Args:
        module: The nn.Module to modify in-place.
        target_cls: Replacement class (default: SafeConvTranspose3d). Must
            accept ``(in_channels, out_channels, kernel_size, stride=...,
            padding=..., bias=...)`` and expose ``weight`` (and ``bias`` when
            ``bias=True``) with the nn.ConvTranspose3d weight shape.
    """
    for name, child in module.named_children():
        if not isinstance(child, nn.ConvTranspose3d):
            replace_conv_transpose3d(child, target_cls)
            continue

        # Forward only non-default settings. Silently dropping e.g. dilation
        # would produce a layer with different output.
        extra_kwargs = {}
        if child.groups != 1:
            extra_kwargs['groups'] = child.groups
        if any(child.output_padding):
            extra_kwargs['output_padding'] = child.output_padding
        if any(d != 1 for d in child.dilation):
            extra_kwargs['dilation'] = child.dilation

        replacement = target_cls(
            child.in_channels, child.out_channels, child.kernel_size,
            stride=child.stride, padding=child.padding,
            bias=child.bias is not None, **extra_kwargs
        )
        # Match device/dtype first so copy_ does not silently cast the
        # weights into a default-device float32 module.
        replacement.to(device=child.weight.device, dtype=child.weight.dtype)
        replacement.train(child.training)

        # Copy weights — same tensor shape, no conversion needed
        with torch.no_grad():
            replacement.weight.copy_(child.weight)
            if child.bias is not None:
                replacement.bias.copy_(child.bias)
        replacement.weight.requires_grad_(child.weight.requires_grad)
        if child.bias is not None:
            replacement.bias.requires_grad_(child.bias.requires_grad)

        setattr(module, name, replacement)