import os
import random
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.utils.data as data_utils
import yaml
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
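
# The script reads everything from a config.yaml living next to it. A minimal
# sketch of the assumed layout, reconstructed from the lookups below (the key
# names are taken from this script; the values are illustrative placeholders):
#
#   selected_model: FNO
#   models:
#     FNO:
#       parameters: {}            # passed to the model class as **kwargs
#   trainings:
#     FNO:
#       seed: 42
#       batch_size: 16
#       learning_rate: 0.001
#       num_epochs: 100
#       parallel_method: DistributedDataParallel
#   datas:
#     FNO:
#       data_path: ./data/ns.pt   # a torch.save'd dict with a 'vorticity' tensor
#       input_length: 10
#       target_length: 10
#       downsample_factor: 1
#   loggings:
#     FNO:
#       backbone: fno
#       log_dir: ./logs
#       checkpoint_dir: ./checkpoints
#       result_dir: ./results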

with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)

selected_model = config['selected_model']

model_config = config['models'][selected_model]
training_config = config['trainings'][selected_model]
data_config = config['datas'][selected_model]
logging_config = config['loggings'][selected_model]

backbone = logging_config['backbone']
log_dir = logging_config['log_dir']
checkpoint_dir = logging_config['checkpoint_dir']
result_dir = logging_config['result_dir']

os.makedirs(log_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(result_dir, exist_ok=True)

logging.basicConfig(filename=f'{log_dir}/{backbone}_training_log.log',
                    level=logging.INFO,
                    format='%(asctime)s %(message)s')

seed = training_config['seed']


def set_seed(seed):
    # Seed every RNG the training run touches so results are reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels trade some speed for reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(seed)
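
# Assumed launch command for the DistributedDataParallel path: any launcher that
# sets the LOCAL_RANK environment variable per process works, e.g. with torchrun
# on a single node with 4 GPUs (substitute this script's filename):
#
#   torchrun --nproc_per_node=4 train_ns.py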

parallel_method = training_config.get('parallel_method', 'DistributedDataParallel')

if parallel_method == 'DistributedDataParallel':
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    num_gpus = dist.get_world_size()

    def reduce_mean(tensor, nprocs):
        # Average a tensor across all processes so every rank logs the same loss.
        rt = tensor.clone()
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        rt /= nprocs
        return rt
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_gpus = torch.cuda.device_count()
    local_rank = 0

    def reduce_mean(tensor, nprocs):
        # Single-process fallback: nothing to reduce.
        return tensor

data_path = data_config['data_path']
data = torch.load(data_path)
ns_all = data['vorticity']

# 80/10/10 train/validation/test split along the sample dimension.
total_samples = ns_all.shape[0]
train_end = int(0.8 * total_samples)
val_end = int(0.9 * total_samples)

train_data = ns_all[:train_end]
val_data = ns_all[train_end:val_end]
test_data = ns_all[val_end:]
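
# Hedged sanity check (not in the original): the layout of data['vorticity']
# depends on how the dataset was generated; the split above only assumes that
# samples live on dimension 0.
assert val_end > train_end > 0, "dataset too small for an 80/10/10 split"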

args = {
    'input_length': data_config['input_length'],
    'target_length': data_config['target_length'],
    'variables_input': data_config.get('variables_input', [0]),
    'variables_output': data_config.get('variables_output', [0]),
    'downsample_factor': data_config['downsample_factor'],
}

from dataloader_ns import train_Dataset, test_Dataset
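
# dataloader_ns is a project-local module. train_Dataset and test_Dataset are
# assumed to yield (input_frames, output_frames) pairs, which is how every
# loop below unpacks a batch; `args` carries the windowing options they need.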

train_dataset = train_Dataset(train_data, args)
val_dataset = test_Dataset(val_data, args)
test_dataset = test_Dataset(test_data, args)

if parallel_method == 'DistributedDataParallel':
    # shuffle=False keeps the val/test shards in their original order; only the
    # training shards get reshuffled (via set_epoch in the train loop below).
    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset, shuffle=False)
    test_sampler = DistributedSampler(test_dataset, shuffle=False)
else:
    train_sampler = None
    val_sampler = None
    test_sampler = None

train_loader = data_utils.DataLoader(
    train_dataset,
    num_workers=0,
    batch_size=training_config['batch_size'],
    sampler=train_sampler,
    shuffle=(train_sampler is None),
)

val_loader = data_utils.DataLoader(
    val_dataset,
    num_workers=0,
    batch_size=training_config['batch_size'],
    sampler=val_sampler,
    shuffle=False,
)

test_loader = data_utils.DataLoader(
    test_dataset,
    num_workers=0,
    batch_size=training_config['batch_size'],
    sampler=test_sampler,
    shuffle=False,
)
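
# num_workers=0 keeps batch preparation in the main process; raising it moves
# loading into worker processes, which usually helps once the model is fast
# enough to be input-bound.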

if local_rank == 0:
    # Print one batch's shapes as a quick sanity check.
    for input_frames, output_frames in train_loader:
        print(f'Dataloader Input shape: {input_frames.shape}, Output shape: {output_frames.shape}')
        break

from model.triton_model import Triton
from model.triton_model_v2 import Triton_v2
from model_baselines.cno import CNO
from model_baselines.dit import Dit
from model_baselines.fno import FNO2d
from model_baselines.lsm import LSM
from model_baselines.mgno import MgNO
from model_baselines.pastnet import PastNetModel
from model_baselines.resnet import ResNet
from model_baselines.simvp import SimVP
from model_baselines.unet import U_net

model_dict = {
    'Triton': Triton,
    'Triton_V2': Triton_v2,
    'FNO': FNO2d,
    'DiT': Dit,
    'SimVP': SimVP,
    'CNO': CNO,
    'MGNO': MgNO,
    'LSM': LSM,
    'PastNet': PastNetModel,
    'ResNet': ResNet,
    'U_net': U_net,
}

model_name = selected_model
# model_params comes straight from the YAML file and is splatted into the model
# constructor, so its keys must match the chosen class's __init__ signature.
model_params = model_config['parameters']

if model_name in model_dict:
    ModelClass = model_dict[model_name]
    model = ModelClass(**model_params)
else:
    raise ValueError(f"Model {model_name} is not defined.")

if local_rank == 0:
    print(f"{model_name} has been successfully loaded!")

model = model.to(device)

if parallel_method == 'DistributedDataParallel':
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True
    )
elif parallel_method == 'DataParallel':
    model = nn.DataParallel(model)
else:
    raise ValueError(f"Unknown parallel method: {parallel_method}")
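
# find_unused_parameters=True lets DDP tolerate models whose forward pass skips
# some registered parameters, at the cost of an extra graph traversal every
# iteration; if every parameter is always used, dropping it is faster.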

criterion = nn.MSELoss()
learning_rate = training_config['learning_rate']
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

num_epochs = training_config['num_epochs']
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0)
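
# CosineAnnealingLR decays the learning rate from its initial value to eta_min
# following lr_t = eta_min + (lr_0 - eta_min) * (1 + cos(pi * t / T_max)) / 2,
# so with T_max=num_epochs it reaches eta_min (here 0) at the final epoch.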

def train(model, train_loader, criterion, optimizer, device, epoch):
    model.train()
    if parallel_method == 'DistributedDataParallel' and train_loader.sampler is not None:
        # Reshuffle the shards each epoch so every rank sees a new ordering.
        train_loader.sampler.set_epoch(epoch)
    train_loss = 0.0
    num_samples = 0
    for inputs, targets in tqdm(train_loader, desc="Training", disable=local_rank != 0):
        inputs = inputs.to(device, non_blocking=True).float()
        targets = targets.to(device, non_blocking=True).float()
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        if parallel_method == 'DistributedDataParallel':
            loss_value = reduce_mean(loss, num_gpus).item()
        else:
            loss_value = loss.item()
        train_loss += loss_value * inputs.size(0)
        num_samples += inputs.size(0)
    # Normalize by the samples this rank actually processed: under DDP each rank
    # only iterates over its shard, so dividing by the full dataset length would
    # understate the loss by a factor of num_gpus.
    return train_loss / num_samples

def validate(model, val_loader, criterion, device):
    model.eval()
    val_loss = 0.0
    num_samples = 0
    with torch.no_grad():
        for inputs, targets in tqdm(val_loader, desc="Validation", disable=local_rank != 0):
            inputs = inputs.to(device, non_blocking=True).float()
            targets = targets.to(device, non_blocking=True).float()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            if parallel_method == 'DistributedDataParallel':
                loss_value = reduce_mean(loss, num_gpus).item()
            else:
                loss_value = loss.item()
            val_loss += loss_value * inputs.size(0)
            num_samples += inputs.size(0)
    return val_loss / num_samples

def test(model, test_loader, criterion, device):
    path = result_dir
    model.eval()
    test_loss = 0.0
    num_samples = 0
    all_inputs = []
    all_targets = []
    all_outputs = []

    with torch.no_grad():
        for inputs, targets in tqdm(test_loader, desc="Testing", disable=local_rank != 0):
            inputs = inputs.to(device, non_blocking=True).float()
            targets = targets.to(device, non_blocking=True).float()
            outputs = model(inputs)

            if local_rank == 0:
                # Caveat: under DistributedDataParallel each rank only iterates
                # over its own shard, so rank 0 saves roughly 1/num_gpus of the
                # test set rather than all of it.
                all_inputs.append(inputs.cpu().numpy())
                all_targets.append(targets.cpu().numpy())
                all_outputs.append(outputs.cpu().numpy())

            loss = criterion(outputs, targets)
            if parallel_method == 'DistributedDataParallel':
                loss_value = reduce_mean(loss, num_gpus).item()
            else:
                loss_value = loss.item()
            test_loss += loss_value * inputs.size(0)
            num_samples += inputs.size(0)

    if local_rank == 0:
        all_inputs = np.concatenate(all_inputs, axis=0)
        all_targets = np.concatenate(all_targets, axis=0)
        all_outputs = np.concatenate(all_outputs, axis=0)

        np.save(f'{path}/{backbone}_inputs.npy', all_inputs)
        np.save(f'{path}/{backbone}_targets.npy', all_targets)
        np.save(f'{path}/{backbone}_outputs.npy', all_outputs)

    return test_loss / num_samples

best_val_loss = float('inf')
best_model_path = f'{checkpoint_dir}/{backbone}_best_model.pth'

if os.path.exists(best_model_path):
    try:
        logging.info('Loading best model from checkpoint.')
        # Load on every rank so DDP replicas stay in sync; a rank-0-only load
        # would leave the other replicas with their freshly initialized weights.
        checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(checkpoint)
    except Exception as e:
        logging.error(f'Error loading model checkpoint: {e}')

for epoch in range(num_epochs):
    if local_rank == 0:
        logging.info(f'Epoch {epoch + 1}/{num_epochs}')
    train_loss = train(model, train_loader, criterion, optimizer, device, epoch)
    val_loss = validate(model, val_loader, criterion, device)

    scheduler.step()

    if local_rank == 0:
        current_lr = optimizer.param_groups[0]['lr']
        logging.info(f'Current Learning Rate: {current_lr:.10f}')

        # Only rank 0 writes the checkpoint, avoiding concurrent writes under DDP.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), best_model_path)

        logging.info(f'Train Loss: {train_loss:.7f}, Val Loss: {val_loss:.7f}')

if parallel_method == 'DistributedDataParallel':
    # Make sure rank 0 has finished writing the best checkpoint before any rank reads it.
    dist.barrier()

# Testing runs on every rank: under DDP, test() calls all_reduce, which would
# deadlock if only rank 0 entered the collective.
try:
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    test_loss = test(model, test_loader, criterion, device)
    if local_rank == 0:
        logging.info(f"Testing completed. Test Loss: {test_loss:.7f}")
except Exception as e:
    logging.error(f'Error loading model checkpoint during testing: {e}')

if parallel_method == 'DistributedDataParallel':
    dist.destroy_process_group()