| import os |
| import sys |
|
|
| |
| |
| |
| |
# Make the repository root importable regardless of the current working
# directory (this file is expected to live one level below the repo root).
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)

# Expose the vendored GVHMR checkout (if present) so its packages can be
# imported without installation; skipped silently when absent.
_GVHMR_ROOT = os.path.join(_REPO_ROOT, "third_party", "GVHMR")
if os.path.isdir(_GVHMR_ROOT) and _GVHMR_ROOT not in sys.path:
    sys.path.insert(0, _GVHMR_ROOT)
|
|
| import builtins |
| from datetime import datetime |
|
|
| import hydra |
| import pytorch_lightning as pl |
| import torch |
| import torch.distributed as dist |
| import yaml |
| from omegaconf import DictConfig, ListConfig, OmegaConf |
| from pytorch_lightning.callbacks.checkpoint import Checkpoint |
| from pytorch_lightning.loggers.tensorboard import TensorBoardLogger |
|
|
| from genmo.callbacks.autoresume_callback import AutoResume, AutoResumeCallback |
| from genmo.utils.net_utils import get_resume_ckpt_path, load_pretrained_model |
| from genmo.utils.pylogger import Log |
| from genmo.utils.tools import ( |
| find_last_version, |
| get_checkpoint_path, |
| rsync_file_from_remote, |
| ) |
| from genmo.utils.vis.rich_logger import print_cfg |
|
|
# Allow `${eval:...}` expressions inside config files.
# NOTE(review): this hands arbitrary config strings to builtins.eval — safe
# only for trusted, locally-authored configs; never run untrusted YAML here.
OmegaConf.register_new_resolver("eval", builtins.eval)
|
|
|
|
def _get_rank():
    """Return this process's global rank, defaulting to 0 when not distributed.

    Scans, in priority order, the environment variables set by common
    launchers (torchrun/torch.distributed, SLURM, LSF jsrun) and returns
    the first one present as an int.
    """
    for env_key in ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK"):
        raw = os.environ.get(env_key)
        if raw is not None:
            return int(raw)
    # No launcher env var set: treat this as a single-process run.
    return 0
