# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional

import torch
from torch import Tensor

from physicsnemo.models.diffusion import EDMPrecond
from physicsnemo.utils.patching import GridPatching2D


# NOTE: use separate compiled wrappers for apply (one for C_in-channel inputs,
# and one each for C_out-channel inputs with / without gradients), to avoid
# recompilation when input shape changes
@torch.compile()
def _apply_wrapper_Cin_channels(patching, input, additional_input=None):
    """
    Apply the patching operation to the input tensor with :math:`C_{in}` channels.
    """
    return patching.apply(input=input, additional_input=additional_input)


@torch.compile()
def _apply_wrapper_Cout_channels_no_grad(patching, input, additional_input=None):
    """
    Apply the patching operation to an input tensor with :math:`C_{out}` channels
    that does not require gradients.
    """
    return patching.apply(input=input, additional_input=additional_input)


@torch.compile()
def _apply_wrapper_Cout_channels_grad(patching, input, additional_input=None):
    """
    Apply the patching operation to an input tensor with :math:`C_{out}` channels
    that requires gradients.
    """
    return patching.apply(input=input, additional_input=additional_input)


@torch.compile()
def _fuse_wrapper(patching, input, batch_size):
    """
    Call ``patching.fuse`` on a batch of patches, forwarding ``batch_size``.
    """
    return patching.fuse(input=input, batch_size=batch_size)


def _apply_wrapper_select(
    input: torch.Tensor, patching: GridPatching2D | None
) -> Callable:
    """
    Select the correct patching wrapper based on the input tensor's
    requires_grad attribute.

    If patching is None, return the identity function (a callable with the
    same ``(patching, input, additional_input=None)`` signature that returns
    ``input`` unchanged).
    If patching is not None, return the appropriate patching wrapper:
    ``_apply_wrapper_Cout_channels_grad`` if ``input.requires_grad`` is True,
    ``_apply_wrapper_Cout_channels_no_grad`` otherwise.
    """
    if patching is None:
        return lambda patching, input, additional_input=None: input
    if input.requires_grad:
        return _apply_wrapper_Cout_channels_grad
    return _apply_wrapper_Cout_channels_no_grad


