import os
import sys

# Ensure repo root is importable when running as `python scripts/train.py`.
# Without this, `genmo.*` may resolve from site-packages while `third_party.*`
# (a namespace package in this repo) fails to import, which Hydra reports as
# "Error locating target ...".
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)

# GVHMR uses absolute imports like `import hmr4d...` internally, so its repo root
# must also be importable.
_GVHMR_ROOT = os.path.join(_REPO_ROOT, "third_party", "GVHMR")
if os.path.isdir(_GVHMR_ROOT) and _GVHMR_ROOT not in sys.path:
    sys.path.insert(0, _GVHMR_ROOT)

import builtins
from datetime import datetime

import hydra
import pytorch_lightning as pl
import torch
import torch.distributed as dist
import yaml
from omegaconf import DictConfig, ListConfig, OmegaConf
from pytorch_lightning.callbacks.checkpoint import Checkpoint
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

from genmo.callbacks.autoresume_callback import AutoResume, AutoResumeCallback
from genmo.utils.net_utils import get_resume_ckpt_path, load_pretrained_model
from genmo.utils.pylogger import Log
from genmo.utils.tools import (
    find_last_version,
    get_checkpoint_path,
    rsync_file_from_remote,
)
from genmo.utils.vis.rich_logger import print_cfg

OmegaConf.register_new_resolver("eval", builtins.eval)
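# The "eval" resolver lets config values be Python expressions; a hypothetical
# YAML entry `num_test_data: ${eval:'8 * 4'}` would resolve to 32 when accessed.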


def _get_rank() -> int:
    # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
    # therefore LOCAL_RANK needs to be checked first.
    rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
    for key in rank_keys:
        rank = os.environ.get(key)
        if rank is not None:
            return int(rank)
    # No rank-related environment variable was set; assume single-process rank 0.
    return 0


global_rank = _get_rank()
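# NOTE: the rank is resolved once at import time, so this assumes the launcher
# (e.g. torchrun or SLURM) has already exported the rank variables per process.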


def get_callbacks(cfg: DictConfig) -> list:
    """Parse and instantiate all the callbacks in the config.

    Supports both flat and nested callback configs. Only nodes containing
    a `_target_` are instantiated.
    """
    if not hasattr(cfg, "callbacks") or cfg.callbacks is None:
        return []

    def _collect_callback_nodes(node):
        collected = []
        if node is None:
            return collected
        # Dict-like node: either directly instantiable or a container of children.
        if isinstance(node, (DictConfig, dict)):
            if "_target_" in node:
                collected.append(node)
            else:
                for child in node.values():
                    collected.extend(_collect_callback_nodes(child))
        # List-like node: recurse into each element.
        elif isinstance(node, (ListConfig, list, tuple)):
            for child in node:
                collected.extend(_collect_callback_nodes(child))
        # Primitive leaves are ignored.
        return collected

    enable_checkpointing = cfg.pl_trainer.get("enable_checkpointing", True)
    callbacks = []
    for cb_conf in _collect_callback_nodes(cfg.callbacks):
        cb = hydra.utils.instantiate(cb_conf, _recursive_=False)
        # Drop checkpoint callbacks when checkpointing is disabled in the trainer config.
        if not enable_checkpointing and isinstance(cb, Checkpoint):
            continue
        callbacks.append(cb)
    return callbacks
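

# A hypothetical callbacks config accepted by `get_callbacks`; nested groups
# such as `monitors` are traversed until a node with `_target_` is found:
#   callbacks:
#     checkpoint:
#       _target_: pytorch_lightning.callbacks.ModelCheckpoint
#       every_n_epochs: 1
#     monitors:
#       lr:
#         _target_: pytorch_lightning.callbacks.LearningRateMonitor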


def train(cfg: DictConfig) -> None:
    """Run training (`fit`) or evaluation (`test`), depending on `cfg.task`."""
    Log.info(f"[Exp Name]: {cfg.exp_name}")

    # Use total sizes across the world (devices x nodes).
    if cfg.task == "fit":
        Log.info(
            f"[GPU x Batch] = {cfg.pl_trainer.devices} x {cfg.data.loader_opts.train.batch_size}"
        )
    num_nodes = cfg.pl_trainer.get("num_nodes", 1)
    cfg.num_test_data *= cfg.pl_trainer.devices * num_nodes
    if (
        "imgfeat_motionx" in cfg.test_datasets
        and "max_num_motions" in cfg.test_datasets.imgfeat_motionx
    ):
        cfg.test_datasets.imgfeat_motionx.max_num_motions *= (
            cfg.pl_trainer.devices * num_nodes
        )

    pl.seed_everything(cfg.seed)
    # Pin each rank to a local GPU (assumes up to 8 GPUs per node); tinycudann
    # allocates on the current device by default.
    torch.cuda.set_device(global_rank % 8)
    version = None
    tb_logger = None
    if cfg.get("timing", False):
        os.environ["DEBUG_TIMING"] = "TRUE"

    # Auto-resume (e.g. after preemption): pick up the last checkpoint/version.
    if AutoResume is not None:
        details = AutoResume.get_resume_details()
        if details:
            cfg.resume_mode = "last"
            version = int(details["version"])
            print(
                f"[Auto Resume] Loading. checkpoint: {details['checkpoint']} version: {details['version']}"
            )
if cfg.task == "test" and not cfg.get("no_checkpoint", False):
test_cp = cfg.get("test_checkpoint", "last")
remote_run_dir = cfg.output_dir.replace("outputs", cfg.remote_results_path)
version = find_last_version(remote_run_dir, cp=test_cp)
checkpoint_dir = f"{remote_run_dir}/version_{version}/checkpoints"
remote_ckpt_path = get_checkpoint_path(checkpoint_dir, test_cp)
if cfg.get("rsync_ckpt", False):
cfg.ckpt_path = remote_ckpt_path.replace(cfg.remote_results_path, "outputs")
if not os.path.exists(cfg.ckpt_path):
print(f"rsyncing from remote: {remote_ckpt_path}")
print(f"output_dir: {cfg.output_dir}")
rsync_file_from_remote(
cfg.ckpt_path,
remote_run_dir,
cfg.output_dir,
hostname="cs-oci-ord-dc-03",
)
else:
cfg.ckpt_path = remote_ckpt_path
print("ckpt path:", cfg.ckpt_path)
cfg.output_dir = f"{cfg.output_dir}/version_{version}"
cfg.logger.name = (
f"{cfg.exp_name}_v{version}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
)
else:
run_root_dir = cfg.output_dir
if version is None and cfg.resume_mode == "last":
version = find_last_version(run_root_dir, cp="last")

    # Preparation: instantiate the datamodule and model from the Hydra config.
    datamodule: pl.LightningDataModule = hydra.utils.instantiate(
        cfg.data, _recursive_=False
    )
    model: pl.LightningModule = hydra.utils.instantiate(cfg.model, _recursive_=False)

    # Warm-start from a pretrained checkpoint (weights only) unless a checkpoint
    # or resume mode is already set.
    if (
        cfg.get("pretrain_ckpt", None) is not None
        and cfg.ckpt_path is None
        and cfg.resume_mode is None
    ):
        cfg.ckpt_path = cfg.pretrain_ckpt
    if cfg.ckpt_path is not None:
        if cfg.get("rsync_ckpt", False) and not os.path.exists(cfg.ckpt_path):
            print(f"rsyncing from remote: {cfg.ckpt_path}")
            cfg.ckpt_path = cfg.ckpt_path.replace(cfg.remote_results_path, "outputs")
            local_dir = cfg.ckpt_path.split("/version_")[0]
            os.makedirs(local_dir, exist_ok=True)
            rsync_file_from_remote(
                cfg.ckpt_path,
                cfg.remote_results_path,
                "outputs",
                hostname="cs-oci-ord-dc-03",
            )
        ckpt = load_pretrained_model(model, cfg.ckpt_path)
        print(f"Loaded pretrained model from {cfg.ckpt_path}")
        if ckpt is not None:
            print(
                "pretrained ckpt info:",
                {"global_step": ckpt["global_step"], "epoch": ckpt["epoch"]},
            )

    # PL callbacks and logger (TensorBoard only).
    if cfg.task == "fit":
        tb_logger = TensorBoardLogger(run_root_dir, version=version, name="")
        version = tb_logger.version
        if global_rank == 0:
            os.makedirs(tb_logger.log_dir, exist_ok=True)
            cfg.output_dir = tb_logger.log_dir
        if cfg.pl_trainer.devices > 1 and "RANK" in os.environ:
            # Make sure rank 0 has created the version directory before the
            # other ranks try to look it up.
            dist.init_process_group("nccl")
            dist.barrier()
        if global_rank != 0:
            if version is None:
                version = find_last_version(run_root_dir, cp=None)
            cfg.output_dir = f"{run_root_dir}/version_{version}"

    callbacks = get_callbacks(cfg)
    has_ckpt_cb = any(isinstance(cb, Checkpoint) for cb in callbacks)
    if not has_ckpt_cb and cfg.pl_trainer.get("enable_checkpointing", True):
        Log.warning("No checkpoint callback found. Disabling PL auto checkpointing.")
        cfg.pl_trainer = {**cfg.pl_trainer, "enable_checkpointing": False}
    if AutoResume is not None:
        callbacks.append(AutoResumeCallback(version))
    logger = tb_logger if tb_logger is not None else False

    # PL Trainer.
    if cfg.task == "test":
        Log.info("Test mode forces full precision.")
        cfg.pl_trainer = {**cfg.pl_trainer, "precision": 32}
    trainer = pl.Trainer(
        accelerator="gpu",
        logger=logger,
        callbacks=callbacks,
        **cfg.pl_trainer,
    )
    print("=" * 20)
    print("version:", version)

    if cfg.task == "fit":
        # Resuming restores full trainer state (optimizer, schedulers, epoch),
        # unlike the weights-only warm start above.
        resume_path = None
        if cfg.resume_mode is not None:
            save_dir = cfg.output_dir + "/checkpoints"
            resume_path = get_resume_ckpt_path(cfg.resume_mode, ckpt_dir=save_dir)
        Log.info("Start Fitting...")
        trainer.fit(
            model,
            datamodule.train_dataloader(),
            datamodule.val_dataloader(),
            ckpt_path=resume_path,
        )
    elif cfg.task == "test":
        Log.info("Start Testing...")
        trainer.test(model, datamodule.test_dataloader())
    else:
        raise ValueError(f"Unknown task: {cfg.task}")
    Log.info("End of script.")


@hydra.main(version_base="1.3", config_path="../configs", config_name="train")
def main(cfg: DictConfig) -> None:
    print_cfg(cfg, use_rich=True)
    train(cfg)


if __name__ == "__main__":
    main()
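
# Typical invocations via Hydra overrides (the option values are illustrative):
#   python scripts/train.py task=fit exp_name=my_exp pl_trainer.devices=8
#   python scripts/train.py task=test exp_name=my_exp test_checkpoint=last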