# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


def _split_pad(total):
    """Split a non-negative pad amount into ``(before, after)``.

    ``before = total // 2`` and ``after = total - before``, so for odd totals
    the trailing side receives the extra element.
    """
    before = total // 2
    return before, total - before


def get_earth_position_index(window_size, ndim=3):
    """
    Revise from WeatherLearn https://github.com/lizhuoq/WeatherLearn
    This function constructs the position index used to reuse symmetrical
    parameters of the position bias.
    implementation from: https://github.com/198808xc/Pangu-Weather/blob/main/pseudocode.py

    Latitude (and pressure level) use an *absolute* pair encoding: the key
    coordinate is scaled by the window size, so ``q + k * win`` is unique for
    every (query, key) pair. Longitude uses only the *relative* offset
    ``q - k``. The three (or two) per-axis indices are then folded into a
    single flat index. For ``ndim=3`` the resulting values lie in
    ``[0, win_pl**2 * win_lat**2 * (2 * win_lon - 1))``. For ``ndim=2`` they lie
    in ``[0, win_lat**2 * (2 * win_lon - 1))``.

    Args:
        window_size (tuple[int]): [pressure levels, latitude, longitude] or [latitude, longitude]
        ndim (int): dimension of tensor, 3 or 2

    Returns:
        position_index (torch.Tensor): [win_pl * win_lat * win_lon, win_pl * win_lat * win_lon] or [win_lat * win_lon, win_lat * win_lon]

    Raises:
        ValueError: if ``ndim`` is neither 2 nor 3.
    """
    if ndim == 3:
        win_pl, win_lat, win_lon = window_size
    elif ndim == 2:
        win_lat, win_lon = window_size
    else:
        raise ValueError(f"ndim must be 2 or 3, got {ndim}")

    # Index in the latitude of query matrix
    coords_hi = torch.arange(win_lat)
    # Index in the latitude of key matrix (scaled so hi + hj * win_lat is unique per pair)
    coords_hj = -torch.arange(win_lat) * win_lat
    # Index in the longitude of the key-value pair
    coords_w = torch.arange(win_lon)

    if ndim == 3:
        # Index in the pressure level of query matrix
        coords_zi = torch.arange(win_pl)
        # Index in the pressure level of key matrix
        coords_zj = -torch.arange(win_pl) * win_pl
        query_axes = [coords_zi, coords_hi, coords_w]
        key_axes = [coords_zj, coords_hj, coords_w]
    else:
        query_axes = [coords_hi, coords_w]
        key_axes = [coords_hj, coords_w]

    # Change the order of the index to calculate the index in total.
    # "ij" is torch.meshgrid's historical default; passing it explicitly keeps
    # the same ordering and avoids the deprecation warning for omitting it.
    coords_1 = torch.stack(torch.meshgrid(query_axes, indexing="ij"))
    coords_2 = torch.stack(torch.meshgrid(key_axes, indexing="ij"))
    coords_flatten_1 = torch.flatten(coords_1, 1)
    coords_flatten_2 = torch.flatten(coords_2, 1)
    # Pairwise (query, key) differences per axis: [ndim, N, N] -> [N, N, ndim]
    coords = coords_flatten_1[:, :, None] - coords_flatten_2[:, None, :]
    coords = coords.permute(1, 2, 0).contiguous()

    # Shift the index for each dimension to start from 0, then scale each
    # axis by the number of values the faster-varying axes can take.
    if ndim == 3:
        coords[:, :, 2] += win_lon - 1
        coords[:, :, 1] *= 2 * win_lon - 1
        coords[:, :, 0] *= (2 * win_lon - 1) * win_lat * win_lat
    else:
        coords[:, :, 1] += win_lon - 1
        coords[:, :, 0] *= 2 * win_lon - 1

    # Sum up the indexes in two/three dimensions
    position_index = coords.sum(-1)

    return position_index


