Spaces:
Sleeping
Sleeping
File size: 5,875 Bytes
c3d0544 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
def get_earth_position_index(window_size, ndim=3):
    """
    Revised from WeatherLearn https://github.com/lizhuoq/WeatherLearn
    This function constructs the position index used to reuse symmetrical
    parameters of the position bias.
    Implementation from: https://github.com/198808xc/Pangu-Weather/blob/main/pseudocode.py

    Pressure-level and latitude positions of the query and the key are encoded
    jointly, so every (query, key) pair along those axes gets its own slot.
    Longitude only contributes the relative offset between query and key.

    Args:
        window_size (tuple[int]): [pressure levels, latitude, longitude] or [latitude, longitude]
        ndim (int): dimension of tensor, 3 or 2

    Returns:
        position_index (torch.Tensor): integer tensor of shape
            [win_pl * win_lat * win_lon, win_pl * win_lat * win_lon] or
            [win_lat * win_lon, win_lat * win_lon]. Its values cover
            ``range(win_pl**2 * win_lat**2 * (2 * win_lon - 1))`` (3D) or
            ``range(win_lat**2 * (2 * win_lon - 1))`` (2D).

    Raises:
        ValueError: if ``ndim`` is not 2 or 3, or (from unpacking) if
            ``window_size`` does not have ``ndim`` entries.
    """
    if ndim == 3:
        win_pl, win_lat, win_lon = window_size
    elif ndim == 2:
        win_lat, win_lon = window_size
    else:
        raise ValueError(f"ndim must be 2 or 3, got {ndim}")

    # Latitude index of the query (i) and of the key (j). Scaling the key index
    # by win_lat makes ``i - (-j * win_lat) = i + j * win_lat`` unique per (i, j).
    coords_hi = torch.arange(win_lat)
    coords_hj = -torch.arange(win_lat) * win_lat
    # Longitude index of the key-value pair; only the offset between them is used.
    coords_w = torch.arange(win_lon)

    if ndim == 3:
        # Same joint (query, key) encoding for the pressure-level axis.
        coords_zi = torch.arange(win_pl)
        coords_zj = -torch.arange(win_pl) * win_pl
        query_axes = (coords_zi, coords_hi, coords_w)
        key_axes = (coords_zj, coords_hj, coords_w)
    else:
        query_axes = (coords_hi, coords_w)
        key_axes = (coords_hj, coords_w)

    # indexing="ij" is torch.meshgrid's historical default. Passing it
    # explicitly keeps the result identical and avoids the deprecation warning.
    coords_1 = torch.stack(torch.meshgrid(*query_axes, indexing="ij"))
    coords_2 = torch.stack(torch.meshgrid(*key_axes, indexing="ij"))
    coords_flatten_1 = torch.flatten(coords_1, 1)
    coords_flatten_2 = torch.flatten(coords_2, 1)
    # Pairwise per-axis differences, rearranged to shape (N, N, ndim).
    coords = coords_flatten_1[:, :, None] - coords_flatten_2[:, None, :]
    coords = coords.permute(1, 2, 0).contiguous()

    # Shift the longitude offset to start from 0. Then scale each slower axis
    # by the number of distinct values of the faster axes (mixed-radix encoding).
    lon_span = 2 * win_lon - 1
    coords[:, :, -1] += win_lon - 1
    coords[:, :, -2] *= lon_span
    if ndim == 3:
        coords[:, :, 0] *= lon_span * win_lat * win_lat

    # Sum up the indexes in two/three dimensions
    position_index = coords.sum(-1)
    return position_index
def get_pad3d(input_resolution, window_size):
    """
    Compute the padding that makes every axis of a (Pl, Lat, Lon) grid an exact
    multiple of the corresponding window size.

    Each axis is padded as evenly as possible on both sides. When the total is
    odd, the extra element goes to the trailing side (right / bottom / back).

    Args:
        input_resolution (tuple[int]): (Pl, Lat, Lon)
        window_size (tuple[int]): (Pl, Lat, Lon)

    Returns:
        padding (tuple[int]): (padding_left, padding_right, padding_top,
            padding_bottom, padding_front, padding_back)
    """

    def split_pad(extent, window):
        # Elements missing to reach the next multiple of ``window`` (0 if aligned),
        # split into a (leading, trailing) pair.
        remainder = extent % window
        total = window - remainder if remainder else 0
        leading = total // 2
        return leading, total - leading

    Pl, Lat, Lon = input_resolution
    win_pl, win_lat, win_lon = window_size
    front, back = split_pad(Pl, win_pl)
    top, bottom = split_pad(Lat, win_lat)
    left, right = split_pad(Lon, win_lon)
    # Longitude first, matching the last-axis-first order used for padding.
    return (left, right, top, bottom, front, back)
