# PRIMA-demo/train.py
"""
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
Official implementation of the paper:
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
Licensed under a modified MIT license
"""
from typing import Optional
import pyrootutils

# Locate the project root via the .git / pyproject.toml indicators, add it to
# sys.path (pythonpath=True), and load a local .env file if one exists, so the
# `prima` package imports below work from any working directory.
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".git", "pyproject.toml"],
    pythonpath=True,
    dotenv=True,
)
import os
import sys
import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.callbacks import TQDMProgressBar
from tqdm import tqdm
from prima.datasets import DataModule
from prima.models.prima import PRIMA
from prima.utils.pylogger import get_pylogger
from prima.utils.misc import log_hyperparameters
import signal

# Reset SIGUSR1 to its default action; SLURM requeueing is handled explicitly
# via SIGUSR2 through the SLURMEnvironment plugin configured below.
signal.signal(signal.SIGUSR1, signal.SIG_DFL)

class MyTQDMProgressBar(TQDMProgressBar):
    """Progress bar with a fixed width so redrawn lines stay aligned in log files."""

    def init_train_tqdm(self):
        bar = super().init_train_tqdm()
        bar.ncols = 150
        bar.dynamic_ncols = False
        return bar

    def init_validation_tqdm(self):
        bar = tqdm(
            desc=self.validation_description,
            position=0,
            disable=self.is_disabled,
            leave=True,
            file=sys.stdout,
            dynamic_ncols=False,
            ncols=150,
        )
        return bar

@hydra.main(version_base="1.2", config_path="./configs_hydra", config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    # Instantiate the data pipeline and the PRIMA model from the Hydra config
    datamodule = DataModule(cfg)
    model = PRIMA(cfg)
    # Set up the TensorBoard logger (writes under <output_dir>/tensorboard)
    logger = TensorBoardLogger(
        os.path.join(cfg.paths.output_dir, 'tensorboard'),
        name='',
        version='',
        default_hp_metric=False,
    )
    loggers = [logger]
    # Set up checkpoint saving
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=os.path.join(cfg.paths.output_dir, 'checkpoints'),
        every_n_epochs=cfg.GENERAL.CHECKPOINT_EPOCHS,
        save_last=True,
        # Monitor a metric so `save_top_k` keeps the best checkpoints instead
        # of the most recent ones; 'val/loss' is the validation loss logged by
        # the model (lower is better).
        monitor='val/loss',
        mode='min',
        save_top_k=cfg.GENERAL.CHECKPOINT_SAVE_TOP_K,
        # 'val/loss' contains a slash, so disable automatic metric-name
        # insertion to keep the slash out of the checkpoint filename.
        auto_insert_metric_name=False,
        filename="best-epoch{epoch:03d}-loss{val/loss:.4f}",
    )
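    # Besides the monitored best-*.ckpt files, save_last=True also writes
    # last.ckpt; if that name already exists (e.g. after a requeue), Lightning
    # versions the new file as last-v1.ckpt, which the resume logic further
    # down checks for first.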
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step')
    callbacks = [
        checkpoint_callback,
        lr_monitor,
        MyTQDMProgressBar(),
    ]
log = get_pylogger(__name__)
log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
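    # hydra.utils.instantiate builds the class named by cfg.trainer._target_
    # with the remaining config entries as constructor kwargs; keyword
    # arguments passed here (callbacks, logger, ...) override same-named
    # config entries.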
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer,
        callbacks=callbacks,
        logger=loggers,
        # Attach the SLURM plugin (requeue on SIGUSR2) only when running under
        # a cluster launcher; otherwise let Lightning detect the environment.
        plugins=(
            SLURMEnvironment(requeue_signal=signal.SIGUSR2)
            if cfg.get('launcher', None) is not None
            else None
        ),
        sync_batchnorm=True,
    )
object_dict = {
"cfg": cfg,
"datamodule": datamodule,
"model": model,
"callbacks": callbacks,
"logger": logger,
"trainer": trainer,
}
if logger:
log.info("Logging hyperparameters!")
log_hyperparameters(object_dict)
    # Train the model, resuming from the most recent 'last' checkpoint when
    # one exists; prefer the versioned 'last-v1.ckpt' over 'last.ckpt' since
    # it is the newer file after a requeue.
    ckpt_path = None
    last_v1_ckpt = os.path.join(cfg.paths.output_dir, 'checkpoints', 'last-v1.ckpt')
    last_ckpt = os.path.join(cfg.paths.output_dir, 'checkpoints', 'last.ckpt')
if os.path.exists(last_v1_ckpt):
ckpt_path = last_v1_ckpt
log.info(f"Resuming from checkpoint: {ckpt_path}")
elif os.path.exists(last_ckpt):
ckpt_path = last_ckpt
log.info(f"Resuming from checkpoint: {ckpt_path}")
else:
log.info("No checkpoint found, starting from scratch")
trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
log.info("Fitting done")
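    # Optional post-training evaluation (a sketch, not part of the original
    # run recipe): Lightning resolves ckpt_path="best" to the best checkpoint
    # tracked by the ModelCheckpoint callback.
    # trainer.validate(model, datamodule=datamodule, ckpt_path="best")
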
if __name__ == "__main__":
import torch
    import gc

    # Collect Python garbage and release cached CUDA memory before training,
    # then report per-GPU usage (useful on shared nodes).
    gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
for i in range(torch.cuda.device_count()):
print(f"GPU {i}: {torch.cuda.memory_allocated(i)/1024**2:.2f} MiB allocated, "
f"{torch.cuda.memory_reserved(i)/1024**2:.2f} MiB reserved")
main()
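
# Under SLURM the same entry point applies; the SLURMEnvironment plugin above
# requeues on SIGUSR2, so a submission script would typically ask SLURM to send
# that signal shortly before the time limit (a sketch):
#   #SBATCH --signal=SIGUSR2@90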