# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| from typing import Callable, Optional | |
| import torch | |
| from torch import Tensor | |
| from physicsnemo.models.diffusion import EDMPrecond | |
| from physicsnemo.utils.patching import GridPatching2D | |
| # NOTE: use two wrappers for apply, to avoid recompilation when input shape changes | |
| def _apply_wrapper_Cin_channels(patching, input, additional_input=None): | |
| """ | |
| Apply the patching operation to the input tensor with :math:`C_{in}` channels. | |
| """ | |
| return patching.apply(input=input, additional_input=additional_input) | |
| def _apply_wrapper_Cout_channels_no_grad(patching, input, additional_input=None): | |
| """ | |
| Apply the patching operation to an input tensor with :math:`C_{out}` | |
| channels that does not require gradients. | |
| """ | |
| return patching.apply(input=input, additional_input=additional_input) | |
| def _apply_wrapper_Cout_channels_grad(patching, input, additional_input=None): | |
| """ | |
| Apply the patching operation to an input tensor with :math:`C_{out}` | |
| channels that requires gradients. | |
| """ | |
| return patching.apply(input=input, additional_input=additional_input) | |
| def _fuse_wrapper(patching, input, batch_size): | |
| return patching.fuse(input=input, batch_size=batch_size) | |
def _apply_wrapper_select(
    input: torch.Tensor, patching: Optional[GridPatching2D]
) -> Callable:
    """
    Select the patching wrapper matching ``input``'s autograd state.

    Parameters
    ----------
    input : torch.Tensor
        Tensor with :math:`C_{out}` channels that will be patched.
    patching : Optional[GridPatching2D]
        Patching utility, or ``None`` (or any falsy value) to disable
        patching.

    Returns
    -------
    Callable
        ``_apply_wrapper_Cout_channels_grad`` when ``input.requires_grad``
        is True, ``_apply_wrapper_Cout_channels_no_grad`` when it is
        False, and an identity function (returns ``input`` unchanged)
        when ``patching`` is falsy. All returned callables share the
        signature ``(patching, input, additional_input=None)``.
    """
    if not patching:
        # No patching: identity with the same call signature as the wrappers.
        return lambda patching, input, additional_input=None: input
    if input.requires_grad:
        return _apply_wrapper_Cout_channels_grad
    return _apply_wrapper_Cout_channels_no_grad
