# OneForecast/models/Fengwu.py
import math
import numpy as np
import torch
from torch import nn
from collections.abc import Sequence
import warnings
##### weight init ######
def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
u1 = norm_cdf((a - mean) / std)
u2 = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [u1, u2], then translate to
# [2u1-1, 2u2-1].
tensor.uniform_(2 * u1 - 1, 2 * u2 - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
# type: (Tensor, float, float, float, float) -> Tensor
r"""Cut & paste from timm master
Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
applied while sampling the normal with mean/std applied, therefore a, b args
should be adjusted to match the range of mean, std args.
"""
with torch.no_grad():
return _trunc_normal_(tensor, mean, std, a, b)
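# Minimal usage sketch (an illustrative helper, not part of the original model;
# it is defined here but never called at import time): fill a weight tensor
# in-place and confirm every value lies inside the truncation bounds.
def _demo_trunc_normal():
    w = torch.empty(4, 8)
    trunc_normal_(w, std=0.02)  # fills w in-place and returns the same tensor
    assert -2.0 <= w.min().item() and w.max().item() <= 2.0
    return w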
class Mlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x: torch.Tensor):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def drop_path(
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
"""Cut & paste from timm master
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (
x.ndim - 1
) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
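# Illustrative check (not part of the original file): with scale_by_keep=True the
# surviving samples are rescaled by 1/keep_prob, so the output preserves the
# input's expected value; with training=False the input passes through unchanged.
def _demo_drop_path():
    x = torch.ones(10000, 3)
    assert torch.equal(drop_path(x, drop_prob=0.25, training=False), x)
    y = drop_path(x, drop_prob=0.25, training=True)
    # Kept rows equal 1 / 0.75, dropped rows are zero; the mean stays near 1.0.
    return y.mean()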
class DropPath(nn.Module):
"""Cut & paste from timm master
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
def extra_repr(self):
return f"drop_prob={round(self.drop_prob,3):0.3f}"
class PatchEmbed2D(nn.Module):
"""
    Adapted from WeatherLearn https://github.com/lizhuoq/WeatherLearn
2D Image to Patch Embedding.
Args:
img_size (tuple[int]): Image size.
patch_size (tuple[int]): Patch token size.
in_chans (int): Number of input image channels.
embed_dim(int): Number of projection output channels.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size, patch_size, in_chans, embed_dim, norm_layer=None):
super().__init__()
self.img_size = img_size
height, width = img_size
        h_patch_size, w_patch_size = patch_size
        padding_left = padding_right = padding_top = padding_bottom = 0
        h_remainder = height % h_patch_size
        w_remainder = width % w_patch_size
        if h_remainder:
            h_pad = h_patch_size - h_remainder
            padding_top = h_pad // 2
            padding_bottom = h_pad - padding_top
        if w_remainder:
            w_pad = w_patch_size - w_remainder
            padding_left = w_pad // 2
            padding_right = w_pad - padding_left
        self.pad = nn.ConstantPad2d(
            (padding_left, padding_right, padding_top, padding_bottom), value=0
        )
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x: torch.Tensor):
B, C, H, W = x.shape
x = self.pad(x)
x = self.proj(x)
if self.norm is not None:
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x
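# Shape sketch (illustrative helper, not part of the original model): a 120x240
# field with a (4, 4) patch size divides evenly, so no padding is applied and
# the token grid is 30x60.
def _demo_patch_embed2d():
    embed = PatchEmbed2D(img_size=(120, 240), patch_size=(4, 4), in_chans=4, embed_dim=192)
    tokens = embed(torch.randn(1, 4, 120, 240))
    assert tokens.shape == (1, 192, 30, 60)
    return tokens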
class PatchEmbed3D(nn.Module):
"""
    Adapted from WeatherLearn https://github.com/lizhuoq/WeatherLearn
3D Image to Patch Embedding.
Args:
img_size (tuple[int]): Image size.
patch_size (tuple[int]): Patch token size.
in_chans (int): Number of input image channels.
embed_dim(int): Number of projection output channels.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size, patch_size, in_chans, embed_dim, norm_layer=None):
super().__init__()
self.img_size = img_size
level, height, width = img_size
l_patch_size, h_patch_size, w_patch_size = patch_size
padding_left = (
padding_right
) = padding_top = padding_bottom = padding_front = padding_back = 0
l_remainder = level % l_patch_size
        h_remainder = height % h_patch_size
w_remainder = width % w_patch_size
if l_remainder:
l_pad = l_patch_size - l_remainder
padding_front = l_pad // 2
padding_back = l_pad - padding_front
if h_remainder:
h_pad = h_patch_size - h_remainder
padding_top = h_pad // 2
padding_bottom = h_pad - padding_top
if w_remainder:
w_pad = w_patch_size - w_remainder
padding_left = w_pad // 2
padding_right = w_pad - padding_left
self.pad = nn.ConstantPad3d(
(
padding_left,
padding_right,
padding_top,
padding_bottom,
padding_front,
padding_back,
),
value=0
)
self.proj = nn.Conv3d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x: torch.Tensor):
B, C, L, H, W = x.shape
x = self.pad(x)
x = self.proj(x)
        if self.norm is not None:
x = self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
return x
class PatchRecovery2D(nn.Module):
"""
    Adapted from WeatherLearn https://github.com/lizhuoq/WeatherLearn
Patch Embedding Recovery to 2D Image.
Args:
img_size (tuple[int]): Lat, Lon
patch_size (tuple[int]): Lat, Lon
in_chans (int): Number of input channels.
out_chans (int): Number of output channels.
"""
def __init__(self, img_size, patch_size, in_chans, out_chans):
super().__init__()
self.img_size = img_size
self.conv = nn.ConvTranspose2d(in_chans, out_chans, patch_size, patch_size)
def forward(self, x):
output = self.conv(x)
_, _, H, W = output.shape
h_pad = H - self.img_size[0]
w_pad = W - self.img_size[1]
padding_top = h_pad // 2
padding_bottom = int(h_pad - padding_top)
padding_left = w_pad // 2
padding_right = int(w_pad - padding_left)
return output[
:, :, padding_top : H - padding_bottom, padding_left : W - padding_right
]
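# Round-trip sketch (illustrative helper, not part of the original model):
# PatchRecovery2D maps a 30x60 token grid back to the 120x240 image, cropping
# any padding halo introduced by the embedding.
def _demo_patch_recovery2d():
    recover = PatchRecovery2D(img_size=(120, 240), patch_size=(4, 4), in_chans=192, out_chans=4)
    out = recover(torch.randn(1, 192, 30, 60))
    assert out.shape == (1, 4, 120, 240)
    return out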
class PatchRecovery3D(nn.Module):
"""
    Adapted from WeatherLearn https://github.com/lizhuoq/WeatherLearn
Patch Embedding Recovery to 3D Image.
Args:
img_size (tuple[int]): Pl, Lat, Lon
patch_size (tuple[int]): Pl, Lat, Lon
in_chans (int): Number of input channels.
out_chans (int): Number of output channels.
"""
def __init__(self, img_size, patch_size, in_chans, out_chans):
super().__init__()
self.img_size = img_size
self.conv = nn.ConvTranspose3d(in_chans, out_chans, patch_size, patch_size)
def forward(self, x: torch.Tensor):
output = self.conv(x)
_, _, Pl, Lat, Lon = output.shape
pl_pad = Pl - self.img_size[0]
lat_pad = Lat - self.img_size[1]
lon_pad = Lon - self.img_size[2]
padding_front = pl_pad // 2
padding_back = pl_pad - padding_front
padding_top = lat_pad // 2
padding_bottom = lat_pad - padding_top
padding_left = lon_pad // 2
padding_right = lon_pad - padding_left
return output[
:,
:,
padding_front : Pl - padding_back,
padding_top : Lat - padding_bottom,
padding_left : Lon - padding_right,
]
class UpSample3D(nn.Module):
"""
    Adapted from WeatherLearn https://github.com/lizhuoq/WeatherLearn
3D Up-sampling operation.
Implementation from: https://github.com/198808xc/Pangu-Weather/blob/main/pseudocode.py
Args:
in_dim (int): Number of input channels.
out_dim (int): Number of output channels.
input_resolution (tuple[int]): [pressure levels, latitude, longitude]
output_resolution (tuple[int]): [pressure levels, latitude, longitude]
"""
def __init__(self, in_dim, out_dim, input_resolution, output_resolution):
super().__init__()
self.linear1 = nn.Linear(in_dim, out_dim * 4, bias=False)
self.linear2 = nn.Linear(out_dim, out_dim, bias=False)
self.norm = nn.LayerNorm(out_dim)
self.input_resolution = input_resolution
self.output_resolution = output_resolution
def forward(self, x: torch.Tensor):
"""
Args:
x (torch.Tensor): (B, N, C)
"""
B, N, C = x.shape
in_pl, in_lat, in_lon = self.input_resolution
out_pl, out_lat, out_lon = self.output_resolution
x = self.linear1(x)
x = x.reshape(B, in_pl, in_lat, in_lon, 2, 2, C // 2).permute(
0, 1, 2, 4, 3, 5, 6
)
x = x.reshape(B, in_pl, in_lat * 2, in_lon * 2, -1)
pad_h = in_lat * 2 - out_lat
pad_w = in_lon * 2 - out_lon
pad_top = pad_h // 2
pad_bottom = pad_h - pad_top
pad_left = pad_w // 2
pad_right = pad_w - pad_left
x = x[
:,
:out_pl,
pad_top : 2 * in_lat - pad_bottom,
pad_left : 2 * in_lon - pad_right,
:,
]
x = x.reshape(x.shape[0], x.shape[1] * x.shape[2] * x.shape[3], x.shape[4])
x = self.norm(x)
x = self.linear2(x)
return x
class UpSample2D(nn.Module):
"""
    Adapted from WeatherLearn https://github.com/lizhuoq/WeatherLearn
2D Up-sampling operation.
Args:
in_dim (int): Number of input channels.
out_dim (int): Number of output channels.
input_resolution (tuple[int]): [latitude, longitude]
output_resolution (tuple[int]): [latitude, longitude]
"""
def __init__(self, in_dim, out_dim, input_resolution, output_resolution):
super().__init__()
self.linear1 = nn.Linear(in_dim, out_dim * 4, bias=False)
self.linear2 = nn.Linear(out_dim, out_dim, bias=False)
self.norm = nn.LayerNorm(out_dim)
self.input_resolution = input_resolution
self.output_resolution = output_resolution
def forward(self, x: torch.Tensor):
"""
Args:
x (torch.Tensor): (B, N, C)
"""
B, N, C = x.shape
in_lat, in_lon = self.input_resolution
out_lat, out_lon = self.output_resolution
x = self.linear1(x)
x = x.reshape(B, in_lat, in_lon, 2, 2, C // 2).permute(0, 1, 3, 2, 4, 5)
x = x.reshape(B, in_lat * 2, in_lon * 2, -1)
pad_h = in_lat * 2 - out_lat
pad_w = in_lon * 2 - out_lon
pad_top = pad_h // 2
pad_bottom = pad_h - pad_top
pad_left = pad_w // 2
pad_right = pad_w - pad_left
x = x[
:, pad_top : 2 * in_lat - pad_bottom, pad_left : 2 * in_lon - pad_right, :
]
x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
x = self.norm(x)
x = self.linear2(x)
return x
class DownSample3D(nn.Module):
"""
    Adapted from WeatherLearn https://github.com/lizhuoq/WeatherLearn
3D Down-sampling operation
Implementation from: https://github.com/198808xc/Pangu-Weather/blob/main/pseudocode.py
Args:
in_dim (int): Number of input channels.
input_resolution (tuple[int]): [pressure levels, latitude, longitude]
output_resolution (tuple[int]): [pressure levels, latitude, longitude]
"""
def __init__(self, in_dim, input_resolution, output_resolution):
super().__init__()
self.linear = nn.Linear(in_dim * 4, in_dim * 2, bias=False)
self.norm = nn.LayerNorm(4 * in_dim)
self.input_resolution = input_resolution
self.output_resolution = output_resolution
in_pl, in_lat, in_lon = self.input_resolution
out_pl, out_lat, out_lon = self.output_resolution
h_pad = out_lat * 2 - in_lat
w_pad = out_lon * 2 - in_lon
pad_top = h_pad // 2
pad_bottom = h_pad - pad_top
pad_left = w_pad // 2
pad_right = w_pad - pad_left
pad_front = pad_back = 0
self.pad = nn.ConstantPad3d(
(pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back), value=0
)
def forward(self, x):
B, N, C = x.shape
in_pl, in_lat, in_lon = self.input_resolution
out_pl, out_lat, out_lon = self.output_resolution
x = x.reshape(B, in_pl, in_lat, in_lon, C)
# Padding the input to facilitate downsampling
x = self.pad(x.permute(0, -1, 1, 2, 3)).permute(0, 2, 3, 4, 1)
x = x.reshape(B, in_pl, out_lat, 2, out_lon, 2, C).permute(0, 1, 2, 4, 3, 5, 6)
x = x.reshape(B, out_pl * out_lat * out_lon, 4 * C)
x = self.norm(x)
x = self.linear(x)
return x
class DownSample2D(nn.Module):
"""
    Adapted from WeatherLearn https://github.com/lizhuoq/WeatherLearn
2D Down-sampling operation
Args:
in_dim (int): Number of input channels.
input_resolution (tuple[int]): [latitude, longitude]
output_resolution (tuple[int]): [latitude, longitude]
"""
def __init__(self, in_dim, input_resolution, output_resolution):
super().__init__()
self.linear = nn.Linear(in_dim * 4, in_dim * 2, bias=False)
self.norm = nn.LayerNorm(4 * in_dim)
self.input_resolution = input_resolution
self.output_resolution = output_resolution
in_lat, in_lon = self.input_resolution
out_lat, out_lon = self.output_resolution
h_pad = out_lat * 2 - in_lat
w_pad = out_lon * 2 - in_lon
pad_top = h_pad // 2
pad_bottom = h_pad - pad_top
pad_left = w_pad // 2
pad_right = w_pad - pad_left
        self.pad = nn.ConstantPad2d((pad_left, pad_right, pad_top, pad_bottom), value=0)
def forward(self, x: torch.Tensor):
B, N, C = x.shape
in_lat, in_lon = self.input_resolution
out_lat, out_lon = self.output_resolution
x = x.reshape(B, in_lat, in_lon, C)
# Padding the input to facilitate downsampling
x = self.pad(x.permute(0, -1, 1, 2)).permute(0, 2, 3, 1)
x = x.reshape(B, out_lat, 2, out_lon, 2, C).permute(0, 1, 3, 2, 4, 5)
x = x.reshape(B, out_lat * out_lon, 4 * C)
x = self.norm(x)
x = self.linear(x)
return x
def get_earth_position_index(window_size, ndim=3):
"""
    Adapted from WeatherLearn https://github.com/lizhuoq/WeatherLearn
    This function constructs the position index to reuse symmetrical parameters of the position bias.
    Implementation from: https://github.com/198808xc/Pangu-Weather/blob/main/pseudocode.py
Args:
window_size (tuple[int]): [pressure levels, latitude, longitude] or [latitude, longitude]
ndim (int): dimension of tensor, 3 or 2
Returns:
position_index (torch.Tensor): [win_pl * win_lat * win_lon, win_pl * win_lat * win_lon] or [win_lat * win_lon, win_lat * win_lon]
"""
if ndim == 3:
win_pl, win_lat, win_lon = window_size
elif ndim == 2:
win_lat, win_lon = window_size
if ndim == 3:
# Index in the pressure level of query matrix
coords_zi = torch.arange(win_pl)
# Index in the pressure level of key matrix
coords_zj = -torch.arange(win_pl) * win_pl
# Index in the latitude of query matrix
coords_hi = torch.arange(win_lat)
# Index in the latitude of key matrix
coords_hj = -torch.arange(win_lat) * win_lat
# Index in the longitude of the key-value pair
coords_w = torch.arange(win_lon)
# Change the order of the index to calculate the index in total
    if ndim == 3:
        coords_1 = torch.stack(
            torch.meshgrid([coords_zi, coords_hi, coords_w], indexing="ij")
        )
        coords_2 = torch.stack(
            torch.meshgrid([coords_zj, coords_hj, coords_w], indexing="ij")
        )
    elif ndim == 2:
        coords_1 = torch.stack(torch.meshgrid([coords_hi, coords_w], indexing="ij"))
        coords_2 = torch.stack(torch.meshgrid([coords_hj, coords_w], indexing="ij"))
coords_flatten_1 = torch.flatten(coords_1, 1)
coords_flatten_2 = torch.flatten(coords_2, 1)
coords = coords_flatten_1[:, :, None] - coords_flatten_2[:, None, :]
coords = coords.permute(1, 2, 0).contiguous()
# Shift the index for each dimension to start from 0
if ndim == 3:
coords[:, :, 2] += win_lon - 1
coords[:, :, 1] *= 2 * win_lon - 1
coords[:, :, 0] *= (2 * win_lon - 1) * win_lat * win_lat
elif ndim == 2:
coords[:, :, 1] += win_lon - 1
coords[:, :, 0] *= 2 * win_lon - 1
# Sum up the indexes in two/three dimensions
position_index = coords.sum(-1)
return position_index
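# Index-table sketch (illustrative helper, not part of the original model): for
# a (2, 6, 12) window the index map is (2*6*12, 2*6*12) = (144, 144), and every
# entry addresses a row of a bias table of length
# win_pl**2 * win_lat**2 * (2 * win_lon - 1) = 3312.
def _demo_earth_position_index():
    idx = get_earth_position_index((2, 6, 12), ndim=3)
    assert idx.shape == (144, 144)
    assert 0 <= int(idx.min()) and int(idx.max()) < 2**2 * 6**2 * (2 * 12 - 1)
    return idx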
def get_pad3d(input_resolution, window_size):
"""
Args:
input_resolution (tuple[int]): (Pl, Lat, Lon)
window_size (tuple[int]): (Pl, Lat, Lon)
Returns:
padding (tuple[int]): (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back)
"""
Pl, Lat, Lon = input_resolution
win_pl, win_lat, win_lon = window_size
padding_left = (
padding_right
) = padding_top = padding_bottom = padding_front = padding_back = 0
pl_remainder = Pl % win_pl
lat_remainder = Lat % win_lat
lon_remainder = Lon % win_lon
if pl_remainder:
pl_pad = win_pl - pl_remainder
padding_front = pl_pad // 2
padding_back = pl_pad - padding_front
if lat_remainder:
lat_pad = win_lat - lat_remainder
padding_top = lat_pad // 2
padding_bottom = lat_pad - padding_top
if lon_remainder:
lon_pad = win_lon - lon_remainder
padding_left = lon_pad // 2
padding_right = lon_pad - padding_left
return (
padding_left,
padding_right,
padding_top,
padding_bottom,
padding_front,
padding_back,
)
def get_pad2d(input_resolution, window_size):
"""
Args:
input_resolution (tuple[int]): Lat, Lon
window_size (tuple[int]): Lat, Lon
Returns:
padding (tuple[int]): (padding_left, padding_right, padding_top, padding_bottom)
"""
input_resolution = [2] + list(input_resolution)
window_size = [2] + list(window_size)
padding = get_pad3d(input_resolution, window_size)
    return padding[:4]
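# Padding sketch (illustrative helper, not part of the original model): a 30x60
# token grid divides evenly by a (6, 12) window, so no padding is needed; a
# 31-row grid pads latitude up to 36, split between top and bottom.
def _demo_get_pad2d():
    assert get_pad2d((30, 60), (6, 12)) == (0, 0, 0, 0)
    left, right, top, bottom = get_pad2d((31, 60), (6, 12))
    assert top + bottom == 5 and left + right == 0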
def crop3d(x: torch.Tensor, resolution):
"""
Args:
x (torch.Tensor): B, C, Pl, Lat, Lon
resolution (tuple[int]): Pl, Lat, Lon
"""
_, _, Pl, Lat, Lon = x.shape
pl_pad = Pl - resolution[0]
lat_pad = Lat - resolution[1]
lon_pad = Lon - resolution[2]
padding_front = pl_pad // 2
padding_back = pl_pad - padding_front
padding_top = lat_pad // 2
padding_bottom = lat_pad - padding_top
padding_left = lon_pad // 2
padding_right = lon_pad - padding_left
return x[
:,
:,
padding_front : Pl - padding_back,
padding_top : Lat - padding_bottom,
padding_left : Lon - padding_right,
]
def crop2d(x: torch.Tensor, resolution):
"""
Args:
x (torch.Tensor): B, C, Lat, Lon
resolution (tuple[int]): Lat, Lon
"""
_, _, Lat, Lon = x.shape
lat_pad = Lat - resolution[0]
lon_pad = Lon - resolution[1]
padding_top = lat_pad // 2
padding_bottom = lat_pad - padding_top
padding_left = lon_pad // 2
padding_right = lon_pad - padding_left
return x[
:, :, padding_top : Lat - padding_bottom, padding_left : Lon - padding_right
]
class EarthAttention2D(nn.Module):
"""
    Adapted from WeatherLearn https://github.com/lizhuoq/WeatherLearn
    2D window attention with earth position bias.
    It supports both shifted and non-shifted windows.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): [latitude, longitude]
window_size (tuple[int]): [latitude, longitude]
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(
self,
dim,
input_resolution,
window_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__()
self.dim = dim
self.window_size = window_size # Wlat, Wlon
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.type_of_windows = input_resolution[0] // window_size[0]
self.earth_position_bias_table = nn.Parameter(
torch.zeros(
(window_size[0] ** 2) * (window_size[1] * 2 - 1),
self.type_of_windows,
num_heads,
)
) # Wlat**2 * Wlon*2-1, Nlat//Wlat, nH
earth_position_index = get_earth_position_index(
window_size, ndim=2
) # Wlat*Wlon, Wlat*Wlon
self.register_buffer("earth_position_index", earth_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.earth_position_bias_table, std=0.02)  # in-place init
self.softmax = nn.Softmax(dim=-1)
def forward(self, x: torch.Tensor, mask=None):
"""
Args:
x: input features with shape of (B * num_lon, num_lat, N, C)
mask: (0/-inf) mask with shape of (num_lon, num_lat, Wlat*Wlon, Wlat*Wlon)
"""
B_, nW_, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B_, nW_, N, 3, self.num_heads, C // self.num_heads)
.permute(3, 0, 4, 1, 2, 5)
)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = q @ k.transpose(-2, -1)
earth_position_bias = self.earth_position_bias_table[
self.earth_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
self.type_of_windows,
-1,
) # Wlat*Wlon, Wlat*Wlon, num_lat, nH
earth_position_bias = earth_position_bias.permute(
3, 2, 0, 1
).contiguous() # nH, num_lat, Wlat*Wlon, Wlat*Wlon
attn = attn + earth_position_bias.unsqueeze(0)
if mask is not None:
nLon = mask.shape[0]
attn = attn.view(
B_ // nLon, nLon, self.num_heads, nW_, N, N
) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, nW_, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).permute(0, 2, 3, 1, 4).reshape(B_, nW_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class EarthAttention3D(nn.Module):
"""
    Adapted from WeatherLearn https://github.com/lizhuoq/WeatherLearn
    3D window attention with earth position bias.
    It supports both shifted and non-shifted windows.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): [pressure levels, latitude, longitude]
window_size (tuple[int]): [pressure levels, latitude, longitude]
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(
self,
dim,
input_resolution,
window_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__()
self.dim = dim
self.window_size = window_size # Wpl, Wlat, Wlon
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.type_of_windows = (input_resolution[0] // window_size[0]) * (
input_resolution[1] // window_size[1]
)
self.earth_position_bias_table = nn.Parameter(
torch.zeros(
(window_size[0] ** 2)
* (window_size[1] ** 2)
* (window_size[2] * 2 - 1),
self.type_of_windows,
num_heads,
)
) # Wpl**2 * Wlat**2 * Wlon*2-1, Npl//Wpl * Nlat//Wlat, nH
earth_position_index = get_earth_position_index(
window_size
) # Wpl*Wlat*Wlon, Wpl*Wlat*Wlon
self.register_buffer("earth_position_index", earth_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.earth_position_bias_table, std=0.02)  # in-place init
self.softmax = nn.Softmax(dim=-1)
def forward(self, x: torch.Tensor, mask=None):
"""
Args:
x: input features with shape of (B * num_lon, num_pl*num_lat, N, C)
mask: (0/-inf) mask with shape of (num_lon, num_pl*num_lat, Wpl*Wlat*Wlon, Wpl*Wlat*Wlon)
"""
B_, nW_, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B_, nW_, N, 3, self.num_heads, C // self.num_heads)
.permute(3, 0, 4, 1, 2, 5)
)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = q @ k.transpose(-2, -1)
earth_position_bias = self.earth_position_bias_table[
self.earth_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1] * self.window_size[2],
self.window_size[0] * self.window_size[1] * self.window_size[2],
self.type_of_windows,
-1,
) # Wpl*Wlat*Wlon, Wpl*Wlat*Wlon, num_pl*num_lat, nH
earth_position_bias = earth_position_bias.permute(
3, 2, 0, 1
).contiguous() # nH, num_pl*num_lat, Wpl*Wlat*Wlon, Wpl*Wlat*Wlon
attn = attn + earth_position_bias.unsqueeze(0)
if mask is not None:
nLon = mask.shape[0]
attn = attn.view(
B_ // nLon, nLon, self.num_heads, nW_, N, N
) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, nW_, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).permute(0, 2, 3, 1, 4).reshape(B_, nW_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Transformer3DBlock(nn.Module):
"""
    Adapted from WeatherLearn https://github.com/lizhuoq/WeatherLearn
    3D Transformer Block
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (tuple[int]): Window size [pressure levels, latitude, longitude].
shift_size (tuple[int]): Shift size for SW-MSA [pressure levels, latitude, longitude].
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(
self,
dim,
input_resolution,
num_heads,
window_size=None,
shift_size=None,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
window_size = (2, 6, 12) if window_size is None else window_size
shift_size = (1, 3, 6) if shift_size is None else shift_size
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
self.norm1 = norm_layer(dim)
padding = get_pad3d(input_resolution, window_size)
self.pad = nn.ConstantPad3d(padding, value=0)
pad_resolution = list(input_resolution)
pad_resolution[0] += padding[-1] + padding[-2]
pad_resolution[1] += padding[2] + padding[3]
pad_resolution[2] += padding[0] + padding[1]
self.attn = EarthAttention3D(
dim=dim,
input_resolution=pad_resolution,
window_size=window_size,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
shift_pl, shift_lat, shift_lon = self.shift_size
self.roll = shift_pl and shift_lon and shift_lat
if self.roll:
attn_mask = get_shift_window_mask(pad_resolution, window_size, shift_size)
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def forward(self, x: torch.Tensor):
Pl, Lat, Lon = self.input_resolution
B, L, C = x.shape
shortcut = x
x = self.norm1(x)
x = x.view(B, Pl, Lat, Lon, C)
# start pad
x = self.pad(x.permute(0, 4, 1, 2, 3)).permute(0, 2, 3, 4, 1)
_, Pl_pad, Lat_pad, Lon_pad, _ = x.shape
shift_pl, shift_lat, shift_lon = self.shift_size
if self.roll:
            shifted_x = torch.roll(
                x, shifts=(-shift_pl, -shift_lat, -shift_lon), dims=(1, 2, 3)
            )
x_windows = window_partition(shifted_x, self.window_size)
# B*num_lon, num_pl*num_lat, win_pl, win_lat, win_lon, C
else:
shifted_x = x
x_windows = window_partition(shifted_x, self.window_size)
# B*num_lon, num_pl*num_lat, win_pl, win_lat, win_lon, C
win_pl, win_lat, win_lon = self.window_size
x_windows = x_windows.view(
x_windows.shape[0], x_windows.shape[1], win_pl * win_lat * win_lon, C
)
# B*num_lon, num_pl*num_lat, win_pl*win_lat*win_lon, C
attn_windows = self.attn(
x_windows, mask=self.attn_mask
) # B*num_lon, num_pl*num_lat, win_pl*win_lat*win_lon, C
attn_windows = attn_windows.view(
attn_windows.shape[0], attn_windows.shape[1], win_pl, win_lat, win_lon, C
)
if self.roll:
shifted_x = window_reverse(
attn_windows, self.window_size, Pl=Pl_pad, Lat=Lat_pad, Lon=Lon_pad
)
# B * Pl * Lat * Lon * C
x = torch.roll(
shifted_x, shifts=(shift_pl, shift_lat, shift_lon), dims=(1, 2, 3)
)
else:
shifted_x = window_reverse(
attn_windows, self.window_size, Pl=Pl_pad, Lat=Lat_pad, Lon=Lon_pad
)
x = shifted_x
# crop, end pad
x = crop3d(x.permute(0, 4, 1, 2, 3), self.input_resolution).permute(
0, 2, 3, 4, 1
)
x = x.reshape(B, Pl * Lat * Lon, C)
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
##### shift window mask ############
def window_partition(x: torch.Tensor, window_size, ndim=3):
"""
Args:
x: (B, Pl, Lat, Lon, C) or (B, Lat, Lon, C)
window_size (tuple[int]): [win_pl, win_lat, win_lon] or [win_lat, win_lon]
ndim (int): dimension of window (3 or 2)
Returns:
windows: (B*num_lon, num_pl*num_lat, win_pl, win_lat, win_lon, C) or (B*num_lon, num_lat, win_lat, win_lon, C)
"""
if ndim == 3:
B, Pl, Lat, Lon, C = x.shape
win_pl, win_lat, win_lon = window_size
x = x.view(
B, Pl // win_pl, win_pl, Lat // win_lat, win_lat, Lon // win_lon, win_lon, C
)
windows = (
x.permute(0, 5, 1, 3, 2, 4, 6, 7)
.contiguous()
.view(-1, (Pl // win_pl) * (Lat // win_lat), win_pl, win_lat, win_lon, C)
)
return windows
elif ndim == 2:
B, Lat, Lon, C = x.shape
win_lat, win_lon = window_size
x = x.view(B, Lat // win_lat, win_lat, Lon // win_lon, win_lon, C)
windows = (
x.permute(0, 3, 1, 2, 4, 5)
.contiguous()
.view(-1, (Lat // win_lat), win_lat, win_lon, C)
)
return windows
def window_reverse(windows, window_size, Pl=1, Lat=1, Lon=1, ndim=3):
"""
Args:
windows: (B*num_lon, num_pl*num_lat, win_pl, win_lat, win_lon, C) or (B*num_lon, num_lat, win_lat, win_lon, C)
window_size (tuple[int]): [win_pl, win_lat, win_lon] or [win_lat, win_lon]
Pl (int): pressure levels
Lat (int): latitude
Lon (int): longitude
ndim (int): dimension of window (3 or 2)
Returns:
x: (B, Pl, Lat, Lon, C) or (B, Lat, Lon, C)
"""
if ndim == 3:
win_pl, win_lat, win_lon = window_size
B = int(windows.shape[0] / (Lon / win_lon))
x = windows.view(
B,
Lon // win_lon,
Pl // win_pl,
Lat // win_lat,
win_pl,
win_lat,
win_lon,
-1,
)
x = x.permute(0, 2, 4, 3, 5, 1, 6, 7).contiguous().view(B, Pl, Lat, Lon, -1)
return x
elif ndim == 2:
win_lat, win_lon = window_size
B = int(windows.shape[0] / (Lon / win_lon))
x = windows.view(B, Lon // win_lon, Lat // win_lat, win_lat, win_lon, -1)
x = x.permute(0, 2, 3, 1, 4, 5).contiguous().view(B, Lat, Lon, -1)
return x
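# Round-trip sketch (illustrative helper, not part of the original model):
# window_partition followed by window_reverse reproduces the input exactly
# whenever the resolution is an integer multiple of the window size.
def _demo_window_roundtrip():
    x = torch.randn(2, 4, 12, 24, 8)  # (B, Pl, Lat, Lon, C)
    windows = window_partition(x, (2, 6, 12))
    y = window_reverse(windows, (2, 6, 12), Pl=4, Lat=12, Lon=24)
    assert torch.equal(x, y)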
def get_shift_window_mask(input_resolution, window_size, shift_size, ndim=3):
"""
    Along the longitude dimension, the leftmost and rightmost indices are actually close to each other.
    If half windows appear at both the leftmost and rightmost positions, they are directly merged into one window.
Args:
input_resolution (tuple[int]): [pressure levels, latitude, longitude] or [latitude, longitude]
window_size (tuple[int]): Window size [pressure levels, latitude, longitude] or [latitude, longitude]
shift_size (tuple[int]): Shift size for SW-MSA [pressure levels, latitude, longitude] or [latitude, longitude]
ndim (int): dimension of window (3 or 2)
Returns:
attn_mask: (n_lon, n_pl*n_lat, win_pl*win_lat*win_lon, win_pl*win_lat*win_lon) or (n_lon, n_lat, win_lat*win_lon, win_lat*win_lon)
"""
if ndim == 3:
Pl, Lat, Lon = input_resolution
win_pl, win_lat, win_lon = window_size
shift_pl, shift_lat, shift_lon = shift_size
img_mask = torch.zeros((1, Pl, Lat, Lon + shift_lon, 1))
elif ndim == 2:
Lat, Lon = input_resolution
win_lat, win_lon = window_size
shift_lat, shift_lon = shift_size
img_mask = torch.zeros((1, Lat, Lon + shift_lon, 1))
if ndim == 3:
pl_slices = (
slice(0, -win_pl),
slice(-win_pl, -shift_pl),
slice(-shift_pl, None),
)
lat_slices = (
slice(0, -win_lat),
slice(-win_lat, -shift_lat),
slice(-shift_lat, None),
)
lon_slices = (
slice(0, -win_lon),
slice(-win_lon, -shift_lon),
slice(-shift_lon, None),
)
cnt = 0
if ndim == 3:
for pl in pl_slices:
for lat in lat_slices:
for lon in lon_slices:
img_mask[:, pl, lat, lon, :] = cnt
cnt += 1
img_mask = img_mask[:, :, :, :Lon, :]
elif ndim == 2:
for lat in lat_slices:
for lon in lon_slices:
img_mask[:, lat, lon, :] = cnt
cnt += 1
img_mask = img_mask[:, :, :Lon, :]
mask_windows = window_partition(
img_mask, window_size, ndim=ndim
) # n_lon, n_pl*n_lat, win_pl, win_lat, win_lon, 1 or n_lon, n_lat, win_lat, win_lon, 1
if ndim == 3:
win_total = win_pl * win_lat * win_lon
elif ndim == 2:
win_total = win_lat * win_lon
mask_windows = mask_windows.view(
mask_windows.shape[0], mask_windows.shape[1], win_total
)
attn_mask = mask_windows.unsqueeze(2) - mask_windows.unsqueeze(3)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
attn_mask == 0, float(0.0)
)
return attn_mask
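# Mask sketch (illustrative helper, not part of the original model): for an
# (8, 12, 24) padded resolution with a (2, 6, 12) window and (1, 3, 6) shift,
# the mask holds one slice per longitude strip of windows; entries are 0 within
# a region and -100 across regions, which suppresses cross-region attention
# after the softmax.
def _demo_shift_window_mask():
    mask = get_shift_window_mask((8, 12, 24), (2, 6, 12), (1, 3, 6))
    assert mask.shape == (24 // 12, (8 // 2) * (12 // 6), 144, 144)
    assert set(mask.unique().tolist()) <= {0.0, -100.0}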
class Transformer2DBlock(nn.Module):
"""
    Adapted from WeatherLearn https://github.com/lizhuoq/WeatherLearn
    2D Transformer Block
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (tuple[int]): Window size [latitude, longitude].
shift_size (tuple[int]): Shift size for SW-MSA [latitude, longitude].
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(
self,
dim,
input_resolution,
num_heads,
window_size=None,
shift_size=None,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
window_size = (6, 12) if window_size is None else window_size
shift_size = (3, 6) if shift_size is None else shift_size
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
self.norm1 = norm_layer(dim)
padding = get_pad2d(input_resolution, window_size)
        self.pad = nn.ConstantPad2d(padding, value=0)
pad_resolution = list(input_resolution)
pad_resolution[0] += padding[2] + padding[3]
pad_resolution[1] += padding[0] + padding[1]
self.attn = EarthAttention2D(
dim=dim,
input_resolution=pad_resolution,
window_size=window_size,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
shift_lat, shift_lon = self.shift_size
self.roll = shift_lon and shift_lat
if self.roll:
attn_mask = get_shift_window_mask(
pad_resolution, window_size, shift_size, ndim=2
)
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def forward(self, x: torch.Tensor):
Lat, Lon = self.input_resolution
B, L, C = x.shape
shortcut = x
x = self.norm1(x)
x = x.view(B, Lat, Lon, C)
# start pad
x = self.pad(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
_, Lat_pad, Lon_pad, _ = x.shape
shift_lat, shift_lon = self.shift_size
if self.roll:
            shifted_x = torch.roll(x, shifts=(-shift_lat, -shift_lon), dims=(1, 2))
x_windows = window_partition(shifted_x, self.window_size, ndim=2)
# B*num_lon, num_lat, win_lat, win_lon, C
else:
shifted_x = x
x_windows = window_partition(shifted_x, self.window_size, ndim=2)
# B*num_lon, num_lat, win_lat, win_lon, C
win_lat, win_lon = self.window_size
x_windows = x_windows.view(
x_windows.shape[0], x_windows.shape[1], win_lat * win_lon, C
)
# B*num_lon, num_lat, win_lat*win_lon, C
attn_windows = self.attn(
x_windows, mask=self.attn_mask
) # B*num_lon, num_lat, win_lat*win_lon, C
attn_windows = attn_windows.view(
attn_windows.shape[0], attn_windows.shape[1], win_lat, win_lon, C
)
if self.roll:
shifted_x = window_reverse(
attn_windows, self.window_size, Lat=Lat_pad, Lon=Lon_pad, ndim=2
)
# B * Lat * Lon * C
x = torch.roll(shifted_x, shifts=(shift_lat, shift_lon), dims=(1, 2))
else:
shifted_x = window_reverse(
attn_windows, self.window_size, Lat=Lat_pad, Lon=Lon_pad, ndim=2
)
x = shifted_x
# crop, end pad
x = crop2d(x.permute(0, 3, 1, 2), self.input_resolution).permute(0, 2, 3, 1)
x = x.reshape(B, Lat * Lon, C)
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class FuserLayer(nn.Module):
"""Revise from WeatherLearn https://github.com/lizhuoq/WeatherLearn
A basic 3D Transformer layer for one stage
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (tuple[int]): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(
self,
dim,
input_resolution,
depth,
num_heads,
window_size,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.blocks = nn.ModuleList(
[
Transformer3DBlock(
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
shift_size=(0, 0, 0) if i % 2 == 0 else None,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i]
if isinstance(drop_path, Sequence)
else drop_path,
norm_layer=norm_layer,
)
for i in range(depth)
]
)
def forward(self, x):
for blk in self.blocks:
x = blk(x)
return x
class EncoderLayer(nn.Module):
"""A 2D Transformer Encoder Module for one stage
Args:
        img_size (tuple[int]): Image size (Lat, Lon).
patch_size (tuple[int]): Patch token size of Patch Embedding.
in_chans (int): number of input channels of Patch Embedding.
dim (int): Number of input channels of transformer.
input_resolution (tuple[int]): Input resolution for transformer before downsampling.
middle_resolution (tuple[int]): Input resolution for transformer after downsampling.
depth (int): Number of blocks for transformer before downsampling.
depth_middle (int): Number of blocks for transformer after downsampling.
num_heads (int): Number of attention heads.
window_size (tuple[int]): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(
self,
img_size,
patch_size,
in_chans,
dim,
input_resolution,
middle_resolution,
depth,
depth_middle,
num_heads,
window_size,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.in_chans = in_chans
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.depth_middle = depth_middle
if isinstance(drop_path, Sequence):
drop_path_middle = drop_path[depth:]
drop_path = drop_path[:depth]
else:
drop_path_middle = drop_path
if isinstance(num_heads, Sequence):
num_heads_middle = num_heads[1]
num_heads = num_heads[0]
else:
num_heads_middle = num_heads
self.patchembed2d = PatchEmbed2D(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=dim,
)
self.blocks = nn.ModuleList(
[
Transformer2DBlock(
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
shift_size=(0, 0) if i % 2 == 0 else None,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i]
if isinstance(drop_path, Sequence)
else drop_path,
norm_layer=norm_layer,
)
for i in range(depth)
]
)
self.downsample = DownSample2D(
in_dim=dim,
input_resolution=input_resolution,
output_resolution=middle_resolution,
)
self.blocks_middle = nn.ModuleList(
[
Transformer2DBlock(
dim=dim * 2,
input_resolution=middle_resolution,
num_heads=num_heads_middle,
window_size=window_size,
shift_size=(0, 0) if i % 2 == 0 else None,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path_middle[i]
if isinstance(drop_path_middle, Sequence)
else drop_path_middle,
norm_layer=norm_layer,
)
for i in range(depth_middle)
]
)
def forward(self, x):
x = self.patchembed2d(x)
B, C, Lat, Lon = x.shape
x = x.reshape(B, C, -1).transpose(1, 2)
for blk in self.blocks:
x = blk(x)
skip = x.reshape(B, Lat, Lon, C)
x = self.downsample(x)
for blk in self.blocks_middle:
x = blk(x)
return x, skip
class DecoderLayer(nn.Module):
"""A 2D Transformer Decoder Module for one stage
Args:
        img_size (tuple[int]): Image size (Lat, Lon).
patch_size (tuple[int]): Patch token size of Patch Recovery.
out_chans (int): number of output channels of Patch Recovery.
dim (int): Number of input channels of transformer.
output_resolution (tuple[int]): Input resolution for transformer after upsampling.
middle_resolution (tuple[int]): Input resolution for transformer before upsampling.
depth (int): Number of blocks for transformer after upsampling.
depth_middle (int): Number of blocks for transformer before upsampling.
num_heads (int): Number of attention heads.
window_size (tuple[int]): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(
self,
img_size,
patch_size,
out_chans,
dim,
output_resolution,
middle_resolution,
depth,
depth_middle,
num_heads,
window_size,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.out_chans = out_chans
self.dim = dim
self.output_resolution = output_resolution
self.depth = depth
self.depth_middle = depth_middle
if isinstance(drop_path, Sequence):
drop_path_middle = drop_path[depth:]
drop_path = drop_path[:depth]
else:
drop_path_middle = drop_path
if isinstance(num_heads, Sequence):
num_heads_middle = num_heads[1]
num_heads = num_heads[0]
else:
num_heads_middle = num_heads
self.blocks_middle = nn.ModuleList(
[
Transformer2DBlock(
dim=dim * 2,
input_resolution=middle_resolution,
num_heads=num_heads_middle,
window_size=window_size,
shift_size=(0, 0) if i % 2 == 0 else None,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path_middle[i]
if isinstance(drop_path_middle, Sequence)
else drop_path_middle,
norm_layer=norm_layer,
)
for i in range(depth_middle)
]
)
self.upsample = UpSample2D(
in_dim=dim * 2,
out_dim=dim,
input_resolution=middle_resolution,
output_resolution=output_resolution,
)
self.blocks = nn.ModuleList(
[
Transformer2DBlock(
dim=dim,
input_resolution=output_resolution,
num_heads=num_heads,
window_size=window_size,
shift_size=(0, 0) if i % 2 == 0 else None,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i]
if isinstance(drop_path, Sequence)
else drop_path,
norm_layer=norm_layer,
)
for i in range(depth)
]
)
self.patchrecovery2d = PatchRecovery2D(img_size, patch_size, 2 * dim, out_chans)
def forward(self, x, skip):
B, Lat, Lon, C = skip.shape
for blk in self.blocks_middle:
x = blk(x)
x = self.upsample(x)
for blk in self.blocks:
x = blk(x)
output = torch.cat([x, skip.reshape(B, -1, C)], dim=-1)
output = output.transpose(1, 2).reshape(B, -1, Lat, Lon)
output = self.patchrecovery2d(output)
return output
class Fengwu(nn.Module):
"""
    FengWu PyTorch impl of: `FengWu: Pushing the Skillful Global Medium-range Weather Forecast beyond 10 Days Lead`
    - https://arxiv.org/pdf/2304.02948.pdf
    Args:
        in_shape (tuple[int]): Input shape (B, C, Lat, Lon). Default: (1, 69, 120, 240)
        pressure_level (int): Number of pressure levels per upper-air variable. Default: 13
        embed_dim (int): Patch embedding dimension. Default: 192
        patch_size (tuple[int]): Patch token size. Default: (4, 4)
        num_heads (tuple[int]): Number of attention heads in different layers.
        window_size (tuple[int]): Window size.
"""
def __init__(
self,
in_shape=(1, 69, 120, 240),
pressure_level=13,
embed_dim=192,
patch_size=(4, 4),
num_heads=(6, 12, 12, 6),
window_size=(2, 6, 12),
**kwargs):
super().__init__()
        img_size = (in_shape[2], in_shape[3])
drop_path = np.linspace(0, 0.2, 8).tolist()
drop_path_fuser = [0.2] * 6
resolution_down1 = (
math.ceil(img_size[0] / patch_size[0]),
math.ceil(img_size[1] / patch_size[1]),
)
resolution_down2 = (
math.ceil(resolution_down1[0] / 2),
math.ceil(resolution_down1[1] / 2),
)
resolution = (resolution_down1, resolution_down2)
self.encoder_surface = EncoderLayer(
img_size=img_size,
patch_size=patch_size,
in_chans=4,
dim=embed_dim,
input_resolution=resolution[0],
middle_resolution=resolution[1],
depth=2,
depth_middle=6,
num_heads=num_heads[:2],
window_size=window_size[1:],
drop_path=drop_path,
)
self.encoder_z = EncoderLayer(
img_size=img_size,
patch_size=patch_size,
in_chans=pressure_level,
dim=embed_dim,
input_resolution=resolution[0],
middle_resolution=resolution[1],
depth=2,
depth_middle=6,
num_heads=num_heads[:2],
window_size=window_size[1:],
drop_path=drop_path,
)
self.encoder_r = EncoderLayer(
img_size=img_size,
patch_size=patch_size,
in_chans=pressure_level,
dim=embed_dim,
input_resolution=resolution[0],
middle_resolution=resolution[1],
depth=2,
depth_middle=6,
num_heads=num_heads[:2],
window_size=window_size[1:],
drop_path=drop_path,
)
self.encoder_u = EncoderLayer(
img_size=img_size,
patch_size=patch_size,
in_chans=pressure_level,
dim=embed_dim,
input_resolution=resolution[0],
middle_resolution=resolution[1],
depth=2,
depth_middle=6,
num_heads=num_heads[:2],
window_size=window_size[1:],
drop_path=drop_path,
)
self.encoder_v = EncoderLayer(
img_size=img_size,
patch_size=patch_size,
in_chans=pressure_level,
dim=embed_dim,
input_resolution=resolution[0],
middle_resolution=resolution[1],
depth=2,
depth_middle=6,
num_heads=num_heads[:2],
window_size=window_size[1:],
drop_path=drop_path,
)
self.encoder_t = EncoderLayer(
img_size=img_size,
patch_size=patch_size,
in_chans=pressure_level,
dim=embed_dim,
input_resolution=resolution[0],
middle_resolution=resolution[1],
depth=2,
depth_middle=6,
num_heads=num_heads[:2],
window_size=window_size[1:],
drop_path=drop_path,
)
self.fuser = FuserLayer(
dim=embed_dim * 2,
input_resolution=(6, resolution[1][0], resolution[1][1]),
depth=6,
num_heads=num_heads[1],
window_size=window_size,
drop_path=drop_path_fuser,
)
self.decoder_surface = DecoderLayer(
img_size=img_size,
patch_size=patch_size,
out_chans=4,
dim=embed_dim,
output_resolution=resolution[0],
middle_resolution=resolution[1],
depth=2,
depth_middle=6,
num_heads=num_heads[:2],
window_size=window_size[1:],
drop_path=drop_path,
)
self.decoder_z = DecoderLayer(
img_size=img_size,
patch_size=patch_size,
out_chans=pressure_level,
dim=embed_dim,
output_resolution=resolution[0],
middle_resolution=resolution[1],
depth=2,
depth_middle=6,
num_heads=num_heads[:2],
window_size=window_size[1:],
drop_path=drop_path,
)
self.decoder_r = DecoderLayer(
img_size=img_size,
patch_size=patch_size,
out_chans=pressure_level,
dim=embed_dim,
output_resolution=resolution[0],
middle_resolution=resolution[1],
depth=2,
depth_middle=6,
num_heads=num_heads[:2],
window_size=window_size[1:],
drop_path=drop_path,
)
self.decoder_u = DecoderLayer(
img_size=img_size,
patch_size=patch_size,
out_chans=pressure_level,
dim=embed_dim,
output_resolution=resolution[0],
middle_resolution=resolution[1],
depth=2,
depth_middle=6,
num_heads=num_heads[:2],
window_size=window_size[1:],
drop_path=drop_path,
)
self.decoder_v = DecoderLayer(
img_size=img_size,
patch_size=patch_size,
out_chans=pressure_level,
dim=embed_dim,
output_resolution=resolution[0],
middle_resolution=resolution[1],
depth=2,
depth_middle=6,
num_heads=num_heads[:2],
window_size=window_size[1:],
drop_path=drop_path,
)
self.decoder_t = DecoderLayer(
img_size=img_size,
patch_size=patch_size,
out_chans=pressure_level,
dim=embed_dim,
output_resolution=resolution[0],
middle_resolution=resolution[1],
depth=2,
depth_middle=6,
num_heads=num_heads[:2],
window_size=window_size[1:],
drop_path=drop_path,
)
def forward(self, x):
"""
Args:
surface (torch.Tensor): 2D n_lat=721, n_lon=1440, chans=4.
z (torch.Tensor): 2D n_lat=721, n_lon=1440, chans=37.
r (torch.Tensor): 2D n_lat=721, n_lon=1440, chans=37.
u (torch.Tensor): 2D n_lat=721, n_lon=1440, chans=37.
v (torch.Tensor): 2D n_lat=721, n_lon=1440, chans=37.
t (torch.Tensor): 2D n_lat=721, n_lon=1440, chans=37.
"""
        surface = x[:, 65:69, :, :]  # 4 surface channels
        u = x[:, 39:52, :, :]  # 13 pressure-level channels of u
        v = x[:, 52:65, :, :]  # 13 pressure-level channels of v
        t = x[:, 26:39, :, :]  # 13 pressure-level channels of t
        r = x[:, 13:26, :, :]  # 13 pressure-level channels of r
        z = x[:, 0:13, :, :]  # 13 pressure-level channels of z
# print(f"surface shape: {surface.shape}")
# print(f"u shape: {u.shape}")
# print(f"v shape: {v.shape}")
# print(f"t shape: {t.shape}")
# print(f"r shape: {r.shape}")
# print(f"z shape: {z.shape}")
surface, skip_surface = self.encoder_surface(surface)
z, skip_z = self.encoder_z(z)
r, skip_r = self.encoder_r(r)
u, skip_u = self.encoder_u(u)
v, skip_v = self.encoder_v(v)
t, skip_t = self.encoder_t(t)
x = torch.cat(
[
surface.unsqueeze(1),
z.unsqueeze(1),
r.unsqueeze(1),
u.unsqueeze(1),
v.unsqueeze(1),
t.unsqueeze(1),
],
dim=1,
)
B, PL, L_SIZE, C = x.shape
x = x.reshape(B, -1, C)
x = self.fuser(x)
x = x.reshape(B, PL, L_SIZE, C)
surface, z, r, u, v, t = (
x[:, 0, :, :],
x[:, 1, :, :],
x[:, 2, :, :],
x[:, 3, :, :],
x[:, 4, :, :],
x[:, 5, :, :],
)
surface = self.decoder_surface(surface, skip_surface)
z = self.decoder_z(z, skip_z)
r = self.decoder_r(r, skip_r)
u = self.decoder_u(u, skip_u)
v = self.decoder_v(v, skip_v)
t = self.decoder_t(t, skip_t)
        # Reassemble upper-air variables in the input channel order, then append surface.
        reshaped = torch.cat((z, r, t, u, v), dim=1)
final_output = torch.cat([reshaped, surface], dim=1)
return final_output
if __name__ == '__main__':
inputs = torch.randn(1, 69, 120, 240)
model = Fengwu()
output = model(inputs)
print(inputs.shape)
print(output.shape)
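    # Channel bookkeeping for the default configuration: 5 upper-air variables
    # x 13 levels + 4 surface channels = 69, so the model is shape-preserving.
    assert output.shape == (1, 69, 120, 240)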