# Omini3D / Diffusion / safe_conv_transpose.py
# maxmo2009's picture
# Sync from local: code + epoch-110 checkpoint, clean README
# 2af0e94 verified
"""
SafeConvTranspose3d: Drop-in replacement for nn.ConvTranspose3d that avoids
the XPU memory leak in the ConvTranspose3d backward pass (oneDNN autograd bug).
Mathematical Background
=======================
ConvTranspose3d (a.k.a. "transposed convolution" or "fractionally-strided
convolution") with parameters:
in_channels=C_in, out_channels=C_out, kernel_size=K, stride=S, padding=P
is the gradient (adjoint) of Conv3d with the same parameters. For an input x
of shape [B, C_in, D, H, W], the output has shape:
[B, C_out, S*(D-1) + K - 2*P, S*(H-1) + K - 2*P, S*(W-1) + K - 2*P]
For our specific case (K=4, S=2, P=1):
output_size = 2*(D-1) + 4 - 2 = 2*D (likewise for H, W)
The operation is mathematically equivalent to:
1. Stride insertion: insert (S-1) zeros between each input element
2. Padding: pad with (K - P - 1) zeros on each side
3. Regular Conv3d with spatially-flipped, channel-transposed weight
Specifically:
Step 1 - Stride insertion:
Input [B, C_in, D, H, W] -> [B, C_in, S*(D-1)+1, S*(H-1)+1, S*(W-1)+1]
For S=2: [B, C_in, 2*D-1, 2*H-1, 2*W-1]
Original values placed at positions 0, S, 2S, ... ; zeros elsewhere.
Step 2 - Padding:
Pad each spatial dimension with (K - P - 1) zeros on each side.
For K=4, P=1: pad = 2 on each side.
Shape becomes: [B, C_in, 2*D+3, 2*H+3, 2*W+3]
Step 3 - Conv3d with transformed weight:
ConvTranspose3d weight shape: [C_in, C_out, K, K, K]
Equivalent Conv3d weight: weight.flip(2,3,4).transpose(0,1)
-> shape [C_out, C_in, K, K, K]
Conv3d(stride=1, padding=0) on the padded input gives:
[B, C_out, (2*D+3 - K + 1), ...] = [B, C_out, 2*D, 2*H, 2*W] (correct!)
Why this is safe on XPU:
The forward uses F.pad (ZERO leak) and F.conv3d (negligible leak).
The backward is computed automatically by PyTorch's autograd through these
same safe ops — no ConvTranspose3d backward kernel is ever invoked.
Specifically:
- F.conv3d backward -> uses Conv3d backward (safe, 0.004 GiB/step)
- F.pad backward -> tensor slicing (trivially safe)
- Stride insertion backward -> gather at stride positions (trivially safe)
- weight.flip().transpose() backward -> indexing (trivially safe)
Forward precision:
Not bit-for-bit identical to nn.ConvTranspose3d due to different summation
order (stride-insert + pad + conv3d vs native transposed conv), but the
difference is negligible: max absolute diff < 5e-7 in float32, no elements
exceeding 1e-6. For activations of order 1 this is only a few float32 ULPs
(machine epsilon is ~1.2e-7), i.e. ordinary rounding noise.
Backward precision:
Gradients match nn.ConvTranspose3d within 1e-5 (input) and 1e-4 (weight)
for float32. Verified across all channel configurations used in the
codebase (16-256 channels).
Implementation choices:
We also provide SafeConvTranspose3d_v2 which uses a custom autograd function
to call F.conv_transpose3d in the forward (bit-for-bit identical) but
replaces the backward with safe Conv3d-based gradient computation.
RECOMMENDATION: Use SafeConvTranspose3d (V1, decomposed forward) because:
- Simpler implementation with no custom autograd
- Fully transparent to PyTorch's autograd
- Compatible with gradient checkpointing, torch.compile, etc.
- The ~5e-7 forward precision loss is negligible for training
- V2's custom autograd requires careful maintenance and is fragile
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
# =============================================================================
# Approach 1 (RECOMMENDED): Decomposed forward pass
# =============================================================================
class SafeConvTranspose3d(nn.Module):
    """Drop-in replacement for nn.ConvTranspose3d that decomposes the operation
    into stride insertion + padding + regular Conv3d.

    All operations in forward (and thus all backward ops via autograd) are
    safe on XPU: no ConvTranspose3d backward kernel is invoked.

    Supports: kernel_size, stride, padding (int or 3-element sequence), bias,
    groups=1, batched [B, C_in, D, H, W] and unbatched [C_in, D, H, W] input,
    and the ``device`` / ``dtype`` factory kwargs.
    Does NOT support: non-zero output_padding, dilation != 1, groups != 1,
    padding_mode other than 'zeros' (nn.ConvTranspose3d itself only accepts
    'zeros').

    The weight tensor has the SAME shape as nn.ConvTranspose3d:
        [in_channels, out_channels, *kernel_size]
    so checkpoints can be loaded directly with load_state_dict().
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True,
                 dilation=1, padding_mode='zeros', device=None, dtype=None):
        super().__init__()
        if groups != 1:
            raise NotImplementedError("SafeConvTranspose3d only supports groups=1")
        # Normalize before comparing: nn.ConvTranspose3d stores output_padding
        # as a tuple, so (0, 0, 0) must be accepted just like the int 0.
        if self._triple(output_padding) != (0, 0, 0):
            raise NotImplementedError("SafeConvTranspose3d does not support output_padding")
        if padding_mode != 'zeros':
            raise ValueError(
                'Only "zeros" padding mode is supported for SafeConvTranspose3d')
        if self._triple(dilation) != (1, 1, 1):
            raise NotImplementedError("SafeConvTranspose3d does not support dilation != 1")

        kernel_size = self._triple(kernel_size)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = self._triple(stride)
        self.padding = self._triple(padding)
        self.groups = groups
        # Fixed values, exposed so code that inspects nn.ConvTranspose3d
        # attributes keeps working on this replacement.
        self.output_padding = (0, 0, 0)
        self.dilation = (1, 1, 1)
        self.padding_mode = 'zeros'

        # Weight shape matches ConvTranspose3d: [in_channels, out_channels, *kernel_size]
        self.weight = nn.Parameter(
            torch.empty(in_channels, out_channels, *kernel_size,
                        device=device, dtype=dtype)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels, device=device, dtype=dtype))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    @staticmethod
    def _triple(value):
        """Return ``value`` as a 3-tuple; an int is broadcast to all three dims."""
        if isinstance(value, int):
            return (value, value, value)
        value = tuple(value)
        if len(value) != 3:
            raise ValueError(f"expected an int or 3 values, got {value}")
        return value

    def reset_parameters(self):
        """Initialize weight and bias the same way nn.ConvTranspose3d does."""
        nn.init.kaiming_uniform_(self.weight, a=5**0.5)
        if self.bias is not None:
            # For the [C_in, C_out, *K] layout, fan_in = C_out * prod(K),
            # which is what nn.ConvTranspose3d uses for its own bias bound.
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            if fan_in != 0:
                bound = 1 / fan_in**0.5
                nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        """Apply the transposed convolution to ``x``.

        Args:
            x: [B, C_in, D, H, W] or unbatched [C_in, D, H, W] tensor.

        Returns:
            [B, C_out, S*(D-1)+K-2P, ...] (batch dim omitted for unbatched input).
        """
        # nn.ConvTranspose3d also accepts unbatched input; add a temporary batch dim.
        unbatched = x.dim() == 4
        if unbatched:
            x = x.unsqueeze(0)
        elif x.dim() != 5:
            raise ValueError(
                f"SafeConvTranspose3d expects 4D (unbatched) or 5D (batched) input, "
                f"got {x.dim()}D")

        B, C_in, D, H, W = x.shape
        sd, sh, sw = self.stride
        kd, kh, kw = self.kernel_size
        pd, ph, pw = self.padding

        # Step 1: Stride insertion — place input values at stride positions,
        # zeros elsewhere. This is the "fractionally-strided" part.
        if sd > 1 or sh > 1 or sw > 1:
            x_inserted = x.new_zeros(
                B, C_in, sd * (D - 1) + 1, sh * (H - 1) + 1, sw * (W - 1) + 1)
            x_inserted[:, :, ::sd, ::sh, ::sw] = x
        else:
            x_inserted = x

        # Step 2: Pad with (kernel_size - padding - 1) zeros on each side.
        # This converts ConvTranspose3d's "padding" (which removes output elements)
        # into the equivalent zero-padding for a regular convolution. A negative
        # value (padding > kernel_size - 1) makes F.pad crop instead, which is
        # still the correct equivalent.
        pad_d = kd - pd - 1
        pad_h = kh - ph - 1
        pad_w = kw - pw - 1
        # F.pad argument order: (W_left, W_right, H_left, H_right, D_left, D_right)
        x_padded = F.pad(x_inserted, (pad_w, pad_w, pad_h, pad_h, pad_d, pad_d))

        # Step 3: Transform weight from ConvTranspose3d layout to Conv3d layout.
        # ConvTranspose3d weight: [C_in, C_out, kD, kH, kW]
        # Equivalent Conv3d weight: [C_out, C_in, kD, kH, kW] with spatial dims flipped
        w_conv = self.weight.flip(2, 3, 4).transpose(0, 1)

        # Step 4: Standard Conv3d (stride=1, padding=0)
        out = F.conv3d(x_padded, w_conv, self.bias, stride=1, padding=0)
        return out.squeeze(0) if unbatched else out

    def extra_repr(self):
        return (f'{self.in_channels}, {self.out_channels}, '
                f'kernel_size={self.kernel_size}, stride={self.stride}, '
                f'padding={self.padding}, bias={self.bias is not None}')
