File size: 4,677 Bytes
c3a4f1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation

Official implementation of the paper:
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
Licensed under a modified MIT license
"""

from typing import Optional
import pyrootutils


# Locate the project root by searching upward from this file for ".git" or
# "pyproject.toml", put it on sys.path (pythonpath=True) and load a .env file
# from it (dotenv=True). This has to run before the project-local `prima`
# imports further down so that those imports resolve.
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".git", "pyproject.toml"],
    pythonpath=True,
    dotenv=True,
)

import os
import sys

import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.callbacks import TQDMProgressBar
from tqdm import tqdm
from prima.datasets import DataModule
from prima.models.prima import PRIMA
from prima.utils.pylogger import get_pylogger
from prima.utils.misc import log_hyperparameters
import signal

# Reset SIGUSR1 to its default disposition (terminate the process) at import time.
# NOTE(review): presumably so SIGUSR1 is never treated as a requeue signal —
# main() configures SLURM requeueing on SIGUSR2 instead; confirm intent.
signal.signal(signal.SIGUSR1, signal.SIG_DFL)


class MyTQDMProgressBar(TQDMProgressBar):
    """TQDM progress bar whose training and validation bars have a fixed width.

    Both bars are rendered ``ncols`` columns wide with terminal-size tracking
    (``dynamic_ncols``) turned off.
    """

    def __init__(self, ncols: int = 150, **kwargs) -> None:
        """
        Args:
            ncols: Width of the training and validation bars, in columns.
                Defaults to 150, the value previously hard-coded in both bars.
            **kwargs: Forwarded unchanged to ``TQDMProgressBar.__init__``.
        """
        super().__init__(**kwargs)
        self._fixed_ncols = ncols

    def init_train_tqdm(self):
        """Create Lightning's default training bar, then pin its width."""
        bar = super().init_train_tqdm()
        # Stop tracking the terminal size so the explicit width is what gets used.
        bar.dynamic_ncols = False
        bar.ncols = self._fixed_ncols
        return bar

    def init_validation_tqdm(self):
        """Create the validation bar at position 0, kept on screen when done."""
        bar = tqdm(
            desc=self.validation_description,
            position=0,
            disable=self.is_disabled,
            leave=True,
            file=sys.stdout,
            dynamic_ncols=False,
            ncols=self._fixed_ncols,
        )
        return bar


def _find_resume_checkpoint(ckpt_dir: str) -> Optional[str]:
    """Return the newest ``last.ckpt`` / ``last-v<N>.ckpt`` in *ckpt_dir*, or ``None``.

    With ``save_last=True``, ``ModelCheckpoint`` writes ``last.ckpt`` and may
    fall back to versioned names such as ``last-v1.ckpt`` when that name is
    taken. Choosing the most recently modified candidate resumes from the
    latest training state. Ties go to the higher version, so ``last-v1.ckpt``
    still beats ``last.ckpt`` when they are equally recent. A fixed preference
    for ``last-v1.ckpt`` could pick a stale file or miss ``last-v2.ckpt``.

    Args:
        ckpt_dir: Directory that ``ModelCheckpoint`` writes into.

    Returns:
        Path of the checkpoint to resume from. Returns ``None`` if the directory
        does not exist or contains no matching checkpoint file.
    """
    if not os.path.isdir(ckpt_dir):
        return None

    prefix, suffix = "last-v", ".ckpt"
    candidates = []
    with os.scandir(ckpt_dir) as entries:
        for entry in entries:
            name = entry.name
            if name == "last.ckpt":
                version = 0
            elif name.startswith(prefix) and name.endswith(suffix):
                digits = name[len(prefix):-len(suffix)]
                # isdecimal() (not isdigit()) matches exactly what int() accepts.
                if not digits.isdecimal():
                    continue
                version = int(digits)
            else:
                continue
            if entry.is_file():
                candidates.append((entry.stat().st_mtime, version, entry.path))

    if not candidates:
        return None
    return max(candidates)[2]


