ArthurY's picture
update source
c3d0544
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
class PatchEmbed2D(nn.Module):
"""
Revise from WeatherLearn https://github.com/lizhuoq/WeatherLearn
2D Image to Patch Embedding.
Args:
img_size (tuple[int]): Image size.
patch_size (tuple[int]): Patch token size.
in_chans (int): Number of input image channels.
embed_dim(int): Number of projection output channels.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size, patch_size, in_chans, embed_dim, norm_layer=None):
super().__init__()
self.img_size = img_size
height, width = img_size
h_patch_size, w_path_size = patch_size
padding_left = padding_right = padding_top = padding_bottom = 0
h_remainder = height % h_patch_size
w_remainder = width % w_path_size
if h_remainder:
h_pad = h_patch_size - h_remainder
padding_top = h_pad // 2
padding_bottom = int(h_pad - padding_top)
if w_remainder:
w_pad = w_path_size - w_remainder
padding_left = w_pad // 2
padding_right = int(w_pad - padding_left)
self.pad = nn.ZeroPad2d(
(padding_left, padding_right, padding_top, padding_bottom)
)
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x: torch.Tensor):
B, C, H, W = x.shape
x = self.pad(x)
x = self.proj(x)
if self.norm is not None:
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x
class PatchEmbed3D(nn.Module):
"""
Revise from WeatherLearn https://github.com/lizhuoq/WeatherLearn
3D Image to Patch Embedding.
Args:
img_size (tuple[int]): Image size.
patch_size (tuple[int]): Patch token size.
in_chans (int): Number of input image channels.
embed_dim(int): Number of projection output channels.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size, patch_size, in_chans, embed_dim, norm_layer=None):
super().__init__()
self.img_size = img_size
level, height, width = img_size
l_patch_size, h_patch_size, w_patch_size = patch_size
padding_left = padding_right = padding_top = padding_bottom = padding_front = (
padding_back
) = 0
l_remainder = level % l_patch_size
h_remainder = height % l_patch_size
w_remainder = width % w_patch_size
if l_remainder:
l_pad = l_patch_size - l_remainder
padding_front = l_pad // 2
padding_back = l_pad - padding_front
if h_remainder:
h_pad = h_patch_size - h_remainder
padding_top = h_pad // 2
padding_bottom = h_pad - padding_top
if w_remainder:
w_pad = w_patch_size - w_remainder
padding_left = w_pad // 2
padding_right = w_pad - padding_left
self.pad = nn.ZeroPad3d(
(
padding_left,
padding_right,
padding_top,
padding_bottom,
padding_front,
padding_back,
)
)
self.proj = nn.Conv3d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x: torch.Tensor):
B, C, L, H, W = x.shape
x = self.pad(x)
x = self.proj(x)
if self.norm:
x = self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
return x
class PatchRecovery2D(nn.Module):
"""
Revise from WeatherLearn https://github.com/lizhuoq/WeatherLearn
Patch Embedding Recovery to 2D Image.
Args:
img_size (tuple[int]): Lat, Lon
patch_size (tuple[int]): Lat, Lon
in_chans (int): Number of input channels.
out_chans (int): Number of output channels.
"""
def __init__(self, img_size, patch_size, in_chans, out_chans):
super().__init__()
self.img_size = img_size
self.conv = nn.ConvTranspose2d(in_chans, out_chans, patch_size, patch_size)
def forward(self, x):
output = self.conv(x)
_, _, H, W = output.shape
h_pad = H - self.img_size[0]
w_pad = W - self.img_size[1]
padding_top = h_pad // 2
padding_bottom = int(h_pad - padding_top)
padding_left = w_pad // 2
padding_right = int(w_pad - padding_left)
return output[
:, :, padding_top : H - padding_bottom, padding_left : W - padding_right
]
class PatchRecovery3D(nn.Module):
"""
Revise from WeatherLearn https://github.com/lizhuoq/WeatherLearn
Patch Embedding Recovery to 3D Image.
Args:
img_size (tuple[int]): Pl, Lat, Lon
patch_size (tuple[int]): Pl, Lat, Lon
in_chans (int): Number of input channels.
out_chans (int): Number of output channels.
"""
def __init__(self, img_size, patch_size, in_chans, out_chans):
super().__init__()
self.img_size = img_size
self.conv = nn.ConvTranspose3d(in_chans, out_chans, patch_size, patch_size)
def forward(self, x: torch.Tensor):
output = self.conv(x)
_, _, Pl, Lat, Lon = output.shape
pl_pad = Pl - self.img_size[0]
lat_pad = Lat - self.img_size[1]
lon_pad = Lon - self.img_size[2]
padding_front = pl_pad // 2
padding_back = pl_pad - padding_front
padding_top = lat_pad // 2
padding_bottom = lat_pad - padding_top
padding_left = lon_pad // 2
padding_right = lon_pad - padding_left
return output[
:,
:,
padding_front : Pl - padding_back,
padding_top : Lat - padding_bottom,
padding_left : Lon - padding_right,
]