# ArthurY's picture
# update source
# c3d0544
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
def get_earth_position_index(window_size, ndim=3):
    """
    Build the index used to look up the earth-specific position bias.

    Revised from WeatherLearn https://github.com/lizhuoq/WeatherLearn
    Implementation from: https://github.com/198808xc/Pangu-Weather/blob/main/pseudocode.py

    For every (query, key) pair of positions inside one window, this function
    returns one non-negative integer. Along pressure level and latitude the
    index depends on the absolute (query, key) pair; along longitude it only
    depends on the query-key offset, so pairs with the same longitude offset
    share an index (this is how the symmetrical bias parameters are reused).

    Args:
        window_size (tuple[int]): [pressure levels, latitude, longitude] or [latitude, longitude]
        ndim (int): dimension of tensor, 3 or 2

    Returns:
        position_index (torch.Tensor): integer tensor of shape
            [win_pl * win_lat * win_lon, win_pl * win_lat * win_lon] or
            [win_lat * win_lon, win_lat * win_lon]. Values lie in
            [0, win_pl**2 * win_lat**2 * (2 * win_lon - 1) - 1] for ndim == 3 and
            [0, win_lat**2 * (2 * win_lon - 1) - 1] for ndim == 2.

    Raises:
        ValueError: if ``ndim`` is neither 2 nor 3, or if ``window_size`` does
            not have exactly ``ndim`` entries (raised by tuple unpacking).
    """
    # Fail early with a clear message instead of an UnboundLocalError later on.
    if ndim not in (2, 3):
        raise ValueError(f"ndim must be 2 or 3, got {ndim}")

    if ndim == 3:
        win_pl, win_lat, win_lon = window_size
    else:
        win_lat, win_lon = window_size

    # Query-side indices count up from 0. Key-side indices are negated and
    # pre-multiplied by the window size, so that (query - key) enumerates every
    # (query, key) pair along that axis as a distinct value in [0, win**2 - 1].
    coords_hi = torch.arange(win_lat)
    coords_hj = -torch.arange(win_lat) * win_lat
    # Longitude uses the same index for query and key: only the offset matters.
    coords_w = torch.arange(win_lon)

    query_axes = [coords_hi, coords_w]
    key_axes = [coords_hj, coords_w]
    if ndim == 3:
        # Pressure level is the leading axis when present.
        query_axes.insert(0, torch.arange(win_pl))
        key_axes.insert(0, -torch.arange(win_pl) * win_pl)

    # indexing="ij" is the behavior the original code relied on implicitly;
    # stating it avoids the PyTorch warning about the default changing.
    coords_1 = torch.stack(torch.meshgrid(*query_axes, indexing="ij"))
    coords_2 = torch.stack(torch.meshgrid(*key_axes, indexing="ij"))
    coords_flatten_1 = torch.flatten(coords_1, 1)
    coords_flatten_2 = torch.flatten(coords_2, 1)

    # Pairwise differences: result is [ndim, N, N] -> permuted to [N, N, ndim].
    coords = coords_flatten_1[:, :, None] - coords_flatten_2[:, None, :]
    coords = coords.permute(1, 2, 0).contiguous()

    # Combine the per-axis values into one index (mixed-radix encoding).
    lon_span = 2 * win_lon - 1
    # Longitude offsets lie in [-(win_lon - 1), win_lon - 1]; shift to start from 0.
    coords[:, :, -1] += win_lon - 1
    # Latitude pairs are scaled past the longitude range.
    coords[:, :, -2] *= lon_span
    if ndim == 3:
        # Pressure-level pairs are scaled past the combined latitude/longitude range.
        coords[:, :, 0] *= lon_span * win_lat * win_lat

    # Sum up the indexes in two/three dimensions
    position_index = coords.sum(-1)
    return position_index
def get_pad3d(input_resolution, window_size):
    """
    Compute the padding that makes each axis a multiple of its window size.

    For every axis, the missing amount up to the next multiple of the window
    size is split in two: the leading side gets the floor half and the
    trailing side gets the rest. An axis that is already a multiple gets
    (0, 0).

    Args:
        input_resolution (tuple[int]): (Pl, Lat, Lon)
        window_size (tuple[int]): (Pl, Lat, Lon)

    Returns:
        padding (tuple[int]): (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back)
    """
    Pl, Lat, Lon = input_resolution
    win_pl, win_lat, win_lon = window_size

    def _split(size, window):
        # Amount needed to reach the next multiple of `window` (0 if aligned).
        remainder = size % window
        total = window - remainder if remainder else 0
        leading = total // 2
        return leading, total - leading

    front, back = _split(Pl, win_pl)
    top, bottom = _split(Lat, win_lat)
    left, right = _split(Lon, win_lon)

    # Ordered from the last axis (Lon) to the first (Pl).
    return (left, right, top, bottom, front, back)
def get_pad2d(input_resolution, window_size):
    """
    Compute the 2D padding by delegating to ``get_pad3d``.

    A dummy leading axis of size 2 with window size 2 is prepended; it is
    already aligned, so it contributes no padding and its two entries are
    dropped from the result.

    Args:
        input_resolution (tuple[int]): Lat, Lon
        window_size (tuple[int]): Lat, Lon

    Returns:
        padding (tuple[int]): (padding_left, padding_right, padding_top, padding_bottom)
    """
    resolution_3d = [2, *input_resolution]
    window_3d = [2, *window_size]
    left, right, top, bottom = get_pad3d(resolution_3d, window_3d)[:4]
    return (left, right, top, bottom)
def crop2d(x: torch.Tensor, resolution):
    """
    Crop the two trailing axes of ``x`` down to ``resolution``.

    The surplus along each axis is removed from both sides: the leading side
    loses the floor half of the surplus and the trailing side loses the rest.

    Args:
        x (torch.Tensor): B, C, Lat, Lon
        resolution (tuple[int]): Lat, Lon

    Returns:
        torch.Tensor: the sliced tensor, B, C, resolution[0], resolution[1]
    """
    _batch, _channels, lat_full, lon_full = x.shape
    target_lat = resolution[0]
    target_lon = resolution[1]

    # Leading offset is the floor half of the surplus on each axis.
    lat_start = (lat_full - target_lat) // 2
    lon_start = (lon_full - target_lon) // 2

    lat_window = slice(lat_start, lat_start + target_lat)
    lon_window = slice(lon_start, lon_start + target_lon)
    return x[:, :, lat_window, lon_window]
def crop3d(x: torch.Tensor, resolution):
    """
    Crop the three trailing axes of ``x`` down to ``resolution``.

    The surplus along each axis is removed from both sides: the leading side
    loses the floor half of the surplus and the trailing side loses the rest.

    Args:
        x (torch.Tensor): B, C, Pl, Lat, Lon
        resolution (tuple[int]): Pl, Lat, Lon

    Returns:
        torch.Tensor: the sliced tensor, B, C, resolution[0], resolution[1], resolution[2]
    """
    _batch, _channels, pl_full, lat_full, lon_full = x.shape
    target_pl = resolution[0]
    target_lat = resolution[1]
    target_lon = resolution[2]

    # Leading offset is the floor half of the surplus on each axis.
    pl_start = (pl_full - target_pl) // 2
    lat_start = (lat_full - target_lat) // 2
    lon_start = (lon_full - target_lon) // 2

    pl_window = slice(pl_start, pl_start + target_pl)
    lat_window = slice(lat_start, lat_start + target_lat)
    lon_window = slice(lon_start, lon_start + target_lon)
    return x[:, :, pl_window, lat_window, lon_window]