|
|
import os
|
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
|
from torch import nn, Tensor
|
|
|
import numpy as np
|
|
|
import h5py
|
|
|
from torch.utils.data import DataLoader, Dataset
|
|
|
from torch.utils.data import Subset
|
|
|
from sklearn.model_selection import train_test_split
|
|
|
import torch.optim as optim
|
|
|
|
|
|
from model_convlstm import ionexDataset, train_npy, nstepsin, nstepsout, stride, EncoderDecoderConvLSTM, max_epochs
|
|
|
from model_LARRES import larres
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Build the windowed dataset from the training array and split off 20 % for
# validation (the split itself is implemented by ionexDataset.split_train_val).
ionexData = ionexDataset(train_npy, nstepsin=nstepsin, nstepsout=nstepsout, stride=stride)
train_data, val_data = ionexData.split_train_val(val_split=0.2)

# NOTE(review): the training loader uses DataLoader's default shuffle=False,
# so batches arrive in the same order every epoch — confirm this is intended.
train_loader = DataLoader(train_data, batch_size=16, num_workers=0)
val_loader = DataLoader(val_data, batch_size=16, num_workers=0)

# Sanity check: report the shape/dtype of the first training batch, if any.
# No axis layout is printed: the training code indexes these tensors with
# five indices, so a 4-D "[N, C, H, W]" label would be misleading.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    X, y = first_batch
    print(f"Shape of X: {X.shape} {X.dtype}")
    print(f"Shape of Y: {y.shape} {y.dtype}")

print(f"Training samples: {len(train_loader.dataset)}")
print(f"Validation samples: {len(val_loader.dataset)}")
|
|
|
|
|
|
|
|
|
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Instantiate the LARRES network and move its parameters onto `device`.
model = larres()
model = model.to(device)

# Adam with a fixed learning rate of 1e-4; the objective is mean absolute error.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
criterion = nn.L1Loss(reduction='mean')

# Lowest validation loss seen so far. It starts at +inf so that any finite
# validation loss counts as an improvement.
best_val_loss = float('inf')
|
|
|
|
|
|
|
|
|
def _residual_loss(model, criterion, data, target):
    """Return the cropped loss between the model output and the residual target.

    The model is trained to predict ``target - data[:, 24:36]`` (the change
    relative to input steps 24..35) rather than ``target`` itself.
    """
    output = model(data)
    # NOTE(review): 24:36 presumably selects the last 12 input steps, i.e. it
    # assumes nstepsin == 36 and nstepsout == 12 — confirm against
    # model_convlstm. The slice is hard-coded, not derived from those values.
    target_residual = target - data[:, 24:36, :, :, :]
    # Only the first 12 entries of dim 1 and the first 71 of dim 3 enter the
    # loss. NOTE(review): 71 looks like a grid-size crop — confirm the layout.
    return criterion(
        output[:, :12, :, :71, :],
        target_residual[:, :12, :, :71, :],
    )


def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    total = 0.0
    n_batches = 0
    for batch_idx, (data, target) in enumerate(loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = _residual_loss(model, criterion, data, target)
        loss.backward()
        optimizer.step()
        # Accumulate a Python float: summing the tensor itself would keep each
        # batch's autograd graph alive until the end of the epoch.
        loss_value = loss.item()
        print(f'  batch {batch_idx}: loss {loss_value:.4f}')
        total += loss_value
        n_batches += 1
    if n_batches == 0:
        raise ValueError('training loader produced no batches')
    return total / n_batches


def _evaluate(model, loader, criterion, device):
    """Return the mean batch loss of ``model`` over ``loader`` without gradients.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    total = 0.0
    n_batches = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            total += _residual_loss(model, criterion, data, target).item()
            n_batches += 1
    if n_batches == 0:
        raise ValueError('validation loader produced no batches')
    return total / n_batches


for epoch in range(max_epochs):
    # Train loss is now a per-batch mean, matching how Val Loss is reported.
    train_loss = _train_one_epoch(model, train_loader, optimizer, criterion, device)
    print(f'Epoch {epoch + 1}/{max_epochs}, Train Loss: {train_loss:.4f}')

    val_loss = _evaluate(model, val_loader, criterion, device)
    print(f'Epoch {epoch + 1}/{max_epochs}, Val Loss: {val_loss:.4f}')

    # Checkpoint the weights whenever validation loss improves on the best so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')
        print('Best model saved!')

print('Training completed.')