ArthurY's picture
update source
c3d0544
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
class UpSample3D(nn.Module):
    """
    Revise from WeatherLearn https://github.com/lizhuoq/WeatherLearn
    3D Up-sampling operation.
    Implementation from: https://github.com/198808xc/Pangu-Weather/blob/main/pseudocode.py

    The module works in four steps:

    1. ``linear1`` projects every token to ``4 * out_dim`` channels.
    2. Those channels are rearranged into a 2x2 patch along (latitude,
       longitude), which doubles the horizontal grid.
    3. The doubled grid is cropped to ``output_resolution``. When the surplus
       is odd, the extra row/column is removed from the bottom/right. The
       pressure-level axis is sliced to its first ``out_pl`` entries.
    4. The tokens go through LayerNorm and ``linear2``.

    Args:
        in_dim (int): Number of input channels.
        out_dim (int): Number of output channels.
        input_resolution (tuple[int]): [pressure levels, latitude, longitude]
        output_resolution (tuple[int]): [pressure levels, latitude, longitude]
    """

    def __init__(self, in_dim, out_dim, input_resolution, output_resolution):
        super().__init__()
        self.linear1 = nn.Linear(in_dim, out_dim * 4, bias=False)
        self.linear2 = nn.Linear(out_dim, out_dim, bias=False)
        self.norm = nn.LayerNorm(out_dim)
        self.input_resolution = input_resolution
        self.output_resolution = output_resolution

    def forward(self, x: torch.Tensor):
        """
        Args:
            x (torch.Tensor): (B, N, C) with N = in_pl * in_lat * in_lon and
                C = in_dim.

        Returns:
            torch.Tensor: (B, P * out_lat * out_lon, out_dim), where
            P = min(in_pl, out_pl) because the pressure axis is only sliced.

        Raises:
            ValueError: If ``output_resolution`` requests more than twice the
                input latitude/longitude extent (the crop would be negative).
        """
        B, _, _ = x.shape
        in_pl, in_lat, in_lon = self.input_resolution
        out_pl, out_lat, out_lon = self.output_resolution

        # Surplus rows/columns of the 2x-upsampled grid that must be cropped.
        pad_h = in_lat * 2 - out_lat
        pad_w = in_lon * 2 - out_lon
        if pad_h < 0 or pad_w < 0:
            # A negative surplus would turn the slice below into a silent
            # mis-crop with the wrong number of tokens.
            raise ValueError(
                f"UpSample3D cannot produce a {out_lat}x{out_lon} grid from a "
                f"{in_lat}x{in_lon} grid: output must be at most 2x the input."
            )
        pad_top = pad_h // 2
        pad_bottom = pad_h - pad_top
        pad_left = pad_w // 2
        pad_right = pad_w - pad_left

        x = self.linear1(x)
        # linear1 emits 4 * out_dim channels. Split them into a 2x2 patch of
        # out_dim each. The width is derived from the projected tensor, not the
        # input width, so any (in_dim, out_dim) pair works.
        sub_c = x.shape[-1] // 4
        x = x.reshape(B, in_pl, in_lat, in_lon, 2, 2, sub_c).permute(
            0, 1, 2, 4, 3, 5, 6
        )
        x = x.reshape(B, in_pl, in_lat * 2, in_lon * 2, sub_c)

        x = x[
            :,
            :out_pl,
            pad_top : 2 * in_lat - pad_bottom,
            pad_left : 2 * in_lon - pad_right,
            :,
        ]
        x = x.reshape(B, x.shape[1] * x.shape[2] * x.shape[3], sub_c)
        x = self.norm(x)
        x = self.linear2(x)
        return x
class UpSample2D(nn.Module):
    """
    Revise from WeatherLearn https://github.com/lizhuoq/WeatherLearn
    2D Up-sampling operation.

    The module works in four steps:

    1. ``linear1`` projects every token to ``4 * out_dim`` channels.
    2. Those channels are rearranged into a 2x2 patch along (latitude,
       longitude), which doubles the grid.
    3. The doubled grid is cropped to ``output_resolution``. When the surplus
       is odd, the extra row/column is removed from the bottom/right.
    4. The tokens go through LayerNorm and ``linear2``.

    Args:
        in_dim (int): Number of input channels.
        out_dim (int): Number of output channels.
        input_resolution (tuple[int]): [latitude, longitude]
        output_resolution (tuple[int]): [latitude, longitude]
    """

    def __init__(self, in_dim, out_dim, input_resolution, output_resolution):
        super().__init__()
        self.linear1 = nn.Linear(in_dim, out_dim * 4, bias=False)
        self.linear2 = nn.Linear(out_dim, out_dim, bias=False)
        self.norm = nn.LayerNorm(out_dim)
        self.input_resolution = input_resolution
        self.output_resolution = output_resolution

    def forward(self, x: torch.Tensor):
        """
        Args:
            x (torch.Tensor): (B, N, C) with N = in_lat * in_lon and C = in_dim.

        Returns:
            torch.Tensor: (B, out_lat * out_lon, out_dim).

        Raises:
            ValueError: If ``output_resolution`` requests more than twice the
                input latitude/longitude extent (the crop would be negative).
        """
        B, _, _ = x.shape
        in_lat, in_lon = self.input_resolution
        out_lat, out_lon = self.output_resolution

        # Surplus rows/columns of the 2x-upsampled grid that must be cropped.
        pad_h = in_lat * 2 - out_lat
        pad_w = in_lon * 2 - out_lon
        if pad_h < 0 or pad_w < 0:
            # A negative surplus would turn the slice below into a silent
            # mis-crop with the wrong number of tokens.
            raise ValueError(
                f"UpSample2D cannot produce a {out_lat}x{out_lon} grid from a "
                f"{in_lat}x{in_lon} grid: output must be at most 2x the input."
            )
        pad_top = pad_h // 2
        pad_bottom = pad_h - pad_top
        pad_left = pad_w // 2
        pad_right = pad_w - pad_left

        x = self.linear1(x)
        # linear1 emits 4 * out_dim channels. Split them into a 2x2 patch of
        # out_dim each. The width is derived from the projected tensor, not the
        # input width, so any (in_dim, out_dim) pair works.
        sub_c = x.shape[-1] // 4
        x = x.reshape(B, in_lat, in_lon, 2, 2, sub_c).permute(0, 1, 3, 2, 4, 5)
        x = x.reshape(B, in_lat * 2, in_lon * 2, sub_c)

        x = x[
            :, pad_top : 2 * in_lat - pad_bottom, pad_left : 2 * in_lon - pad_right, :
        ]
        x = x.reshape(B, x.shape[1] * x.shape[2], sub_c)
        x = self.norm(x)
        x = self.linear2(x)
        return x
