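"""Hydra + PyTorch Lightning train/test entry point (scripts/train.py)."""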
import builtins
import os
from datetime import datetime
import hydra
import pytorch_lightning as pl
import torch
import torch.distributed as dist
import yaml
from omegaconf import DictConfig, ListConfig, OmegaConf
from pytorch_lightning.callbacks.checkpoint import Checkpoint
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from genmo.callbacks.autoresume_callback import AutoResume, AutoResumeCallback
from genmo.utils.net_utils import get_resume_ckpt_path, load_pretrained_model
from genmo.utils.pylogger import Log
from genmo.utils.tools import (
find_last_version,
get_checkpoint_path,
rsync_file_from_remote,
)
from genmo.utils.vis.rich_logger import print_cfg
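
# Allow Python expressions in configs via `${eval:'...'}`, e.g. `${eval:'64 * 4'}`.
# Note: this evaluates arbitrary code from the config, so configs must be trusted.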
OmegaConf.register_new_resolver("eval", builtins.eval)
def _get_rank():
# SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
# therefore LOCAL_RANK needs to be checked first
rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
for key in rank_keys:
rank = os.environ.get(key)
if rank is not None:
return int(rank)
    # fall back to rank 0 when no rank-related environment variable is set
return 0
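
# Resolved once at import time; used below for per-process GPU binding and
# rank-0-only filesystem setup.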
global_rank = _get_rank()
def get_callbacks(cfg: DictConfig) -> list:
"""Parse and instantiate all the callbacks in the config.
Supports both flat and nested callback configs. Only nodes containing
a `_target_` are instantiated.
"""
    if not hasattr(cfg, "callbacks") or cfg.callbacks is None:
        # empty list (not None) so callers can iterate and append safely
        return []
def _collect_callback_nodes(node):
collected = []
if node is None:
return collected
# Dict-like node
if isinstance(node, (DictConfig, dict)):
# direct instantiable config
if "_target_" in node:
collected.append(node)
else:
for child in node.values():
collected.extend(_collect_callback_nodes(child))
# List-like node
elif isinstance(node, (ListConfig, list, tuple)):
for child in node:
collected.extend(_collect_callback_nodes(child))
# primitives are ignored
return collected
enable_checkpointing = cfg.pl_trainer.get("enable_checkpointing", True)
callbacks = []
for cb_conf in _collect_callback_nodes(cfg.callbacks):
cb = hydra.utils.instantiate(cb_conf, _recursive_=False)
if not enable_checkpointing and isinstance(cb, Checkpoint):
continue
callbacks.append(cb)
return callbacks
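
# Illustrative shape of a callbacks config this parser accepts (entry names are
# hypothetical; the real entries live in the Hydra configs):
#
# callbacks:
#   checkpoint:
#     _target_: pytorch_lightning.callbacks.ModelCheckpoint
#     save_top_k: -1
#   extra:
#     - _target_: pytorch_lightning.callbacks.LearningRateMonitor
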
def train(cfg: DictConfig) -> None:
"""Train/Test"""
Log.info(f"[Exp Name]: {cfg.exp_name}")
    # report the effective (total) batch size: devices x per-GPU batch size
if cfg.task == "fit":
Log.info(
f"[GPU x Batch] = {cfg.pl_trainer.devices} x {cfg.data.loader_opts.train.batch_size}"
)
num_nodes = cfg.pl_trainer.get("num_nodes", 1)
cfg.num_test_data *= cfg.pl_trainer.devices * num_nodes
if (
"imgfeat_motionx" in cfg.test_datasets
and "max_num_motions" in cfg.test_datasets.imgfeat_motionx
):
cfg.test_datasets.imgfeat_motionx.max_num_motions *= (
cfg.pl_trainer.devices * num_nodes
)
pl.seed_everything(cfg.seed)
    # set the default CUDA device per process (assumes up to 8 GPUs per node);
    # tinycudann allocates its default memory on the current device
    torch.cuda.set_device(global_rank % 8)
version = None
tb_logger = None
if cfg.get("timing", False):
os.environ["DEBUG_TIMING"] = "TRUE"
if AutoResume is not None:
details = AutoResume.get_resume_details()
if details:
cfg.resume_mode = "last"
version = int(details["version"])
            print(
                f"[Auto Resume] Loading checkpoint {details['checkpoint']} "
                f"(version {details['version']})"
            )
if cfg.task == "test" and not cfg.get("no_checkpoint", False):
test_cp = cfg.get("test_checkpoint", "last")
remote_run_dir = cfg.output_dir.replace("outputs", cfg.remote_results_path)
version = find_last_version(remote_run_dir, cp=test_cp)
checkpoint_dir = f"{remote_run_dir}/version_{version}/checkpoints"
remote_ckpt_path = get_checkpoint_path(checkpoint_dir, test_cp)
if cfg.get("rsync_ckpt", False):
cfg.ckpt_path = remote_ckpt_path.replace(cfg.remote_results_path, "outputs")
if not os.path.exists(cfg.ckpt_path):
print(f"rsyncing from remote: {remote_ckpt_path}")
print(f"output_dir: {cfg.output_dir}")
rsync_file_from_remote(
cfg.ckpt_path,
remote_run_dir,
cfg.output_dir,
hostname="cs-oci-ord-dc-03",
)
else:
cfg.ckpt_path = remote_ckpt_path
print("ckpt path:", cfg.ckpt_path)
cfg.output_dir = f"{cfg.output_dir}/version_{version}"
cfg.logger.name = (
f"{cfg.exp_name}_v{version}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
)
else:
run_root_dir = cfg.output_dir
if version is None and cfg.resume_mode == "last":
version = find_last_version(run_root_dir, cp="last")
# preparation
datamodule: pl.LightningDataModule = hydra.utils.instantiate(
cfg.data, _recursive_=False
)
model: pl.LightningModule = hydra.utils.instantiate(cfg.model, _recursive_=False)
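    # Use the init-only pretrain checkpoint only when neither an explicit
    # ckpt_path nor a resume mode was requested.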
if (
cfg.get("pretrain_ckpt", None) is not None
and cfg.ckpt_path is None
and cfg.resume_mode is None
):
cfg.ckpt_path = cfg.pretrain_ckpt
if cfg.ckpt_path is not None:
if cfg.get("rsync_ckpt", False) and not os.path.exists(cfg.ckpt_path):
print(f"rsyncing from remote: {cfg.ckpt_path}")
cfg.ckpt_path = cfg.ckpt_path.replace(cfg.remote_results_path, "outputs")
local_dir = cfg.ckpt_path.split("/version_")[0]
os.makedirs(local_dir, exist_ok=True)
rsync_file_from_remote(
cfg.ckpt_path,
cfg.remote_results_path,
"outputs",
hostname="cs-oci-ord-dc-03",
)
ckpt = load_pretrained_model(model, cfg.ckpt_path)
print(f"Loaded pretrained model from {cfg.ckpt_path}")
if ckpt is not None:
print(
"pretrained ckpt info:",
{"global_step": ckpt["global_step"], "epoch": ckpt["epoch"]},
)
# PL callbacks and logger (TensorBoard only)
if cfg.task == "fit":
tb_logger = TensorBoardLogger(run_root_dir, version=version, name="")
version = tb_logger.version
if global_rank == 0:
os.makedirs(tb_logger.log_dir, exist_ok=True)
cfg.output_dir = tb_logger.log_dir
        if cfg.pl_trainer.devices > 1 and "RANK" in os.environ:
            # barrier so non-zero ranks wait until rank 0 has created the
            # version directory before they resolve it below
            dist.init_process_group("nccl")
            dist.barrier()
if global_rank != 0:
if version is None:
version = find_last_version(run_root_dir, cp=None)
cfg.output_dir = f"{run_root_dir}/version_{version}"
callbacks = get_callbacks(cfg)
    has_ckpt_cb = any(isinstance(cb, Checkpoint) for cb in callbacks)
    if not has_ckpt_cb and cfg.pl_trainer.get("enable_checkpointing", True):
        Log.warning("No checkpoint callback found. Disabling PL auto-checkpointing.")
        cfg.pl_trainer = {**cfg.pl_trainer, "enable_checkpointing": False}
if AutoResume is not None:
callbacks.append(AutoResumeCallback(version))
logger = tb_logger if tb_logger is not None else False
# PL-Trainer
if cfg.task == "test":
Log.info("Test mode forces full-precision.")
cfg.pl_trainer = {**cfg.pl_trainer, "precision": 32}
trainer = pl.Trainer(
accelerator="gpu",
        logger=logger,
callbacks=callbacks,
**cfg.pl_trainer,
)
print("=" * 20)
print("version:", version)
if cfg.task == "fit":
resume_path = None
if cfg.resume_mode is not None:
save_dir = cfg.output_dir + "/checkpoints"
resume_path = get_resume_ckpt_path(cfg.resume_mode, ckpt_dir=save_dir)
Log.info("Start Fitting...")
trainer.fit(
model,
datamodule.train_dataloader(),
datamodule.val_dataloader(),
ckpt_path=resume_path,
)
elif cfg.task == "test":
Log.info("Start Testing...")
trainer.test(model, datamodule.test_dataloader())
else:
raise ValueError(f"Unknown task: {cfg.task}")
Log.info("End of script.")
@hydra.main(version_base="1.3", config_path="../configs", config_name="train")
def main(cfg: DictConfig) -> None:
print_cfg(cfg, use_rich=True)
train(cfg)
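
# Typical invocations (Hydra overrides; values after `=` are illustrative):
#   python scripts/train.py task=fit exp_name=my_run
#   python scripts/train.py task=test test_checkpoint=last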
if __name__ == "__main__":
main()