"""
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
Official implementation of the paper:
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
Licensed under a modified MIT license
"""
from typing import Optional
import pyrootutils
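# Locate the repository root (marked by .git or pyproject.toml), add it to
# sys.path, and load environment variables from a .env file if present. This
# must run before the `prima` imports below so the package resolves.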
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".git", "pyproject.toml"],
    pythonpath=True,
    dotenv=True,
)
import os
import sys
import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.callbacks import TQDMProgressBar
from tqdm import tqdm
from prima.datasets import DataModule
from prima.models.prima import PRIMA
from prima.utils.pylogger import get_pylogger
from prima.utils.misc import log_hyperparameters
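# Reset SIGUSR1 to its default action; SLURM auto-requeueing is wired to
# SIGUSR2 instead (see the SLURMEnvironment plugin passed to the Trainer).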
import signal
signal.signal(signal.SIGUSR1, signal.SIG_DFL)
class MyTQDMProgressBar(TQDMProgressBar):
    """Progress bar with a fixed width (ncols=150) instead of dynamic resizing."""

    def init_train_tqdm(self):
        bar = super().init_train_tqdm()
        bar.ncols = 150
        bar.dynamic_ncols = False
        return bar

    def init_validation_tqdm(self):
        bar = tqdm(
            desc=self.validation_description,
            position=0,
            disable=self.is_disabled,
            leave=True,
            file=sys.stdout,
            dynamic_ncols=False,
            ncols=150,
        )
        return bar
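
# Hydra entry point: composes ./configs_hydra/train.yaml with any command-line
# overrides and passes the resulting config in as `cfg`.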
@hydra.main(version_base="1.2", config_path="./configs_hydra", config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    datamodule = DataModule(cfg)
    model = PRIMA(cfg)
    # Set up the TensorBoard logger.
    logger = TensorBoardLogger(
        os.path.join(cfg.paths.output_dir, 'tensorboard'),
        name='',
        version='',
        default_hp_metric=False,
    )
    loggers = [logger]

    # Set up checkpoint saving.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=os.path.join(cfg.paths.output_dir, 'checkpoints'),
        every_n_epochs=cfg.GENERAL.CHECKPOINT_EPOCHS,
        save_last=True,
        # Monitor the validation loss (logged as 'val/loss', lower is better) so
        # `save_top_k` keeps the best checkpoints instead of the most recent ones.
        monitor='val/loss',
        mode='min',
        save_top_k=cfg.GENERAL.CHECKPOINT_SAVE_TOP_K,
        # 'val/loss' contains a slash, so disable automatic metric-name insertion
        # (which would put the checkpoint in a 'val' subdirectory) and spell the
        # metric out in the filename template instead.
        auto_insert_metric_name=False,
        filename="best-epoch{epoch:03d}-val_loss{val/loss:.4f}",
    )
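    # Log the learning rate to TensorBoard at every optimizer step.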
    lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step')
    callbacks = [
        checkpoint_callback,
        lr_monitor,
        MyTQDMProgressBar(),
    ]
    log = get_pylogger(__name__)
    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    # Enable SLURM auto-requeueing (signalled via SIGUSR2) only when running
    # under a cluster launcher.
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer,
        callbacks=callbacks,
        logger=loggers,
        plugins=(SLURMEnvironment(requeue_signal=signal.SIGUSR2) if cfg.get('launcher') is not None else None),
        sync_batchnorm=True,
    )
    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }
    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)
    # Determine which checkpoint to resume from. When last.ckpt already exists,
    # Lightning versions the next one as last-v1.ckpt, so the suffixed file is
    # the more recent and takes precedence.
    ckpt_path = None
    last_v1_ckpt = os.path.join(cfg.paths.output_dir, 'checkpoints', 'last-v1.ckpt')
    last_ckpt = os.path.join(cfg.paths.output_dir, 'checkpoints', 'last.ckpt')
    if os.path.exists(last_v1_ckpt):
        ckpt_path = last_v1_ckpt
        log.info(f"Resuming from checkpoint: {ckpt_path}")
    elif os.path.exists(last_ckpt):
        ckpt_path = last_ckpt
        log.info(f"Resuming from checkpoint: {ckpt_path}")
    else:
        log.info("No checkpoint found, starting from scratch")

    # Train the model, resuming if a checkpoint was found.
    trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
    log.info("Fitting done")

if __name__ == "__main__":
    import gc

    import torch

    # Free host-side garbage and cached CUDA memory before training starts,
    # then report per-GPU memory usage.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        for i in range(torch.cuda.device_count()):
            print(
                f"GPU {i}: {torch.cuda.memory_allocated(i) / 1024**2:.2f} MiB allocated, "
                f"{torch.cuda.memory_reserved(i) / 1024**2:.2f} MiB reserved"
            )
    main()
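
# Example invocation (the `trainer.max_epochs` override is illustrative; only
# `paths.output_dir` is known to exist in this repository's configs):
#   python train.py paths.output_dir=outputs/run1 trainer.max_epochs=100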