# =============================================================================
# Approach 2: Custom autograd — real forward, safe backward
# =============================================================================
class _SafeConvTranspose3dFunc(Function):
    """Custom autograd function that uses F.conv_transpose3d in forward
    (bit-for-bit identical) but computes gradients using Conv3d-based ops
    in backward (avoiding the leaky oneDNN ConvTranspose3d backward kernel).

    Gradient derivation (per spatial dim: stride S, padding P, dilation d,
    output_padding op, kernel size K):
        For y = conv_transpose3d(x, w, stride=S, padding=P, dilation=d,
                                 output_padding=op):
            grad_x = conv3d(grad_y, w, stride=S, padding=P, dilation=d),
                cropped to x's spatial size (conv3d can return extra trailing
                positions when op >= S, which is legal whenever d > S).
            grad_w = conv3d(pad(stride_insert(x)).T, grad_y.T, stride=d).flip(spatial)
                where stride_insert inserts (S-1) zeros between elements,
                pad adds d*(K-1)-P zeros on the left and d*(K-1)-P+op on the
                right (negative values crop), and .T swaps batch/channel.
                The conv stride of d picks out the dilated kernel taps; the
                spatial flip accounts for the flip in the forward decomposition.
            grad_bias = grad_y.sum(dim=(0, 2, 3, 4))

    The weight gradient is only implemented for groups=1. Unbatched
    [C_in, D, H, W] inputs are handled by adding a batch dim in backward.
    """

    @staticmethod
    def _triple(value):
        """Return ``value`` as a 3-tuple; an int is broadcast to all three dims."""
        if isinstance(value, int):
            return (value, value, value)
        return tuple(value)

    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding, output_padding, groups, dilation):
        # Normalize so backward can iterate per spatial dim even if ints were passed.
        stride = _SafeConvTranspose3dFunc._triple(stride)
        padding = _SafeConvTranspose3dFunc._triple(padding)
        output_padding = _SafeConvTranspose3dFunc._triple(output_padding)
        dilation = _SafeConvTranspose3dFunc._triple(dilation)

        # Use the real conv_transpose3d for bit-for-bit identical forward
        output = F.conv_transpose3d(
            input, weight, bias,
            stride=stride, padding=padding,
            output_padding=output_padding, groups=groups, dilation=dilation
        )
        # NOTE(review): under torch.autocast the saved input/weight keep their
        # original dtype while grad_output may be reduced precision, which
        # F.conv3d in backward would reject — consider torch.amp.custom_fwd /
        # custom_bwd if this is used with autocast.
        ctx.save_for_backward(input, weight, bias)
        ctx.stride = stride
        ctx.padding = padding
        ctx.output_padding = output_padding
        ctx.groups = groups
        ctx.dilation = dilation
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        stride, padding = ctx.stride, ctx.padding
        output_padding, dilation, groups = ctx.output_padding, ctx.dilation, ctx.groups

        # Unbatched forward input: add a batch dim so the shape logic is uniform.
        unbatched = input.dim() == 4
        if unbatched:
            input = input.unsqueeze(0)
            grad_output = grad_output.unsqueeze(0)

        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            # grad_input of ConvTranspose3d = Conv3d(grad_output, weight), since
            # ConvTranspose3d is the adjoint of Conv3d with the same parameters.
            grad_input = F.conv3d(
                grad_output, weight,
                bias=None, stride=stride, padding=padding,
                dilation=dilation, groups=groups
            )
            # Drop any extra trailing positions produced when output_padding >= stride.
            crop = (slice(None), slice(None)) + tuple(slice(0, n) for n in input.shape[2:])
            grad_input = grad_input[crop]
            if unbatched:
                grad_input = grad_input.squeeze(0)

        if ctx.needs_input_grad[1]:
            if groups != 1:
                raise NotImplementedError(
                    "_SafeConvTranspose3dFunc only computes weight gradients for groups=1")
            # Forward decomposition: y = conv3d(x_padded, w.flip(spatial).T(0,1),
            # dilation=d). Its weight gradient, written as a conv over the batch
            # axis, is conv3d(x_padded.T(0,1), grad_y.T(0,1), stride=d); the
            # result is already in [C_in, C_out, K...] layout and only needs the
            # spatial flip undone.
            batch, channels = input.shape[:2]
            if any(s > 1 for s in stride):
                inserted = tuple(s * (n - 1) + 1 for s, n in zip(stride, input.shape[2:]))
                x_inserted = input.new_zeros(batch, channels, *inserted)
                every_stride = (slice(None), slice(None)) + tuple(
                    slice(None, None, s) for s in stride
                )
                x_inserted[every_stride] = input
            else:
                x_inserted = input

            # F.pad wants (W_left, W_right, H_left, H_right, D_left, D_right).
            pad_sizes = []
            for k, p, d, op in zip(reversed(weight.shape[2:]), reversed(padding),
                                   reversed(dilation), reversed(output_padding)):
                left = d * (k - 1) - p
                pad_sizes.extend([left, left + op])
            x_padded = F.pad(x_inserted, pad_sizes)

            # conv3d([C_in, B, ...], weight=[C_out, B, ...], stride=d) -> [C_in, C_out, K...]
            grad_w_conv = F.conv3d(
                x_padded.transpose(0, 1), grad_output.transpose(0, 1), stride=dilation
            )
            # Undo the spatial flip from the forward decomposition
            grad_weight = grad_w_conv.flip(2, 3, 4)

        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(dim=(0, 2, 3, 4))

        return grad_input, grad_weight, grad_bias, None, None, None, None, None