def get_pad3d(input_resolution, window_size):
    """
    Compute the padding needed to make each axis divisible by its window size.

    The padding for each axis is split as evenly as possible, with the extra
    element (for odd amounts) placed on the trailing side. The tuple order
    matches ``torch.nn.functional.pad`` for a 5-D tensor, i.e. last axis first.

    Args:
        input_resolution (tuple[int]): (Pl, Lat, Lon)
        window_size (tuple[int]): (Pl, Lat, Lon)

    Returns:
        padding (tuple[int]): (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back)
    """
    Pl, Lat, Lon = input_resolution
    win_pl, win_lat, win_lon = window_size

    # (-n) % w is the smallest non-negative amount that makes n divisible by w
    padding_front, padding_back = _split_pad((-Pl) % win_pl)
    padding_top, padding_bottom = _split_pad((-Lat) % win_lat)
    padding_left, padding_right = _split_pad((-Lon) % win_lon)

    return (
        padding_left,
        padding_right,
        padding_top,
        padding_bottom,
        padding_front,
        padding_back,
    )


def get_pad2d(input_resolution, window_size):
    """
    Compute the padding needed to make Lat/Lon divisible by the window size.

    Args:
        input_resolution (tuple[int]): Lat, Lon
        window_size (tuple[int]): Lat, Lon

    Returns:
        padding (tuple[int]): (padding_left, padding_right, padding_top, padding_bottom)
    """
    # Reuse the 3-D helper with a dummy level axis that never needs padding.
    padding = get_pad3d((1, *input_resolution), (1, *window_size))
    return padding[:4]


def crop2d(x: torch.Tensor, resolution):
    """
    Remove the padding added according to ``get_pad2d``.

    The excess on each axis is removed with the same before/after split used
    by ``get_pad2d``, so ``crop2d(F.pad(x, get_pad2d(r, w)), r)`` recovers ``x``.

    Args:
        x (torch.Tensor): B, C, Lat, Lon
        resolution (tuple[int]): Lat, Lon

    Returns:
        torch.Tensor: slice of ``x`` with spatial shape ``resolution``.

    Raises:
        ValueError: if ``resolution`` is larger than ``x`` on any axis.
    """
    _, _, Lat, Lon = x.shape
    lat_pad = Lat - resolution[0]
    lon_pad = Lon - resolution[1]
    if lat_pad < 0 or lon_pad < 0:
        raise ValueError(
            f"crop resolution {tuple(resolution)} exceeds input size {(Lat, Lon)}"
        )

    padding_top, padding_bottom = _split_pad(lat_pad)
    padding_left, padding_right = _split_pad(lon_pad)

    return x[
        :, :, padding_top : Lat - padding_bottom, padding_left : Lon - padding_right
    ]


def crop3d(x: torch.Tensor, resolution):
    """
    Remove the padding added according to ``get_pad3d``.

    The excess on each axis is removed with the same before/after split used
    by ``get_pad3d``, so ``crop3d(F.pad(x, get_pad3d(r, w)), r)`` recovers ``x``.

    Args:
        x (torch.Tensor): B, C, Pl, Lat, Lon
        resolution (tuple[int]): Pl, Lat, Lon

    Returns:
        torch.Tensor: slice of ``x`` with spatial shape ``resolution``.

    Raises:
        ValueError: if ``resolution`` is larger than ``x`` on any axis.
    """
    _, _, Pl, Lat, Lon = x.shape
    pl_pad = Pl - resolution[0]
    lat_pad = Lat - resolution[1]
    lon_pad = Lon - resolution[2]
    if pl_pad < 0 or lat_pad < 0 or lon_pad < 0:
        raise ValueError(
            f"crop resolution {tuple(resolution)} exceeds input size {(Pl, Lat, Lon)}"
        )

    padding_front, padding_back = _split_pad(pl_pad)
    padding_top, padding_bottom = _split_pad(lat_pad)
    padding_left, padding_right = _split_pad(lon_pad)

    return x[
        :,
        :,
        padding_front : Pl - padding_back,
        padding_top : Lat - padding_bottom,
        padding_left : Lon - padding_right,
    ]