| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| from typing import Tuple, List |
| import warnings |
| warnings.filterwarnings("ignore") |
|
|
| import os |
| import torch |
| import argparse |
| from pathlib import Path |
| from omegaconf import OmegaConf, DictConfig |
| from einops._torch_specific import allow_ops_in_compiled_graph |
| allow_ops_in_compiled_graph() |
|
|
| import pytorch_lightning as pl |
| from pytorch_lightning.callbacks import ModelCheckpoint, Callback |
| from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy |
| from pytorch_lightning.loggers import Logger, TensorBoardLogger |
| from pytorch_lightning.utilities import rank_zero_info |
|
|
| from hy3dshape.utils import get_config_from_file, instantiate_from_config |
|
|
|
|
class SetupCallback(Callback):
    """Lightning callback that prepares the run's output folders.

    The log and checkpoint directories are resolved under ``basedir`` when the
    callback is constructed, and created on disk when fitting starts (global
    rank 0 only).

    Args:
        config: Experiment configuration; stored as ``self.config`` but not read
            by this class itself.
        basedir: Root output directory.
        logdir: Name of the log sub-directory under ``basedir``.
        ckptdir: Name of the checkpoint sub-directory under ``basedir``.
    """

    def __init__(self, config: DictConfig, basedir: Path, logdir: str = "log", ckptdir: str = "ckpt") -> None:
        super().__init__()
        log_path = basedir / logdir
        ckpt_path = basedir / ckptdir
        self.logdir = log_path
        self.ckptdir = ckpt_path
        self.config = config

    def on_fit_start(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule) -> None:
        """Create ``self.logdir`` and ``self.ckptdir`` if missing; no-op off global rank 0."""
        if trainer.global_rank != 0:
            return
        for target_dir in (self.logdir, self.ckptdir):
            os.makedirs(target_dir, exist_ok=True)
|
|
|
|
def setup_callbacks(config: DictConfig) -> Tuple[List[Callback], Logger]:
    """Build the Lightning callbacks and the logger for a training run.

    ``config.training.output_dir`` is created unconditionally (not rank-gated).
    The returned callbacks are, in order:

    * a ``SetupCallback`` resolving ``<output_dir>/log`` and ``<output_dir>/ckpt``;
    * a ``ModelCheckpoint`` writing ``ckpt-{step:08d}`` files into the ckpt dir every
      ``training.every_n_train_steps`` steps and keeping all of them (``save_top_k=-1``);
    * one callback per entry of the optional top-level ``callbacks`` mapping, each
      built with ``instantiate_from_config``.

    Args:
        config: Merged experiment config. Reads ``training.output_dir``,
            ``training.monitor``, ``training.every_n_train_steps``, the optional
            ``training.monitor_mode`` (``"min"``/``"max"``) and the optional
            ``callbacks`` section (absent, null, or a mapping of callback configs).

    Returns:
        ``(callbacks, logger)``; ``logger`` is a TensorBoard logger rooted at the log dir.
    """
    training_cfg = config.training
    basedir = Path(training_cfg.output_dir)
    os.makedirs(basedir, exist_ok=True)

    setup_callback = SetupCallback(config, basedir)

    monitor = training_cfg.monitor
    # Loss-like metrics (e.g. the default "val/total_loss") improve downwards, so
    # "max" would make ModelCheckpoint report the worst checkpoint as best. With
    # save_top_k=-1 every checkpoint is still written; only best/top-k tracking
    # depends on `mode`. An explicit `training.monitor_mode` overrides the guess.
    explicit_mode = training_cfg.monitor_mode if "monitor_mode" in training_cfg else None
    inferred_mode = "min" if monitor is not None and "loss" in str(monitor) else "max"
    checkpoint_callback = ModelCheckpoint(
        dirpath=setup_callback.ckptdir,
        filename="ckpt-{step:08d}",
        monitor=monitor,
        mode=explicit_mode or inferred_mode,
        save_top_k=-1,
        verbose=False,
        every_n_train_steps=training_cfg.every_n_train_steps)

    # `callbacks:` may be missing entirely or present but null; treat both as empty.
    custom_cfgs = config["callbacks"] if "callbacks" in config else None
    custom_callbacks = [instantiate_from_config(cb_cfg) for cb_cfg in (custom_cfgs or {}).values()]

    all_callbacks: List[Callback] = [setup_callback, checkpoint_callback, *custom_callbacks]

    logger = TensorBoardLogger(save_dir=str(setup_callback.logdir), name="tensorboard")

    return all_callbacks, logger
