"""Building blocks for DGMR solar radiation nowcasting models.

Initial implementation (commit 2fa5aae).
"""
import einops
import torch
from torch.nn import functional as F
from torch.nn.utils.parametrizations import spectral_norm
class GBlock(torch.nn.Module):
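    """Residual generator block (G block) in the style of DGMR.

    Two spectrally normalised 3x3 convolutions on the main path, with a
    1x1 convolution on the shortcut when the channel count changes.
    Spatial resolution is preserved.
    """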
def __init__(self, in_channel: int, out_channel: int):
super().__init__()
self.in_channel = in_channel
self.out_channel = out_channel
        # NOTE: batch norm is disabled for now; uncomment to re-enable.
        # self.bn1 = torch.nn.BatchNorm2d(in_channel)
        # self.bn2 = torch.nn.BatchNorm2d(out_channel)
self.relu = torch.nn.ReLU()
# conv1x1
self.conv1x1 = spectral_norm(
torch.nn.Conv2d(in_channel, out_channel, kernel_size=1)
)
self.conv3x3_1 = spectral_norm(
torch.nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1)
)
self.conv3x3_2 = spectral_norm(
torch.nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1)
)
def forward(self, x: torch.Tensor):
        # Apply the 1x1 shortcut convolution only when the channel count changes.
if x.shape[1] != self.out_channel:
res = self.conv1x1(x)
else:
res = x.clone()
        # First 3x3 convolution of the main path.
# x = self.bn1(x)
x = self.relu(x)
x = self.conv3x3_1(x)
        # Second 3x3 convolution of the main path.
# x = self.bn2(x)
x = self.relu(x)
x = self.conv3x3_2(x)
y = x + res
return y
class Up_GBlock(torch.nn.Module):
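    """Upsampling residual generator block.

    Both the main path and the shortcut are upsampled 2x with
    nearest-neighbour interpolation; the channel count is halved
    (out_channel = in_channel // 2).
    """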
def __init__(self, in_channel: int):
super().__init__()
self.in_channel = in_channel
        self.out_channel = in_channel // 2
        # NOTE: batch norm is disabled for now; uncomment to re-enable.
        # self.bn1 = torch.nn.BatchNorm2d(in_channel)
        # self.bn2 = torch.nn.BatchNorm2d(in_channel)
self.relu = torch.nn.ReLU()
self.up = torch.nn.Upsample(scale_factor=2, mode='nearest')
self.conv1x1 = spectral_norm(
torch.nn.Conv2d(in_channel, self.out_channel, kernel_size=1)
)
self.conv3x3_1 = spectral_norm(
torch.nn.Conv2d(in_channel, in_channel, kernel_size=3, padding=1)
)
self.conv3x3_2 = spectral_norm(
torch.nn.Conv2d(
in_channel,
self.out_channel,
kernel_size=3,
padding=1))
def forward(self, x):
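        # Shortcut path: upsample, then a 1x1 conv to halve the channels.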
res = self.up(x)
res = self.conv1x1(res)
# x = self.bn1(x)
x = self.relu(x)
x = self.up(x)
x = self.conv3x3_1(x)
# x = self.bn2(x)
x = self.relu(x)
x = self.conv3x3_2(x)
y = x + res
return y
class DBlock(torch.nn.Module):
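    """Residual discriminator block (D block).

    Supports 2D and 3D convolutions via `conv_type`. `apply_relu`
    toggles the leading ReLU and `apply_down` toggles 2x average-pool
    downsampling on both paths.
    """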
def __init__(
self,
in_channel: int,
out_channel: int,
conv_type='2d',
apply_relu=True,
apply_down=False):
super().__init__()
self.in_channel = in_channel
self.out_channel = out_channel
self.apply_relu = apply_relu
self.apply_down = apply_down
        # Construct layers according to the convolution type.
        if conv_type == '2d':
            self.avg_pool = torch.nn.AvgPool2d(kernel_size=2, stride=2)
            conv = torch.nn.Conv2d
        elif conv_type == '3d':
            # AvgPool3d with kernel 2 also halves the temporal dimension.
            self.avg_pool = torch.nn.AvgPool3d(kernel_size=2, stride=2)
            conv = torch.nn.Conv3d
        else:
            raise ValueError(f"conv_type must be '2d' or '3d', got {conv_type!r}")
self.relu = torch.nn.ReLU()
self.conv1x1 = spectral_norm(
conv(in_channel, out_channel, kernel_size=1)
)
self.conv3x3_1 = spectral_norm(
conv(in_channel, out_channel, kernel_size=3, padding=1)
)
self.conv3x3_2 = spectral_norm(
conv(out_channel, out_channel, kernel_size=3, padding=1)
)
def forward(self, x):
        # Shortcut path: 1x1 conv if the channel count changes, then optional downsampling.
if x.shape[1] != self.out_channel:
res = self.conv1x1(x)
else:
res = x.clone()
if self.apply_down:
res = self.avg_pool(res)
        # Main path.
if self.apply_relu:
x = self.relu(x)
x = self.conv3x3_1(x)
x = self.relu(x)
x = self.conv3x3_2(x)
if self.apply_down:
x = self.avg_pool(x)
        # Residual connection.
y = res + x
return y
class LBlock(torch.nn.Module):
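    """Residual block of the latent-conditioning stack (L block).

    Grows the channel count: the shortcut concatenates the input with a
    1x1 convolution providing the extra (out_channel - in_channel)
    channels, so out_channel must exceed in_channel. No spectral
    normalisation is applied here.
    """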
def __init__(self, in_channel, out_channel):
super().__init__()
self.in_channel = in_channel
self.out_channel = out_channel
self.relu = torch.nn.ReLU()
conv = torch.nn.Conv2d
self.conv1x1 = conv(
in_channel,
(out_channel - in_channel),
kernel_size=1)
self.conv3x3_1 = conv(in_channel, in_channel, kernel_size=3, padding=1)
self.conv3x3_2 = conv(
in_channel,
out_channel,
kernel_size=3,
padding=1)
def forward(self, x):
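        # Shortcut: concatenate the input with the extra channels from the 1x1 conv.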
res = torch.cat([x, self.conv1x1(x)], dim=1)
x = self.relu(x)
x = self.conv3x3_1(x)
x = self.relu(x)
x = self.conv3x3_2(x)
y = x + res
return y
# Attention Layer
def attention_einsum(q, k, v):
"""
Apply self-attention to tensors
"""
# Reshape 3D tensor to 2D tensor with first dimension L = h x w
k = einops.rearrange(k, "h w c -> (h w) c") # [h, w, c] -> [L, c]
v = einops.rearrange(v, "h w c -> (h w) c") # [h, w, c] -> [L, c]
# Einstein summation corresponding to the query * key operation.
beta = F.softmax(torch.einsum("hwc, Lc->hwL", q, k), dim=-1)
# Einstein summation corresponding to the attention * value operation.
out = torch.einsum("hwL, Lc->hwc", beta, v)
return out
class AttentionLayer(torch.nn.Module):
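    """Spatial self-attention layer with a learned residual gate.

    Queries, keys, and values are computed with 1x1 convolutions at
    reduced widths (out_channel // ratio_kq and out_channel // ratio_v).
    The attended output is rescaled by a zero-initialised `gamma` and
    added back to the input, so out_channel must equal in_channel for
    the residual connection to hold.
    """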
def __init__(self, in_channel, out_channel, ratio_kq=8, ratio_v=8):
super().__init__()
self.ratio_kq = ratio_kq
self.ratio_v = ratio_v
self.in_channel = in_channel
self.out_channel = out_channel
        # Compute query, key, and value with 1x1 convolutions.
self.query = torch.nn.Conv2d(
in_channel,
out_channel // ratio_kq,
kernel_size=1,
bias=False
)
self.key = torch.nn.Conv2d(
in_channel,
out_channel // ratio_kq,
kernel_size=1,
bias=False
)
self.value = torch.nn.Conv2d(
in_channel,
out_channel // ratio_v,
kernel_size=1,
bias=False
)
        self.conv = torch.nn.Conv2d(
            out_channel // ratio_v,  # must match the value projection width
            out_channel,
            kernel_size=1,
            bias=False
        )
self.gamma = torch.nn.Parameter(torch.zeros(1))
    def forward(self, x):
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        # attention_einsum expects channels-last [h, w, c] inputs, while the
        # 1x1 convolutions produce [c, h, w] per example, so permute each
        # example before attending and permute the result back afterwards.
        out = []
        for i in range(x.shape[0]):
            out.append(attention_einsum(
                query[i].permute(1, 2, 0),
                key[i].permute(1, 2, 0),
                value[i].permute(1, 2, 0),
            ))
        out = torch.stack(out, dim=0).permute(0, 3, 1, 2)  # [b, c, h, w]
        out = self.gamma * self.conv(out)
        out = out + x  # skip connection
        return out
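

# A minimal, hypothetical smoke test (not part of the original commit): it
# runs each block on random inputs and asserts the output shapes the blocks
# are expected to produce. All sizes and channel counts below are arbitrary
# illustrative assumptions.
if __name__ == "__main__":
    x = torch.randn(2, 8, 32, 32)  # [batch, channels, height, width]

    assert GBlock(8, 16)(x).shape == (2, 16, 32, 32)
    assert Up_GBlock(8)(x).shape == (2, 4, 64, 64)
    assert DBlock(8, 16, apply_down=True)(x).shape == (2, 16, 16, 16)
    assert LBlock(8, 24)(x).shape == (2, 24, 32, 32)
    assert AttentionLayer(8, 8)(x).shape == (2, 8, 32, 32)

    # The 3D D block expects [batch, channels, time, height, width] and
    # halves the temporal axis along with the spatial ones when downsampling.
    x3d = torch.randn(2, 8, 4, 32, 32)
    assert DBlock(8, 16, conv_type='3d', apply_down=True)(x3d).shape == (2, 16, 2, 16, 16)

    print("All block smoke tests passed.")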