import os
import sys
import time
import torch
import gc
import logging
import argparse
import numpy as np
from icecream import ic
from shutil import copyfile
from collections import OrderedDict

import torch.nn as nn
import torch.cuda.amp as amp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import torch.utils.checkpoint as checkpoint

from my_utils import logging_utils
logging_utils.config_logger()
from my_utils.YParams import YParams
from my_utils.data_loader import get_data_loader
from my_utils.darcy_loss import LossScaler, channel_wise_LpLoss

from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap as ruamelDict
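
# High-level flow of this training script (orientation comment summarizing the code below):
#   1. Build distributed train/valid data loaders and the OneForecast model on this rank's GPU.
#   2. Optionally resume from a checkpoint, or start multi-step fine-tuning from a pretrained one-step model.
#   3. Train with optional AMP, a cosine-annealing LR schedule, and gradient checkpointing for multi-step rollouts.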


class Trainer():

    def count_parameters(self):
        return sum(p.numel() for p in self.model.parameters() if p.requires_grad)

    def __init__(self, params, world_rank):
        self.params = params
        self.world_rank = world_rank

        # Default device (overridden below once the local rank is known).
        self.device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'

        # Bind this process to its local GPU; LOCAL_RANK is set by the distributed launcher.
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        self.device = torch.device('cuda', local_rank)
        logging.info('device: %s' % self.device)

        # Resolve the data paths relative to this script so the run does not depend on the working directory.
        script_dir = os.path.dirname(os.path.abspath(__file__))
        train_data_path = os.path.join(script_dir, params.train_data_path)
        valid_data_path = os.path.join(script_dir, params.valid_data_path)

        logging.info('rank %d, begin data loader init' % world_rank)
        self.train_data_loader, self.train_dataset, self.train_sampler = get_data_loader(
            params,
            train_data_path,
            dist.is_initialized(),
            train=True)
        self.valid_data_loader, self.valid_dataset, self.valid_sampler = get_data_loader(
            params,
            valid_data_path,
            dist.is_initialized(),
            train=False)

        if params.loss_channel_wise:
            self.loss_obj = channel_wise_LpLoss(scale=params.loss_scale)

        self.mse_loss_scaler = LossScaler()

        logging.info('rank %d, data loader initialized' % world_rank)

        if params.nettype == 'OneForecast':
            from models.OneForecast import OneForecast as model
        else:
            raise Exception("not implemented")

        self.model = model(params).to(self.device)

        # Note: the 'FusedAdam' option currently maps to the standard torch.optim.Adam.
        if params.optimizer_type == 'FusedAdam':
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=params.lr)

        if params.enable_amp:
            self.gscaler = amp.GradScaler()

        if dist.is_initialized():
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[params.local_rank],
                output_device=params.local_rank,
                find_unused_parameters=False
            )
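
        # Training progress counters; these defaults are overwritten if a checkpoint is restored below.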
        self.iters = 0
        self.startEpoch = 0

        # Resume single-step training from the latest checkpoint of this run, if requested.
        if (params.multi_steps_finetune == 1) and (params.resuming):
            logging.info("Loading checkpoint %s" % params.checkpoint_path)
            self.restore_checkpoint(params.checkpoint_path)

        # For multi-step fine-tuning, start from the pretrained one-step model and reset the counters.
        if params.multi_steps_finetune > 1:
            logging.info("Starting from pretrained one-step model at %s" % params.pretrained_ckpt_path)
            self.restore_checkpoint(params.pretrained_ckpt_path)
            self.iters = 0
            self.startEpoch = 0
            logging.info("Adding %d epochs specified in config file for refining pretrained model" % params.finetune_max_epochs)
            params['max_epochs'] = params.finetune_max_epochs

        self.epoch = self.startEpoch

        if params.scheduler == 'CosineAnnealingLR':
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=params.max_epochs,
                last_epoch=self.startEpoch - 1
            )
        else:
            self.scheduler = None

        if params.log_to_screen:
            logging.info("Number of trainable model parameters: {}".format(self.count_parameters()))

    def switch_off_grad(self, model):
        # Freeze all parameters of the given model.
        for param in model.parameters():
            param.requires_grad = False

    def train(self):
        if self.params.log_to_screen:
            logging.info("Starting Training Loop...")

        best_valid_loss = 1.e6
        for epoch in range(self.startEpoch, self.params.max_epochs):
            if dist.is_initialized():
                self.train_sampler.set_epoch(epoch)
                self.valid_sampler.set_epoch(epoch)

            start = time.time()
            tr_time, data_time, step_time, train_logs = self.train_one_epoch()
            valid_time, valid_logs = self.validate_one_epoch()

            if self.params.scheduler == 'CosineAnnealingLR':
                self.scheduler.step()
                if self.epoch >= self.params.max_epochs:
                    logging.info('Time taken for epoch {} is {} sec'.format(epoch + 1, time.time() - start))
                    logging.info('lr for epoch {} is {}'.format(epoch + 1, self.optimizer.param_groups[0]['lr']))
                    logging.info('train data time={}, train per epoch time={}, train per step time={}, valid time={}'.format(data_time, tr_time, step_time, valid_time))
                    logging.info('Train loss: {}. Valid loss: {}'.format(train_logs['train_loss'], valid_logs['valid_loss']))
                    logging.info("Terminating training after reaching params.max_epochs while LR scheduler is set to CosineAnnealingLR")
                    sys.exit(0)

            # Only rank 0 writes checkpoints; the best checkpoint is kept separately.
            if self.world_rank == 0:
                if self.params.save_checkpoint:
                    self.save_checkpoint(self.params.checkpoint_path)
                    if valid_logs['valid_loss'] <= best_valid_loss:
                        logging.info('Val loss improved from {} to {}'.format(best_valid_loss, valid_logs['valid_loss']))
                        self.save_checkpoint(self.params.best_checkpoint_path)
                        best_valid_loss = valid_logs['valid_loss']

            if self.params.log_to_screen:
                logging.info('Time taken for epoch {} is {} sec'.format(epoch + 1, time.time() - start))
                logging.info('lr for epoch {} is {}'.format(epoch + 1, self.optimizer.param_groups[0]['lr']))
                logging.info('train data time={}, train per epoch time={}, train per step time={}, valid time={}'.format(data_time, tr_time, step_time, valid_time))
                logging.info('Train loss: {}. Valid loss: {}'.format(train_logs['train_loss'], valid_logs['valid_loss']))

            # Release cached GPU memory and collect garbage between epochs.
            torch.cuda.empty_cache()
            gc.collect()

    def train_one_epoch(self):
        self.epoch += 1
        tr_time = 0
        data_time = 0
        self.model.train()

        steps_in_one_epoch = 0
        for i, data in enumerate(self.train_data_loader, 0):
            self.iters += 1
            steps_in_one_epoch += 1

            data_start = time.time()
            (inp, tar) = data
            data_time += time.time() - data_start

            tr_start = time.time()
            self.model.zero_grad()

            # Number of autoregressive steps to unroll (1 during pretraining, >1 during fine-tuning).
            num_steps = self.params.multi_steps_finetune

            with amp.autocast(enabled=self.params.enable_amp):
                gen_prev = None
                loss = 0.0
                cw_loss = 0.0

                # Autoregressive rollout: after the first step, the model is fed its own prediction.
                for step_idx in range(num_steps):
                    if step_idx == 0:
                        inp_step_1 = inp.to(self.device, dtype=torch.float32)
                        if self.params.multi_steps_finetune == 1:
                            gen_cur = self.model(inp_step_1)
                        else:
                            # Gradient checkpointing trades compute for memory during multi-step fine-tuning.
                            gen_cur = checkpoint.checkpoint(self.model, inp_step_1, use_reentrant=False)
                    else:
                        gen_cur = checkpoint.checkpoint(self.model, gen_prev, use_reentrant=False)

                    if self.params.multi_steps_finetune == 1:
                        tar_step = tar[:, self.params.out_channels].to(self.device, dtype=torch.float)
                    else:
                        tar_step = tar[:, step_idx, self.params.out_channels].to(self.device, dtype=torch.float)

                    if self.params.use_loss_scaler_from_metnet3:
                        gen_cur = self.mse_loss_scaler(gen_cur)

                    loss_step, cw_loss_step = self.loss_obj(gen_cur, tar_step)

                    loss += loss_step
                    cw_loss += cw_loss_step
                    if step_idx == 0:
                        del inp
                        # Track the one-step MSE for quick monitoring.
                        mse1 = torch.mean((gen_cur - tar_step) ** 2).item()

                    gen_prev = gen_cur

                    del tar_step, gen_cur
                del gen_prev
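
            # `loss` is the sum of the per-step losses over the rollout; `cw_loss` accumulates the
            # per-channel components, which are used only for logging.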

            if self.params.enable_amp:
                self.gscaler.scale(loss).backward()
                self.gscaler.step(self.optimizer)
            else:
                loss.backward()
                self.optimizer.step()
            print('1_step_mse:', mse1)

            if self.params.enable_amp:
                self.gscaler.update()

            tr_time += time.time() - tr_start
            # Each epoch processes a single batch before leaving the loader loop.
            break

        logs = {'train_loss': loss}

        for vi, v in enumerate(self.params.out_variables):
            logs[f'{v}_train_loss'] = cw_loss[vi]

        # Average the logged losses across ranks.
        if dist.is_initialized():
            for key in sorted(logs.keys()):
                dist.all_reduce(logs[key].detach())
                logs[key] = float(logs[key] / dist.get_world_size())

        step_time = tr_time / steps_in_one_epoch

        return tr_time, data_time, step_time, logs

    def validate_one_epoch(self):
        logging.info('validating...')
        self.model.eval()

        # Buffer layout: [mse, l1, per-channel losses..., step count]; kept on-device so it can be all-reduced.
        valid_buff = torch.zeros((3 + self.params.N_out_channels), dtype=torch.float32, device=self.device)
        valid_loss = valid_buff[0].view(-1)
        valid_l1 = valid_buff[1].view(-1)
        valid_steps = valid_buff[-1].view(-1)

        valid_start = time.time()
        with torch.no_grad():
            for i, data in enumerate(self.valid_data_loader, 0):
                # Validation is limited to the first two batches.
                if i > 1:
                    break
                inp, tar = map(lambda x: x.to(self.device, dtype=torch.float), data)
                num_steps = self.params.multi_steps_finetune
                for step_idx in range(num_steps):
                    if step_idx == 0:
                        inp_step_1 = inp.to(self.device, dtype=torch.float32)
                        gen_cur = self.model(inp_step_1)
                    else:
                        gen_cur = self.model(gen_prev)

                    if self.params.multi_steps_finetune == 1:
                        tar_step = tar[:, self.params.out_channels].to(self.device, dtype=torch.float)
                    else:
                        tar_step = tar[:, step_idx, self.params.out_channels].to(self.device, dtype=torch.float)
                    if step_idx == 0:
                        del inp_step_1
                    gen_prev = gen_cur

                    # Keep the prediction and target of the final rollout step for the metrics below.
                    if step_idx == self.params.multi_steps_finetune - 1:
                        gen, tar = gen_cur, tar_step

                    del tar_step, gen_cur
                del gen_prev

                gen = gen.to(self.device, dtype=torch.float)

                _, cw_valid_loss = self.loss_obj(gen, tar)
                valid_loss_ = torch.mean((gen[:, :, :, :] - tar[:, :, :, :]) ** 2).item()
                valid_loss += valid_loss_
                valid_l1 += nn.functional.l1_loss(gen, tar)

                for vi, v in enumerate(self.params.out_variables):
                    valid_buff[vi + 2] += cw_valid_loss[vi]

                valid_steps += 1.

                del inp, gen, tar

        os.makedirs(os.path.join(self.params['experiment_dir'], str(i)), exist_ok=True)

        if dist.is_initialized():
            dist.all_reduce(valid_buff)

        # Normalize the accumulated metrics by the number of validation steps (last buffer entry).
        valid_buff[0:-1] = valid_buff[0:-1] / valid_buff[-1]
        valid_buff_cpu = valid_buff.detach().cpu().numpy()

        valid_time = time.time() - valid_start

        logs = {'valid_loss': valid_buff_cpu[0],
                'valid_l1': valid_buff_cpu[1]}
        for vi, v in enumerate(self.params.out_variables):
            logs[f'{v}_valid_loss'] = valid_buff_cpu[vi + 2]

        return valid_time, logs

    def save_checkpoint(self, checkpoint_path, model=None):
        """ We intentionally require a checkpoint_path to be passed
        in order to allow Ray Tune to use this function """

        if not model:
            model = self.model

        torch.save({'iters': self.iters, 'epoch': self.epoch, 'model_state': model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict()}, checkpoint_path)

    def restore_checkpoint(self, checkpoint_path):
        """ We intentionally require a checkpoint_path to be passed
        in order to allow Ray Tune to use this function """
        checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(self.params.local_rank))
        try:
            self.model.load_state_dict(checkpoint['model_state'])
        except Exception:
            # Checkpoints saved from a DistributedDataParallel model carry a 'module.' prefix; strip it.
            new_state_dict = OrderedDict()
            for key, val in checkpoint['model_state'].items():
                name = key[7:]
                new_state_dict[name] = val
            self.model.load_state_dict(new_state_dict)
        self.iters = checkpoint['iters']
        self.startEpoch = checkpoint['epoch']
        if self.params.resuming:
            # Restore the optimizer state only when resuming the same training run.
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_num", default='00', type=str)
    parser.add_argument("--yaml_config", default='./config/Model.yaml', type=str)
    parser.add_argument("--multi_steps_finetune", default=1, type=int)
    parser.add_argument("--finetune_max_epochs", default=50, type=int)
    parser.add_argument("--batch_size", default=16, type=int)
    parser.add_argument("--config", default='AFNO', type=str)
    parser.add_argument("--enable_amp", action='store_true')
    parser.add_argument("--epsilon_factor", default=0, type=float)
    parser.add_argument("--local_rank", default=-1, type=int, help='node rank for distributed training')
    args = parser.parse_args()
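
    # Example launch (illustrative, assuming a single node with 4 GPUs; torchrun supplies the
    # WORLD_SIZE and LOCAL_RANK environment variables this script relies on):
    #   torchrun --nproc_per_node=4 <this_script>.py --config AFNO --batch_size 16 --enable_amp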

    script_dir = os.path.dirname(os.path.abspath(__file__))
    yaml_path = os.path.join(script_dir, args.yaml_config)
    params = YParams(os.path.abspath(yaml_path), args.config, True)
    params['epsilon_factor'] = args.epsilon_factor
    params['multi_steps_finetune'] = args.multi_steps_finetune
    params['finetune_max_epochs'] = args.finetune_max_epochs

    params['world_size'] = 1
    if 'WORLD_SIZE' in os.environ:
        params['world_size'] = int(os.environ['WORLD_SIZE'])
    print('world_size :', params['world_size'])

    print('Initialize distributed process group...')
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    params['local_rank'] = local_rank

    torch.backends.cudnn.benchmark = True
    world_rank = dist.get_rank()
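
    # The per-GPU batch size is the global batch size divided evenly across the participating ranks.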
    params['global_batch_size'] = args.batch_size
    params['batch_size'] = int(args.batch_size // params['world_size'])
    params['enable_amp'] = args.enable_amp

    if params['multi_steps_finetune'] > 1:
        # Multi-step fine-tuning: chain from the best checkpoint of the previous stage
        # (the one-step run for stage 2, otherwise the previous fine-tuning stage).
        script_dir = os.path.dirname(os.path.abspath(__file__))
        exp_dir_path = os.path.join(script_dir, params.exp_dir)
        pretrained_expDir = os.path.join(exp_dir_path, args.config, str(args.run_num))
        multi_steps = params['multi_steps_finetune']
        if params['multi_steps_finetune'] > 2:
            params['pretrained_ckpt_path'] = os.path.join(pretrained_expDir, f'{multi_steps-1}_steps_finetune/training_checkpoints/best_ckpt.tar')
        else:
            params['pretrained_ckpt_path'] = os.path.join(pretrained_expDir, 'training_checkpoints/best_ckpt.tar')

        expDir = os.path.join(pretrained_expDir, f'{multi_steps}_steps_finetune')
        if world_rank == 0:
            os.makedirs(expDir, exist_ok=True)
            os.makedirs(os.path.join(expDir, 'training_checkpoints/'), exist_ok=True)

        params['experiment_dir'] = os.path.abspath(expDir)
        params['checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/ckpt.tar')
        params['best_checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')

        params['resuming'] = True
    else:
        script_dir = os.path.dirname(os.path.abspath(__file__))
        exp_dir_path = os.path.join(script_dir, params.exp_dir)
        expDir = os.path.join(exp_dir_path, args.config, str(args.run_num))
        if world_rank == 0:
            os.makedirs(expDir, exist_ok=True)
            os.makedirs(os.path.join(expDir, 'training_checkpoints/'), exist_ok=True)
            copyfile(yaml_path, os.path.join(expDir, 'config.yaml'))

        params['experiment_dir'] = os.path.abspath(expDir)
        params['checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/ckpt.tar')
        params['best_checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')

        # Resume automatically if a checkpoint from a previous run already exists.
        args.resuming = os.path.isfile(params.checkpoint_path)
        params['resuming'] = args.resuming

    if world_rank == 0:
        logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'train.log'))
        logging_utils.log_versions()
        params.log()

    params['log_to_screen'] = (world_rank == 0) and params['log_to_screen']

    params['in_channels'] = np.array(params['in_channels'])
    params['out_channels'] = np.array(params['out_channels'])
    params['N_out_channels'] = len(params['out_channels'])
    params['N_in_channels'] = len(params['in_channels'])

    if world_rank == 0:
        # Dump the resolved hyperparameters next to the experiment checkpoints for reproducibility.
        hparams = ruamelDict()
        yaml = YAML()
        for key, value in params.params.items():
            hparams[str(key)] = str(value)
        with open(os.path.join(expDir, 'hyperparams.yaml'), 'w') as hpfile:
            yaml.dump(hparams, hpfile)

    trainer = Trainer(params, world_rank)
    trainer.train()
    logging.info('DONE ---- rank %d' % world_rank)