|
|
|
|
def merge_cfg(cfg, arg_cfg):
    """Merge parsed CLI arguments with the config file's ``training`` section.

    For every key present in both, the value from ``cfg.training`` wins (argparse
    always supplies defaults, so explicit CLI values cannot be told apart from
    defaults here). That winning value is also written back into ``arg_cfg`` in
    place; when called with ``vars(args)`` this updates the ``argparse.Namespace``
    itself, so callers reading ``args`` afterwards see the config-file values.
    Keys that exist only in ``cfg.training`` are kept rather than discarded.

    Args:
        cfg: Loaded experiment config (an OmegaConf ``DictConfig``); its
            ``training`` section may be missing or null.
        arg_cfg: Mutable mapping of CLI argument names to values; updated in place.

    Returns:
        ``cfg`` with ``cfg.training`` replaced by the merged ``DictConfig``.
    """
    file_training = cfg.training if "training" in cfg else None
    if file_training is None:
        file_training = {}

    # Config-file values override CLI values (and are mirrored back into arg_cfg).
    for key in arg_cfg.keys():
        if key in file_training:
            arg_cfg[key] = file_training[key]

    merged = dict(arg_cfg)
    # Carry over training keys with no matching CLI argument. Membership is
    # re-checked because OmegaConf reports unset ('???') values as absent; those
    # are skipped, just as the override loop above skips them.
    for key in file_training:
        if key not in merged and key in file_training:
            merged[key] = file_training[key]

    cfg.training = DictConfig(merged)
    return cfg
