Spaces:
Sleeping
Sleeping
File size: 4,606 Bytes
c3d0544 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Literal, Tuple, Union
import numpy as np
import torch
from physicsnemo.models.diffusion.preconditioning import EDMPrecondSuperResolution
from physicsnemo.models.meta import ModelMetaData
@dataclass
class tEDMPrecondSuperResMetaData(ModelMetaData):
    """Metadata record for the ``tEDMPrecondSuperRes`` model.

    Overrides the defaults of the :class:`ModelMetaData` fields listed
    below; only ``amp_gpu`` is enabled, every other flag is disabled.
    """

    # Identifier under which this model is registered/reported.
    name: str = "tEDMPrecondSuperRes"

    # --- Optimization flags ---
    jit: bool = False
    cuda_graphs: bool = False
    amp_cpu: bool = False
    amp_gpu: bool = True
    torch_fx: bool = False

    # --- Data type flags ---
    bf16: bool = False

    # --- Inference flags ---
    onnx: bool = False

    # --- Physics-informed flags ---
    func_torch: bool = False
    auto_grad: bool = False
class tEDMPrecondSuperRes(EDMPrecondSuperResolution):
    """
    Preconditioning proposed in the paper `Heavy-Tailed Diffusion Models,
    Pandey et al. <https://arxiv.org/abs/2410.14171>`_ (t-EDM). A variant of
    :class:`~physicsnemo.models.diffusion.preconditioning.EDMPrecondSuperResolution`
    that replaces the traditional Gaussian noise with a noise sampled from a
    Student-t distribution.

    Parameters
    ----------
    img_resolution : Union[int, Tuple[int, int]]
        Spatial resolution :math:`(H, W)` of the image. If a single int is provided,
        the image is assumed to be square.
    img_in_channels : int
        Number of input channels in the low-resolution input image.
    img_out_channels : int
        Number of output channels in the high-resolution output image.
    use_fp16 : bool, optional
        Whether to use half-precision floating point (FP16) for model execution,
        by default False.
    model_type : str, optional
        Class name of the underlying model. Must be one of the following:
        'SongUNet', 'SongUNetPosEmbd', 'SongUNetPosLtEmbd', 'DhariwalUNet'.
        Defaults to 'SongUNetPosEmbd'.
    sigma_data : float, optional
        Expected standard deviation of the training data, by default 0.5.
    sigma_min : float, optional
        Minimum supported noise level, by default 0.0.
    sigma_max : float, optional
        Maximum supported noise level, by default inf.
    nu : int, optional, default=10
        Number of degrees of freedom used for the Student-t distribution.
        Must be strictly greater than 2.
    **model_kwargs : dict
        Keyword arguments passed to the underlying model `__init__` method.

    Raises
    ------
    ValueError
        If ``nu`` is not strictly greater than 2.
    """

    def __init__(
        self,
        img_resolution: Union[int, Tuple[int, int]],
        img_in_channels: int,
        img_out_channels: int,
        use_fp16: bool = False,
        model_type: Literal[
            "SongUNetPosEmbd", "SongUNetPosLtEmbd", "SongUNet", "DhariwalUNet"
        ] = "SongUNetPosEmbd",
        sigma_data: float = 0.5,
        sigma_min: float = 0.0,
        sigma_max: float = float("inf"),
        nu: int = 10,
        **model_kwargs: Any,
    ):
        # NOTE: Check if nu is greater than 2. This is to ensure the variance of the
        # Student-t prior during sampling is finite.
        if nu <= 2:
            raise ValueError(f"Expected nu > 2, but got {nu}.")

        super().__init__(
            img_resolution=img_resolution,
            img_in_channels=img_in_channels,
            img_out_channels=img_out_channels,
            use_fp16=use_fp16,
            model_type=model_type,
            sigma_data=sigma_data,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            **model_kwargs,
        )
        self.nu = nu
        self.meta = tEDMPrecondSuperResMetaData()

    def forward(
        self,
        x: torch.Tensor,
        img_lr: torch.Tensor,
        sigma: torch.Tensor,
        force_fp32: bool = False,
        **model_kwargs: Any,
    ):
        """
        Rescale ``sigma`` by :math:`\\sqrt{\\nu / (\\nu - 2)}` and delegate to
        the parent class ``forward``.

        Parameters
        ----------
        x : torch.Tensor
            Forwarded unchanged to the parent ``forward``.
        img_lr : torch.Tensor
            Forwarded unchanged to the parent ``forward``.
        sigma : torch.Tensor
            Noise level. It is multiplied by ``sqrt(nu / (nu - 2))`` before
            being forwarded; the tensor passed by the caller is left
            unmodified.
        force_fp32 : bool, optional
            Forwarded unchanged to the parent ``forward``, by default False.
        **model_kwargs : Any
            Additional keyword arguments forwarded to the parent ``forward``.

        Returns
        -------
        The value returned by the parent class ``forward``.
        """
        # Rescale sigma to account for nu scaling.
        # NOTE: This is deliberately out-of-place. An in-place ``sigma *= ...``
        # would modify the caller's tensor, so a caller reusing the same sigma
        # tensor across calls would see the scaling applied cumulatively; it
        # would also fail for leaf tensors that require grad.
        nu_scale = float(np.sqrt(self.nu / (self.nu - 2)))
        scaled_sigma = sigma * nu_scale
        return super().forward(x, img_lr, scaled_sigma, force_fp32, **model_kwargs)
|