class SafeConvTranspose3d_v2(nn.Module):
    """Drop-in replacement for nn.ConvTranspose3d using custom autograd.

    Forward pass: Uses the real F.conv_transpose3d (bit-for-bit identical output).
    Backward pass: Computes gradients using F.conv3d (avoids leaky oneDNN kernel),
    via _SafeConvTranspose3dFunc.

    Weight shape is identical to nn.ConvTranspose3d: [in_channels, out_channels, *kernel_size]

    Supports groups=1, zero output_padding (int 0 or (0, 0, 0)),
    padding_mode='zeros', batched [B, C_in, D, H, W] and unbatched
    [C_in, D, H, W] input, and the ``device`` / ``dtype`` factory kwargs.
    ``dilation`` is stored and forwarded unchanged to the autograd function.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True,
                 dilation=1, padding_mode='zeros', device=None, dtype=None):
        super().__init__()
        if groups != 1:
            raise NotImplementedError("SafeConvTranspose3d_v2 only supports groups=1")
        # Normalize before comparing so the tuple form (0, 0, 0) — what
        # nn.ConvTranspose3d stores — is accepted like the int 0.
        if self._triple(output_padding) != (0, 0, 0):
            raise NotImplementedError("SafeConvTranspose3d_v2 does not support output_padding")
        if padding_mode != 'zeros':
            raise ValueError(
                'Only "zeros" padding mode is supported for SafeConvTranspose3d_v2')

        kernel_size = self._triple(kernel_size)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = self._triple(stride)
        self.padding = self._triple(padding)
        self.output_padding = (0, 0, 0)
        self.groups = groups
        self.dilation = self._triple(dilation)
        self.padding_mode = 'zeros'

        # Weight shape matches ConvTranspose3d: [in_channels, out_channels, *kernel_size]
        self.weight = nn.Parameter(
            torch.empty(in_channels, out_channels, *kernel_size,
                        device=device, dtype=dtype)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels, device=device, dtype=dtype))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    @staticmethod
    def _triple(value):
        """Return ``value`` as a 3-tuple; an int is broadcast to all three dims."""
        if isinstance(value, int):
            return (value, value, value)
        value = tuple(value)
        if len(value) != 3:
            raise ValueError(f"expected an int or 3 values, got {value}")
        return value

    def reset_parameters(self):
        """Initialize weight and bias the same way nn.ConvTranspose3d does."""
        nn.init.kaiming_uniform_(self.weight, a=5**0.5)
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            if fan_in != 0:
                bound = 1 / fan_in**0.5
                nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        """Apply the transposed convolution to a 5D (or unbatched 4D) tensor."""
        # Give the autograd function a batched tensor; strip the dim afterwards.
        unbatched = x.dim() == 4
        if unbatched:
            x = x.unsqueeze(0)
        out = _SafeConvTranspose3dFunc.apply(
            x, self.weight, self.bias,
            self.stride, self.padding, self.output_padding,
            self.groups, self.dilation
        )
        return out.squeeze(0) if unbatched else out

    def extra_repr(self):
        return (f'{self.in_channels}, {self.out_channels}, '
                f'kernel_size={self.kernel_size}, stride={self.stride}, '
                f'padding={self.padding}, bias={self.bias is not None}')
# =============================================================================
# Utility: in-place replacement of ConvTranspose3d in existing models
# =============================================================================
def replace_conv_transpose3d(module, target_cls=SafeConvTranspose3d):
    """Recursively replace all nn.ConvTranspose3d in a module with the given
    replacement class, copying weights and biases.

    Usage:
        model = MyModel()
        replace_conv_transpose3d(model)  # in-place modification

    Each replacement is placed on the same device and dtype as the layer it
    replaces and keeps that layer's ``requires_grad`` flags and train/eval
    mode, so the function also works after ``model.to(device)``.

    Args:
        module: The nn.Module to modify in-place.
        target_cls: Replacement class (default: SafeConvTranspose3d). It is
            constructed as ``target_cls(in_channels, out_channels, kernel_size,
            stride=..., padding=..., bias=...)``, plus ``dilation=...`` when the
            original layer uses a non-default dilation (so the replacement can
            honour or reject it instead of silently dropping it).

    Returns:
        ``module`` itself (modified in place), for convenience.

    Raises:
        NotImplementedError: if a ConvTranspose3d uses groups != 1 or a
            non-zero output_padding.
    """
    for name, child in module.named_children():
        if not isinstance(child, nn.ConvTranspose3d):
            replace_conv_transpose3d(child, target_cls)
            continue

        # Explicit raises rather than asserts: asserts vanish under `python -O`.
        if child.groups != 1:
            raise NotImplementedError(f"groups={child.groups} not supported")
        if any(child.output_padding):
            raise NotImplementedError(
                f"output_padding={child.output_padding} not supported")

        kwargs = dict(stride=child.stride, padding=child.padding,
                      bias=child.bias is not None)
        if any(d != 1 for d in child.dilation):
            # Previously dropped, which silently turned a dilated layer into an
            # undilated one; forwarding it lets the target honour or reject it.
            kwargs['dilation'] = child.dilation
        replacement = target_cls(child.in_channels, child.out_channels,
                                 child.kernel_size, **kwargs)

        # Match the original layer's device/dtype before copying parameters.
        replacement.to(device=child.weight.device, dtype=child.weight.dtype)
        with torch.no_grad():
            # Same tensor shape ([C_in, C_out, *K]), no layout conversion needed.
            replacement.weight.copy_(child.weight)
            if child.bias is not None:
                replacement.bias.copy_(child.bias)
        replacement.weight.requires_grad_(child.weight.requires_grad)
        if child.bias is not None:
            replacement.bias.requires_grad_(child.bias.requires_grad)
        replacement.train(child.training)

        setattr(module, name, replacement)
    return module