import os
import random
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.utils.data as data_utils
import yaml
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
with open('config.yaml', 'r') as f:
config = yaml.safe_load(f)
selected_model = config['selected_model']
model_config = config['models'][selected_model]
training_config = config['trainings'][selected_model]
data_config = config['datas'][selected_model]
logging_config = config['loggings'][selected_model]
backbone = logging_config['backbone']
log_dir = logging_config['log_dir']
checkpoint_dir = logging_config['checkpoint_dir']
result_dir = logging_config['result_dir']
os.makedirs(log_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(result_dir, exist_ok=True)
logging.basicConfig(filename=f'{log_dir}/{backbone}_training_log.log',
level=logging.INFO,
format='%(asctime)s %(message)s')
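# Note: all ranks share this log file; the logging calls below are guarded by
# local_rank == 0 so only the main process records progress.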
seed = training_config['seed']
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
set_seed(seed)
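# cudnn.deterministic=True plus benchmark=False makes kernel selection
# reproducible at some throughput cost; results can still differ across GPU
# models and CUDA versions.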
# ============== Distributed Training Settings ===============
parallel_method = training_config.get('parallel_method', 'DistributedDataParallel')
if parallel_method == 'DistributedDataParallel':
dist.init_process_group(backend='nccl')
local_rank = int(os.environ.get('LOCAL_RANK', 0))
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
num_gpus = dist.get_world_size()
def reduce_mean(tensor, nprocs):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= nprocs
return rt
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_gpus = torch.cuda.device_count()
    local_rank = 0  # DataParallel runs in a single process, so rank 0 is the only rank
def reduce_mean(tensor, nprocs):
return tensor
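# With DistributedDataParallel this script expects a launcher such as torchrun
# to set LOCAL_RANK / WORLD_SIZE, e.g. (script name illustrative):
#   torchrun --nproc_per_node=4 train.py
# For DataParallel or single-GPU runs a plain `python train.py` suffices.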
# ============== Data loader ==============
data_path = data_config['data_path']
data = torch.load(data_path)
ns_all = data['vorticity']  # assumes the saved dict holds a 'vorticity' tensor of shape [N, T, H, W]
# Split the data into training, validation, and test sets (80/10/10)
total_samples = ns_all.shape[0]
train_end = int(0.8 * total_samples)   # 80% for training
val_end = int(0.9 * total_samples)     # next 10% for validation
train_data = ns_all[:train_end]        # e.g. [1024, 100, 128, 128] when N = 1280
val_data = ns_all[train_end:val_end]   # e.g. [128, 100, 128, 128]
test_data = ns_all[val_end:]           # e.g. [128, 100, 128, 128]
args = {
'input_length': data_config['input_length'],
'target_length': data_config['target_length'],
'variables_input': data_config.get('variables_input', [0]),
'variables_output': data_config.get('variables_output', [0]),
'downsample_factor': data_config['downsample_factor']
}
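# The dataset classes (defined in dataloader_ns) presumably window each
# trajectory into (input_length, target_length) frame pairs, select the
# channels listed in variables_input / variables_output, and spatially
# downsample by downsample_factor.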
from dataloader_ns import train_Dataset, test_Dataset
train_dataset = train_Dataset(train_data, args)
val_dataset = test_Dataset(val_data, args)
test_dataset = test_Dataset(test_data, args)
if parallel_method == 'DistributedDataParallel':
    train_sampler = DistributedSampler(train_dataset)
    # shuffle=False keeps the evaluation order deterministic across epochs.
    val_sampler = DistributedSampler(val_dataset, shuffle=False)
    test_sampler = DistributedSampler(test_dataset, shuffle=False)
else:
train_sampler = None
val_sampler = None
test_sampler = None
train_loader = data_utils.DataLoader(
train_dataset,
num_workers=0,
batch_size=training_config['batch_size'],
sampler=train_sampler,
shuffle=(train_sampler is None)
)
val_loader = data_utils.DataLoader(
val_dataset,
num_workers=0,
batch_size=training_config['batch_size'],
sampler=val_sampler,
shuffle=False
)
test_loader = data_utils.DataLoader(
test_dataset,
num_workers=0,
batch_size=training_config['batch_size'],
sampler=test_sampler,
shuffle=False
)
if local_rank == 0:
for input_frames, output_frames in train_loader:
print(f'Dataloader Input shape: {input_frames.shape}, Output shape: {output_frames.shape}')
break
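# With the config above this is expected to print something like
# [B, input_length, C, H, W] / [B, target_length, C, H, W]; the exact layout
# depends on the dataset implementation.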
# ============== Model settings ==============
# Model registry
from model.triton_model import Triton
from model_baselines.fno import FNO2d
from model_baselines.dit import Dit
from model_baselines.simvp import SimVP
from model.triton_model_v2 import Triton_v2
from model_baselines.cno import CNO
from model_baselines.mgno import MgNO
from model_baselines.lsm import LSM
from model_baselines.pastnet import PastNetModel
from model_baselines.resnet import ResNet
from model_baselines.unet import U_net
model_dict = {
'Triton': Triton,
'Triton_V2': Triton_v2,
'FNO': FNO2d,
'DiT': Dit,
'SimVP': SimVP,
'CNO': CNO,
'MGNO': MgNO,
'LSM': LSM,
'PastNet': PastNetModel,
'ResNet': ResNet,
'U_net': U_net,
}
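# To register a new baseline: import its class, add a name -> class entry
# above, and give it a matching block in config.yaml (its constructor kwargs
# are read from config['models'][name]['parameters']).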
model_name = selected_model
model_params = model_config['parameters']
# Check if the model is in the registry
if model_name in model_dict:
    ModelClass = model_dict[model_name]
    model = ModelClass(**model_params)
    if local_rank == 0:
        print(f"{model_name} has been loaded successfully!")
else:
    raise ValueError(f"Model {model_name} is not defined.")
model = model.to(device)
# Process according to the parallelization method
if parallel_method == 'DistributedDataParallel':
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
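    # find_unused_parameters=True lets DDP tolerate parameters that receive no
    # gradient in a given step, at the cost of an extra graph traversal; set
    # it to False if every parameter is always used.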
elif parallel_method == 'DataParallel':
model = nn.DataParallel(model)
else:
raise ValueError(f"Unknown parallel method: {parallel_method}")
# ============== Loss Function and Optimizer ==============
criterion = nn.MSELoss()
learning_rate = training_config['learning_rate']
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
num_epochs = training_config['num_epochs']
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0)
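# CosineAnnealingLR decays the learning rate from its initial value to
# eta_min over T_max steps:
#   lr(t) = eta_min + (lr0 - eta_min) * (1 + cos(pi * t / T_max)) / 2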
# ============== Training, validation, and testing functions ==============
def train(model, train_loader, criterion, optimizer, device, epoch):
model.train()
    # Reshuffle the shards each epoch so every rank sees a new ordering.
    if isinstance(train_loader.sampler, DistributedSampler):
        train_loader.sampler.set_epoch(epoch)
train_loss = 0.0
for inputs, targets in tqdm(train_loader, desc="Training", disable=local_rank != 0):
inputs = inputs.to(device, non_blocking=True).float() # B 1 1 128 128
targets = targets.to(device, non_blocking=True).float()
optimizer.zero_grad()
outputs = model(inputs) # B 1 1 128 128
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
if parallel_method == 'DistributedDataParallel':
loss_value = reduce_mean(loss, num_gpus).item()
else:
loss_value = loss.item()
train_loss += loss_value * inputs.size(0)
    # Each rank iterates only its sampler's shard, so normalize by the local
    # sample count (this equals the full dataset size without a DDP sampler).
    return train_loss / len(train_loader.sampler)
def validate(model, val_loader, criterion, device):
model.eval()
val_loss = 0.0
with torch.no_grad():
for inputs, targets in tqdm(val_loader, desc="Validation", disable=local_rank != 0):
inputs = inputs.to(device, non_blocking=True).float()
targets = targets.to(device, non_blocking=True).float()
outputs = model(inputs)
loss = criterion(outputs, targets)
if parallel_method == 'DistributedDataParallel':
loss_value = reduce_mean(loss, num_gpus).item()
else:
loss_value = loss.item()
val_loss += loss_value * inputs.size(0)
    return val_loss / len(val_loader.sampler)  # normalize by the local shard size, as in train()
def test(model, test_loader, criterion, device):
path = result_dir
model.eval()
test_loss = 0.0
all_inputs = []
all_targets = []
all_outputs = []
with torch.no_grad():
for inputs, targets in tqdm(test_loader, desc="Testing", disable=local_rank != 0):
inputs = inputs.to(device, non_blocking=True).float()
targets = targets.to(device, non_blocking=True).float()
outputs = model(inputs)
if local_rank == 0:
all_inputs.append(inputs.cpu().numpy())
all_targets.append(targets.cpu().numpy())
all_outputs.append(outputs.cpu().numpy())
loss = criterion(outputs, targets)
if parallel_method == 'DistributedDataParallel':
loss_value = reduce_mean(loss, num_gpus).item()
else:
loss_value = loss.item()
test_loss += loss_value * inputs.size(0)
if local_rank == 0:
all_inputs = np.concatenate(all_inputs, axis=0)
all_targets = np.concatenate(all_targets, axis=0)
all_outputs = np.concatenate(all_outputs, axis=0)
np.save(f'{path}/{backbone}_inputs.npy', all_inputs)
np.save(f'{path}/{backbone}_targets.npy', all_targets)
np.save(f'{path}/{backbone}_outputs.npy', all_outputs)
    return test_loss / len(test_loader.sampler)  # normalize by the local shard size
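# Caveat: under DDP each rank iterates only its sampler shard, so the arrays
# saved above hold rank 0's share of the test set rather than the full set.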
# ============== Main training Loop ==============
best_val_loss = float('inf')
best_model_path = f'{checkpoint_dir}/{backbone}_best_model.pth'
# Resume from the best checkpoint if one exists. In DDP every rank must load
# the same weights, otherwise the replicas silently diverge.
if os.path.exists(best_model_path):
    try:
        if local_rank == 0:
            logging.info('Loading best model from checkpoint.')
        checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(checkpoint)
    except Exception as e:
        logging.error(f'Error loading model checkpoint: {e}')
for epoch in range(num_epochs):
if local_rank == 0:
logging.info(f'Epoch {epoch + 1}/{num_epochs}')
train_loss = train(model, train_loader, criterion, optimizer, device, epoch)
val_loss = validate(model, val_loader, criterion, device)
scheduler.step()
if local_rank == 0:
current_lr = optimizer.param_groups[0]['lr']
logging.info(f'Current Learning Rate: {current_lr:.10f}')
if val_loss < best_val_loss:
best_val_loss = val_loss
torch.save(model.state_dict(), best_model_path)
logging.info(f'Train Loss: {train_loss:.7f}, Val Loss: {val_loss:.7f}')
# Run the final test pass on every rank: `test` calls reduce_mean, whose
# all_reduce would deadlock if only rank 0 entered the collective.
try:
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    test_loss = test(model, test_loader, criterion, device)
    if local_rank == 0:
        logging.info(f"Testing completed. Test Loss: {test_loss:.7f}")
except Exception as e:
    logging.error(f'Error loading model checkpoint during testing: {e}')
if parallel_method == 'DistributedDataParallel':
dist.destroy_process_group()