Spaces:
Sleeping
Sleeping
| # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. | |
| # SPDX-FileCopyrightText: All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import torch | |
| from torch import nn | |
class UpSample3D(nn.Module):
    """
    Revise from WeatherLearn https://github.com/lizhuoq/WeatherLearn
    3D Up-sampling operation.
    Implementation from: https://github.com/198808xc/Pangu-Weather/blob/main/pseudocode.py

    ``linear1`` projects every token to ``4 * out_dim`` channels. These are laid
    out as a 2x2 patch of ``out_dim``-channel sub-tokens in the
    (latitude, longitude) plane, which doubles the horizontal grid. The doubled
    grid is center-cropped to ``output_resolution``; an odd surplus is trimmed
    from the bottom/right edge. The result is then layer-normed and projected
    by ``linear2``. The pressure-level axis is not upsampled: only its first
    ``out_pl`` levels are kept.

    Args:
        in_dim (int): Number of input channels.
        out_dim (int): Number of output channels.
        input_resolution (tuple[int]): [pressure levels, latitude, longitude]
        output_resolution (tuple[int]): [pressure levels, latitude, longitude].
            Latitude/longitude are expected to be at most twice the
            corresponding input sizes (the doubled grid is cropped, never padded).
    """

    def __init__(self, in_dim, out_dim, input_resolution, output_resolution):
        super().__init__()
        self.linear1 = nn.Linear(in_dim, out_dim * 4, bias=False)
        self.linear2 = nn.Linear(out_dim, out_dim, bias=False)
        self.norm = nn.LayerNorm(out_dim)
        self.input_resolution = input_resolution
        self.output_resolution = output_resolution

    def forward(self, x: torch.Tensor):
        """
        Args:
            x (torch.Tensor): (B, N, C) where N = pl * lat * lon of
                ``input_resolution`` and C = ``in_dim``.

        Returns:
            torch.Tensor: (B, min(in_pl, out_pl) * out_lat * out_lon, out_dim),
            given the latitude/longitude precondition in the class docstring.
        """
        B, _, _ = x.shape  # unpacking also enforces a rank-3 input
        in_pl, in_lat, in_lon = self.input_resolution
        out_pl, out_lat, out_lon = self.output_resolution

        x = self.linear1(x)
        # Width of each 2x2 sub-token, taken from linear1's real output
        # (4 * out_dim) instead of assuming in_dim == 2 * out_dim.
        sub_dim = x.shape[-1] // 4
        # (B, pl, lat, lon, 2, 2, c) -> (B, pl, lat, 2, lon, 2, c): interleave
        # each patch with the grid so the next reshape yields (2*lat, 2*lon).
        x = x.reshape(B, in_pl, in_lat, in_lon, 2, 2, sub_dim).permute(
            0, 1, 2, 4, 3, 5, 6
        )
        x = x.reshape(B, in_pl, in_lat * 2, in_lon * 2, sub_dim)

        # Surplus rows/columns of the doubled grid, split between both edges.
        pad_h = in_lat * 2 - out_lat
        pad_w = in_lon * 2 - out_lon
        pad_top = pad_h // 2
        pad_bottom = pad_h - pad_top
        pad_left = pad_w // 2
        pad_right = pad_w - pad_left
        x = x[
            :,
            :out_pl,
            pad_top : 2 * in_lat - pad_bottom,
            pad_left : 2 * in_lon - pad_right,
            :,
        ]
        x = x.reshape(x.shape[0], x.shape[1] * x.shape[2] * x.shape[3], x.shape[4])
        x = self.norm(x)
        x = self.linear2(x)
        return x
class UpSample2D(nn.Module):
    """
    Revise from WeatherLearn https://github.com/lizhuoq/WeatherLearn
    2D Up-sampling operation.

    ``linear1`` projects every token to ``4 * out_dim`` channels. These are laid
    out as a 2x2 patch of ``out_dim``-channel sub-tokens, which doubles both
    grid dimensions. The doubled grid is center-cropped to
    ``output_resolution``; an odd surplus is trimmed from the bottom/right
    edge. The result is then layer-normed and projected by ``linear2``.

    Args:
        in_dim (int): Number of input channels.
        out_dim (int): Number of output channels.
        input_resolution (tuple[int]): [latitude, longitude]
        output_resolution (tuple[int]): [latitude, longitude]. Each entry is
            expected to be at most twice the corresponding input size (the
            doubled grid is cropped, never padded).
    """

    def __init__(self, in_dim, out_dim, input_resolution, output_resolution):
        super().__init__()
        self.linear1 = nn.Linear(in_dim, out_dim * 4, bias=False)
        self.linear2 = nn.Linear(out_dim, out_dim, bias=False)
        self.norm = nn.LayerNorm(out_dim)
        self.input_resolution = input_resolution
        self.output_resolution = output_resolution

    def forward(self, x: torch.Tensor):
        """
        Args:
            x (torch.Tensor): (B, N, C) where N = lat * lon of
                ``input_resolution`` and C = ``in_dim``.

        Returns:
            torch.Tensor: (B, out_lat * out_lon, out_dim), given the
            precondition in the class docstring.
        """
        B, _, _ = x.shape  # unpacking also enforces a rank-3 input
        in_lat, in_lon = self.input_resolution
        out_lat, out_lon = self.output_resolution

        x = self.linear1(x)
        # Width of each 2x2 sub-token, taken from linear1's real output
        # (4 * out_dim) instead of assuming in_dim == 2 * out_dim.
        sub_dim = x.shape[-1] // 4
        # (B, lat, lon, 2, 2, c) -> (B, lat, 2, lon, 2, c): interleave each
        # patch with the grid so the next reshape yields (2*lat, 2*lon).
        x = x.reshape(B, in_lat, in_lon, 2, 2, sub_dim).permute(0, 1, 3, 2, 4, 5)
        x = x.reshape(B, in_lat * 2, in_lon * 2, sub_dim)

        # Surplus rows/columns of the doubled grid, split between both edges.
        pad_h = in_lat * 2 - out_lat
        pad_w = in_lon * 2 - out_lon
        pad_top = pad_h // 2
        pad_bottom = pad_h - pad_top
        pad_left = pad_w // 2
        pad_right = pad_w - pad_left
        x = x[
            :, pad_top : 2 * in_lat - pad_bottom, pad_left : 2 * in_lon - pad_right, :
        ]
        x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
        x = self.norm(x)
        x = self.linear2(x)
        return x
