import builtins
import os
from datetime import datetime

import hydra
import pytorch_lightning as pl
import torch
import torch.distributed as dist
import yaml
from omegaconf import DictConfig, ListConfig, OmegaConf
from pytorch_lightning.callbacks.checkpoint import Checkpoint
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

from genmo.callbacks.autoresume_callback import AutoResume, AutoResumeCallback
from genmo.utils.net_utils import get_resume_ckpt_path, load_pretrained_model
from genmo.utils.pylogger import Log
from genmo.utils.tools import (
    find_last_version,
    get_checkpoint_path,
    rsync_file_from_remote,
)
from genmo.utils.vis.rich_logger import print_cfg
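
# The "eval" resolver lets config files compute values inline via OmegaConf
# interpolation, e.g. `${eval:'64 * 8'}` (illustrative expression). The string
# is passed to Python's builtin eval, so configs are treated as trusted code.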
OmegaConf.register_new_resolver("eval", builtins.eval)


def _get_rank():
    """Infer the global rank from common launcher environment variables."""
    rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
    for key in rank_keys:
        rank = os.environ.get(key)
        if rank is not None:
            return int(rank)

    return 0
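

# Resolve the global rank from the environment at import time, before any
# torch.distributed process group exists, so that device selection and
# rank-0-only filesystem setup in train() can rely on it.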
global_rank = _get_rank()


def get_callbacks(cfg: DictConfig) -> list:
    """Parse and instantiate all the callbacks in the config.

    Supports both flat and nested callback configs. Only nodes containing
    a `_target_` are instantiated.
    """
    if not hasattr(cfg, "callbacks") or cfg.callbacks is None:
        return []

    def _collect_callback_nodes(node):
        """Recursively gather every config node that defines a `_target_`."""
        collected = []
        if node is None:
            return collected

        if isinstance(node, (DictConfig, dict)):
            if "_target_" in node:
                collected.append(node)
            else:
                for child in node.values():
                    collected.extend(_collect_callback_nodes(child))

        elif isinstance(node, (ListConfig, list, tuple)):
            for child in node:
                collected.extend(_collect_callback_nodes(child))

        return collected

    # Skip checkpoint callbacks entirely when checkpointing is disabled.
    enable_checkpointing = cfg.pl_trainer.get("enable_checkpointing", True)
    callbacks = []
    for cb_conf in _collect_callback_nodes(cfg.callbacks):
        cb = hydra.utils.instantiate(cb_conf, _recursive_=False)
        if not enable_checkpointing and isinstance(cb, Checkpoint):
            continue
        callbacks.append(cb)
    return callbacks
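

# Illustrative callbacks layout that get_callbacks() accepts; flat and nested
# groups both work, since only nodes defining `_target_` are instantiated
# (the callback names below are hypothetical, not taken from this repo's configs):
#
#   callbacks:
#     lr_monitor:
#       _target_: pytorch_lightning.callbacks.LearningRateMonitor
#     checkpointing:
#       last:
#         _target_: pytorch_lightning.callbacks.ModelCheckpoint
#         save_last: true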


def train(cfg: DictConfig) -> None:
    """Train or test, depending on cfg.task ("fit" or "test")."""
    Log.info(f"[Exp Name]: {cfg.exp_name}")

    if cfg.task == "fit":
        Log.info(
            f"[GPU x Batch] = {cfg.pl_trainer.devices} x {cfg.data.loader_opts.train.batch_size}"
        )
    # Scale test-data limits by the total number of ranks (devices x nodes).
    num_nodes = cfg.pl_trainer.get("num_nodes", 1)
    cfg.num_test_data *= cfg.pl_trainer.devices * num_nodes
    if (
        "imgfeat_motionx" in cfg.test_datasets
        and "max_num_motions" in cfg.test_datasets.imgfeat_motionx
    ):
        cfg.test_datasets.imgfeat_motionx.max_num_motions *= (
            cfg.pl_trainer.devices * num_nodes
        )
    pl.seed_everything(cfg.seed)
    torch.cuda.set_device(global_rank % 8)  # assumes at most 8 GPUs per node
    version = None
    tb_logger = None

    if cfg.get("timing", False):
        os.environ["DEBUG_TIMING"] = "TRUE"

    # If the auto-resume hook reports resume details, resume from the last
    # checkpoint of the recorded version.
    if AutoResume is not None:
        details = AutoResume.get_resume_details()
        if details:
            cfg.resume_mode = "last"
            version = int(details["version"])
            print(
                f"[Auto Resume] Loading checkpoint: {details['checkpoint']}, version: {details['version']}"
            )

    if cfg.task == "test" and not cfg.get("no_checkpoint", False):
        # Locate the checkpoint to evaluate in the remote results directory,
        # optionally rsync-ing it into the local `outputs` tree first.
        test_cp = cfg.get("test_checkpoint", "last")
        remote_run_dir = cfg.output_dir.replace("outputs", cfg.remote_results_path)
        version = find_last_version(remote_run_dir, cp=test_cp)
        checkpoint_dir = f"{remote_run_dir}/version_{version}/checkpoints"
        remote_ckpt_path = get_checkpoint_path(checkpoint_dir, test_cp)
        if cfg.get("rsync_ckpt", False):
            cfg.ckpt_path = remote_ckpt_path.replace(cfg.remote_results_path, "outputs")
            if not os.path.exists(cfg.ckpt_path):
                print(f"rsyncing from remote: {remote_ckpt_path}")
                print(f"output_dir: {cfg.output_dir}")
                rsync_file_from_remote(
                    cfg.ckpt_path,
                    remote_run_dir,
                    cfg.output_dir,
                    hostname="cs-oci-ord-dc-03",
                )
        else:
            cfg.ckpt_path = remote_ckpt_path
        print("ckpt path:", cfg.ckpt_path)
        cfg.output_dir = f"{cfg.output_dir}/version_{version}"
        cfg.logger.name = (
            f"{cfg.exp_name}_v{version}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
        )
    else:
        run_root_dir = cfg.output_dir
        if version is None and cfg.resume_mode == "last":
            version = find_last_version(run_root_dir, cp="last")

    # Instantiate the datamodule and model from the config.
    datamodule: pl.LightningDataModule = hydra.utils.instantiate(
        cfg.data, _recursive_=False
    )
    model: pl.LightningModule = hydra.utils.instantiate(cfg.model, _recursive_=False)

    # Fall back to the pretrained checkpoint when neither an explicit checkpoint
    # nor a resume mode is requested.
    if (
        cfg.get("pretrain_ckpt", None) is not None
        and cfg.ckpt_path is None
        and cfg.resume_mode is None
    ):
        cfg.ckpt_path = cfg.pretrain_ckpt

    if cfg.ckpt_path is not None:
        # Mirror the remote checkpoint into the local `outputs` tree if missing.
        if cfg.get("rsync_ckpt", False) and not os.path.exists(cfg.ckpt_path):
            print(f"rsyncing from remote: {cfg.ckpt_path}")
            cfg.ckpt_path = cfg.ckpt_path.replace(cfg.remote_results_path, "outputs")
            local_dir = cfg.ckpt_path.split("/version_")[0]
            os.makedirs(local_dir, exist_ok=True)
            rsync_file_from_remote(
                cfg.ckpt_path,
                cfg.remote_results_path,
                "outputs",
                hostname="cs-oci-ord-dc-03",
            )

        ckpt = load_pretrained_model(model, cfg.ckpt_path)
        print(f"Loaded pretrained model from {cfg.ckpt_path}")
        if ckpt is not None:
            print(
                "pretrained ckpt info:",
                {"global_step": ckpt["global_step"], "epoch": ckpt["epoch"]},
            )

    # For training, the TensorBoard logger picks the run version and output dir.
    if cfg.task == "fit":
        tb_logger = TensorBoardLogger(run_root_dir, version=version, name="")
        version = tb_logger.version
        if global_rank == 0:
            os.makedirs(tb_logger.log_dir, exist_ok=True)
        cfg.output_dir = tb_logger.log_dir

        if cfg.pl_trainer.devices > 1 and "RANK" in os.environ:
            dist.init_process_group("nccl")
            dist.barrier()

        # Non-zero ranks adopt the version directory created by rank 0.
        if global_rank != 0:
            if version is None:
                version = find_last_version(run_root_dir, cp=None)
            cfg.output_dir = f"{run_root_dir}/version_{version}"

    callbacks = get_callbacks(cfg)
    has_ckpt_cb = any(isinstance(cb, Checkpoint) for cb in callbacks)
    if not has_ckpt_cb and cfg.pl_trainer.get("enable_checkpointing", True):
        Log.warning("No checkpoint callback found. Disabling PL auto checkpointing.")
        cfg.pl_trainer = {**cfg.pl_trainer, "enable_checkpointing": False}
    if AutoResume is not None:
        callbacks.append(AutoResumeCallback(version))

    logger = tb_logger if tb_logger is not None else False

    # Build the trainer; test runs are forced to full precision.
    if cfg.task == "test":
        Log.info("Test mode forces full precision.")
        cfg.pl_trainer = {**cfg.pl_trainer, "precision": 32}
    trainer = pl.Trainer(
        accelerator="gpu",
        logger=logger,
        callbacks=callbacks,
        **cfg.pl_trainer,
    )

    print("=" * 20)
    print("version:", version)

    if cfg.task == "fit":
        resume_path = None
        if cfg.resume_mode is not None:
            save_dir = cfg.output_dir + "/checkpoints"
            resume_path = get_resume_ckpt_path(cfg.resume_mode, ckpt_dir=save_dir)
        Log.info("Start Fitting...")
        trainer.fit(
            model,
            datamodule.train_dataloader(),
            datamodule.val_dataloader(),
            ckpt_path=resume_path,
        )
    elif cfg.task == "test":
        Log.info("Start Testing...")
        trainer.test(model, datamodule.test_dataloader())
    else:
        raise ValueError(f"Unknown task: {cfg.task}")

    Log.info("End of script.")


@hydra.main(version_base="1.3", config_path="../configs", config_name="train")
def main(cfg) -> None:
    print_cfg(cfg, use_rich=True)
    train(cfg)
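

# Typical invocations (illustrative; the script path and experiment name are
# placeholders, and any config key can be overridden on the command line):
#   python train.py task=fit exp_name=my_exp pl_trainer.devices=8
#   python train.py task=test exp_name=my_exp test_checkpoint=last rsync_ckpt=true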
if __name__ == "__main__":
    main()