File size: 13,516 Bytes
c3d0544
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Callable, Optional

import torch
from torch import Tensor

from physicsnemo.models.diffusion import EDMPrecond
from physicsnemo.utils.patching import GridPatching2D


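# NOTE: use separate compiled wrappers for apply (one for C_in-channel inputs,
# and one each for C_out-channel inputs with / without gradients), to avoid
# recompilation when the input shape or requires_grad changes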
@torch.compile()
def _apply_wrapper_Cin_channels(patching, input, additional_input=None):
    """
    Compiled forwarder to ``patching.apply`` for tensors carrying
    :math:`C_{in}` channels.

    Parameters
    ----------
    patching
        Object whose ``apply`` method is invoked.
    input
        Tensor handed to ``patching.apply`` as ``input``.
    additional_input : optional
        Forwarded unchanged as ``additional_input``. By default ``None``.

    Returns
    -------
    The value returned by ``patching.apply``.
    """
    patched = patching.apply(input=input, additional_input=additional_input)
    return patched


@torch.compile()
def _apply_wrapper_Cout_channels_no_grad(patching, input, additional_input=None):
    """
    Compiled forwarder to ``patching.apply`` for tensors carrying
    :math:`C_{out}` channels that do not require gradients.

    Parameters
    ----------
    patching
        Object whose ``apply`` method is invoked.
    input
        Tensor handed to ``patching.apply`` as ``input``.
    additional_input : optional
        Forwarded unchanged as ``additional_input``. By default ``None``.

    Returns
    -------
    The value returned by ``patching.apply``.
    """
    patched = patching.apply(input=input, additional_input=additional_input)
    return patched


@torch.compile()
def _apply_wrapper_Cout_channels_grad(patching, input, additional_input=None):
    """
    Compiled forwarder to ``patching.apply`` for tensors carrying
    :math:`C_{out}` channels that require gradients.

    Parameters
    ----------
    patching
        Object whose ``apply`` method is invoked.
    input
        Tensor handed to ``patching.apply`` as ``input``.
    additional_input : optional
        Forwarded unchanged as ``additional_input``. By default ``None``.

    Returns
    -------
    The value returned by ``patching.apply``.
    """
    patched = patching.apply(input=input, additional_input=additional_input)
    return patched


@torch.compile()
def _fuse_wrapper(patching, input, batch_size):
    """
    Compiled forwarder to ``patching.fuse``.

    Parameters
    ----------
    patching
        Object whose ``fuse`` method is invoked.
    input
        Tensor handed to ``patching.fuse`` as ``input``.
    batch_size
        Forwarded unchanged as ``batch_size``.

    Returns
    -------
    The value returned by ``patching.fuse``.
    """
    fused = patching.fuse(input=input, batch_size=batch_size)
    return fused


def _apply_wrapper_select(
    input: torch.Tensor, patching: GridPatching2D | None
) -> Callable:
    """
    Pick the callable used to patch a :math:`C_{out}`-channel tensor.

    Parameters
    ----------
    input : torch.Tensor
        Tensor that is about to be patched; only its ``requires_grad`` flag
        is inspected.
    patching : GridPatching2D | None
        Patching utility. A falsy value (e.g. ``None``) disables patching.

    Returns
    -------
    Callable
        - A pass-through ``(patching, input, additional_input=None) -> input``
          when ``patching`` is falsy.
        - ``_apply_wrapper_Cout_channels_grad`` when ``input.requires_grad``
          is ``True``.
        - ``_apply_wrapper_Cout_channels_no_grad`` otherwise.
    """
    if not patching:
        # No patching requested: hand back an identity with the same call
        # signature as the compiled wrappers so call sites stay uniform.
        return lambda patching, input, additional_input=None: input

    return (
        _apply_wrapper_Cout_channels_grad
        if input.requires_grad
        else _apply_wrapper_Cout_channels_no_grad
    )


def stochastic_sampler(
    net: torch.nn.Module,
    latents: Tensor,
    img_lr: Tensor,
    class_labels: Optional[Tensor] = None,
    randn_like: Callable[[Tensor], Tensor] = torch.randn_like,
    patching: Optional[GridPatching2D] = None,
    mean_hr: Optional[Tensor] = None,
    lead_time_label: Optional[Tensor] = None,
    num_steps: int = 18,
    sigma_min: float = 0.002,
    sigma_max: float = 800,
    rho: float = 7,
    S_churn: float = 0,
    S_min: float = 0,
    S_max: float = float("inf"),
    S_noise: float = 1,
) -> Tensor:
    r"""
    Proposed EDM sampler (Algorithm 2) with minor changes to enable
    super-resolution and patch-based diffusion.

    Parameters
    ----------
    net : torch.nn.Module
        The neural network model that generates denoised images from noisy
        inputs.
        Expected signature: ``net(x, x_lr, t_hat, class_labels,
        lead_time_label=lead_time_label,
        embedding_selector=embedding_selector)``.
        If ``net`` is an instance of ``EDMPrecond``, it is instead called as
        ``net(x, t_hat, condition=x_lr, class_labels=class_labels, ...)``.

        Inputs:
            - **x** (*torch.Tensor*): Noisy input of shape :math:`(B, C_{out}, H, W)`
            - **x_lr** (*torch.Tensor*): Conditioning input of shape :math:`(B, C_{cond}, H, W)`
            - **t_hat** (*torch.Tensor*): Noise level of shape :math:`(B, 1, 1, 1)` or scalar
            - **class_labels** (*torch.Tensor, optional*): Optional class labels
            - **lead_time_label** (*torch.Tensor, optional*): Optional lead time labels
            - **embedding_selector** (*callable, optional*): Function to select
              positional embeddings. Used for patch-based diffusion.

        Output:
            - **denoised** (*torch.Tensor*): Denoised prediction of shape :math:`(B, C_{out}, H, W)`

        Required attributes:
            - **sigma_min** (*float*): Minimum supported noise level for the model
            - **sigma_max** (*float*): Maximum supported noise level for the model
            - **round_sigma** (*callable*): Method to convert sigma values to
              tensor representation

    latents : Tensor
        The latent variables (e.g., noise) used as the initial input for the
        sampler. Has shape :math:`(B, C_{out}, H, W)`.
    img_lr : Tensor
        Low-resolution input image for conditioning the super-resolution
        process. Must have shape :math:`(B, C_{lr}, H, W)`.
    class_labels : Optional[Tensor], optional
        Class labels for conditional generation, if required by the model. By
        default ``None``.
    randn_like : Callable[[Tensor], Tensor]
        Function to generate random noise with the same shape as the input
        tensor.
        By default ``torch.randn_like``.
    patching : Optional[GridPatching2D], default=None
        A patching utility for patch-based diffusion. Implements methods to
        extract patches from an image and batch the patches along dim=0.
        Should also implement a ``fuse`` method to reconstruct the original
        image from a batch of patches. See
        :class:`~physicsnemo.utils.patching.GridPatching2D` for details. By
        default ``None``, in which case non-patched diffusion is used.
    mean_hr : Optional[Tensor], optional
        Optional tensor containing mean high-resolution images for
        conditioning. Must have same height and width as ``img_lr``, with shape
        :math:`(B_{hr}, C_{hr}, H, W)` where the batch dimension
        :math:`B_{hr}` can be 1, equal to batch_size, or omitted. If
        :math:`B_{hr} = 1` or is omitted, ``mean_hr`` will be expanded to match
        the batch size of ``img_lr``. By default ``None``.
    lead_time_label : Optional[Tensor], optional
        Optional lead time labels. By default ``None``.
    num_steps : int
        Number of time steps for the sampler. Must be at least 1; with a single
        step the only noise level used is the (clamped) ``sigma_max``. By
        default 18.
    sigma_min : float
        Minimum noise level. Clamped from below by ``net.sigma_min``. By
        default 0.002.
    sigma_max : float
        Maximum noise level. Clamped from above by ``net.sigma_max``. By
        default 800.
    rho : float
        Exponent used in the time step discretization. By default 7.
    S_churn : float
        Churn parameter controlling the level of noise added in each step. By
        default 0.
    S_min : float
        Minimum time step for applying churn. By default 0.
    S_max : float
        Maximum time step for applying churn. By default ``float("inf")``.
    S_noise : float
        Noise scaling factor applied during the churn step. By default 1.

    Returns
    -------
    Tensor
        The final denoised image produced by the sampler. Same shape as
        ``latents``: :math:`(B, C_{out}, H, W)`.

    Raises
    ------
    ValueError
        If ``patching`` is neither ``None`` nor a ``GridPatching2D``; if
        patching is used and ``img_lr`` and ``latents`` differ in height or
        width; if ``img_lr`` and ``latents`` differ in batch size; if
        ``num_steps`` is smaller than 1; or if ``mean_hr`` and ``img_lr``
        differ in height or width.

    See Also
    --------
    :class:`~physicsnemo.models.diffusion.preconditioning.EDMPrecondSuperResolution`: A model
        wrapper that provides preconditioning for super-resolution diffusion
        models and implements the required interface for this sampler.
    """

    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Safety check on type of patching
    if patching is not None and not isinstance(patching, GridPatching2D):
        raise ValueError("patching must be an instance of GridPatching2D.")

    # Evaluate the patching switch once. Truthiness (rather than an
    # ``is not None`` test) is kept on purpose so that this function agrees
    # with ``_apply_wrapper_select``, which also branches on truthiness.
    use_patching = bool(patching)

    # Safety check: if patching is used then img_lr and latents must have same
    # height and width, otherwise there is mismatch in the number
    # of patches extracted to form the final batch_size.
    if use_patching:
        if img_lr.shape[-2:] != latents.shape[-2:]:
            raise ValueError(
                f"img_lr and latents must have the same height and width, "
                f"but found {img_lr.shape[-2:]} vs {latents.shape[-2:]}. "
            )
    # img_lr and latents must also have the same batch_size, otherwise mismatch
    # when processed by the network
    if img_lr.shape[0] != latents.shape[0]:
        raise ValueError(
            f"img_lr and latents must have the same batch size, but found "
            f"{img_lr.shape[0]} vs {latents.shape[0]}."
        )

    # Fail early with a clear message instead of an IndexError on an empty
    # noise schedule further down.
    if num_steps < 1:
        raise ValueError(f"num_steps must be at least 1, but found {num_steps}.")

    # Time step discretization.
    step_indices = torch.arange(num_steps, device=latents.device)
    # Guard the denominator: for num_steps == 1 the unguarded expression is
    # 0 / 0 = NaN, which would poison every subsequent computation. With the
    # guard the single-step schedule is simply [sigma_max].
    ramp = step_indices / max(num_steps - 1, 1)
    t_steps = (
        sigma_max ** (1 / rho)
        + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
    ) ** rho
    t_steps = torch.cat(
        [net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]
    )  # t_N = 0

    batch_size = img_lr.shape[0]

    # conditioning = [mean_hr, img_lr, global_lr, pos_embd]
    x_lr = img_lr
    if mean_hr is not None:
        if mean_hr.shape[-2:] != img_lr.shape[-2:]:
            raise ValueError(
                f"mean_hr and img_lr must have the same height and width, "
                f"but found {mean_hr.shape[-2:]} vs {img_lr.shape[-2:]}."
            )
        x_lr = torch.cat((mean_hr.expand(x_lr.shape[0], -1, -1, -1), x_lr), dim=1)

    # input and position padding + patching
    if use_patching:
        # Patch the conditioning tensor; img_lr is handed to the patching
        # utility as its additional input.
        x_lr = _apply_wrapper_Cin_channels(
            patching=patching, input=x_lr, additional_input=img_lr
        )

        # Function to select the correct positional embedding for each patch
        def patch_embedding_selector(emb):
            # emb: (N_pe, image_shape_y, image_shape_x)
            # return: (batch_size * patch_num, N_pe, patch_shape_y, patch_shape_x)
            return patching.apply(emb.expand(batch_size, -1, -1, -1))

    else:
        patch_embedding_selector = None

    optional_args = {}
    if lead_time_label is not None:
        optional_args["lead_time_label"] = lead_time_label
    if use_patching:
        optional_args["embedding_selector"] = patch_embedding_selector

    # Loop-invariant work, hoisted out of the sampling loop: the conditioning
    # tensor only needs to be moved once, and the calling convention of the
    # network does not change between steps.
    x_lr = x_lr.to(latents.device)
    condition_as_kwarg = isinstance(net, EDMPrecond)

    def _denoise(x: Tensor, sigma: Tensor) -> Tensor:
        """Patch ``x`` (if enabled), run ``net`` at noise level ``sigma`` and
        fuse the patches back into full images (if enabled)."""
        # Patched input
        # (batch_size * patch_num, C_out, patch_shape_y, patch_shape_x)
        x_batch = _apply_wrapper_select(input=x, patching=patching)(
            patching=patching, input=x
        ).to(latents.device)

        if condition_as_kwarg:
            # Conditioning info is passed as keyword arg
            out = net(
                x_batch,
                sigma,
                condition=x_lr,
                class_labels=class_labels,
                **optional_args,
            )
        else:
            out = net(
                x_batch,
                x_lr,
                sigma,
                class_labels,
                **optional_args,
            )

        if use_patching:
            # Un-patch the denoised image
            # (batch_size, C_out, img_shape_y, img_shape_x)
            out = _fuse_wrapper(patching=patching, input=out, batch_size=batch_size)
        return out

    # Main sampling loop.
    x_next = latents * t_steps[0]
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        # NOTE(review): Algorithm 2 of the EDM paper caps gamma at
        # sqrt(2) - 1; this implementation does not. Left unchanged to
        # preserve existing results - confirm whether the omission is
        # intentional.
        gamma = S_churn / num_steps if S_min <= t_cur <= S_max else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)

        x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur)

        # Euler step.
        denoised = _denoise(x_hat, t_hat)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction. Skipped on the last step, where
        # t_next == 0 and the division below would be undefined.
        if i < num_steps - 1:
            denoised = _denoise(x_next, t_next)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
    return x_next