| import sys |
| import os |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) |
|
|
| import argparse |
| import datetime |
| import numpy as np |
| import time |
| import torch |
| import torch.backends.cudnn as cudnn |
| import json |
| import os |
| from functools import partial |
| from pathlib import Path |
|
|
| from timm.models import create_model |
| from optim_factory import create_optimizer, get_parameter_groups, LayerDecayValueAssigner |
|
|
| from engine import train_one_epoch_postprocess, validation_one_epoch_postprocess |
| from utils import NativeScalerWithGradNormCount as NativeScaler |
| from utils import multiple_samples_collate |
| import utils |
| from aurora import Aurora |
| from dataset import utils_data |
| from torch import nn |
| import xarray as xr |
|
|
| from aurora.batch import interpolate_numpy |
| from datetime import timedelta |
| import multiprocessing as mp |
|
|
| from dataset.score import CrpsGaussianLoss, EECRPSGaussianLoss |
|
|
|
|
| |
def get_args():
    """Build the command-line parser for the post-processing fine-tuning script and parse ``sys.argv``.

    Returns:
        tuple: ``(args, ds_init)`` where ``args`` is the parsed ``argparse.Namespace``
        and ``ds_init`` is always ``None`` (kept so callers can keep unpacking two values).
    """
    # add_help=False is kept from the original: no -h/--help option is registered.
    parser = argparse.ArgumentParser('Finetuning for post-processing', add_help=False)
    parser.add_argument('--model', default='Aurora', type=str, metavar='MODEL',
                        help='Name of model to train')

    # Batch / epoch bookkeeping.
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--val_batch_size', default=64, type=int)
    parser.add_argument('--epochs', default=30, type=int)
    parser.add_argument('--update_freq', default=1, type=int)
    parser.add_argument('--save_ckpt_freq', default=100, type=int)

    # Optimizer parameters.
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')
    parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
        weight decay. We use a cosine schedule for WD and using a larger decay by
        the end of training improves performance for ViTs.""")

    # Learning-rate schedule parameters.
    parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
                        help='learning rate (default: 1e-3)')
    parser.add_argument('--layer_decay', type=float, default=0.75)

    parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    # The help text previously quoted 1e-5 although the default is 1e-6.
    parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-6)')

    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
                        help='num of steps to warmup LR, will overload warmup_epochs if set > 0')

    # Evaluation cropping / sampling options.
    parser.add_argument('--crop_pct', type=float, default=None)
    parser.add_argument('--short_side_size', type=int, default=224)
    parser.add_argument('--test_num_segment', type=int, default=5)
    parser.add_argument('--test_num_crop', type=int, default=3)

    # Fine-tuning inputs: checkpoint and the train/validation date ranges.
    parser.add_argument('--finetune', default='', help='finetune from checkpoint')
    parser.add_argument('--train_start_date', default='', help='train_start_date')
    parser.add_argument('--train_end_date', default='', help='train_end_date')
    parser.add_argument('--val_start_date', default='', help='val_start_date')
    parser.add_argument('--val_end_date', default='', help='val_end_date')
    parser.add_argument('--model_key', default='model|module', type=str)
    parser.add_argument('--model_prefix', default='', type=str)
    parser.add_argument('--init_scale', default=0.001, type=float)
    parser.add_argument('--use_checkpoint', action='store_true')
    parser.set_defaults(use_checkpoint=False)
    # --use_mean_pooling defaults to True; --use_cls switches it off via the shared dest.
    parser.add_argument('--use_mean_pooling', action='store_true')
    parser.set_defaults(use_mean_pooling=True)
    parser.add_argument('--use_cls', action='store_false', dest='use_mean_pooling')

    # Target / dataset parameters.
    parser.add_argument('--var', default='T2M', type=str, help='target variable')
    parser.add_argument('--nb_classes', default=400, type=int,
                        help='number of the classification types')

    # Output, device, seeding and resume options.
    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default=None,
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='',
                        help='resume from checkpoint')
    parser.add_argument('--auto_resume', action='store_true')
    parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
    parser.set_defaults(auto_resume=True)

    parser.add_argument('--save_ckpt', action='store_true')
    parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt')
    parser.set_defaults(save_ckpt=True)

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true',
                        help='Perform evaluation only')
    parser.add_argument('--dist_eval', action='store_true', default=False,
                        help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=6, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    parser.add_argument('--full', action='store_true', dest='full')
    # This call used to be set_defaults(pin_mem=False), which silently overrode the
    # pin_mem=True default declared just above. The default belongs to --full.
    parser.set_defaults(full=False)

    # Distributed training parameters.
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    parser.add_argument('--mode', default='full', help='fine_tuning mode')

    # No extra initialisation object is produced here; the slot is kept for the caller.
    ds_init = None

    return parser.parse_args(), ds_init
|
|
|
|
def _load_static_vars(static_path, lat, lon):
    """Read the ``lsm``, ``z`` and ``slt`` fields from ``static_path`` and regrid them to ``lat``/``lon``.

    Each field is taken at index 0 of its leading axis, interpolated with
    ``interpolate_numpy`` from the file's own latitude/longitude coordinates, and
    returned as a float32 tensor keyed by variable name. The dataset is closed
    before returning; the values have already been read into arrays by then.
    """
    with xr.open_dataset(static_path, engine="netcdf4") as static_ds:
        src_lat = static_ds.latitude.values
        src_lon = static_ds.longitude.values
        return {
            name: torch.from_numpy(
                interpolate_numpy(static_ds[name].values[0], lat=src_lat, lon=src_lon,
                                  lat_new=lat, lon_new=lon)
            ).float()
            for name in ("lsm", "z", "slt")
        }


def _format_copy_log(test_stats):
    """Join every metric except ``valid_loss`` as `` &``-prefixed values rounded to 3 decimals."""
    return " &" + " &".join([str(round(v, 3)) for k, v in test_stats.items() if k != "valid_loss"])


def _append_json_log(args, log_writer, records, trailing_blank_line=False):
    """Append ``records`` to ``<output_dir>/log.txt``, one JSON document per line.

    Does nothing unless ``args.output_dir`` is set and this is the main process.
    ``log_writer`` is flushed first when present. ``trailing_blank_line`` adds one
    empty line after the records.
    """
    if not (args.output_dir and utils.is_main_process()):
        return
    if log_writer is not None:
        log_writer.flush()
    with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")
        if trailing_blank_line:
            f.write("\n")


def main(args, ds_init):
    """Fine-tune (or only evaluate) the Aurora model on the post-processing task.

    Steps, in order: initialise distributed mode, seed torch/numpy, load static
    fields and EFI datasets, build train/validation datasets and loaders, build
    the model from the local pretrained checkpoint, create the optimizer and the
    LR / weight-decay schedules, optionally resume, then either run a single
    validation pass (``args.eval``) or train for ``args.epochs`` epochs with a
    validation pass after each one.

    Args:
        args: Namespace produced by ``get_args``. It is modified in place:
            ``data_path`` is set, ``lr``/``min_lr``/``warmup_lr`` are rescaled and
            ``weight_decay_end`` is defaulted.
        ds_init: When not ``None``, ``utils.create_ds_config(args)`` is called.

    Raises:
        SystemExit: With code 0 after evaluation when ``args.eval`` is set.
    """
    utils.init_distributed_mode(args)

    if ds_init is not None:
        utils.create_ds_config(args)

    print(args)

    device = torch.device(args.device)

    # Offset the seed by the process rank so each rank gets a different value.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    # Target grid: 361 latitudes from 90 to -90 and 720 longitudes (the 360 endpoint is dropped).
    lat = torch.linspace(90, -90, 361)
    lon = torch.linspace(0, 360, 721)[:-1]
    level = (500, 850)
    static_vars = _load_static_vars("../aux_data/static.nc", lat, lon)

    surf_vars = ('SSTK', 'TCW', 'TCWV', 'CP', 'MSL', 'TCC', 'U10M', 'V10M', 'T2M', 'TP', 'SKT')
    upper_vars = ["Z", "T", "Q", "W", "D", "U", "V"]

    # Output lists interleave each variable with its "_std" companion;
    # every second entry ([::2]) is the base variable name.
    out_surf_vars = ["T2M", "T2M_std", 'U10M', 'U10M_std', 'V10M', 'V10M_std']
    out_upper_vars = ["T", "T_std", "Z", "Z_std"]
    out_upper_level = [850, 500]
    target_surf_vars = out_surf_vars[::2]
    target_upper_vars = out_upper_vars[::2]

    criterion1 = CrpsGaussianLoss()
    criterion2 = EECRPSGaussianLoss()

    args.data_path = "../datasets/post-process/process"

    # These datasets stay open: they are handed to the validation routine below.
    surface_efis = [
        xr.open_dataset(os.path.join(args.data_path, f"efi_{var}.nc"))
        for var in target_surf_vars
    ]
    # Each upper-air variable is paired with the level at the same position.
    upper_efis = [
        xr.open_dataset(os.path.join(args.data_path, f"efi_{var}{lvl}.nc"))
        for var, lvl in zip(target_upper_vars, out_upper_level)
    ]

    if args.full:
        T = 10
        print("using full ensemble dataset")
    else:
        T = 2

    dataset_train = utils_data.Dataset_Postprocessing(
        data_path=args.data_path, start_date=args.train_start_date, end_date=args.train_end_date,
        full=args.full, surface=surf_vars, upper=upper_vars, levels=level,
        target_surface=target_surf_vars, target_upper=target_upper_vars, target_level=out_upper_level)
    dataset_val = utils_data.Dataset_Postprocessing(
        data_path=args.data_path, start_date=args.val_start_date, end_date=args.val_end_date, val=True,
        full=args.full, surface=surf_vars, upper=upper_vars, levels=level,
        target_surface=target_surf_vars, target_upper=target_upper_vars, target_level=out_upper_level)

    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )

    print("Sampler_train = %s" % str(sampler_train))
    if args.dist_eval:
        if len(dataset_val) % num_tasks != 0:
            print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                  'This will slightly alter validation results as extra duplicate entries are added to achieve '
                  'equal num of samples per-process.')
        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    # Only rank 0 writes TensorBoard logs, and only when a log directory was given.
    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
    else:
        log_writer = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=args.val_batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,
    )

    use_ours = True
    model = Aurora(
        use_ours=use_ours, autocast=True,
        surf_vars=surf_vars,
        atmos_vars=upper_vars,
        timestep=timedelta(hours=0),
        out_surf_vars=out_surf_vars,
        out_atmos_vars=out_upper_vars,
        task="postprocess",
        max_history_size=T,
        ours_prompt_length=5,
    )

    model.configure_activation_checkpointing()
    print("load pretrain model")
    model.load_checkpoint_local("../aurora-0.25-pretrained.ckpt", strict=False)

    model.to(device)
    model.train()

    model_ema = None
    model_without_ddp = model

    # Rescale the learning rates in proportion to the effective batch size, with 256 as reference.
    total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
    num_training_steps_per_epoch = len(dataset_train) // total_batch_size
    args.lr = args.lr * total_batch_size / 256
    args.min_lr = args.min_lr * total_batch_size / 256
    args.warmup_lr = args.warmup_lr * total_batch_size / 256
    print("LR = %.8f" % args.lr)
    print("Batch size = %d" % total_batch_size)
    print("Update frequent = %d" % args.update_freq)
    print("Number of training examples = %d" % len(dataset_train))
    print("Number of training training per epoch = %d" % num_training_steps_per_epoch)

    num_layers = model_without_ddp.get_num_layers()
    if args.layer_decay < 1.0:
        # One scale per layer id 0..num_layers+1; the last id gets scale 1.
        assigner = LayerDecayValueAssigner(
            [args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)])
    else:
        assigner = None

    # Queried before the optional DDP wrap, as in the original.
    skip_weight_decay_list = model.no_weight_decay()

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], find_unused_parameters=False)
        model_without_ddp = model.module

    optimizer = create_optimizer(
        args, model_without_ddp, skip_list=skip_weight_decay_list,
        get_num_layer=assigner.get_layer_id if assigner is not None else None,
        get_layer_scale=assigner.get_scale if assigner is not None else None)
    loss_scaler = NativeScaler()

    print("Use step level LR scheduler!")
    lr_schedule_values = utils.cosine_scheduler(
        args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
        warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
    )
    if args.weight_decay_end is None:
        args.weight_decay_end = args.weight_decay
    wd_schedule_values = utils.cosine_scheduler(
        args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
    print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))

    utils.auto_load_model(
        args=args, model=model, model_without_ddp=model_without_ddp,
        optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema)

    # The same validation call serves eval-only mode and the per-epoch check.
    validate = partial(
        validation_one_epoch_postprocess, data_loader_val, model, device,
        lat=lat, lon=lon, level=level, criterion1=criterion1, criterion2=criterion2,
        static_vars=static_vars, surf_vars=surf_vars, upper_vars=upper_vars, model_name=args.model,
        surface_efis=surface_efis, upper_efis=upper_efis,
        out_surf_vars=target_surf_vars, out_upper_vars=target_upper_vars,
        out_upper_level=out_upper_level,
    )

    if args.eval:
        print("Start evaluating")
        test_stats = validate()
        test_log = {f'val_{k}': v for k, v in test_stats.items()}
        copy_log = _format_copy_log(test_stats)
        _append_json_log(args, log_writer, ["Eval:", test_log, copy_log], trailing_blank_line=True)
        # sys.exit instead of the site-provided exit(), which is not guaranteed to exist.
        sys.exit(0)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    train_time_only = 0
    total_step = args.epochs * num_training_steps_per_epoch
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)
        if log_writer is not None:
            log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq)
        train_start_time = time.time()
        train_stats = train_one_epoch_postprocess(
            model, data_loader_train, optimizer,
            device, epoch, loss_scaler, args.clip_grad,
            log_writer=log_writer, start_steps=epoch * num_training_steps_per_epoch,
            lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values,
            num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq,
            criterion=criterion1,
            lat=lat, lon=lon, level=level, static_vars=static_vars,
            surf_vars=surf_vars, upper_vars=upper_vars, model_name=args.model,
            out_surf_vars=target_surf_vars, out_upper_vars=target_upper_vars,
            out_upper_level=out_upper_level, use_ours=use_ours, total_step=total_step,
        )
        # Only the training call is counted here; checkpointing and validation are excluded.
        train_time_only += time.time() - train_start_time

        if args.output_dir and args.save_ckpt:
            if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
                utils.save_model(
                    args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                    loss_scaler=loss_scaler, epoch=epoch)

        test_stats = validate()
        log_stats = {'epoch': epoch,
                     **{f'train_{k}': v for k, v in train_stats.items()}}
        test_log = {f'val_{k}': v for k, v in test_stats.items()}
        copy_log = _format_copy_log(test_stats)
        _append_json_log(args, log_writer, [log_stats, test_log, copy_log], trailing_blank_line=True)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))

    train_time_str = str(datetime.timedelta(seconds=int(train_time_only)))
    print('Training time only {}'.format(train_time_str))
    print('Total time {}'.format(total_time_str))

    # Each record gets a real newline. Previously the "\n" of the first record was inside
    # json.dumps(), so it was escaped and both records ended up on the same line.
    _append_json_log(args, log_writer, [
        'Training time only {}'.format(train_time_str),
        'Total time {}'.format(total_time_str),
    ])
|
|
if __name__ == '__main__':
    # Script entry point: parse the command line, make sure the output
    # directory exists (when one was requested), then hand over to main().
    cli_args, init_state = get_args()
    requested_dir = cli_args.output_dir
    if requested_dir:
        out_path = Path(requested_dir)
        out_path.mkdir(parents=True, exist_ok=True)
    main(cli_args, init_state)
|
|