# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from ..utils import get_earth_position_index, trunc_normal_
class EarthAttention3D(nn.Module):
    """
    Revise from WeatherLearn https://github.com/lizhuoq/WeatherLearn
    Window multi-head self-attention over 3D (pressure, latitude, longitude)
    windows with a learned earth-specific position bias. Supports both
    shifted and non-shifted windows.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): [pressure levels, latitude, longitude]
        window_size (tuple[int]): [pressure levels, latitude, longitude]
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        input_resolution,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (Wpl, Wlat, Wlon)
        self.num_heads = num_heads

        per_head = dim // num_heads
        self.scale = qk_scale or per_head**-0.5

        # Windows at different (pressure, latitude) positions each get their
        # own bias slice; longitude is handled via the relative index below.
        self.type_of_windows = (input_resolution[0] // window_size[0]) * (
            input_resolution[1] // window_size[1]
        )

        # Bias table of shape:
        # (Wpl**2 * Wlat**2 * (Wlon*2-1), Npl//Wpl * Nlat//Wlat, nH)
        num_bias_entries = (
            (window_size[0] ** 2)
            * (window_size[1] ** 2)
            * (window_size[2] * 2 - 1)
        )
        self.earth_position_bias_table = nn.Parameter(
            torch.zeros(num_bias_entries, self.type_of_windows, num_heads)
        )

        # Precomputed lookup of shape (Wpl*Wlat*Wlon, Wpl*Wlat*Wlon) mapping
        # each token pair to a row of the bias table.
        self.register_buffer(
            "earth_position_index", get_earth_position_index(window_size)
        )

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.earth_position_bias_table = trunc_normal_(
            self.earth_position_bias_table, std=0.02
        )
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor, mask=None):
        """
        Args:
            x: input features with shape of (B * num_lon, num_pl*num_lat, N, C)
            mask: (0/-inf) mask with shape of (num_lon, num_pl*num_lat, Wpl*Wlat*Wlon, Wpl*Wlat*Wlon)
        """
        batch_win, win_types, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # Project and split into per-head q, k, v:
        # each of shape (B_, nH, nW_, N, head_dim).
        qkv = self.qkv(x).reshape(
            batch_win, win_types, tokens, 3, self.num_heads, head_dim
        )
        q, k, v = qkv.permute(3, 0, 4, 1, 2, 5).unbind(0)

        attn = (q * self.scale) @ k.transpose(-2, -1)

        # Gather the position bias:
        # (Wpl*Wlat*Wlon, Wpl*Wlat*Wlon, num_pl*num_lat, nH)
        flat_window = (
            self.window_size[0] * self.window_size[1] * self.window_size[2]
        )
        bias = self.earth_position_bias_table[
            self.earth_position_index.view(-1)
        ].view(flat_window, flat_window, self.type_of_windows, -1)
        # -> (nH, num_pl*num_lat, Wpl*Wlat*Wlon, Wpl*Wlat*Wlon)
        bias = bias.permute(3, 2, 0, 1).contiguous()
        attn = attn + bias.unsqueeze(0)

        if mask is not None:
            # Expose the longitude dimension so the shifted-window mask
            # broadcasts, then collapse it back.
            num_lon = mask.shape[0]
            attn = attn.view(
                batch_win // num_lon,
                num_lon,
                self.num_heads,
                win_types,
                tokens,
                tokens,
            ) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, win_types, tokens, tokens)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        out = (attn @ v).permute(0, 2, 3, 1, 4).reshape(
            batch_win, win_types, tokens, channels
        )
        out = self.proj_drop(self.proj(out))
        return out
class EarthAttention2D(nn.Module):
    """
    Revise from WeatherLearn https://github.com/lizhuoq/WeatherLearn
    Window multi-head self-attention over 2D (latitude, longitude) windows
    with a learned earth-specific position bias. Supports both shifted and
    non-shifted windows.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): [latitude, longitude]
        window_size (tuple[int]): [latitude, longitude]
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        input_resolution,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (Wlat, Wlon)
        self.num_heads = num_heads

        per_head = dim // num_heads
        self.scale = qk_scale or per_head**-0.5

        # Windows at different latitudes each get their own bias slice;
        # longitude is handled via the relative index below.
        self.type_of_windows = input_resolution[0] // window_size[0]

        # Bias table of shape (Wlat**2 * (Wlon*2-1), Nlat//Wlat, nH).
        num_bias_entries = (window_size[0] ** 2) * (window_size[1] * 2 - 1)
        self.earth_position_bias_table = nn.Parameter(
            torch.zeros(num_bias_entries, self.type_of_windows, num_heads)
        )

        # Precomputed lookup of shape (Wlat*Wlon, Wlat*Wlon) mapping each
        # token pair to a row of the bias table.
        self.register_buffer(
            "earth_position_index", get_earth_position_index(window_size, ndim=2)
        )

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.earth_position_bias_table = trunc_normal_(
            self.earth_position_bias_table, std=0.02
        )
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor, mask=None):
        """
        Args:
            x: input features with shape of (B * num_lon, num_lat, N, C)
            mask: (0/-inf) mask with shape of (num_lon, num_lat, Wlat*Wlon, Wlat*Wlon)
        """
        batch_win, win_types, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # Project and split into per-head q, k, v:
        # each of shape (B_, nH, nW_, N, head_dim).
        qkv = self.qkv(x).reshape(
            batch_win, win_types, tokens, 3, self.num_heads, head_dim
        )
        q, k, v = qkv.permute(3, 0, 4, 1, 2, 5).unbind(0)

        attn = (q * self.scale) @ k.transpose(-2, -1)

        # Gather the position bias: (Wlat*Wlon, Wlat*Wlon, num_lat, nH)
        flat_window = self.window_size[0] * self.window_size[1]
        bias = self.earth_position_bias_table[
            self.earth_position_index.view(-1)
        ].view(flat_window, flat_window, self.type_of_windows, -1)
        # -> (nH, num_lat, Wlat*Wlon, Wlat*Wlon)
        bias = bias.permute(3, 2, 0, 1).contiguous()
        attn = attn + bias.unsqueeze(0)

        if mask is not None:
            # Expose the longitude dimension so the shifted-window mask
            # broadcasts, then collapse it back.
            num_lon = mask.shape[0]
            attn = attn.view(
                batch_win // num_lon,
                num_lon,
                self.num_heads,
                win_types,
                tokens,
                tokens,
            ) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, win_types, tokens, tokens)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        out = (attn @ v).permute(0, 2, 3, 1, 4).reshape(
            batch_win, win_types, tokens, channels
        )
        out = self.proj_drop(self.proj(out))
        return out