ArthurY's picture
update source
c3d0544
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
class CubeEmbedding(nn.Module):
"""
3D Image Cube Embedding
Args:
img_size (tuple[int]): Image size [T, Lat, Lon].
patch_size (tuple[int]): Patch token size [T, Lat, Lon].
in_chans (int): Number of input image channels.
embed_dim (int): Number of projection output channels.
norm_layer (nn.Module, optional): Normalization layer. Default: torch.nn.LayerNorm
"""
def __init__(
self, img_size, patch_size, in_chans, embed_dim, norm_layer=nn.LayerNorm
):
super().__init__()
patches_resolution = [
img_size[0] // patch_size[0],
img_size[1] // patch_size[1],
img_size[2] // patch_size[2],
]
self.img_size = img_size
self.patches_resolution = patches_resolution
self.embed_dim = embed_dim
self.proj = nn.Conv3d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x: torch.Tensor):
B, C, T, Lat, Lon = x.shape
x = self.proj(x).reshape(B, self.embed_dim, -1).transpose(1, 2) # B T*Lat*Lon C
if self.norm is not None:
x = self.norm(x)
x = x.transpose(1, 2).reshape(B, self.embed_dim, *self.patches_resolution)
return x
class ConvBlock(nn.Module):
"""
Conv2d block
Args:
in_chans (int): Number of input channels.
out_chans (int): Number of output channels.
num_groups (int): Number of groups to separate the channels into for group normalization.
num_residuals (int, optinal): Number of Conv2d operator. Default: 2
upsample (int, optinal): 1: Upsample, 0: Conv, -1: Downsample. Default: 0
"""
def __init__(self, in_chans, out_chans, num_groups, num_residuals=2, upsample=0):
super().__init__()
if upsample == 1:
self.conv = nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2)
elif upsample == -1:
self.conv = nn.Conv2d(
in_chans, out_chans, kernel_size=(3, 3), stride=2, padding=1
)
elif upsample == 0:
self.conv = nn.Conv2d(
in_chans, out_chans, kernel_size=(3, 3), stride=1, padding=1
)
blk = []
for i in range(num_residuals):
blk.append(
nn.Conv2d(out_chans, out_chans, kernel_size=3, stride=1, padding=1)
)
blk.append(nn.GroupNorm(num_groups, out_chans))
blk.append(nn.SiLU())
self.b = nn.Sequential(*blk)
def forward(self, x):
x = self.conv(x)
x_skip = x
x = self.b(x)
return x + x_skip