# LARRES / model_convlstm.py
# Provenance (Hugging Face upload): user "Staty", commit 2b21abc ("Upload 50 files").
import torch
import torch.nn.functional as F
from torch import nn, Tensor
import numpy as np
import h5py
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split
#Obtained from: https://holmdk.github.io/2020/04/02/video_prediction.html
class ConvLSTMCell(nn.Module):
    """A single ConvLSTM cell: an LSTM whose gate transforms are 2-D convolutions.

    Adapted from: https://holmdk.github.io/2020/04/02/video_prediction.html
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        """
        Initialize the ConvLSTM cell.

        Parameters
        ----------
        input_dim: int
            Number of channels of the input tensor.
        hidden_dim: int
            Number of channels of the hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        # "Same" padding so the spatial size of the state is preserved.
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias
        # A single convolution produces all four gate pre-activations (i, f, o, g).
        self.conv = nn.Conv2d(
            in_channels=self.input_dim + self.hidden_dim,
            out_channels=4 * self.hidden_dim,
            kernel_size=self.kernel_size,
            padding=self.padding,
            bias=self.bias,
        )

    def forward(self, input_tensor, cur_state):
        """Advance the cell one time step; return the next (h, c) pair."""
        h_prev, c_prev = cur_state
        # Stack input and previous hidden state along the channel axis.
        stacked = torch.cat([input_tensor, h_prev], dim=1)
        gates = self.conv(stacked)
        in_gate, forget_gate, out_gate, cell_gate = torch.split(gates, self.hidden_dim, dim=1)
        i = torch.sigmoid(in_gate)
        f = torch.sigmoid(forget_gate)
        o = torch.sigmoid(out_gate)
        g = torch.tanh(cell_gate)
        c_next = f * c_prev + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return zero-filled (h, c) states on the same device as the conv weights."""
        height, width = image_size
        device = self.conv.weight.device
        state_shape = (batch_size, self.hidden_dim, height, width)
        return (torch.zeros(state_shape, device=device),
                torch.zeros(state_shape, device=device))
def process_highdim_array(arr):
    """Reshape the trailing spatial dims of ``arr`` from (71, 73) to (72, 72).

    The last column is dropped (73 -> 72) and one all-zero row is appended
    (71 -> 72), yielding square maps whose leading dimensions are untouched.

    Parameters
    ----------
    arr : np.ndarray
        Array of any rank >= 2 whose last two dimensions are (71, 73),
        e.g. (60, 1, 71, 73) or (1, 60, 1, 71, 73).

    Returns
    -------
    np.ndarray
        Array with the same leading dimensions and trailing dims (72, 72).

    Raises
    ------
    ValueError
        If the last two dimensions are not (71, 73).
    """
    if arr.shape[-2:] != (71, 73):
        raise ValueError("输入数组的最后两个维度必须是 (71, 73)")
    # Drop the last column: (..., 71, 73) -> (..., 71, 72).
    arr_trimmed = arr[..., :-1]
    # The original hard-coded a 4-pair pad spec, which makes np.pad raise for
    # any rank other than 4; build the spec from the array's rank so every
    # leading-dimension layout works.
    pad_width = [(0, 0)] * (arr_trimmed.ndim - 2) + [(0, 1), (0, 0)]
    # Append one zero row: (..., 71, 72) -> (..., 72, 72).
    return np.pad(arr_trimmed, pad_width, mode='constant', constant_values=0)
class ionexDataset(Dataset):
    """Sliding-window dataset over a time-ordered stack of ionosphere maps.

    Each sample is an (X, y) pair: X holds ``nstepsin`` consecutive frames and
    y the ``nstepsout`` frames that immediately follow; windows advance by
    ``stride`` frames. Both parts are passed through ``process_highdim_array``
    so their spatial dimensions become (72, 72).
    """

    def __init__(self, npy_data, nstepsin=36, nstepsout=12, stride=12):
        self.data = npy_data.astype(np.float32)
        self.nstepsin = nstepsin
        self.nstepsout = nstepsout
        self.stride = stride
        # Window start positions; the bound ensures a full input+target window fits.
        last_start = len(self.data) - nstepsout - nstepsin + 1
        self.idx = np.arange(0, last_start, stride)

    def __getitem__(self, index):
        # Locate the boundaries of this window.
        start = self.idx[index]
        split = start + self.nstepsin
        stop = split + self.nstepsout
        # Defensive guard; unreachable given how self.idx is constructed.
        if stop > len(self.data):
            return None, None
        # Input frames precede the split point, target frames follow it.
        seq_x = self.data[start:split]
        seq_y = self.data[split:stop]
        return process_highdim_array(seq_x), process_highdim_array(seq_y)

    def __len__(self):
        return len(self.idx)

    def split_train_val(self, val_split=0.25):
        """Randomly partition sample indices into train/validation Subsets."""
        all_indices = list(range(len(self)))
        train_idx, val_idx = train_test_split(all_indices, test_size=val_split)
        return Subset(self, train_idx), Subset(self, val_idx)
# ---- experiment configuration --------------------------------------------
nstepsin = 36    # input time steps per sample
nstepsout = 12   # predicted time steps per sample
stride = 12      # window stride between consecutive samples
max_epochs = 200
# batch_size=2

# ---- data loading ---------------------------------------------------------
# Values are stored scaled by 10; divide to recover the physical units.
# NOTE(review): the training file is 'train2015.h5' but its dataset key is
# '2020' — confirm the file contents match the intended year.
# Context managers fix the original leak where the first File handle was
# overwritten before being closed (only the second file was ever closed).
with h5py.File('train2015.h5', 'r') as f:
    train_npy = np.array(f['2020']) / 10
with h5py.File('test2015.h5', 'r') as f:
    test_npy = np.array(f['2015']) / 10
# Alternative dataset configurations kept from the original for reference:
# f = h5py.File('train2015.h5', 'r')
# train_npy=np.array(f['2020'])/10
# f1=h5py.File('c1pg2015.h5', 'r')
# f = h5py.File('test2015.h5', 'r')
# test_npy=np.array(f['2015'])/10-np.array(f1['2015'])/10
# f = h5py.File('train2020.h5', 'r')
# train_npy=np.array(f['2020'])/10
# f = h5py.File('test2020.h5', 'r')
# test_npy=np.array(f['2020'])/10
# f = h5py.File('train2020.h5', 'r')
# train_npy=np.array(f['2020'])/10
# f1=h5py.File('c1pg2020.h5', 'r')
# f = h5py.File('test2020.h5', 'r')
# test_npy=np.array(f['2020'])/10-np.array(f1['2020'])/10
print("Training data:", train_npy.shape)
print("Testing data:", test_npy.shape)
class EncoderDecoderConvLSTM(nn.Module):
    """Seq2seq ConvLSTM: a 3-layer encoder reads the input sequence, then a
    3-layer decoder — seeded with the encoder's final states — rolls out
    ``nstepsout`` future frames.
    """

    def __init__(self, nf, in_chan, out_chan, nstepsout=12):
        """
        Parameters
        ----------
        nf : int
            Number of hidden channels in every ConvLSTM layer.
        in_chan : int
            Number of channels of each input frame.
        out_chan : int
            Unused; kept for interface compatibility (the head always emits 1 channel).
        nstepsout : int
            Number of future frames to predict.
        """
        super().__init__()
        self.nstepsout = nstepsout
        self.encoder_1_convlstm = ConvLSTMCell(input_dim=in_chan, hidden_dim=nf, kernel_size=(3, 3), bias=True)
        self.encoder_2_convlstm = ConvLSTMCell(input_dim=nf, hidden_dim=nf, kernel_size=(3, 3), bias=True)
        self.encoder_3_convlstm = ConvLSTMCell(input_dim=nf, hidden_dim=nf, kernel_size=(3, 3), bias=True)
        self.decoder_1_convlstm = ConvLSTMCell(input_dim=nf, hidden_dim=nf, kernel_size=(3, 3), bias=True)
        self.decoder_2_convlstm = ConvLSTMCell(input_dim=nf, hidden_dim=nf, kernel_size=(3, 3), bias=True)
        self.decoder_3_convlstm = ConvLSTMCell(input_dim=nf, hidden_dim=nf, kernel_size=(3, 3), bias=True)
        # 1x1 conv head mapping nf hidden channels to a single output channel.
        self.conv2d = nn.Conv2d(in_channels=nf, out_channels=1, kernel_size=(1, 1))

    def forward(self, x, future_seq=0, hidden_state=None):
        """Predict ``self.nstepsout`` frames from the input sequence.

        Parameters
        ----------
        x : Tensor
            Input of shape (batch, seq_len, channels, height, width).
        future_seq, hidden_state
            Unused; kept for interface compatibility with existing callers.

        Returns
        -------
        Tensor of shape (batch, nstepsout, 1, height, width).
        """
        b, seq_len, _, h, w = x.size()
        # Encoder: zero-initialized states. (The original seeded layer 3 from
        # decoder_3_convlstm — behaviourally identical since every layer shares
        # hidden size nf and init_hidden returns zeros; fixed for consistency.)
        h1, c1 = self.encoder_1_convlstm.init_hidden(batch_size=b, image_size=(h, w))
        h2, c2 = self.encoder_2_convlstm.init_hidden(batch_size=b, image_size=(h, w))
        h3, c3 = self.encoder_3_convlstm.init_hidden(batch_size=b, image_size=(h, w))
        for t in range(seq_len):
            h1, c1 = self.encoder_1_convlstm(input_tensor=x[:, t, :, :], cur_state=[h1, c1])
            h2, c2 = self.encoder_2_convlstm(input_tensor=h1, cur_state=[h2, c2])
            h3, c3 = self.encoder_3_convlstm(input_tensor=h2, cur_state=[h3, c3])
        # Decoder: seed each layer with the matching encoder layer's final state.
        h4, c4 = h1, c1
        h5, c5 = h2, c2
        h6, c6 = h3, c3
        outputs = []
        for t in range(self.nstepsout):
            # h3 is the fixed encoder context: it is fed to the decoder at
            # every step and never updated during prediction.
            h4, c4 = self.decoder_1_convlstm(input_tensor=h3, cur_state=[h4, c4])
            h5, c5 = self.decoder_2_convlstm(input_tensor=h4, cur_state=[h5, c5])
            h6, c6 = self.decoder_3_convlstm(input_tensor=h5, cur_state=[h6, c6])
            # NOTE(review): the head reads h4 (decoder layer 1), so decoder
            # layers 2-3 never influence the output. Preserved as-is because
            # trained weights may depend on it, but this looks like it should
            # be h6 — confirm with the author.
            outputs.append(self.conv2d(h4))
        outputs = torch.stack(outputs, 1)
        return outputs