# feat: initial implementation of DGMR solar radiation nowcasting models
# (commit 2fa5aae)
import torch
import torch.nn.functional as F
from torch.nn.utils.parametrizations import spectral_norm
from torch.autograd import Variable
class ConvGRUCell(torch.nn.Module):
    """Single ConvGRU recurrent cell with spectrally-normalized convolutions.

    ``conv1`` produces the reset and update gates in one pass (hence the
    ``2 * out_channel`` output maps); ``conv2`` produces the candidate
    hidden state. Spectral normalization matches DGMR-style GAN training.
    """

    def __init__(self, in_channel, out_channel, kernel_size=3):
        super().__init__()
        # "same" padding for odd kernel sizes.
        padding = kernel_size // 2
        self.out_channel = out_channel
        # Gate convolution: emits reset and update gates stacked on dim 1.
        self.conv1 = spectral_norm(
            torch.nn.Conv2d(
                in_channels=in_channel + out_channel,
                out_channels=2 * out_channel,
                kernel_size=kernel_size,
                padding=padding
            )
        )
        # Candidate-state convolution.
        self.conv2 = spectral_norm(
            torch.nn.Conv2d(
                in_channels=in_channel + out_channel,
                out_channels=out_channel,
                kernel_size=kernel_size,
                padding=padding
            )
        )

    def forward(self, x, h_st):
        """Run one GRU step.

        Args:
            x: input tensor, shape (batch, in_channel, H, W).
            h_st: previous hidden state, shape (batch, out_channel, H', W')
                with H' >= H and W' >= W. When the spatial sizes differ,
                ``x`` is reflection-padded up to the hidden state's size.

        Returns:
            Tuple ``(out, new_state)`` — two references to the same tensor
            of shape (batch, out_channel, H', W').
        """
        # Reflection-pad x so its spatial dims match the hidden state.
        # The original only handled a mismatch of exactly 1 pixel; padding
        # by the actual (non-negative) difference generalizes that while
        # behaving identically for differences of 0 or 1.
        pad_h = max(h_st.shape[2] - x.shape[2], 0)
        pad_w = max(h_st.shape[3] - x.shape[3], 0)
        if pad_h or pad_w:
            # F.pad order for 4D input: (left, right, top, bottom);
            # pad the right edge of W and the top edge of H, as before.
            x = F.pad(x, (0, pad_w, pad_h, 0), "reflect")

        gates = self.conv1(torch.cat([x, h_st], dim=1))
        reset_raw, update_raw = torch.split(gates, self.out_channel, dim=1)
        reset_gate = torch.sigmoid(reset_raw)
        update_gate = torch.sigmoid(update_raw)

        candidate = torch.tanh(
            self.conv2(torch.cat([x, h_st * reset_gate], dim=1)))
        # Convex blend of candidate and previous state.
        out = (1 - update_gate) * candidate + update_gate * h_st
        # Output and new state are the same tensor, per the original contract.
        return out, out
class ConvGRU(torch.nn.Module):
    """Unrolls a :class:`ConvGRUCell` over the time axis of a frame sequence.

    Input is (batch, time, channels, width, height); the per-step cell
    outputs are stacked along a new leading time axis.
    """

    def __init__(self, in_channel, out_channel, kernel_size):
        super().__init__()
        self.out_channel = out_channel
        self.convgru_cell = ConvGRUCell(in_channel, out_channel, kernel_size)

    def _get_init_state(self, batch_size, imd_w, imd_h, dtype):
        """Return an all-zero initial hidden state.

        Bug fix: the original referenced ``self.h`` / ``self.w``, which are
        never assigned anywhere in the class, raising AttributeError whenever
        no ``init_hidden`` was supplied. The spatial sizes passed in are now
        used. The deprecated ``Variable`` wrapper is dropped (it has been an
        identity wrapper since PyTorch 0.4).

        Args:
            batch_size: leading batch dimension.
            imd_w / imd_h: spatial dimensions of the hidden state.
            dtype: tensor type string as returned by ``Tensor.type()``.

        Returns:
            Zero tensor of shape (batch_size, out_channel, imd_w, imd_h).
        """
        return torch.zeros(
            batch_size, self.out_channel, imd_w, imd_h).type(dtype)

    def forward(self, x_sequence, init_hidden=None):
        """
        Args:
            x_sequence: shape (batch_size, time, c, width, height).
            init_hidden: optional initial hidden state of shape
                (batch_size, out_channel, width, height); zeros when None.

        Returns:
            outputs: shape (time, batch_size, out_channel, width, height).
        """
        seq_len = x_sequence.shape[1]
        img_w = x_sequence.shape[3]
        img_h = x_sequence.shape[4]
        # Match the input's type (and, for CUDA tensors, its device class).
        dtype = x_sequence.type()
        if init_hidden is None:
            hidden_state = self._get_init_state(
                x_sequence.shape[0], img_w, img_h, dtype)
        else:
            hidden_state = init_hidden
        out_list = []
        for t in range(seq_len):
            out, hidden_state = self.convgru_cell(
                x_sequence[:, t, :, :, :], hidden_state)
            out_list.append(out)
        return torch.stack(out_list, dim=0)