def stochastic_sampler(
    net: torch.nn.Module,
    latents: Tensor,
    img_lr: Tensor,
    class_labels: Optional[Tensor] = None,
    randn_like: Callable[[Tensor], Tensor] = torch.randn_like,
    patching: Optional[GridPatching2D] = None,
    mean_hr: Optional[Tensor] = None,
    lead_time_label: Optional[Tensor] = None,
    num_steps: int = 18,
    sigma_min: float = 0.002,
    sigma_max: float = 800,
    rho: float = 7,
    S_churn: float = 0,
    S_min: float = 0,
    S_max: float = float("inf"),
    S_noise: float = 1,
) -> Tensor:
    r"""
    Proposed EDM sampler (Algorithm 2) with minor changes to enable
    super-resolution and patch-based diffusion.

    Parameters
    ----------
    net : torch.nn.Module
        The neural network model that generates denoised images from noisy
        inputs.
        Expected signature: ``net(x, x_lr, t_hat, class_labels,
        lead_time_label=lead_time_label,
        embedding_selector=embedding_selector)``.

        Inputs:

        - **x** (*torch.Tensor*): Noisy input of shape :math:`(B, C_{out}, H, W)`
        - **x_lr** (*torch.Tensor*): Conditioning input of shape :math:`(B, C_{cond}, H, W)`
        - **t_hat** (*torch.Tensor*): Noise level of shape :math:`(B, 1, 1, 1)` or scalar
        - **class_labels** (*torch.Tensor, optional*): Optional class labels
        - **lead_time_label** (*torch.Tensor, optional*): Optional lead time labels
        - **embedding_selector** (*callable, optional*): Function to select
          positional embeddings. Used for patch-based diffusion.

        Output:

        - **denoised** (*torch.Tensor*): Denoised prediction of shape :math:`(B, C_{out}, H, W)`

        Required attributes:

        - **sigma_min** (*float*): Minimum supported noise level for the model
        - **sigma_max** (*float*): Maximum supported noise level for the model
        - **round_sigma** (*callable*): Method to convert sigma values to
          tensor representation
    latents : Tensor
        The latent variables (e.g., noise) used as the initial input for the
        sampler. Has shape :math:`(B, C_{out}, H, W)`.
    img_lr : Tensor
        Low-resolution input image for conditioning the super-resolution
        process. Must have shape :math:`(B, C_{lr}, H, W)`.
    class_labels : Optional[Tensor], optional
        Class labels for conditional generation, if required by the model. By
        default ``None``.
    randn_like : Callable[[Tensor], Tensor]
        Function to generate random noise with the same shape as the input
        tensor. By default ``torch.randn_like``.
    patching : Optional[GridPatching2D], default=None
        A patching utility for patch-based diffusion. Implements methods to
        extract patches from an image and batch the patches along dim=0.
        Should also implement a ``fuse`` method to reconstruct the original
        image from a batch of patches. See
        :class:`~physicsnemo.utils.patching.GridPatching2D` for details. By
        default ``None``, in which case non-patched diffusion is used.
    mean_hr : Optional[Tensor], optional
        Optional tensor containing mean high-resolution images for
        conditioning. Must have same height and width as ``img_lr``, with shape
        :math:`(B_{hr}, C_{hr}, H, W)` where the batch dimension
        :math:`B_{hr}` can be either 1, equal to batch_size, or omitted. If
        :math:`B_{hr} = 1` or is omitted, ``mean_hr`` will be expanded to match
        the shape of ``img_lr``. By default ``None``.
    lead_time_label : Optional[Tensor], optional
        Optional lead time labels. By default ``None``.
    num_steps : int
        Number of time steps for the sampler. By default 18.
    sigma_min : float
        Minimum noise level. By default 0.002.
    sigma_max : float
        Maximum noise level. By default 800.
    rho : float
        Exponent used in the time step discretization. By default 7.
    S_churn : float
        Churn parameter controlling the level of noise added in each step. By
        default 0.
    S_min : float
        Minimum time step for applying churn. By default 0.
    S_max : float
        Maximum time step for applying churn. By default ``float("inf")``.
    S_noise : float
        Noise scaling factor applied during the churn step. By default 1.

    Returns
    -------
    Tensor
        The final denoised image produced by the sampler. Same shape as
        ``latents``: :math:`(B, C_{out}, H, W)`.

    See Also
    --------
    :class:`~physicsnemo.models.diffusion.preconditioning.EDMPrecondSuperResolution`: A model
        wrapper that provides preconditioning for super-resolution diffusion
        models and implements the required interface for this sampler.
    """
    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Safety check on type of patching
    if patching is not None and not isinstance(patching, GridPatching2D):
        raise ValueError("patching must be an instance of GridPatching2D.")

    # Safety check: if patching is used then img_lr and latents must have same
    # height and width, otherwise there is mismatch in the number
    # of patches extracted to form the final batch_size.
    if patching:
        if img_lr.shape[-2:] != latents.shape[-2:]:
            raise ValueError(
                f"img_lr and latents must have the same height and width, "
                f"but found {img_lr.shape[-2:]} vs {latents.shape[-2:]}. "
            )
        # img_lr and latents must also have the same batch_size, otherwise
        # mismatch when processed by the network
        if img_lr.shape[0] != latents.shape[0]:
            raise ValueError(
                f"img_lr and latents must have the same batch size, but found "
                f"{img_lr.shape[0]} vs {latents.shape[0]}."
            )

    # Time step discretization (EDM noise schedule): interpolate between
    # sigma_max and sigma_min in rho-warped space, then append t_N = 0.
    step_indices = torch.arange(num_steps, device=latents.device)
    t_steps = (
        sigma_max ** (1 / rho)
        + step_indices
        / (num_steps - 1)
        * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
    ) ** rho
    t_steps = torch.cat(
        [net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]
    )  # t_N = 0

    batch_size = img_lr.shape[0]
    # conditioning = [mean_hr, img_lr, global_lr, pos_embd]
    x_lr = img_lr
    if mean_hr is not None:
        if mean_hr.shape[-2:] != img_lr.shape[-2:]:
            raise ValueError(
                f"mean_hr and img_lr must have the same height and width, "
                f"but found {mean_hr.shape[-2:]} vs {img_lr.shape[-2:]}."
            )
        x_lr = torch.cat((mean_hr.expand(x_lr.shape[0], -1, -1, -1), x_lr), dim=1)

    # input and position padding + patching
    if patching:
        # Patched conditioning [x_lr, mean_hr]
        # (batch_size * patch_num, C_in + C_out, patch_shape_y, patch_shape_x)
        x_lr = _apply_wrapper_Cin_channels(
            patching=patching, input=x_lr, additional_input=img_lr
        )

        # Function to select the correct positional embedding for each patch
        def patch_embedding_selector(emb):
            # emb: (N_pe, image_shape_y, image_shape_x)
            # return: (batch_size * patch_num, N_pe, patch_shape_y, patch_shape_x)
            return patching.apply(emb.expand(batch_size, -1, -1, -1))

    else:
        patch_embedding_selector = None

    optional_args = {}
    if lead_time_label is not None:
        optional_args["lead_time_label"] = lead_time_label
    if patching:
        optional_args["embedding_selector"] = patch_embedding_selector

    # Main sampling loop.
    x_next = latents * t_steps[0]
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily. Per EDM Algorithm 2, gamma is capped at
        # sqrt(2) - 1 so that t_hat <= sqrt(2) * t_cur; without the cap, large
        # S_churn values would over-amplify the injected noise.
        gamma = (
            min(S_churn / num_steps, 2**0.5 - 1) if S_min <= t_cur <= S_max else 0
        )
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur)

        # Euler step. Perform patching operation on score tensor if patch-based
        # generation is used.
        x_hat_batch = _apply_wrapper_select(input=x_hat, patching=patching)(
            patching=patching, input=x_hat
        ).to(latents.device)
        x_lr = x_lr.to(latents.device)
        if isinstance(net, EDMPrecond):
            # Conditioning info is passed as keyword arg
            denoised = net(
                x_hat_batch,
                t_hat,
                condition=x_lr,
                class_labels=class_labels,
                **optional_args,
            )
        else:
            denoised = net(
                x_hat_batch,
                x_lr,
                t_hat,
                class_labels,
                **optional_args,
            )
        if patching:
            # Un-patch the denoised image
            # (batch_size, C_out, img_shape_y, img_shape_x)
            denoised = _fuse_wrapper(
                patching=patching, input=denoised, batch_size=batch_size
            )
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order (Heun) correction on all but the final step.
        if i < num_steps - 1:
            # Patched input
            # (batch_size * patch_num, C_out, patch_shape_y, patch_shape_x)
            x_next_batch = _apply_wrapper_select(input=x_next, patching=patching)(
                patching=patching, input=x_next
            ).to(latents.device)
            if isinstance(net, EDMPrecond):
                # Conditioning info is passed as keyword arg
                denoised = net(
                    x_next_batch,
                    t_next,
                    condition=x_lr,
                    class_labels=class_labels,
                    **optional_args,
                )
            else:
                denoised = net(
                    x_next_batch,
                    x_lr,
                    t_next,
                    class_labels,
                    **optional_args,
                )
            if patching:
                # Un-patch the denoised image
                # (batch_size, C_out, img_shape_y, img_shape_x)
                denoised = _fuse_wrapper(
                    patching=patching, input=denoised, batch_size=batch_size
                )
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

    return x_next