class DownSample3D(nn.Module):
"""
Revise from WeatherLearn https://github.com/lizhuoq/WeatherLearn
3D Down-sampling operation
Implementation from: https://github.com/198808xc/Pangu-Weather/blob/main/pseudocode.py
Args:
in_dim (int): Number of input channels.
input_resolution (tuple[int]): [pressure levels, latitude, longitude]
output_resolution (tuple[int]): [pressure levels, latitude, longitude]
"""
def __init__(self, in_dim, input_resolution, output_resolution):
super().__init__()
self.linear = nn.Linear(in_dim * 4, in_dim * 2, bias=False)
self.norm = nn.LayerNorm(4 * in_dim)
self.input_resolution = input_resolution
self.output_resolution = output_resolution
in_pl, in_lat, in_lon = self.input_resolution
out_pl, out_lat, out_lon = self.output_resolution
h_pad = out_lat * 2 - in_lat
w_pad = out_lon * 2 - in_lon
pad_top = h_pad // 2
pad_bottom = h_pad - pad_top
pad_left = w_pad // 2
pad_right = w_pad - pad_left
pad_front = pad_back = 0
self.pad = nn.ZeroPad3d(
(pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back)
)
def forward(self, x):
B, N, C = x.shape
in_pl, in_lat, in_lon = self.input_resolution
out_pl, out_lat, out_lon = self.output_resolution
x = x.reshape(B, in_pl, in_lat, in_lon, C)
# Padding the input to facilitate downsampling
x = self.pad(x.permute(0, -1, 1, 2, 3)).permute(0, 2, 3, 4, 1)
x = x.reshape(B, in_pl, out_lat, 2, out_lon, 2, C).permute(0, 1, 2, 4, 3, 5, 6)
x = x.reshape(B, out_pl * out_lat * out_lon, 4 * C)
x = self.norm(x)
x = self.linear(x)
return x
class DownSample2D(nn.Module):
    """
    Revise from WeatherLearn https://github.com/lizhuoq/WeatherLearn
    2D Down-sampling operation

    The module merges each 2x2 (latitude, longitude) patch into a single
    token:

    - The grid is zero-padded so latitude/longitude become ``2 * out_lat`` and
      ``2 * out_lon``. When the padding is odd, the extra row/column goes to
      the bottom/right.
    - The four tokens of every patch are concatenated along channels
      (C -> 4C).
    - The result passes through LayerNorm and is projected to 2C.

    Args:
        in_dim (int): Number of input channels.
        input_resolution (tuple[int]): [latitude, longitude]
        output_resolution (tuple[int]): [latitude, longitude]
    """

    def __init__(self, in_dim, input_resolution, output_resolution):
        super().__init__()
        merged_dim = 4 * in_dim
        self.linear = nn.Linear(merged_dim, 2 * in_dim, bias=False)
        self.norm = nn.LayerNorm(merged_dim)
        self.input_resolution = input_resolution
        self.output_resolution = output_resolution

        lat_in, lon_in = input_resolution
        lat_out, lon_out = output_resolution
        extra_lat = 2 * lat_out - lat_in
        extra_lon = 2 * lon_out - lon_in
        # Split the padding evenly; an odd remainder lands on bottom / right.
        top, left = extra_lat // 2, extra_lon // 2
        bottom, right = extra_lat - top, extra_lon - left
        # ZeroPad2d tuple order: (left, right, top, bottom) over the last two axes.
        self.pad = nn.ZeroPad2d((left, right, top, bottom))

    def forward(self, x: torch.Tensor):
        """
        Args:
            x (torch.Tensor): (B, N, C) with N = in_lat * in_lon.

        Returns:
            torch.Tensor: (B, out_lat * out_lon, 2 * C).
        """
        batch, _, channels = x.shape
        lat_in, lon_in = self.input_resolution
        lat_out, lon_out = self.output_resolution

        grid = x.reshape(batch, lat_in, lon_in, channels)
        # Pad on the trailing (lat, lon) axes, with channels temporarily in front.
        grid = self.pad(grid.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

        # Split lat/lon into (coarse index, 2) and bring each 2x2 patch together.
        patches = grid.reshape(batch, lat_out, 2, lon_out, 2, channels)
        patches = patches.permute(0, 1, 3, 2, 4, 5)

        # One token per coarse cell, with the patch folded into the channels.
        tokens = patches.reshape(batch, lat_out * lon_out, 4 * channels)
        return self.linear(self.norm(tokens))