# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn


class CubeEmbedding(nn.Module):
    """
    3D Image Cube Embedding

    Splits a (T, Lat, Lon) cube into non-overlapping patches and projects
    each patch to an ``embed_dim``-dimensional embedding with a strided
    Conv3d (equivalent to a per-patch linear projection).

    Args:
        img_size (tuple[int]): Image size [T, Lat, Lon].
        patch_size (tuple[int]): Patch token size [T, Lat, Lon].
        in_chans (int): Number of input image channels.
        embed_dim (int): Number of projection output channels.
        norm_layer (nn.Module, optional): Normalization layer.
            Default: torch.nn.LayerNorm
    """

    def __init__(
        self, img_size, patch_size, in_chans, embed_dim, norm_layer=nn.LayerNorm
    ):
        super().__init__()
        # Number of patches along each axis. Floor division: any trailing
        # remainder of the input is dropped by the strided convolution.
        patches_resolution = [
            size // patch for size, patch in zip(img_size, patch_size)
        ]

        self.img_size = img_size
        self.patches_resolution = patches_resolution
        self.embed_dim = embed_dim
        # kernel_size == stride => non-overlapping patch projection.
        self.proj = nn.Conv3d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x: torch.Tensor):
        """Embed ``x`` of shape (B, C, T, Lat, Lon) to (B, embed_dim, T', Lat', Lon')."""
        batch = x.shape[0]
        # (B, C, T, Lat, Lon) -> (B, embed_dim, T'*Lat'*Lon') -> (B, T'*Lat'*Lon', embed_dim)
        x = self.proj(x).reshape(batch, self.embed_dim, -1).transpose(1, 2)
        if self.norm is not None:
            # Normalize over the channel (last) dimension.
            x = self.norm(x)
        x = x.transpose(1, 2).reshape(batch, self.embed_dim, *self.patches_resolution)
        return x


class ConvBlock(nn.Module):
    """
    Conv2d residual block with an optional resolution change on the input.

    Args:
        in_chans (int): Number of input channels.
        out_chans (int): Number of output channels.
        num_groups (int): Number of groups to separate the channels into
            for group normalization.
        num_residuals (int, optional): Number of Conv2d operators. Default: 2
        upsample (int, optional): 1: Upsample, 0: Conv, -1: Downsample.
            Default: 0

    Raises:
        ValueError: If ``upsample`` is not one of -1, 0 or 1.
    """

    def __init__(self, in_chans, out_chans, num_groups, num_residuals=2, upsample=0):
        super().__init__()
        if upsample == 1:
            # Doubles the spatial resolution.
            self.conv = nn.ConvTranspose2d(
                in_chans, out_chans, kernel_size=2, stride=2
            )
        elif upsample == -1:
            # Halves the spatial resolution.
            self.conv = nn.Conv2d(
                in_chans, out_chans, kernel_size=(3, 3), stride=2, padding=1
            )
        elif upsample == 0:
            # Resolution-preserving convolution.
            self.conv = nn.Conv2d(
                in_chans, out_chans, kernel_size=(3, 3), stride=1, padding=1
            )
        else:
            # Previously an invalid value left ``self.conv`` unassigned and
            # failed later in forward() with AttributeError; fail fast here.
            raise ValueError(f"upsample must be -1, 0 or 1, got {upsample}")

        blk = []
        for _ in range(num_residuals):
            blk.append(
                nn.Conv2d(out_chans, out_chans, kernel_size=3, stride=1, padding=1)
            )
            blk.append(nn.GroupNorm(num_groups, out_chans))
            blk.append(nn.SiLU())

        self.b = nn.Sequential(*blk)

    def forward(self, x):
        """Apply the (re)sampling conv, then the residual conv stack."""
        x = self.conv(x)
        x_skip = x  # skip connection taken after the resampling conv
        x = self.b(x)
        return x + x_skip