@hydra.main(version_base="1.2", config_path="./configs_hydra", config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    """Train PRIMA using the Hydra config *cfg*.

    Steps:
    - Build the ``DataModule`` and ``PRIMA`` model.
    - Set up a TensorBoard logger, a ``ModelCheckpoint`` that tracks
      ``val/loss``, a learning-rate monitor and the fixed-width progress bar.
    - Instantiate the trainer from ``cfg.trainer`` and call ``trainer.fit``.

    Training resumes from the newest ``last*.ckpt`` in
    ``<cfg.paths.output_dir>/checkpoints`` if one exists.

    Returns:
        Currently always ``None``.
    """
    datamodule = DataModule(cfg)
    model = PRIMA(cfg)

    # Setup Tensorboard logger. name='' and version='' write straight into
    # <output_dir>/tensorboard without name/version sub-directories.
    logger = TensorBoardLogger(os.path.join(cfg.paths.output_dir, 'tensorboard'), name='', version='',
                               default_hp_metric=False)
    loggers = [logger]

    # Setup checkpoint saving
    ckpt_dir = os.path.join(cfg.paths.output_dir, 'checkpoints')
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=ckpt_dir,
        every_n_epochs=cfg.GENERAL.CHECKPOINT_EPOCHS,
        save_last=True,
        # Monitor the validation loss logged as 'val/loss' (lower is better) so
        # `save_top_k` keeps the best checkpoints rather than just the latest.
        monitor='val/loss',
        mode='min',
        save_top_k=cfg.GENERAL.CHECKPOINT_SAVE_TOP_K,
        # The template must reference the key that is actually logged. Lightning
        # fills unknown keys with 0, so the previous "{val_loss}" rendered as
        # 0.0000 whenever only 'val/loss' is logged (as the monitor key implies).
        # auto_insert_metric_name=False stops the "/" in "val/loss" from leaking
        # into the file name, where it would create a sub-directory.
        filename="best-epoch={epoch:03d}-val_loss={val/loss:.4f}",
        auto_insert_metric_name=False,
    )

    lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step')
    callbacks = [
        checkpoint_callback,
        lr_monitor,
        MyTQDMProgressBar(),
    ]

    log = get_pylogger(__name__)
    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer,
        callbacks=callbacks,
        logger=loggers,
        # Only when a launcher is configured (SLURM runs): requeue on SIGUSR2.
        plugins=(SLURMEnvironment(requeue_signal=signal.SIGUSR2) if (cfg.get('launcher', None) is not None) else None),
        sync_batchnorm=True,
    )

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    # Train the model, resuming from the most recent "last" checkpoint if any.
    ckpt_path = _find_resume_checkpoint(ckpt_dir)
    if ckpt_path is not None:
        log.info(f"Resuming from checkpoint: {ckpt_path}")
    else:
        log.info("No checkpoint found, starting from scratch")

    trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
    log.info("Fitting done")


def _startup_gpu_report() -> None:
    """Collect garbage, then print per-GPU CUDA memory usage.

    When CUDA is available, the CUDA cache is emptied first. Then the allocated
    and reserved memory (in MiB) of every visible device is printed. Without
    CUDA, only the garbage collection runs.
    """
    import torch
    import gc

    gc.collect()

    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()
    bytes_per_mib = 1024 ** 2
    for device_index in range(torch.cuda.device_count()):
        allocated_mib = torch.cuda.memory_allocated(device_index) / bytes_per_mib
        reserved_mib = torch.cuda.memory_reserved(device_index) / bytes_per_mib
        print(f"GPU {device_index}: {allocated_mib:.2f} MiB allocated, "
              f"{reserved_mib:.2f} MiB reserved")


if __name__ == "__main__":
    # NOTE(review): this runs before main() builds any model or data, so the
    # cleanup is presumably close to a no-op and the printout a baseline; confirm
    # it is still wanted.
    _startup_gpu_report()
    main()