|
|
|
|
# Resolved once at import time so the rank is known before Lightning/dist
# initialization happens inside train().
global_rank = _get_rank()
|
|
|
|
| def get_callbacks(cfg: DictConfig) -> list: |
| """Parse and instantiate all the callbacks in the config. |
| |
| Supports both flat and nested callback configs. Only nodes containing |
| a `_target_` are instantiated. |
| """ |
| if not hasattr(cfg, "callbacks") or cfg.callbacks is None: |
| return None |
|
|
| def _collect_callback_nodes(node): |
| collected = [] |
| if node is None: |
| return collected |
| |
| if isinstance(node, (DictConfig, dict)): |
| |
| if "_target_" in node: |
| collected.append(node) |
| else: |
| for child in node.values(): |
| collected.extend(_collect_callback_nodes(child)) |
| |
| elif isinstance(node, (ListConfig, list, tuple)): |
| for child in node: |
| collected.extend(_collect_callback_nodes(child)) |
| |
| return collected |
|
|
| enable_checkpointing = cfg.pl_trainer.get("enable_checkpointing", True) |
| callbacks = [] |
| for cb_conf in _collect_callback_nodes(cfg.callbacks): |
| cb = hydra.utils.instantiate(cb_conf, _recursive_=False) |
| if not enable_checkpointing and isinstance(cb, Checkpoint): |
| continue |
| callbacks.append(cb) |
| return callbacks |
|
|
|
|
def train(cfg: DictConfig) -> None:
    """Run a single `fit` or `test` session described by the hydra config.

    Side effects: mutates `cfg` in place (test-data scaling, checkpoint
    paths, output dirs), seeds all RNGs, may rsync checkpoints from a
    remote host, and may initialize a NCCL process group for multi-GPU
    fitting.

    Raises:
        ValueError: if `cfg.task` is neither "fit" nor "test".
    """
    Log.info(f"[Exp Name]: {cfg.exp_name}")

    if cfg.task == "fit":
        Log.info(
            f"[GPU x Batch] = {cfg.pl_trainer.devices} x {cfg.data.loader_opts.train.batch_size}"
        )
        # Scale test quotas by the world size so the per-process share
        # stays constant across single- and multi-GPU runs.
        num_nodes = cfg.pl_trainer.get("num_nodes", 1)
        cfg.num_test_data *= cfg.pl_trainer.devices * num_nodes
        if (
            "imgfeat_motionx" in cfg.test_datasets
            and "max_num_motions" in cfg.test_datasets.imgfeat_motionx
        ):
            cfg.test_datasets.imgfeat_motionx.max_num_motions *= (
                cfg.pl_trainer.devices * num_nodes
            )
    pl.seed_everything(cfg.seed)
    # NOTE(review): assumes at most 8 GPUs per node — confirm for other setups.
    torch.cuda.set_device(global_rank % 8)
    version = None  # index of the tensorboard-style version_N run directory
    tb_logger = None

    if cfg.get("timing", False):
        os.environ["DEBUG_TIMING"] = "TRUE"

    # Cluster auto-resume: if the job was preempted and restarted, force
    # resuming from the last checkpoint of the recorded version.
    if AutoResume is not None:
        details = AutoResume.get_resume_details()
        if details:
            cfg.resume_mode = "last"
            version = int(details["version"])
            print(
                f"[Auto Resume] Loading. checkpoint: {details['checkpoint']} version: {details['version']}"
            )

    if cfg.task == "test" and not cfg.get("no_checkpoint", False):
        # Locate the checkpoint to evaluate on the remote results store,
        # optionally mirroring it into the local `outputs` tree via rsync.
        test_cp = cfg.get("test_checkpoint", "last")
        remote_run_dir = cfg.output_dir.replace("outputs", cfg.remote_results_path)
        version = find_last_version(remote_run_dir, cp=test_cp)
        checkpoint_dir = f"{remote_run_dir}/version_{version}/checkpoints"
        remote_ckpt_path = get_checkpoint_path(checkpoint_dir, test_cp)
        if cfg.get("rsync_ckpt", False):
            cfg.ckpt_path = remote_ckpt_path.replace(cfg.remote_results_path, "outputs")
            if not os.path.exists(cfg.ckpt_path):
                print(f"rsyncing from remote: {remote_ckpt_path}")
                print(f"output_dir: {cfg.output_dir}")
                rsync_file_from_remote(
                    cfg.ckpt_path,
                    remote_run_dir,
                    cfg.output_dir,
                    hostname="cs-oci-ord-dc-03",
                )
        else:
            # No rsync requested: read the checkpoint from the remote path.
            cfg.ckpt_path = remote_ckpt_path
        print("ckpt path:", cfg.ckpt_path)
        cfg.output_dir = f"{cfg.output_dir}/version_{version}"
        # Timestamped logger name keeps repeated test runs distinguishable.
        cfg.logger.name = (
            f"{cfg.exp_name}_v{version}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
        )
    else:
        run_root_dir = cfg.output_dir
        if version is None and cfg.resume_mode == "last":
            version = find_last_version(run_root_dir, cp="last")

    # _recursive_=False leaves nested configs for the instantiated objects
    # to resolve themselves.
    datamodule: pl.LightningDataModule = hydra.utils.instantiate(
        cfg.data, _recursive_=False
    )
    model: pl.LightningModule = hydra.utils.instantiate(cfg.model, _recursive_=False)

    # A pretrain checkpoint is only used when not resuming and no explicit
    # checkpoint was requested.
    if (
        cfg.get("pretrain_ckpt", None) is not None
        and cfg.ckpt_path is None
        and cfg.resume_mode is None
    ):
        cfg.ckpt_path = cfg.pretrain_ckpt

    if cfg.ckpt_path is not None:
        if cfg.get("rsync_ckpt", False) and not os.path.exists(cfg.ckpt_path):
            # Mirror the remote checkpoint into the local `outputs` tree.
            print(f"rsyncing from remote: {cfg.ckpt_path}")
            cfg.ckpt_path = cfg.ckpt_path.replace(cfg.remote_results_path, "outputs")
            local_dir = cfg.ckpt_path.split("/version_")[0]
            os.makedirs(local_dir, exist_ok=True)
            rsync_file_from_remote(
                cfg.ckpt_path,
                cfg.remote_results_path,
                "outputs",
                hostname="cs-oci-ord-dc-03",
            )

        # Load weights here; loop/optimizer state is restored separately by
        # `ckpt_path=` in `trainer.fit` when resuming.
        ckpt = load_pretrained_model(model, cfg.ckpt_path)
        print(f"Loaded pretrained model from {cfg.ckpt_path}")
        if ckpt is not None:
            print(
                "pretrained ckpt info:",
                {"global_step": ckpt["global_step"], "epoch": ckpt["epoch"]},
            )

    if cfg.task == "fit":
        # Rank 0 allocates (or reuses) the version_N directory; other ranks
        # discover it after the barrier below.
        tb_logger = TensorBoardLogger(run_root_dir, version=version, name="")
        version = tb_logger.version
        if global_rank == 0:
            os.makedirs(tb_logger.log_dir, exist_ok=True)
        cfg.output_dir = tb_logger.log_dir

        # Initialize NCCL early so non-zero ranks can wait for rank 0 to
        # create the run directory before looking it up.
        if cfg.pl_trainer.devices > 1 and "RANK" in os.environ:
            dist.init_process_group("nccl")
            dist.barrier()

        if global_rank != 0:
            if version is None:
                version = find_last_version(run_root_dir, cp=None)
            cfg.output_dir = f"{run_root_dir}/version_{version}"

    callbacks = get_callbacks(cfg)
    has_ckpt_cb = any([isinstance(cb, Checkpoint) for cb in callbacks])
    if not has_ckpt_cb and cfg.pl_trainer.get("enable_checkpointing", True):
        # Without an explicit checkpoint callback, stop Lightning from
        # silently adding its default ModelCheckpoint.
        Log.warning("No checkpoint-callback found. Disabling PL auto checkpointing.")
        cfg.pl_trainer = {**cfg.pl_trainer, "enable_checkpointing": False}
    if AutoResume is not None:
        callbacks.append(AutoResumeCallback(version))

    # Lightning expects `False` (not None) to disable logging.
    logger = tb_logger if tb_logger is not None else False

    if cfg.task == "test":
        # Evaluation always runs in fp32 regardless of training precision.
        Log.info("Test mode forces full-precision.")
        cfg.pl_trainer = {**cfg.pl_trainer, "precision": 32}
    trainer = pl.Trainer(
        accelerator="gpu",
        logger=logger if logger is not None else False,
        callbacks=callbacks,
        **cfg.pl_trainer,
    )

    print("=" * 20)
    print("version:", version)

    if cfg.task == "fit":
        resume_path = None
        if cfg.resume_mode is not None:
            save_dir = cfg.output_dir + "/checkpoints"
            resume_path = get_resume_ckpt_path(cfg.resume_mode, ckpt_dir=save_dir)
        Log.info("Start Fitting...")
        trainer.fit(
            model,
            datamodule.train_dataloader(),
            datamodule.val_dataloader(),
            ckpt_path=resume_path,
        )
    elif cfg.task == "test":
        Log.info("Start Testing...")
        trainer.test(model, datamodule.test_dataloader())
    else:
        raise ValueError(f"Unknown task: {cfg.task}")

    Log.info("End of script.")
|
|
|
|
@hydra.main(version_base="1.3", config_path="../configs", config_name="train")
def main(cfg) -> None:
    """CLI entry point: pretty-print the composed config, then run `train`."""
    print_cfg(cfg, use_rich=True)
    train(cfg)
|
|
|
|
# Script entry point; hydra composes the config and applies CLI overrides.
if __name__ == "__main__":
    main()
|
|