def stochastic_sampler(
    net: torch.nn.Module,
    latents: Tensor,
    img_lr: Tensor,
    class_labels: Optional[Tensor] = None,
    randn_like: Callable[[Tensor], Tensor] = torch.randn_like,
    patching: Optional[GridPatching2D] = None,
    mean_hr: Optional[Tensor] = None,
    lead_time_label: Optional[Tensor] = None,
    num_steps: int = 18,
    sigma_min: float = 0.002,
    sigma_max: float = 800,
    rho: float = 7,
    S_churn: float = 0,
    S_min: float = 0,
    S_max: float = float("inf"),
    S_noise: float = 1,
) -> Tensor:
    r"""
    Proposed EDM sampler (Algorithm 2) with minor changes to enable
    super-resolution and patch-based diffusion.

    Parameters
    ----------
    net : torch.nn.Module
        The neural network model that generates denoised images from noisy
        inputs.
        Expected signature: ``net(x, x_lr, t_hat, class_labels,
        lead_time_label=lead_time_label,
        embedding_selector=embedding_selector)``. If ``net`` is an instance of
        ``EDMPrecond``, it is instead called as ``net(x, t_hat,
        condition=x_lr, class_labels=class_labels, ...)``.

        Inputs:
            - **x** (*torch.Tensor*): Noisy input of shape
              :math:`(B, C_{out}, H, W)`
            - **x_lr** (*torch.Tensor*): Conditioning input of shape
              :math:`(B, C_{cond}, H, W)`
            - **t_hat** (*torch.Tensor*): Noise level of shape
              :math:`(B, 1, 1, 1)` or scalar
            - **class_labels** (*torch.Tensor, optional*): Optional class
              labels
            - **lead_time_label** (*torch.Tensor, optional*): Optional lead
              time labels
            - **embedding_selector** (*callable, optional*): Function to
              select positional embeddings. Used for patch-based diffusion.

        Output:
            - **denoised** (*torch.Tensor*): Denoised prediction of shape
              :math:`(B, C_{out}, H, W)`

        Required attributes:
            - **sigma_min** (*float*): Minimum supported noise level for the
              model
            - **sigma_max** (*float*): Maximum supported noise level for the
              model
            - **round_sigma** (*callable*): Method to convert sigma values to
              tensor representation

    latents : Tensor
        The latent variables (e.g., noise) used as the initial input for the
        sampler. Has shape :math:`(B, C_{out}, H, W)`.
    img_lr : Tensor
        Low-resolution input image for conditioning the super-resolution
        process. Must have shape :math:`(B, C_{lr}, H, W)`.
    class_labels : Optional[Tensor], optional
        Class labels for conditional generation, if required by the model. By
        default ``None``.
    randn_like : Callable[[Tensor], Tensor]
        Function to generate random noise with the same shape as the input
        tensor. By default ``torch.randn_like``.
    patching : Optional[GridPatching2D], default=None
        A patching utility for patch-based diffusion. Implements methods to
        extract patches from an image and batch the patches along dim=0.
        Should also implement a ``fuse`` method to reconstruct the original
        image from a batch of patches. See
        :class:`~physicsnemo.utils.patching.GridPatching2D` for details. By
        default ``None``, in which case non-patched diffusion is used.
    mean_hr : Optional[Tensor], optional
        Optional tensor containing mean high-resolution images for
        conditioning. Must have same height and width as ``img_lr``, with
        shape :math:`(B_{hr}, C_{hr}, H, W)` where the batch dimension
        :math:`B_{hr}` can be 1, equal to batch_size, or omitted. If
        :math:`B_{hr} = 1` or is omitted, ``mean_hr`` will be expanded to
        match the shape of ``img_lr``. By default ``None``.
    lead_time_label : Optional[Tensor], optional
        Optional lead time labels. By default ``None``.
    num_steps : int
        Number of time steps for the sampler. Must be at least 1; with a
        single step the sampler performs one Euler step from ``sigma_max``
        to 0. By default 18.
    sigma_min : float
        Minimum noise level. By default 0.002.
    sigma_max : float
        Maximum noise level. By default 800.
    rho : float
        Exponent used in the time step discretization. By default 7.
    S_churn : float
        Churn parameter controlling the level of noise added in each step. By
        default 0.
    S_min : float
        Minimum time step for applying churn. By default 0.
    S_max : float
        Maximum time step for applying churn. By default ``float("inf")``.
    S_noise : float
        Noise scaling factor applied during the churn step. By default 1.

    Returns
    -------
    Tensor
        The final denoised image produced by the sampler. Same shape as
        ``latents``: :math:`(B, C_{out}, H, W)`.

    Raises
    ------
    ValueError
        If ``num_steps`` is smaller than 1, if ``patching`` is neither
        ``None`` nor a ``GridPatching2D`` instance, if patching is used and
        ``img_lr`` and ``latents`` differ in height/width, if ``img_lr`` and
        ``latents`` differ in batch size, or if ``mean_hr`` and ``img_lr``
        differ in height/width.

    See Also
    --------
    :class:`~physicsnemo.models.diffusion.preconditioning.EDMPrecondSuperResolution`:
        A model wrapper that provides preconditioning for super-resolution
        diffusion models and implements the required interface for this
        sampler.
    """
    # A schedule needs at least one noise level; fail early with a clear
    # message instead of an IndexError on an empty schedule further down.
    if num_steps < 1:
        raise ValueError(
            f"num_steps must be a positive integer, but found {num_steps}."
        )

    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Safety check on type of patching
    if patching is not None and not isinstance(patching, GridPatching2D):
        raise ValueError("patching must be an instance of GridPatching2D.")

    # Safety check: if patching is used then img_lr and latents must have same
    # height and width, otherwise there is mismatch in the number
    # of patches extracted to form the final batch_size.
    if patching is not None:
        if img_lr.shape[-2:] != latents.shape[-2:]:
            raise ValueError(
                f"img_lr and latents must have the same height and width, "
                f"but found {img_lr.shape[-2:]} vs {latents.shape[-2:]}. "
            )
    # img_lr and latents must also have the same batch_size, otherwise mismatch
    # when processed by the network
    if img_lr.shape[0] != latents.shape[0]:
        raise ValueError(
            f"img_lr and latents must have the same batch size, but found "
            f"{img_lr.shape[0]} vs {latents.shape[0]}."
        )

    # Time step discretization.
    step_indices = torch.arange(num_steps, device=latents.device)
    # max(..., 1) avoids 0 / 0 = NaN for num_steps == 1 (schedule is then
    # [sigma_max, 0]); it is a no-op for num_steps >= 2.
    t_steps = (
        sigma_max ** (1 / rho)
        + step_indices
        / max(num_steps - 1, 1)
        * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
    ) ** rho
    t_steps = torch.cat(
        [net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]
    )  # t_N = 0

    batch_size = img_lr.shape[0]

    # conditioning = [mean_hr, img_lr, global_lr, pos_embd]
    x_lr = img_lr
    if mean_hr is not None:
        if mean_hr.shape[-2:] != img_lr.shape[-2:]:
            raise ValueError(
                f"mean_hr and img_lr must have the same height and width, "
                f"but found {mean_hr.shape[-2:]} vs {img_lr.shape[-2:]}."
            )
        x_lr = torch.cat((mean_hr.expand(x_lr.shape[0], -1, -1, -1), x_lr), dim=1)

    # input and position padding + patching
    if patching is not None:
        # Patched conditioning [x_lr, mean_hr]
        # (batch_size * patch_num, C_in + C_out, patch_shape_y, patch_shape_x)
        x_lr = _apply_wrapper_Cin_channels(
            patching=patching, input=x_lr, additional_input=img_lr
        )

        # Function to select the correct positional embedding for each patch
        def patch_embedding_selector(emb):
            # emb: (N_pe, image_shape_y, image_shape_x)
            # return: (batch_size * patch_num, N_pe, patch_shape_y, patch_shape_x)
            return patching.apply(emb.expand(batch_size, -1, -1, -1))

    else:
        patch_embedding_selector = None

    optional_args = {}
    if lead_time_label is not None:
        optional_args["lead_time_label"] = lead_time_label
    if patching is not None:
        optional_args["embedding_selector"] = patch_embedding_selector

    # Loop-invariant work, done once instead of at every sampling step:
    # move the conditioning to the sampling device and decide how the
    # conditioning is handed to the network.
    x_lr = x_lr.to(latents.device)
    condition_as_kwarg = isinstance(net, EDMPrecond)

    def _denoise(x: Tensor, sigma: Tensor) -> Tensor:
        # Evaluate the network on ``x`` at noise level ``sigma``: patch the
        # input if patch-based generation is used, run the network, then fuse
        # the patches back into full images.
        # The wrapper is selected per call because ``x.requires_grad`` is
        # inspected on the tensor actually passed in.
        # Patched input
        # (batch_size * patch_num, C_out, patch_shape_y, patch_shape_x)
        x_batch = _apply_wrapper_select(input=x, patching=patching)(
            patching=patching, input=x
        ).to(latents.device)
        if condition_as_kwarg:
            # Conditioning info is passed as keyword arg
            out = net(
                x_batch,
                sigma,
                condition=x_lr,
                class_labels=class_labels,
                **optional_args,
            )
        else:
            out = net(
                x_batch,
                x_lr,
                sigma,
                class_labels,
                **optional_args,
            )
        if patching is not None:
            # Un-patch the denoised image
            # (batch_size, C_out, img_shape_y, img_shape_x)
            out = _fuse_wrapper(patching=patching, input=out, batch_size=batch_size)
        return out

    # Main sampling loop.
    x_next = latents * t_steps[0]
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = S_churn / num_steps if S_min <= t_cur <= S_max else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur)

        # Euler step. The patching operation is performed on the network
        # input/output if patch-based generation is used.
        denoised = _denoise(x_hat, t_hat)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction (skipped on the last step, where
        # t_next = 0).
        if i < num_steps - 1:
            denoised = _denoise(x_next, t_next)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
    return x_next