Spaces:
Sleeping
Sleeping
| # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. | |
| # SPDX-FileCopyrightText: All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import torch | |
def get_earth_position_index(window_size, ndim=3):
    """
    Construct the position index used to look up the Earth-specific position
    bias, reusing symmetric entries of the bias table.

    Revised from WeatherLearn https://github.com/lizhuoq/WeatherLearn
    Implementation from: https://github.com/198808xc/Pangu-Weather/blob/main/pseudocode.py

    Along longitude only the relative offset between query and key
    (``i - j``) is encoded. Along latitude (and pressure level for
    ``ndim == 3``) the absolute (query, key) index pair is encoded as
    ``i + j * win``. This lets the bias depend on where a window sits in
    latitude / altitude.

    Args:
        window_size (tuple[int]): [pressure levels, latitude, longitude] or [latitude, longitude]
        ndim (int): dimension of tensor, 3 or 2

    Returns:
        position_index (torch.Tensor): [win_pl * win_lat * win_lon, win_pl * win_lat * win_lon]
            or [win_lat * win_lon, win_lat * win_lon]. Values lie in
            ``[0, win_pl**2 * win_lat**2 * (2 * win_lon - 1))`` for ``ndim == 3``
            and ``[0, win_lat**2 * (2 * win_lon - 1))`` for ``ndim == 2``.

    Raises:
        ValueError: if ``ndim`` is not 2 or 3, or ``window_size`` does not
            contain exactly ``ndim`` entries.
    """
    if ndim == 3:
        win_pl, win_lat, win_lon = window_size
    elif ndim == 2:
        win_lat, win_lon = window_size
    else:
        # Without this guard an unsupported ndim would surface later as an
        # UnboundLocalError on win_lat, which hides the real problem.
        raise ValueError(f"ndim must be 2 or 3, got {ndim}")

    # Latitude: query index i and key index j are packed as i + j * win_lat
    # (hi - hj below), giving a unique code in [0, win_lat**2).
    coords_hi = torch.arange(win_lat)
    coords_hj = -torch.arange(win_lat) * win_lat
    # Longitude: the same range for query and key; only their difference is kept.
    coords_w = torch.arange(win_lon)

    if ndim == 3:
        # Pressure level: packed like latitude, giving a code in [0, win_pl**2).
        coords_zi = torch.arange(win_pl)
        coords_zj = -torch.arange(win_pl) * win_pl
        query_axes = [coords_zi, coords_hi, coords_w]
        key_axes = [coords_zj, coords_hj, coords_w]
    else:
        query_axes = [coords_hi, coords_w]
        key_axes = [coords_hj, coords_w]

    # "ij" is torch.meshgrid's historical default; passing it explicitly keeps
    # the same result and avoids the warning emitted when it is omitted.
    coords_1 = torch.stack(torch.meshgrid(query_axes, indexing="ij"))
    coords_2 = torch.stack(torch.meshgrid(key_axes, indexing="ij"))
    coords_flatten_1 = torch.flatten(coords_1, 1)
    coords_flatten_2 = torch.flatten(coords_2, 1)
    # Pairwise (query, key) differences: [ndim, N, N] -> [N, N, ndim],
    # where N is the number of tokens in a window.
    coords = coords_flatten_1[:, :, None] - coords_flatten_2[:, None, :]
    coords = coords.permute(1, 2, 0).contiguous()

    # Shift the longitude offset to start from 0, then scale each axis so the
    # per-axis codes form a mixed-radix number that can simply be summed.
    if ndim == 3:
        coords[:, :, 2] += win_lon - 1
        coords[:, :, 1] *= 2 * win_lon - 1
        coords[:, :, 0] *= (2 * win_lon - 1) * win_lat * win_lat
    else:
        coords[:, :, 1] += win_lon - 1
        coords[:, :, 0] *= 2 * win_lon - 1

    # Sum up the indexes in two/three dimensions
    position_index = coords.sum(-1)
    return position_index
