Spaces:
Sleeping
Sleeping
File size: 3,650 Bytes
96da58e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
from typing import Optional, Tuple
import pyrootutils
# Locate the project root by searching from this file for a ".git" or
# "pyproject.toml" marker. `root` is used further down to build the Hydra
# config path.
# NOTE(review): this call sits before the `hamer` imports, presumably because
# pythonpath=True is what makes the package importable (and dotenv=True loads
# a .env file) -- confirm against the pyrootutils documentation before moving it.
root = pyrootutils.setup_root(
search_from=__file__,
indicator=[".git", "pyproject.toml"],
pythonpath=True,
dotenv=True,
)
import os
from pathlib import Path
import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins.environments import SLURMEnvironment
#from pytorch_lightning.trainingtype import DDPPlugin
from yacs.config import CfgNode
from hamer.configs import dataset_config
from hamer.datasets import HAMERDataModule
from hamer.models.hamer import HAMER
from hamer.utils.pylogger import get_pylogger
from hamer.utils.misc import task_wrapper, log_hyperparameters
# HACK: reset the SIGUSR1 handler to its default disposition so that Lightning
# is free to install its own handler later.
# Based on https://github.com/facebookincubator/submitit/issues/1709#issuecomment-1246758283
import signal
signal.signal(signal.SIGUSR1, signal.SIG_DFL)
# Module-level logger; used by train() for progress messages.
log = get_pylogger(__name__)
@pl.utilities.rank_zero.rank_zero_only
def save_configs(model_cfg: CfgNode, dataset_cfg: CfgNode, rootdir: str):
    """Write the model and dataset configurations into ``rootdir``.

    ``rootdir`` is created (including missing parents) when it does not exist.
    The model config is serialized with ``OmegaConf.save`` to
    ``model_config.yaml``; the text returned by ``dataset_cfg.dump()`` is
    written to ``dataset_config.yaml``.
    """
    # NOTE(review): `model_cfg` is annotated as CfgNode, but train() passes the
    # Hydra DictConfig here and it is saved through OmegaConf.
    Path(rootdir).mkdir(parents=True, exist_ok=True)

    model_cfg_path = os.path.join(rootdir, 'model_config.yaml')
    dataset_cfg_path = os.path.join(rootdir, 'dataset_config.yaml')

    OmegaConf.save(config=model_cfg, f=model_cfg_path)

    dataset_yaml = dataset_cfg.dump()
    with open(dataset_cfg_path, 'w') as out_file:
        out_file.write(dataset_yaml)
@task_wrapper
def train(cfg: DictConfig) -> Tuple[dict, dict]:
    """Build the data module, model and trainer from ``cfg`` and run training.

    Args:
        cfg: Configuration composed by Hydra. This function reads
            ``cfg.paths.output_dir``, ``cfg.GENERAL.CHECKPOINT_STEPS``,
            ``cfg.GENERAL.CHECKPOINT_SAVE_TOP_K``, ``cfg.trainer`` and the
            optional ``launcher`` key.

    Returns:
        A ``(metric_dict, object_dict)`` tuple: the trainer's callback metrics
        after fitting, and a dict with the objects instantiated here
        (``cfg``, ``datamodule``, ``model``, ``callbacks``, ``logger``,
        ``trainer``).
    """
    # Load dataset config
    dataset_cfg = dataset_config()

    # Save configs
    save_configs(cfg, dataset_cfg, cfg.paths.output_dir)

    # Setup training and validation datasets
    datamodule = HAMERDataModule(cfg, dataset_cfg)

    # Setup model
    model = HAMER(cfg)

    # Setup Tensorboard logger
    logger = TensorBoardLogger(
        os.path.join(cfg.paths.output_dir, 'tensorboard'),
        name='',
        version='',
        default_hp_metric=False,
    )
    loggers = [logger]

    # Setup checkpoint saving
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=os.path.join(cfg.paths.output_dir, 'checkpoints'),
        every_n_train_steps=cfg.GENERAL.CHECKPOINT_STEPS,
        save_last=True,
        save_top_k=cfg.GENERAL.CHECKPOINT_SAVE_TOP_K,
    )
    lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step')
    # No progress-bar callback is registered here, so none is instantiated:
    # constructing one that is never passed to the trainer only adds an
    # optional dependency to the training entry point.
    callbacks = [
        checkpoint_callback,
        lr_monitor,
    ]

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    # When a launcher is configured, requeue on SIGUSR2 (Submitit uses SIGUSR2);
    # otherwise pass no plugins.
    if cfg.get('launcher', None) is not None:
        plugins = SLURMEnvironment(requeue_signal=signal.SIGUSR2)
    else:
        plugins = None
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer,
        callbacks=callbacks,
        logger=loggers,
        plugins=plugins,
    )

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    # Train the model. ckpt_path='last' asks Lightning to resume from the most
    # recent checkpoint (see the Trainer.fit documentation for the behavior
    # when none exists).
    trainer.fit(model, datamodule=datamodule, ckpt_path='last')
    log.info("Fitting done")

    # Honor the declared Tuple[dict, dict] return type instead of implicitly
    # returning None.
    return trainer.callback_metrics, object_dict
@hydra.main(
    version_base="1.2",
    config_path=str(root/"hamer/configs_hydra"),
    config_name="train.yaml",
)
def main(cfg: DictConfig) -> Optional[float]:
    """Hydra entry point: run training with the composed configuration.

    Whatever ``train`` returns is discarded, so this always returns ``None``.
    """
    # train the model
    train(cfg)
    return None


if __name__ == "__main__":
    main()
|