import os
import sys
import warnings
from pathlib import Path

import hydra
import torch
import wandb
from colorama import Fore
from hydra.core.hydra_config import HydraConfig
from jaxtyping import install_import_hook
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers.wandb import WandbLogger
from omegaconf import DictConfig, OmegaConf

# Make the repository root importable when this file is run as a script.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.model.model import get_model

warnings.filterwarnings("ignore")

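# Import everything under `src` through jaxtyping's import hook so that
# beartype checks tensor shapes and dtypes at runtime.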
with install_import_hook(
    ("src",),
    ("beartype", "beartype"),
):
    from src.config import load_typed_root_config
    from src.dataset.data_module import DataModule
    from src.global_cfg import set_cfg
    from src.loss import get_losses
    from src.misc.LocalLogger import LocalLogger
    from src.misc.step_tracker import StepTracker
    from src.misc.wandb_tools import update_checkpoint_path
    from src.model.model_wrapper import ModelWrapper

def cyan(text: str) -> str:
    return f"{Fore.CYAN}{text}{Fore.RESET}"


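# Entry point. Hydra composes the configuration from ../config/main.yaml plus
# any command-line overrides and passes the result to `train`.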
@hydra.main(
    version_base=None,
    config_path="../config",
    config_name="main",
)
def train(cfg_dict: DictConfig):
    cfg = load_typed_root_config(cfg_dict)
    set_cfg(cfg_dict)

    # Hydra creates a unique output directory for each run.
    output_dir = Path(HydraConfig.get()["runtime"]["output_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)
    print(cyan(f"Saving outputs to {output_dir}."))

    cfg.train.output_path = output_dir

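    # Set up callbacks and logging: Weights & Biases when enabled, otherwise a
    # purely local logger.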
    callbacks = []
    if cfg_dict.wandb.mode != "disabled":
        logger = WandbLogger(
            project=cfg_dict.wandb.project,
            mode=cfg_dict.wandb.mode,
            name=f"{cfg_dict.wandb.name} ({output_dir.parent.name}/{output_dir.name})",
            tags=cfg_dict.wandb.get("tags", None),
            log_model=False,
            save_dir=output_dir,
            config=OmegaConf.to_container(cfg_dict),
        )
        callbacks.append(
            LearningRateMonitor(logging_interval="step", log_momentum=True)
        )

        # Snapshot the source code into the wandb run for reproducibility.
        if wandb.run is not None:
            wandb.run.log_code("src")
    else:
        logger = LocalLogger()

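    # Periodic checkpointing. Monitoring "info/global_step" with mode="max"
    # keeps the most recent checkpoints rather than the best-scoring ones.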
    callbacks.append(
        ModelCheckpoint(
            dirpath=output_dir / "checkpoints",
            every_n_train_steps=cfg.checkpointing.every_n_train_steps,
            save_top_k=cfg.checkpointing.save_top_k,
            save_weights_only=cfg.checkpointing.save_weights_only,
            monitor="info/global_step",
            mode="max",
        )
    )
    # Use "_" instead of "=" in checkpoint filenames; "=" can be awkward in
    # shell commands and on some filesystems.
    callbacks[-1].CHECKPOINT_EQUALS_CHAR = "_"

    # Resolve which checkpoint to load, if any (the path may reference a wandb run).
    checkpoint_path = update_checkpoint_path(cfg.checkpointing.load, cfg.wandb)

    # Shares the current global step with other components, e.g. the data loaders.
    step_tracker = StepTracker()

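    # Step-based training: max_epochs=-1 defers to max_steps, and validation is
    # triggered every val_check_interval steps rather than once per epoch.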
    trainer = Trainer(
        max_epochs=-1,
        num_nodes=cfg.trainer.num_nodes,
        accelerator="gpu",
        logger=logger,
        devices="auto",
        strategy=(
            "ddp_find_unused_parameters_true"
            if torch.cuda.device_count() > 1
            else "auto"
        ),
        callbacks=callbacks,
        val_check_interval=cfg.trainer.val_check_interval,
        check_val_every_n_epoch=None,
        enable_progress_bar=False,
        gradient_clip_val=cfg.trainer.gradient_clip_val,
        max_steps=cfg.trainer.max_steps,
        precision=cfg.trainer.precision,
        accumulate_grad_batches=cfg.trainer.accumulate_grad_batches,
        # Keep gradients enabled at test time when poses are aligned by optimization.
        inference_mode=not (cfg.mode == "test" and cfg.test.align_pose),
    )
    # Offset the seed by rank so that each process is seeded differently.
    torch.manual_seed(cfg_dict.seed + trainer.global_rank)

    model = get_model(cfg.model.encoder, cfg.model.decoder)

    model_wrapper = ModelWrapper(
        cfg.optimizer,
        cfg.test,
        cfg.train,
        model,
        get_losses(cfg.loss),
        step_tracker,
    )
    data_module = DataModule(
        cfg.dataset,
        cfg.data_loader,
        step_tracker,
        global_rank=trainer.global_rank,
    )

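    # Train or evaluate, resuming from checkpoint_path when one is provided.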
    if cfg.mode == "train":
        trainer.fit(model_wrapper, datamodule=data_module, ckpt_path=checkpoint_path)
    else:
        trainer.test(
            model_wrapper,
            datamodule=data_module,
            ckpt_path=checkpoint_path,
        )
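

# Example invocations (override names follow the config keys used above; the
# script path is hypothetical):
#   python src/main.py mode=train wandb.mode=disabled
#   python src/main.py mode=test checkpointing.load=path/to/model.ckpt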
if __name__ == "__main__":
    train()