def get_pad3d(input_resolution, window_size):
    """
    Compute the padding that makes each axis a multiple of its window size.

    When an axis needs an odd amount of padding, the extra element goes on the
    trailing side (right / bottom / back).

    Args:
        input_resolution (tuple[int]): (Pl, Lat, Lon)
        window_size (tuple[int]): (Pl, Lat, Lon)

    Returns:
        padding (tuple[int]): (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back)
    """

    def _split(size, win):
        # Return (leading, trailing) padding for one axis; (0, 0) if it
        # already divides evenly.
        leftover = size % win
        if not leftover:
            return 0, 0
        total = win - leftover
        head = total // 2
        return head, total - head

    Pl, Lat, Lon = input_resolution
    win_pl, win_lat, win_lon = window_size

    front, back = _split(Pl, win_pl)
    top, bottom = _split(Lat, win_lat)
    left, right = _split(Lon, win_lon)
    return (left, right, top, bottom, front, back)
def get_pad2d(input_resolution, window_size):
    """
    Compute the padding that makes Lat and Lon multiples of their window sizes.

    Args:
        input_resolution (tuple[int]): Lat, Lon
        window_size (tuple[int]): Lat, Lon

    Returns:
        padding (tuple[int]): (padding_left, padding_right, padding_top, padding_bottom)
    """
    # Reuse the 3-D helper with a dummy leading axis of size 2 and window 2.
    # 2 % 2 == 0, so that axis adds no padding and its entries are dropped.
    pad_left, pad_right, pad_top, pad_bottom, *_ = get_pad3d(
        (2, *input_resolution), (2, *window_size)
    )
    return pad_left, pad_right, pad_top, pad_bottom
def crop2d(x: torch.Tensor, resolution):
    """
    Center-crop the Lat/Lon axes of ``x`` down to ``resolution``.

    The split matches ``get_pad2d``: when the excess along an axis is odd, the
    extra element is removed from the trailing side (bottom / right).

    Args:
        x (torch.Tensor): B, C, Lat, Lon
        resolution (tuple[int]): Lat, Lon

    Returns:
        torch.Tensor: slice of ``x`` with shape (B, C, resolution[0], resolution[1])

    Raises:
        ValueError: if ``resolution`` is larger than ``x`` along either axis.
    """
    _, _, Lat, Lon = x.shape
    lat_pad = Lat - resolution[0]
    lon_pad = Lon - resolution[1]
    if lat_pad < 0 or lon_pad < 0:
        # A negative pad would produce negative slice bounds and silently
        # return a tensor of the wrong shape.
        raise ValueError(
            f"crop resolution {tuple(resolution)} exceeds input size {(Lat, Lon)}"
        )
    padding_top = lat_pad // 2
    padding_bottom = lat_pad - padding_top
    padding_left = lon_pad // 2
    padding_right = lon_pad - padding_left
    return x[
        :, :, padding_top : Lat - padding_bottom, padding_left : Lon - padding_right
    ]
def crop3d(x: torch.Tensor, resolution):
    """
    Center-crop the Pl/Lat/Lon axes of ``x`` down to ``resolution``.

    The split matches ``get_pad3d``: when the excess along an axis is odd, the
    extra element is removed from the trailing side (back / bottom / right).

    Args:
        x (torch.Tensor): B, C, Pl, Lat, Lon
        resolution (tuple[int]): Pl, Lat, Lon

    Returns:
        torch.Tensor: slice of ``x`` with shape
        (B, C, resolution[0], resolution[1], resolution[2])

    Raises:
        ValueError: if ``resolution`` is larger than ``x`` along any axis.
    """
    _, _, Pl, Lat, Lon = x.shape
    pl_pad = Pl - resolution[0]
    lat_pad = Lat - resolution[1]
    lon_pad = Lon - resolution[2]
    if pl_pad < 0 or lat_pad < 0 or lon_pad < 0:
        # A negative pad would produce negative slice bounds and silently
        # return a tensor of the wrong shape.
        raise ValueError(
            f"crop resolution {tuple(resolution)} exceeds input size {(Pl, Lat, Lon)}"
        )
    padding_front = pl_pad // 2
    padding_back = pl_pad - padding_front
    padding_top = lat_pad // 2
    padding_bottom = lat_pad - padding_top
    padding_left = lon_pad // 2
    padding_right = lon_pad - padding_left
    return x[
        :,
        :,
        padding_front : Pl - padding_back,
        padding_top : Lat - padding_bottom,
        padding_left : Lon - padding_right,
    ]