# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. # SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Model architecture layers used in the paper "Elucidating the Design Space of Diffusion-Based Generative Models". """ import contextlib import importlib import math from typing import Any, Dict, List, Literal, Set import numpy as np import nvtx import torch from einops import rearrange from torch.nn.functional import elu, gelu, leaky_relu, relu, sigmoid, silu, tanh from physicsnemo.models.diffusion import weight_init # Import apex GroupNorm if installed only _is_apex_available = False if torch.cuda.is_available(): try: apex_gn_module = importlib.import_module("apex.contrib.group_norm") ApexGroupNorm = getattr(apex_gn_module, "GroupNorm") _is_apex_available = True except ImportError: pass def _validate_amp(amp_mode: bool) -> None: """Raise if `amp_mode` is False but PyTorch autocast (CPU or CUDA) is active. Parameters ---------- amp_mode : bool Your intended AMP flag. Set False when you require full precision. """ try: cuda_amp = bool(torch.is_autocast_enabled()) except AttributeError: # very old PyTorch cuda_amp = False try: cpu_amp = bool(torch.is_autocast_enabled("cpu")) except AttributeError: cpu_amp = False if not amp_mode and (cuda_amp or cpu_amp): active = [] if cuda_amp: active.append("cuda") if cpu_amp: active.append("cpu") raise RuntimeError( f"amp_mode=False but torch autocast is enabled on: {', '.join(active)}. " "Disable autocast for this region or set amp_mode=True if mixed precision is intended." ) class Linear(torch.nn.Module): """ A fully connected (dense) layer implementation. The layer's weights and biases can be initialized using custom initialization strategies like "kaiming_normal", and can be further scaled by factors `init_weight` and `init_bias`. Parameters ---------- in_features : int Size of each input sample. out_features : int Size of each output sample. bias : bool, optional The biases of the layer. If set to `None`, the layer will not learn an additive bias. By default True. init_mode : str, optional (default="kaiming_normal") The mode/type of initialization to use for weights and biases. Supported modes are: - "xavier_uniform": Xavier (Glorot) uniform initialization. - "xavier_normal": Xavier (Glorot) normal initialization. - "kaiming_uniform": Kaiming (He) uniform initialization. - "kaiming_normal": Kaiming (He) normal initialization. By default "kaiming_normal". init_weight : float, optional A scaling factor to multiply with the initialized weights. By default 1. init_bias : float, optional A scaling factor to multiply with the initialized biases. By default 0. amp_mode : bool, optional A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. """ def __init__( self, in_features: int, out_features: int, bias: bool = True, init_mode: str = "kaiming_normal", init_weight: int = 1, init_bias: int = 0, amp_mode: bool = False, ): super().__init__() self.in_features = in_features self.out_features = out_features self.amp_mode = amp_mode init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features) self.weight = torch.nn.Parameter( weight_init([out_features, in_features], **init_kwargs) * init_weight ) self.bias = ( torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias) if bias else None ) def forward(self, x): weight, bias = self.weight, self.bias _validate_amp(self.amp_mode) if not self.amp_mode: if self.weight is not None and self.weight.dtype != x.dtype: weight = self.weight.to(x.dtype) if self.bias is not None and self.bias.dtype != x.dtype: bias = self.bias.to(x.dtype) x = x @ weight.t() if self.bias is not None: x = x.add_(bias) return x class Conv2d(torch.nn.Module): """ A custom 2D convolutional layer implementation with support for up-sampling, down-sampling, and custom weight and bias initializations. The layer's weights and biases canbe initialized using custom initialization strategies like "kaiming_normal", and can be further scaled by factors `init_weight` and `init_bias`. Parameters ---------- in_channels : int Number of channels in the input image. out_channels : int Number of channels produced by the convolution. kernel : int Size of the convolving kernel. bias : bool, optional The biases of the layer. If set to `None`, the layer will not learn an additive bias. By default True. up : bool, optional Whether to perform up-sampling. By default False. down : bool, optional Whether to perform down-sampling. By default False. resample_filter : List[int], optional Filter to be used for resampling. By default [1, 1]. fused_resample : bool, optional If True, performs fused up-sampling and convolution or fused down-sampling and convolution. By default False. init_mode : str, optional (default="kaiming_normal") init_mode : str, optional (default="kaiming_normal") The mode/type of initialization to use for weights and biases. Supported modes are: - "xavier_uniform": Xavier (Glorot) uniform initialization. - "xavier_normal": Xavier (Glorot) normal initialization. - "kaiming_uniform": Kaiming (He) uniform initialization. - "kaiming_normal": Kaiming (He) normal initialization. By default "kaiming_normal". init_weight : float, optional A scaling factor to multiply with the initialized weights. By default 1.0. init_bias : float, optional A scaling factor to multiply with the initialized biases. By default 0.0. fused_conv_bias: bool, optional A boolean flag indicating whether bias will be passed as a parameter of conv2d. By default False. amp_mode : bool, optional A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. """ def __init__( self, in_channels: int, out_channels: int, kernel: int, bias: bool = True, up: bool = False, down: bool = False, resample_filter: List[int] = [1, 1], fused_resample: bool = False, init_mode: str = "kaiming_normal", init_weight: float = 1.0, init_bias: float = 0.0, fused_conv_bias: bool = False, amp_mode: bool = False, ): if up and down: raise ValueError("Both 'up' and 'down' cannot be true at the same time.") if not kernel and fused_conv_bias: print( "Warning: Kernel is required when fused_conv_bias is enabled. Setting fused_conv_bias to False." ) fused_conv_bias = False super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.up = up self.down = down self.fused_resample = fused_resample self.fused_conv_bias = fused_conv_bias self.amp_mode = amp_mode init_kwargs = dict( mode=init_mode, fan_in=in_channels * kernel * kernel, fan_out=out_channels * kernel * kernel, ) self.weight = ( torch.nn.Parameter( weight_init([out_channels, in_channels, kernel, kernel], **init_kwargs) * init_weight ) if kernel else None ) self.bias = ( torch.nn.Parameter(weight_init([out_channels], **init_kwargs) * init_bias) if kernel and bias else None ) f = torch.as_tensor(resample_filter, dtype=torch.float32) f = f.ger(f).unsqueeze(0).unsqueeze(1) / f.sum().square() self.register_buffer("resample_filter", f if up or down else None) def forward(self, x): weight, bias, resample_filter = self.weight, self.bias, self.resample_filter _validate_amp(self.amp_mode) if not self.amp_mode: if self.weight is not None and self.weight.dtype != x.dtype: weight = self.weight.to(x.dtype) if self.bias is not None and self.bias.dtype != x.dtype: bias = self.bias.to(x.dtype) if ( self.resample_filter is not None and self.resample_filter.dtype != x.dtype ): resample_filter = self.resample_filter.to(x.dtype) w = weight if weight is not None else None b = bias if bias is not None else None f = resample_filter if resample_filter is not None else None w_pad = w.shape[-1] // 2 if w is not None else 0 f_pad = (f.shape[-1] - 1) // 2 if f is not None else 0 if self.fused_resample and self.up and w is not None: x = torch.nn.functional.conv_transpose2d( x, f.mul(4).tile([self.in_channels, 1, 1, 1]), groups=self.in_channels, stride=2, padding=max(f_pad - w_pad, 0), ) if self.fused_conv_bias: x = torch.nn.functional.conv2d( x, w, padding=max(w_pad - f_pad, 0), bias=b ) else: x = torch.nn.functional.conv2d(x, w, padding=max(w_pad - f_pad, 0)) elif self.fused_resample and self.down and w is not None: x = torch.nn.functional.conv2d(x, w, padding=w_pad + f_pad) if self.fused_conv_bias: x = torch.nn.functional.conv2d( x, f.tile([self.out_channels, 1, 1, 1]), groups=self.out_channels, stride=2, bias=b, ) else: x = torch.nn.functional.conv2d( x, f.tile([self.out_channels, 1, 1, 1]), groups=self.out_channels, stride=2, ) else: if self.up: x = torch.nn.functional.conv_transpose2d( x, f.mul(4).tile([self.in_channels, 1, 1, 1]), groups=self.in_channels, stride=2, padding=f_pad, ) if self.down: x = torch.nn.functional.conv2d( x, f.tile([self.in_channels, 1, 1, 1]), groups=self.in_channels, stride=2, padding=f_pad, ) if w is not None: # ask in corrdiff channel whether w will ever be none if self.fused_conv_bias: x = torch.nn.functional.conv2d(x, w, padding=w_pad, bias=b) else: x = torch.nn.functional.conv2d(x, w, padding=w_pad) if b is not None and not self.fused_conv_bias: x = x.add_(b.reshape(1, -1, 1, 1)) return x def _compute_groupnorm_groups( num_channels: int, num_groups: int = 32, min_channels_per_group: int = 4, ) -> int: """ Compute the number of groups for GroupNorm based on the number of channels and the minimum number of channels per group. Parameters ---------- num_channels : int Number of channels in the input tensor. num_groups : int, optional, default=32 Desired number of groups to divide the input channels. This might be adjusted based on the ``min_channels_per_group``. min_channels_per_group : int, optional, default=4 Minimum channels required per group. This ensures that no group has fewer channels than this number. Returns ------- int The number of groups to use for GroupNorm. """ num_groups: int = min( num_groups, (num_channels + min_channels_per_group - 1) // min_channels_per_group, ) if num_channels % num_groups != 0: raise ValueError( "num_channels must be divisible by num_groups or min_channels_per_group" ) return num_groups def get_group_norm( num_channels: int, num_groups: int = 32, min_channels_per_group: int = 4, eps: float = 1e-5, use_apex_gn: bool = False, act: str | None = None, amp_mode: bool = False, ) -> torch.nn.Module: """ Utility function to get the GroupNorm layer, either from apex or from torch. Parameters ---------- num_channels : int Number of channels in the input tensor. num_groups : int, optional, default=32 Desired number of groups to divide the input channels. This might be adjusted based on the ``min_channels_per_group``. min_channels_per_group : int, optional, default=4 Minimum channels required per group. This ensures that no group has fewer channels than this number. eps : float, optional, default=1e-5 A small number added to the variance to prevent division by zero. use_apex_gn : bool, optional, default=False A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout. Need to set this as False on cpu. act : str, optional, default=None The activation function to use when fusing activation with GroupNorm. amp_mode : bool, optional, default=False A boolean flag indicating whether mixed-precision (AMP) training is enabled. Returns ------- torch.nn.Module The GroupNorm layer. If ``use_apex_gn`` is ``True``, returns an ApexGroupNorm layer, otherwise returns an instance of :class:`~physicsnemo.models.diffusion.layers.GroupNorm`. .. note:: If ``num_channels`` is not divisible by ``num_groups``, the actual number of groups might be adjusted to satisfy the ``min_channels_per_group`` condition. """ if use_apex_gn and not _is_apex_available: raise ValueError("'apex' is not installed, set `use_apex_gn=False`") act: str | None = act.lower() if act else act if use_apex_gn: # adjust number of groups to be consistent with GroupNorm num_groups: int = _compute_groupnorm_groups( num_channels, num_groups, min_channels_per_group ) return ApexGroupNorm( num_groups=num_groups, num_channels=num_channels, eps=eps, affine=True, act=act, ) else: return GroupNorm( num_channels=num_channels, num_groups=num_groups, min_channels_per_group=min_channels_per_group, eps=eps, act=act, amp_mode=amp_mode, ) class GroupNorm(torch.nn.Module): """ A custom Group Normalization layer implementation. Group Normalization (GN) divides the channels of the input tensor into groups and normalizes the features within each group independently. It does not require the batch size as in Batch Normalization, making it suitable for batch sizes of any size or even for batch-free scenarios. Parameters ---------- num_channels : int Number of channels in the input tensor. num_groups : int, optional, default=32 Desired number of groups to divide the input channels. This might be adjusted based on the ``min_channels_per_group``. min_channels_per_group : int, optional, default=4 Minimum channels required per group. This ensures that no group has fewer channels than this number. eps : float, optional, default=1e-5 A small number added to the variance to prevent division by zero. use_apex_gn : bool, optional, default=False Deprecated. Please use :func:`~physicsnemo.models.diffusion.layers.get_group_norm` instead. fused_act : bool, optional, default=False Deprecated. Please use :func:`~physicsnemo.models.diffusion.layers.get_group_norm` instead. act : str, optional, default=None The activation function to use when fusing activation with GroupNorm. amp_mode : bool, optional, default=False A boolean flag indicating whether mixed-precision (AMP) training is enabled. Forward ------- x : torch.Tensor 4-D input tensor of shape :math:`(B, C, H, W)`, where :math:`B` is batch size, :math:`C` is ``num_channels``, and :math:`H, W` are spatial dimensions. Outputs ------- torch.Tensor Output tensor of the same shape as input: :math:`(B, C, H, W)`. .. note:: If ``num_channels`` is not divisible by ``num_groups``, the actual number of groups might be adjusted to satisfy the ``min_channels_per_group`` condition. """ def __init__( self, num_channels: int, num_groups: int = 32, min_channels_per_group: int = 4, eps: float = 1e-5, use_apex_gn: bool = False, fused_act: bool = False, act: str | None = None, amp_mode: bool = False, ): super().__init__() # backwards compatibility warnings if use_apex_gn: raise ValueError( "'use_apex_gn' is deprecated. Please use 'get_group_norm' to enable " "Apex-based group norm." ) if fused_act: raise ValueError( "'fused_act' is deprecated and only supported for Apex-based group norm. " "Please use `get_group_norm` to enable fused activations." ) # initialize groupnorm self.num_groups: int = _compute_groupnorm_groups( num_channels, num_groups, min_channels_per_group ) self.eps = eps self.weight = torch.nn.Parameter(torch.ones(num_channels)) self.bias = torch.nn.Parameter(torch.zeros(num_channels)) self.act = act.lower() if act else act self.act_fn = None if self.act is not None: self.act_fn = self.get_activation_function() self.amp_mode = amp_mode def forward(self, x): weight, bias = self.weight, self.bias _validate_amp(self.amp_mode) if not self.amp_mode: if weight.dtype != x.dtype: weight = self.weight.to(x.dtype) if bias.dtype != x.dtype: bias = self.bias.to(x.dtype) if self.training: # Use default torch implementation of GroupNorm for training # This does not support channels last memory format x = torch.nn.functional.group_norm( x, num_groups=self.num_groups, weight=weight, bias=bias, eps=self.eps, ) else: # Use custom GroupNorm implementation that supports channels last # memory layout for inference x = rearrange(x, "b (g c) h w -> b g c h w", g=self.num_groups) mean = x.mean(dim=[2, 3, 4], keepdim=True) var = x.var(dim=[2, 3, 4], keepdim=True) x = (x - mean) * (var + self.eps).rsqrt() x = rearrange(x, "b g c h w -> b (g c) h w") weight = rearrange(weight, "c -> 1 c 1 1") bias = rearrange(bias, "c -> 1 c 1 1") x = x * weight + bias if self.act_fn is not None: x = self.act_fn(x) return x def get_activation_function(self): """ Get activation function given string input """ activation_map = { "silu": silu, "relu": relu, "leaky_relu": leaky_relu, "sigmoid": sigmoid, "tanh": tanh, "gelu": gelu, "elu": elu, } act_fn = activation_map.get(self.act, None) if act_fn is None: raise ValueError(f"Unknown activation function: {self.act}") return act_fn class AttentionOp(torch.autograd.Function): """ Attention weight computation, i.e., softmax(Q^T * K). Performs all computation using FP32, but uses the original datatype for inputs/outputs/gradients to conserve memory. """ @staticmethod def forward(ctx, q, k): w = ( torch.einsum( "ncq,nck->nqk", q.to(torch.float32), (k / torch.sqrt(torch.tensor(k.shape[1]))).to(torch.float32), ) .softmax(dim=2) .to(q.dtype) ) ctx.save_for_backward(q, k, w) return w @staticmethod def backward(ctx, dw): q, k, w = ctx.saved_tensors db = torch._softmax_backward_data( grad_output=dw.to(torch.float32), output=w.to(torch.float32), dim=2, input_dtype=torch.float32, ) dq = torch.einsum("nck,nqk->ncq", k.to(torch.float32), db).to( q.dtype ) / np.sqrt(k.shape[1]) dk = torch.einsum("ncq,nqk->nck", q.to(torch.float32), db).to( k.dtype ) / np.sqrt(k.shape[1]) return dq, dk class Attention(torch.nn.Module): """ Self-attention block used in U-Net-style architectures, such as DDPM++, NCSN++, and ADM. Applies GroupNorm followed by multi-head self-attention and a projection layer. Parameters ---------- out_channels : int Number of channels :math:`C` in the input and output feature maps. num_heads : int Number of attention heads. Must be a positive integer. eps : float, optional, default=1e-5 Epsilon value for numerical stability in GroupNorm. init_zero : dict, optional, default={'init_weight': 0} Initialization parameters with zero weights for certain layers. init_attn : dict, optional, default=None Initialization parameters specific to attention mechanism layers. Defaults to 'init' if not provided. init : dict, optional, default={} Initialization parameters for convolutional and linear layers. use_apex_gn : bool, optional, default=False A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout. Need to set this as False on cpu. amp_mode : bool, optional, default=False A boolean flag indicating whether mixed-precision (AMP) training is enabled. fused_conv_bias: bool, optional, default=False A boolean flag indicating whether bias will be passed as a parameter of conv2d. Forward ------- x : torch.Tensor Input tensor of shape :math:`(B, C, H, W)`, where :math:`B` is batch size, :math:`C` is `out_channels`, and :math:`H, W` are spatial dimensions. Outputs ------- torch.Tensor Output tensor of the same shape as input: :math:`(B, C, H, W)`. """ def __init__( self, *, out_channels: int, num_heads: int, eps: float = 1e-5, init_zero: Dict[str, Any] = dict(init_weight=0), init_attn: Any = None, init: Dict[str, Any] = dict(), use_apex_gn: bool = False, amp_mode: bool = False, fused_conv_bias: bool = False, ) -> None: super().__init__() # Parameters validation if not isinstance(num_heads, int) or num_heads <= 0: raise ValueError( f"`num_heads` must be a positive integer, but got {num_heads}" ) if out_channels % num_heads != 0: raise ValueError( f"`out_channels` must be divisible by `num_heads`, but got {out_channels} and {num_heads}" ) self.num_heads = num_heads self.norm = get_group_norm( num_channels=out_channels, eps=eps, use_apex_gn=use_apex_gn, amp_mode=amp_mode, ) self.qkv = Conv2d( in_channels=out_channels, out_channels=out_channels * 3, kernel=1, fused_conv_bias=fused_conv_bias, amp_mode=amp_mode, **(init_attn if init_attn is not None else init), ) self.proj = Conv2d( in_channels=out_channels, out_channels=out_channels, kernel=1, fused_conv_bias=fused_conv_bias, amp_mode=amp_mode, **init_zero, ) def forward(self, x: torch.Tensor) -> torch.Tensor: x1: torch.Tensor = self.qkv(self.norm(x)) # # NOTE: V1.0.1 implementation # q, k, v = x1.reshape( # x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1 # ).unbind(2) # w = AttentionOp.apply(q, k) # attn = torch.einsum("nqk,nck->ncq", w, v) q, k, v = ( ( x1.reshape( x.shape[0], self.num_heads, x.shape[1] // self.num_heads, 3, -1 ) ) .permute(0, 1, 4, 3, 2) .unbind(-2) ) attn = torch.nn.functional.scaled_dot_product_attention( q, k, v, scale=1 / math.sqrt(k.shape[-1]) ) attn = attn.transpose(-1, -2) x: torch.Tensor = self.proj(attn.reshape(*x.shape)).add_(x) return x class UNetBlock(torch.nn.Module): """ Unified U-Net block with optional up/downsampling and self-attention. Represents the union of all features employed by the DDPM++, NCSN++, and ADM architectures. Parameters: ----------- in_channels : int Number of input channels :math:`C_{in}`. out_channels : int Number of output channels :math:`C_{out}`. emb_channels : int Number of embedding channels :math:`C_{emb}`. up : bool, optional, default=False If True, applies upsampling in the forward pass. down : bool, optional, default=False If True, applies downsampling in the forward pass. attention : bool, optional, default=False If True, enables the self-attention mechanism in the block. num_heads : int, optional, default=None Number of attention heads. If None, defaults to :math:`C_{out} / 64`. channels_per_head : int, optional, default=64 Number of channels per attention head. dropout : float, optional, default=0.0 Dropout probability. skip_scale : float, optional, default=1.0 Scale factor applied to skip connections. eps : float, optional, default=1e-5 Epsilon value used for normalization layers. resample_filter : List[int], optional, default=``[1, 1]`` Filter for resampling layers. resample_proj : bool, optional, default=False If True, resampling projection is enabled. adaptive_scale : bool, optional, default=True If True, uses adaptive scaling in the forward pass. init : dict, optional, default=``{}`` Initialization parameters for convolutional and linear layers. init_zero : dict, optional, default=``{'init_weight': 0}`` Initialization parameters with zero weights for certain layers. init_attn : dict, optional, default=``None`` Initialization parameters specific to attention mechanism layers. Defaults to ``init`` if not provided. use_apex_gn : bool, optional, default=False A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout. Need to set this as False on cpu. act : str, optional, default=None The activation function to use when fusing activation with GroupNorm. fused_conv_bias: bool, optional, default=False A boolean flag indicating whether bias will be passed as a parameter of conv2d. profile_mode: bool, optional, default=False A boolean flag indicating whether to enable all nvtx annotations during profiling. amp_mode : bool, optional, default=False A boolean flag indicating whether mixed-precision (AMP) training is enabled. Forward ------- x : torch.Tensor Input tensor of shape :math:`(B, C_{in}, H, W)`, where :math:`B` is batch size, :math:`C_{in}` is ``in_channels``, and :math:`H, W` are spatial dimensions. emb : torch.Tensor Embedding tensor of shape :math:`(B, C_{emb})`, where :math:`B` is batch size, and :math:`C_{emb}` is ``emb_channels``. Outputs ------- torch.Tensor Output tensor of shape :math:`(B, C_{out}, H, W)`, where :math:`B` is batch size, :math:`C_{out}` is ``out_channels``, and :math:`H, W` are spatial dimensions. """ # NOTE: these attributes have specific usage in old checkpoints, do not # reuse them! _reserved_attributes: Set[str] = set(["norm2", "qkv", "proj"]) def __init__( self, in_channels: int, out_channels: int, emb_channels: int, up: bool = False, down: bool = False, attention: bool = False, num_heads: int | None = None, channels_per_head: int = 64, dropout: float = 0.0, skip_scale: float = 1.0, eps: float = 1e-5, resample_filter: List[int] = [1, 1], resample_proj: bool = False, adaptive_scale: bool = True, init: Dict[str, Any] = dict(), init_zero: Dict[str, Any] = dict(init_weight=0), init_attn: Any = None, use_apex_gn: bool = False, act: str = "silu", fused_conv_bias: bool = False, profile_mode: bool = False, amp_mode: bool = False, ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.emb_channels = emb_channels self.num_heads = ( 0 if not attention else ( num_heads if num_heads is not None else out_channels // channels_per_head ) ) self.attention = attention self.dropout = dropout self.skip_scale = skip_scale self.adaptive_scale = adaptive_scale self.profile_mode = profile_mode self.amp_mode = amp_mode self.norm0 = get_group_norm( num_channels=in_channels, eps=eps, use_apex_gn=use_apex_gn, act=act, amp_mode=amp_mode, ) self.conv0 = Conv2d( in_channels=in_channels, out_channels=out_channels, kernel=3, up=up, down=down, resample_filter=resample_filter, fused_conv_bias=fused_conv_bias, amp_mode=amp_mode, **init, ) self.affine = Linear( in_features=emb_channels, out_features=out_channels * (2 if adaptive_scale else 1), amp_mode=amp_mode, **init, ) if self.adaptive_scale: self.norm1 = get_group_norm( num_channels=out_channels, eps=eps, use_apex_gn=use_apex_gn, amp_mode=amp_mode, ) else: self.norm1 = get_group_norm( num_channels=out_channels, eps=eps, use_apex_gn=use_apex_gn, act=act, amp_mode=amp_mode, ) self.conv1 = Conv2d( in_channels=out_channels, out_channels=out_channels, kernel=3, fused_conv_bias=fused_conv_bias, amp_mode=amp_mode, **init_zero, ) self.skip = None if out_channels != in_channels or up or down: kernel = 1 if resample_proj or out_channels != in_channels else 0 fused_conv_bias = fused_conv_bias if kernel != 0 else False self.skip = Conv2d( in_channels=in_channels, out_channels=out_channels, kernel=kernel, up=up, down=down, resample_filter=resample_filter, fused_conv_bias=fused_conv_bias, amp_mode=amp_mode, **init, ) if self.attention: self.attn = Attention( out_channels=out_channels, num_heads=self.num_heads, eps=eps, init_zero=init_zero, init_attn=init_attn, init=init, use_apex_gn=use_apex_gn, amp_mode=amp_mode, fused_conv_bias=fused_conv_bias, ) else: self.attn = None # A hook to migrate legacy attention module self.register_load_state_dict_pre_hook(self._migrate_attention_module) def forward(self, x, emb): with ( nvtx.annotate(message="UNetBlock", color="purple") if self.profile_mode else contextlib.nullcontext() ): orig = x x = self.conv0(self.norm0(x)) params = self.affine(emb).unsqueeze(2).unsqueeze(3) _validate_amp(self.amp_mode) if not self.amp_mode: if params.dtype != x.dtype: params = params.to(x.dtype) # type: ignore if self.adaptive_scale: scale, shift = params.chunk(chunks=2, dim=1) x = silu(torch.addcmul(shift, self.norm1(x), scale + 1)) else: x = self.norm1(x.add_(params)) x = self.conv1( torch.nn.functional.dropout(x, p=self.dropout, training=self.training) ) x = x.add_(self.skip(orig) if self.skip is not None else orig) x = x * self.skip_scale if self.attn: x = self.attn(x) x = x * self.skip_scale return x def __setattr__(self, name, value): """Prevent setting attributes with reserved names. Parameters ---------- name : str Attribute name. value : Any Attribute value. """ if name in getattr(self.__class__, "_reserved_attributes", set()): raise AttributeError(f"Attribute '{name}' is reserved and cannot be set.") super().__setattr__(name, value) @staticmethod def _migrate_attention_module( module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, ): """``load_state_dict`` pre-hook that handles legacy checkpoints that stored attention layers at root. The earliest versions of ``UNetBlock`` stored the attention-layer parameters directly on the block using attribute names contained in ``_reserved_attributes``. These have since been moved under the dedicated ``attn`` sub-module. This helper migrates the parameter names so that older checkpoints can still be loaded. """ _mapping = { f"{prefix}norm2.weight": f"{prefix}attn.norm.weight", f"{prefix}norm2.bias": f"{prefix}attn.norm.bias", f"{prefix}qkv.weight": f"{prefix}attn.qkv.weight", f"{prefix}qkv.bias": f"{prefix}attn.qkv.bias", f"{prefix}proj.weight": f"{prefix}attn.proj.weight", f"{prefix}proj.bias": f"{prefix}attn.proj.bias", } for old_key, new_key in _mapping.items(): if old_key in state_dict: # NOTE: Only migrate if destination key not already present to # avoid accidental overwriting when both are present. if new_key not in state_dict: state_dict[new_key] = state_dict.pop(old_key) else: raise ValueError( f"Checkpoint contains both legacy and new keys for {old_key}" ) class PositionalEmbedding(torch.nn.Module): """ A module for generating positional embeddings based on timesteps. This embedding technique is employed in the DDPM++ and ADM architectures. Parameters: ----------- num_channels : int Number of channels for the embedding. max_positions : int, optional Maximum number of positions for the embeddings, by default 10000. endpoint : bool, optional If True, the embedding considers the endpoint. By default False. amp_mode : bool, optional A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. learnable : bool, optional A boolean flag indicating whether learnable positional embedding is enabled. Defaults to False. freq_embed_dim: int, optional The dimension of the frequency embedding. Defaults to None, in which case it will be set to num_channels. mlp_hidden_dim: int, optional The dimension of the hidden layer in the MLP. Defaults to None, in which case it will be set to 2 * num_channels. Only applicable if learnable is True; if learnable is False, this parameter is ignored. embed_fn: Literal["cos_sin", "np_sin_cos"], optional The function to use for embedding into sin/cos features (allows for swapping the order of sin/cos). Defaults to 'cos_sin'. Options: - 'cos_sin': Uses torch to compute frequency embeddings and returns in order (cos, sin) - 'np_sin_cos': Uses numpy to compute frequency embeddings and returns in order (sin, cos) """ def __init__( self, num_channels: int, max_positions: int = 10000, endpoint: bool = False, amp_mode: bool = False, learnable: bool = False, freq_embed_dim: int | None = None, mlp_hidden_dim: int | None = None, embed_fn: Literal["cos_sin", "np_sin_cos"] = "cos_sin", ): super().__init__() self.num_channels = num_channels self.max_positions = max_positions self.endpoint = endpoint self.amp_mode = amp_mode self.learnable = learnable self.embed_fn = embed_fn if freq_embed_dim is None: freq_embed_dim = num_channels self.freq_embed_dim = freq_embed_dim if learnable: if mlp_hidden_dim is None: mlp_hidden_dim = 2 * num_channels self.mlp = torch.nn.Sequential( torch.nn.Linear(freq_embed_dim, mlp_hidden_dim, bias=True), torch.nn.SiLU(), torch.nn.Linear(mlp_hidden_dim, num_channels, bias=True), ) if self.embed_fn == "np_sin_cos": half_embed_dim = freq_embed_dim // 2 pow = np.arange(half_embed_dim, dtype=np.float32) / half_embed_dim w = np.exp(-np.log(self.max_positions) * pow) self.register_buffer("freqs", torch.from_numpy(w).float()) def _cos_sin_embedding(self, x): freqs = torch.arange( start=0, end=self.freq_embed_dim // 2, dtype=torch.float32, device=x.device ) freqs = freqs / (self.freq_embed_dim // 2 - (1 if self.endpoint else 0)) freqs = (1 / self.max_positions) ** freqs _validate_amp(self.amp_mode) if not self.amp_mode: if freqs.dtype != x.dtype: freqs = freqs.to(x.dtype) x = x.ger(freqs) x = torch.cat([x.cos(), x.sin()], dim=1) return x def _sin_cos_embedding_np(self, x): x = torch.outer(x, self.freqs) x = torch.cat([x.sin(), x.cos()], dim=1) return x def forward(self, x): if self.embed_fn == "cos_sin": x = self._cos_sin_embedding(x) elif self.embed_fn == "np_sin_cos": x = self._sin_cos_embedding_np(x) if self.learnable: x = self.mlp(x) return x class FourierEmbedding(torch.nn.Module): """ Generates Fourier embeddings for timesteps, primarily used in the NCSN++ architecture. This class generates embeddings by first multiplying input tensor `x` and internally stored random frequencies, and then concatenating the cosine and sine of the resultant. Parameters: ----------- num_channels : int The number of channels in the embedding. The final embedding size will be 2 * num_channels because of concatenation of cosine and sine results. scale : int, optional A scale factor applied to the random frequencies, controlling their range and thereby the frequency of oscillations in the embedding space. By default 16. amp_mode : bool, optional A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. """ def __init__(self, num_channels: int, scale: int = 16, amp_mode: bool = False): super().__init__() self.register_buffer("freqs", torch.randn(num_channels // 2) * scale) self.amp_mode = amp_mode def forward(self, x): freqs = self.freqs _validate_amp(self.amp_mode) if not self.amp_mode: if x.dtype != self.freqs.dtype: freqs = self.freqs.to(x.dtype) x = x.ger((2 * np.pi * freqs)) x = torch.cat([x.cos(), x.sin()], dim=1) return x