|
|
|
|
def _str2bool(value: str) -> bool:
    """Parse a command-line boolean (``true``/``false``, ``1``/``0``, ``yes``/``no``, ...)."""
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args(argv=None):
    """Parse the training script's command-line arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses ``sys.argv[1:]``,
            exactly as before.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--fast", action='store_true')
    parser.add_argument("-c", "--config", type=str, required=True)
    parser.add_argument("-s", "--seed", type=int, default=0)
    # Distributed layout.
    parser.add_argument("-nn", "--num_nodes", type=int, default=1)
    parser.add_argument("-ng", "--num_gpus", type=int, default=1)
    # Optimisation schedule.
    parser.add_argument("-u", "--update_every", type=int, default=1)
    parser.add_argument("-st", "--steps", type=int, default=50000000)
    parser.add_argument("-lr", "--base_lr", type=float, default=4.5e-6)
    # Mixed precision.
    parser.add_argument("-a", "--use_amp", default=False, action="store_true")
    parser.add_argument("--amp_type", type=str, default="16")
    parser.add_argument("--gradient_clip_val", type=float, default=None)
    parser.add_argument("--gradient_clip_algorithm", type=str, default=None)
    # Checkpointing, logging and validation cadence.
    parser.add_argument("--every_n_train_steps", type=int, default=50000)
    parser.add_argument("--log_every_n_steps", type=int, default=50)
    parser.add_argument("--val_check_interval", type=int, default=1024)
    parser.add_argument("--limit_val_batches", type=int, default=64)
    parser.add_argument("--monitor", type=str, default="val/total_loss")
    parser.add_argument("--output_dir", type=str, help="the output directory to save everything.")
    parser.add_argument("--ckpt_path", type=str, default="", help="the restore checkpoints.")
    # Distributed strategy selection.
    parser.add_argument("--deepspeed", default=False, action="store_true")
    parser.add_argument("--deepspeed2", default=False, action="store_true")
    # `type=bool` would turn any non-empty string (including "False") into True,
    # so parse the optional value explicitly; a bare `--scale_lr` still means True.
    parser.add_argument("--scale_lr", type=_str2bool, nargs="?", const=True, default=False,
                        help="scale base-lr by num_nodes * ngpu * batch_size * n_accumulate")
    return parser.parse_args(argv)
| |
|
|
if __name__ == "__main__":
    # Entry point: parse the CLI, merge it with the YAML config, build data/model, then train.
    args = get_args()

    if args.fast:
        # Enable TF32 matmuls/convolutions and lower float32 matmul precision for speed;
        # also shorten torch's (private) DataLoader worker status-check interval.
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.set_float32_matmul_precision('medium')
        torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL = 0.05

    # Seed through Lightning before any model/data objects are constructed.
    pl.seed_everything(args.seed, workers=True)

    # merge_cfg lets config-file `training` values take precedence over CLI values and
    # writes them back into `args` (vars() returns the namespace's own __dict__).
    config = get_config_from_file(args.config)
    config = merge_cfg(config, vars(args))
    training_cfg = config.training

    rank_zero_info("Begin to print configuration ...")
    rank_zero_info(OmegaConf.to_yaml(config))
    rank_zero_info("Finish print ...")

    callbacks, loggers = setup_callbacks(config)

    data: pl.LightningDataModule = instantiate_from_config(config.dataset)
    model: pl.LightningModule = instantiate_from_config(config.model)

    nodes = args.num_nodes
    ngpus = args.num_gpus
    base_lr = training_cfg.base_lr
    accumulate_grad_batches = training_cfg.update_every

    # A launcher-provided NNODES overrides both the CLI and the config file.
    if 'NNODES' in os.environ:
        nodes = int(os.environ['NNODES'])
        training_cfg.num_nodes = nodes
        args.num_nodes = nodes

    if args.scale_lr:
        # The dataset batch size is only needed for the scaled LR, so read it only here.
        batch_size = config.dataset.params.batch_size
        model.learning_rate = accumulate_grad_batches * nodes * ngpus * batch_size * base_lr
        info = f"Setting learning rate to {model.learning_rate:.2e} = {accumulate_grad_batches} (accumulate)"
        info += f" * {nodes} (nodes) * {ngpus} (num_gpus) * {batch_size} (batchsize) * {base_lr:.2e} (base_lr)"
        rank_zero_info(info)
    else:
        model.learning_rate = base_lr
        rank_zero_info("++++ NOT USING LR SCALING ++++")
        rank_zero_info(f"Setting learning rate to {model.learning_rate:.2e}")

    # A distributed strategy is only chosen when more than one process is involved.
    ddp_strategy = None
    if nodes > 1 or ngpus > 1:
        if args.deepspeed:
            ddp_strategy = DeepSpeedStrategy(stage=1)
        elif args.deepspeed2:
            ddp_strategy = 'deepspeed_stage_2'
        else:
            ddp_strategy = DDPStrategy(find_unused_parameters=False, bucket_cap_mb=1500)

    rank_zero_info('*' * 100)
    if training_cfg.use_amp:
        # YAML can deliver 16/32 as integers while the CLI gives strings; normalise first.
        # Raise instead of `assert` so the check survives `python -O`.
        amp_type = str(training_cfg.amp_type)
        if amp_type not in ('bf16', '16', '32'):
            raise ValueError(f"Invalid amp_type: {amp_type}")
        rank_zero_info(f'Using {amp_type} precision')
    else:
        amp_type = 32
        rank_zero_info('Using 32 bit precision')
    rank_zero_info('*' * 100)

    trainer_kwargs = dict(
        max_steps=training_cfg.steps,
        precision=amp_type,
        callbacks=callbacks,
        accelerator="gpu",
        devices=ngpus,
        num_nodes=nodes,
        gradient_clip_val=training_cfg.get('gradient_clip_val'),
        gradient_clip_algorithm=training_cfg.get('gradient_clip_algorithm'),
        accumulate_grad_batches=accumulate_grad_batches,
        logger=loggers,
        log_every_n_steps=training_cfg.log_every_n_steps,
        val_check_interval=training_cfg.val_check_interval,
        limit_val_batches=training_cfg.limit_val_batches,
        check_val_every_n_epoch=None,
    )
    # Only pass `strategy` when one was selected: Lightning 2.x rejects strategy=None
    # (its default is "auto"), while 1.x defaults to None, so omitting it is equivalent there.
    if ddp_strategy is not None:
        trainer_kwargs["strategy"] = ddp_strategy
    trainer = pl.Trainer(**trainer_kwargs)

    # An empty --ckpt_path means "start from scratch".
    if training_cfg.ckpt_path == '':
        training_cfg.ckpt_path = None
    trainer.fit(model, datamodule=data, ckpt_path=training_cfg.ckpt_path)
|
|