# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Set, Tuple, Union

import torch

from physicsnemo.models.diffusion.utils import _wrapped_property
from physicsnemo.models.meta import ModelMetaData
from physicsnemo.models.module import Module

# Package from which the wrapped architectures are looked up by class name.
network_module = importlib.import_module("physicsnemo.models.diffusion")


@dataclass
class MetaData(ModelMetaData):
    """Metadata flags shared by the U-Net wrappers defined in this module."""

    name: str = "UNet"
    # Optimization
    jit: bool = False
    cuda_graphs: bool = False
    amp_cpu: bool = False
    amp_gpu: bool = True
    torch_fx: bool = False
    # Data type
    bf16: bool = True
    # Inference
    onnx: bool = False
    # Physics informed
    func_torch: bool = False
    auto_grad: bool = False


class UNet(Module):  # TODO a lot of redundancy, need to clean up
    r"""
    This interface provides a U-Net wrapper for CorrDiff deterministic
    regression model (and other deterministic downsampling models).
    It supports the following architectures:

    - :class:`~physicsnemo.models.diffusion.song_unet.SongUNet`

    - :class:`~physicsnemo.models.diffusion.song_unet.SongUNetPosEmbd`

    - :class:`~physicsnemo.models.diffusion.song_unet.SongUNetPosLtEmbd`

    - :class:`~physicsnemo.models.diffusion.dhariwal_unet.DhariwalUNet`

    It shares the same architecture as a conditional diffusion model. It does
    so by concatenating a conditioning image to a zero-filled latent state,
    and by setting the noise level and the class labels to zero.

    Parameters
    ----------
    img_resolution : Union[int, Tuple[int, int]]
        The resolution of the input/output image. If a single int is provided,
        then the image is assumed to be square.
    img_in_channels : int
        Number of channels in the input image.
    img_out_channels : int
        Number of channels in the output image.
    use_fp16: bool, optional, default=False
        Execute the underlying model at FP16 precision.
    model_type: Literal['SongUNet', 'SongUNetPosEmbd', 'SongUNetPosLtEmbd', 'DhariwalUNet'], default='SongUNetPosEmbd'
        Class name of the underlying architecture. Must be one of the
        following: 'SongUNet', 'SongUNetPosEmbd', 'SongUNetPosLtEmbd',
        'DhariwalUNet'.
    **model_kwargs : dict
        Keyword arguments passed to the underlying architecture `__init__`
        method. Please refer to the documentation of these classes for details
        on how to call and use these models directly.

    Forward
    -------
    x : torch.Tensor
        The input tensor, typically zero-filled, of shape
        :math:`(B, C_{in}, H_{in}, W_{in})`.
    img_lr : torch.Tensor
        Conditioning image of shape :math:`(B, C_{lr}, H_{in}, W_{in})`.
    force_fp32 : bool, optional, default=False
        If ``True``, run the wrapped model in FP32 even when ``use_fp16`` is
        set.
    **model_kwargs : dict
        Additional keyword arguments to pass to the underlying architecture
        forward method.

    Outputs
    -------
    torch.Tensor
        Output tensor of shape :math:`(B, C_{out}, H_{in}, W_{in})` (same
        spatial dimensions as the input).
    """

    __model_checkpoint_version__ = "0.2.0"
    __supported_model_checkpoint_version__ = {
        "0.1.0": (
            "Loading UNet checkpoint from older version 0.1.0 (current version "
            "is 0.2.0). This version is still supported, but consider "
            "re-saving the model to upgrade to version 0.2.0 and remove this "
            "warning."
        )
    }

    # Classes that can be wrapped by this UNet class.
    _wrapped_classes: Set[str] = {
        "SongUNetPosEmbd",
        "SongUNetPosLtEmbd",
        "SongUNet",
        "DhariwalUNet",
    }

    # Arguments of the __init__ method that can be overridden with the
    # ``Module.from_checkpoint`` method. Here, since we use splatted arguments
    # for the wrapped model instance, we allow overriding of any overridable
    # argument of the wrapped classes.
    _overridable_args: Set[str] = set.union(
        *(
            getattr(getattr(network_module, cls_name), "_overridable_args", set())
            for cls_name in _wrapped_classes
        )
    )

    @classmethod
    def _backward_compat_arg_mapper(
        cls, version: str, args: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Map arguments from older versions to current version format.

        Parameters
        ----------
        version : str
            Version of the checkpoint being loaded
        args : Dict[str, Any]
            Arguments dictionary from the checkpoint

        Returns
        -------
        Dict[str, Any]
            Updated arguments dictionary compatible with current version
        """
        # Call parent class method first
        args = super()._backward_compat_arg_mapper(version, args)

        if version == "0.1.0":
            # In version 0.1.0, img_channels and the sigma parameters were
            # accepted but unused, so they are dropped if present.
            for unused_key in ("img_channels", "sigma_min", "sigma_max", "sigma_data"):
                args.pop(unused_key, None)

        return args

    def __init__(
        self,
        img_resolution: Union[int, Tuple[int, int]],
        img_in_channels: int,
        img_out_channels: int,
        use_fp16: bool = False,
        model_type: Literal[
            "SongUNetPosEmbd", "SongUNetPosLtEmbd", "SongUNet", "DhariwalUNet"
        ] = "SongUNetPosEmbd",
        **model_kwargs: Any,
    ):
        # Pass a metadata *instance* (not the dataclass type itself), as
        # StormCastUNet does.
        super().__init__(meta=MetaData())

        # Validation
        if model_type not in self._wrapped_classes:
            # Sorted so that the message is deterministic (sets are unordered).
            raise ValueError(
                f"Model type '{model_type}' is not supported. "
                f"Must be one of: {', '.join(sorted(self._wrapped_classes))}"
            )

        # for compatibility with older versions that took only 1 dimension
        if isinstance(img_resolution, int):
            self.img_shape_x = self.img_shape_y = img_resolution
        else:
            self.img_shape_y = img_resolution[0]
            self.img_shape_x = img_resolution[1]

        self.img_in_channels = img_in_channels
        self.img_out_channels = img_out_channels

        # The wrapped network receives the latent state and the conditioning
        # image concatenated along the channel dimension (see ``forward``),
        # hence the summed number of input channels.
        model_class = getattr(network_module, model_type)
        self.model = model_class(
            img_resolution=img_resolution,
            in_channels=img_in_channels + img_out_channels,
            out_channels=img_out_channels,
            **model_kwargs,
        )

        # Must be assigned after ``self.model`` exists: the setter casts the
        # whole module to the requested precision.
        self.use_fp16 = use_fp16

    # Properties delegated to the wrapped model
    amp_mode = _wrapped_property(
        "amp_mode",
        "model",
        "Set to ``True`` when using automatic mixed precision.",
    )
    profile_mode = _wrapped_property(
        "profile_mode",
        "model",
        "Set to ``True`` to enable profiling of the wrapped model.",
    )

    @property
    def use_fp16(self) -> bool:
        """
        bool: Whether the model uses float16 precision.

        Returns
        -------
        bool
            True if the model is in float16 mode, False otherwise.
        """
        return self._use_fp16

    @use_fp16.setter
    def use_fp16(self, value: bool):
        """
        Set whether the model should use float16 precision.

        Parameters
        ----------
        value : bool
            If True, moves the model to torch.float16. If False, moves to
            torch.float32. The integers 0 and 1 are also accepted.

        Raises
        ------
        ValueError
            If `value` is not a boolean.
        """
        # NOTE: allow 0/1 values for older checkpoints
        if not (isinstance(value, bool) or value in [0, 1]):
            raise ValueError(
                f"`use_fp16` must be a boolean, but got {type(value).__name__}."
            )
        # Normalize legacy 0/1 so that the getter always returns a real bool.
        self._use_fp16 = bool(value)
        self.to(torch.float16 if self._use_fp16 else torch.float32)

    def forward(
        self,
        x: torch.Tensor,
        img_lr: Union[torch.Tensor, None],
        force_fp32: bool = False,
        **model_kwargs: Any,
    ) -> torch.Tensor:
        """
        Run the wrapped network with zero noise labels and no class labels.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor passed to the wrapped model.
        img_lr : torch.Tensor or None
            Conditioning image. When not ``None`` it is concatenated to ``x``
            along dimension 1 before the wrapped model is called.
        force_fp32 : bool, optional, default=False
            If ``True``, the wrapped model is run in float32 regardless of
            ``use_fp16``.
        **model_kwargs : dict
            Additional keyword arguments forwarded to the wrapped model.

        Returns
        -------
        torch.Tensor
            Output of the wrapped model, cast to float32.

        Raises
        ------
        ValueError
            If the wrapped model's output dtype differs from the selected
            compute dtype while autocast is not enabled.
        """
        # SR: concatenate input channels
        if img_lr is not None:
            x = torch.cat((x, img_lr), dim=1)

        # float16 is only used when requested, not overridden, and on CUDA.
        run_in_fp16 = self.use_fp16 and not force_fp32 and x.device.type == "cuda"
        dtype = torch.float16 if run_in_fp16 else torch.float32

        # The noise-level input is a zero vector with one entry per batch item.
        noise_labels = torch.zeros(x.shape[0], dtype=dtype, device=x.device)
        F_x = self.model(
            x.to(dtype),
            noise_labels,
            class_labels=None,
            **model_kwargs,
        )

        if (F_x.dtype != dtype) and not torch.is_autocast_enabled():
            raise ValueError(
                f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead."
            )

        # The output is always returned in float32.
        D_x = F_x.to(torch.float32)
        return D_x

    def round_sigma(self, sigma: Union[float, List, torch.Tensor]) -> torch.Tensor:
        """
        Convert a given sigma value(s) to a tensor representation.

        Parameters
        ----------
        sigma : Union[float, List, torch.Tensor]
            The sigma value(s) to convert.

        Returns
        -------
        torch.Tensor
            The tensor representation of the provided sigma value(s).
        """
        return torch.as_tensor(sigma)


class StormCastUNet(Module):
    """
    U-Net wrapper for StormCast; used so the same Song U-Net network can be
    re-used for this model.

    Parameters
    ----------
    img_resolution : int or List[int]
        The resolution of the input/output image.
    img_in_channels : int
        Number of input color channels.
    img_out_channels : int
        Number of output color channels.
    use_fp16: bool, optional
        Execute the underlying model at FP16 precision?, by default False.
    sigma_min: float, optional
        Minimum supported noise level, by default 0.
    sigma_max: float, optional
        Maximum supported noise level, by default float('inf').
    sigma_data: float, optional
        Expected standard deviation of the training data, by default 0.5.
    model_type: str, optional
        Class name of the underlying model, by default 'SongUNet'.
    **model_kwargs : dict
        Keyword arguments for the underlying model.
    """

    def __init__(
        self,
        img_resolution,
        img_in_channels,
        img_out_channels,
        use_fp16=False,
        sigma_min=0,
        sigma_max=float("inf"),
        sigma_data=0.5,
        model_type="SongUNet",
        **model_kwargs,
    ):
        super().__init__(meta=MetaData("StormCastUNet"))

        # NOTE(review): the axis order here (x = index 0, y = index 1) is the
        # opposite of UNet (y = index 0, x = index 1). Left unchanged for
        # backward compatibility -- confirm which convention is intended.
        if isinstance(img_resolution, int):
            self.img_shape_x = self.img_shape_y = img_resolution
        else:
            self.img_shape_x = img_resolution[0]
            self.img_shape_y = img_resolution[1]

        self.img_in_channels = img_in_channels
        self.img_out_channels = img_out_channels

        # Plain attribute here: unlike UNet, assigning it does not cast the
        # module parameters.
        self.use_fp16 = use_fp16
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data

        model_class = getattr(network_module, model_type)
        self.model = model_class(
            img_resolution=img_resolution,
            in_channels=img_in_channels,
            out_channels=img_out_channels,
            **model_kwargs,
        )

    # Properties delegated to the wrapped model
    amp_mode = _wrapped_property(
        "amp_mode",
        "model",
        "Set to ``True`` when using automatic mixed precision.",
    )
    profile_mode = _wrapped_property(
        "profile_mode",
        "model",
        "Set to ``True`` to enable profiling of the wrapped model.",
    )

    def forward(self, x, force_fp32=False, **model_kwargs):
        """
        Run a forward pass of the StormCast regression U-Net.

        Parameters
        ----------
        x : torch.Tensor
            Input to the U-Net.
        force_fp32 : bool, optional
            Force casting to float32 if True, by default False.
        **model_kwargs : dict
            Additional keyword arguments forwarded to the wrapped model.

        Returns
        -------
        torch.Tensor
            Output (prediction) of the U-Net, cast to float32.

        Raises
        ------
        ValueError
            If the wrapped model's output dtype differs from the selected
            compute dtype while autocast is not enabled.
        """
        x = x.to(torch.float32)

        # float16 is only used when requested, not overridden, and on CUDA.
        run_in_fp16 = self.use_fp16 and not force_fp32 and x.device.type == "cuda"
        dtype = torch.float16 if run_in_fp16 else torch.float32

        # NOTE(review): the zero noise labels are created in ``x.dtype``
        # (float32 after the cast above), whereas UNet.forward creates them in
        # ``dtype``. Left unchanged -- confirm whether this is intentional.
        noise_labels = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
        F_x = self.model(
            x.to(dtype),
            noise_labels,
            class_labels=None,
            **model_kwargs,
        )

        if (F_x.dtype != dtype) and not torch.is_autocast_enabled():
            raise ValueError(
                f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead."
            )

        D_x = F_x.to(torch.float32)
        return D_x