| class DownSample3D(nn.Module): | |
| """ | |
| Revise from WeatherLearn https://github.com/lizhuoq/WeatherLearn | |
| 3D Down-sampling operation | |
| Implementation from: https://github.com/198808xc/Pangu-Weather/blob/main/pseudocode.py | |
| Args: | |
| in_dim (int): Number of input channels. | |
| input_resolution (tuple[int]): [pressure levels, latitude, longitude] | |
| output_resolution (tuple[int]): [pressure levels, latitude, longitude] | |
| """ | |
| def __init__(self, in_dim, input_resolution, output_resolution): | |
| super().__init__() | |
| self.linear = nn.Linear(in_dim * 4, in_dim * 2, bias=False) | |
| self.norm = nn.LayerNorm(4 * in_dim) | |
| self.input_resolution = input_resolution | |
| self.output_resolution = output_resolution | |
| in_pl, in_lat, in_lon = self.input_resolution | |
| out_pl, out_lat, out_lon = self.output_resolution | |
| h_pad = out_lat * 2 - in_lat | |
| w_pad = out_lon * 2 - in_lon | |
| pad_top = h_pad // 2 | |
| pad_bottom = h_pad - pad_top | |
| pad_left = w_pad // 2 | |
| pad_right = w_pad - pad_left | |
| pad_front = pad_back = 0 | |
| self.pad = nn.ZeroPad3d( | |
| (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back) | |
| ) | |
| def forward(self, x): | |
| B, N, C = x.shape | |
| in_pl, in_lat, in_lon = self.input_resolution | |
| out_pl, out_lat, out_lon = self.output_resolution | |
| x = x.reshape(B, in_pl, in_lat, in_lon, C) | |
| # Padding the input to facilitate downsampling | |
| x = self.pad(x.permute(0, -1, 1, 2, 3)).permute(0, 2, 3, 4, 1) | |
| x = x.reshape(B, in_pl, out_lat, 2, out_lon, 2, C).permute(0, 1, 2, 4, 3, 5, 6) | |
| x = x.reshape(B, out_pl * out_lat * out_lon, 4 * C) | |
| x = self.norm(x) | |
| x = self.linear(x) | |
| return x | |
class DownSample2D(nn.Module):
    """
    Revise from WeatherLearn https://github.com/lizhuoq/WeatherLearn
    2D Down-sampling operation

    The (latitude, longitude) grid is zero-padded to twice the output grid.
    Each 2x2 patch is then merged into a single token by concatenating its
    four neighbours along channels (C -> 4C). The merged tokens are
    layer-normed and projected to 2C.

    Args:
        in_dim (int): Number of input channels.
        input_resolution (tuple[int]): [latitude, longitude]
        output_resolution (tuple[int]): [latitude, longitude]
    """

    def __init__(self, in_dim, input_resolution, output_resolution):
        super().__init__()
        self.linear = nn.Linear(in_dim * 4, in_dim * 2, bias=False)
        self.norm = nn.LayerNorm(4 * in_dim)
        self.input_resolution = input_resolution
        self.output_resolution = output_resolution

        lat_in, lon_in = self.input_resolution
        lat_out, lon_out = self.output_resolution
        # Split the required padding in half; any odd remainder goes to the
        # trailing (bottom/right) edge.
        top, odd_h = divmod(2 * lat_out - lat_in, 2)
        left, odd_w = divmod(2 * lon_out - lon_in, 2)
        # ZeroPad2d order: (lon left, lon right, lat top, lat bottom)
        self.pad = nn.ZeroPad2d((left, left + odd_w, top, top + odd_h))

    def forward(self, x: torch.Tensor):
        """
        Args:
            x (torch.Tensor): (B, N, C) where N = lat * lon of
                ``input_resolution``.

        Returns:
            torch.Tensor: (B, out_lat * out_lon, 2 * C)
        """
        batch, _, channels = x.shape
        lat_in, lon_in = self.input_resolution
        lat_out, lon_out = self.output_resolution

        grid = x.reshape(batch, lat_in, lon_in, channels)
        # ZeroPad2d pads the trailing two dims, so move channels to dim 1
        # for padding and back afterwards.
        grid = self.pad(grid.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        # Split lat/lon into (cells, 2) and gather each 2x2 patch next to C.
        patches = grid.reshape(batch, lat_out, 2, lon_out, 2, channels)
        patches = patches.permute(0, 1, 3, 2, 4, 5)
        merged = patches.reshape(batch, lat_out * lon_out, 4 * channels)
        return self.linear(self.norm(merged))