def get_pad2d(input_resolution, window_size):
    """
    2D counterpart of ``get_pad3d``: padding that makes Lat and Lon exact
    multiples of the window size.

    Args:
        input_resolution (tuple[int]): Lat, Lon
        window_size (tuple[int]): Lat, Lon

    Returns:
        padding (tuple[int]): (padding_left, padding_right, padding_top, padding_bottom)
    """
    # Prepend a dummy leading axis of size 2 with window 2. It divides evenly,
    # so it adds no padding, and its (front, back) entries are discarded.
    left, right, top, bottom, _, _ = get_pad3d(
        (2, *input_resolution), (2, *window_size)
    )
    return left, right, top, bottom
def crop2d(x: torch.Tensor, resolution):
    """
    Center-crop the last two axes of ``x`` down to ``resolution``.

    The amount removed on each side mirrors ``get_pad2d``: when the excess is
    odd, the extra element is removed from the bottom / right.

    Args:
        x (torch.Tensor): B, C, Lat, Lon (any number of leading dimensions is accepted)
        resolution (tuple[int]): Lat, Lon

    Returns:
        torch.Tensor: a slice of ``x`` with shape ``(*x.shape[:-2], Lat, Lon)``.

    Raises:
        ValueError: if ``resolution`` is larger than the last two axes of ``x``.
    """
    Lat, Lon = x.shape[-2:]
    lat_pad = Lat - resolution[0]
    lon_pad = Lon - resolution[1]
    if lat_pad < 0 or lon_pad < 0:
        # A negative excess would otherwise yield a silently wrong slice.
        raise ValueError(
            f"Cannot crop spatial size {(Lat, Lon)} to larger resolution "
            f"{tuple(resolution)}"
        )
    padding_top = lat_pad // 2
    padding_bottom = lat_pad - padding_top
    padding_left = lon_pad // 2
    padding_right = lon_pad - padding_left
    return x[
        ..., padding_top : Lat - padding_bottom, padding_left : Lon - padding_right
    ]
def crop3d(x: torch.Tensor, resolution):
    """
    Center-crop the last three axes of ``x`` down to ``resolution``.

    The amount removed on each side mirrors ``get_pad3d``: when the excess is
    odd, the extra element is removed from the back / bottom / right.

    Args:
        x (torch.Tensor): B, C, Pl, Lat, Lon (any number of leading dimensions is accepted)
        resolution (tuple[int]): Pl, Lat, Lon

    Returns:
        torch.Tensor: a slice of ``x`` with shape ``(*x.shape[:-3], Pl, Lat, Lon)``.

    Raises:
        ValueError: if ``resolution`` is larger than the last three axes of ``x``.
    """
    Pl, Lat, Lon = x.shape[-3:]
    pl_pad = Pl - resolution[0]
    lat_pad = Lat - resolution[1]
    lon_pad = Lon - resolution[2]
    if pl_pad < 0 or lat_pad < 0 or lon_pad < 0:
        # A negative excess would otherwise yield a silently wrong slice.
        raise ValueError(
            f"Cannot crop spatial size {(Pl, Lat, Lon)} to larger resolution "
            f"{tuple(resolution)}"
        )
    padding_front = pl_pad // 2
    padding_back = pl_pad - padding_front
    padding_top = lat_pad // 2
    padding_bottom = lat_pad - padding_top
    padding_left = lon_pad // 2
    padding_right = lon_pad - padding_left
    return x[
        ...,
        padding_front : Pl - padding_back,
        padding_top : Lat - padding_bottom,
        padding_left : Lon - padding_right,
    ]
|