# Loupe / src/train.py — training entry point.
# Uploaded by xxwyyds ("Upload 86 files", commit 891e05c, verified).
import os
import sys
import hydra
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import (
LearningRateMonitor,
RichProgressBar,
ModelCheckpoint,
)
from pytorch_lightning.loggers import TensorBoardLogger
from omegaconf import DictConfig
from pathlib import Path
from models.loupe import LoupeModel, LoupeConfig
from data_module import DataModule
from lit_model import LitModel
from utils import convert_deepspeed_checkpoint
# Repository root: two directory levels above this file.
project_root = Path(__file__).resolve().parent.parent

# Output root used by main() for checkpoints and TensorBoard logs. The
# CATTINO_TASK_HOME environment variable wins; otherwise <project_root>/results.
save_path = os.environ.get("CATTINO_TASK_HOME", f"{project_root}/results")

# Put the current working directory at the front of the module search path.
# NOTE(review): this runs after the project imports above, so it only affects
# imports performed later (presumably Hydra `_target_` lookups — confirm).
sys.path.insert(0, ".")
@hydra.main(
    config_path=str(project_root / "configs"),
    config_name="train",
    version_base=None,
)
def main(cfg: DictConfig) -> None:
    """Train a Loupe model from the Hydra ``train`` config.

    Seeds everything, builds the checkpoint callback, TensorBoard logger and
    Lightning ``Trainer``, fits ``LitModel`` on ``DataModule`` and, for
    DeepSpeed runs that saved a best checkpoint, converts it with
    ``convert_deepspeed_checkpoint``.

    Args:
        cfg: Composed Hydra config. Reads ``seed``, ``stage.name``,
            ``ckpt.saver``, ``trainer``, ``hparams.epoch``,
            ``hparams.accumulate_grad_batches``, ``hparams.grad_clip_val``
            and ``model``.

    Raises:
        ValueError: If ``cfg.stage.name`` is ``"test"`` (training only).
    """
    # omegaconf is already a dependency of this file (DictConfig).
    from omegaconf import OmegaConf

    pl.seed_everything(cfg.seed)
    if cfg.stage.name == "test":
        raise ValueError(
            "This script is for training only. Please use one of cls, seg, or cls_seg stages."
        )

    checkpoint_callback = hydra.utils.instantiate(
        cfg.ckpt.saver,
        dirpath=os.path.join(save_path, "checkpoints"),
    )
    logger = TensorBoardLogger(
        save_dir=save_path,
        name=cfg.stage.name,
        default_hp_metric=False,
    )

    # Plain, resolved containers: nested values (e.g. a `devices` list) reach
    # Lightning as list/dict instead of ListConfig/DictConfig.
    trainer_cfg = OmegaConf.to_container(cfg.trainer, resolve=True)
    trainer_overrides = dict(
        # `.get`: a trainer config without `fast_dev_run` must not fail in
        # Hydra's struct mode.
        devices=1 if trainer_cfg.get("fast_dev_run") else "auto",
        max_epochs=cfg.hparams.epoch,
        val_check_interval=0.05,
        log_every_n_steps=2,
        accumulate_grad_batches=cfg.hparams.accumulate_grad_batches,
        gradient_clip_val=cfg.hparams.grad_clip_val,
    )
    # Anything set explicitly under `cfg.trainer` overrides the defaults above.
    trainer_overrides.update(trainer_cfg)
    trainer = pl.Trainer(
        logger=logger,
        callbacks=[LearningRateMonitor(), RichProgressBar(), checkpoint_callback],
        **trainer_overrides,
    )
    torch.set_float32_matmul_precision("medium")

    loupe_config = LoupeConfig(stage=cfg.stage.name, **cfg.model)
    loupe = LoupeModel(loupe_config)
    model = LitModel(cfg, loupe)
    data_module = DataModule(cfg, loupe_config)
    trainer.fit(model, data_module)

    # `strategy` may be absent, null or a non-string object; a plain
    # `"deepspeed" in strategy` would raise TypeError on null.
    strategy = trainer_overrides.get("strategy")
    uses_deepspeed = isinstance(strategy, str) and "deepspeed" in strategy
    # Lightning's `enable_checkpointing` defaults to True, so only an explicit
    # False disables it (defaulting to False here would skip conversion
    # unless the config set it to True).
    checkpointing_enabled = trainer_overrides.get("enable_checkpointing") is not False
    if uses_deepspeed and checkpointing_enabled:
        best_model_path = checkpoint_callback.best_model_path
        # Empty when no checkpoint was saved (e.g. fast_dev_run); converting
        # would then target the bare `checkpoints/` directory.
        if best_model_path:
            # NOTE(review): the converted file goes under
            # <project_root>/checkpoints, not save_path/checkpoints where the
            # raw checkpoints were written — confirm this is intended. This
            # also runs on every rank; confirm convert_deepspeed_checkpoint
            # handles multi-process runs.
            convert_deepspeed_checkpoint(
                cfg,
                checkpoint_callback,
                os.path.join(
                    project_root,
                    "checkpoints",
                    # normpath strips a trailing separator so basename is
                    # the checkpoint name rather than "".
                    os.path.basename(os.path.normpath(best_model_path)),
                ),
            )


if __name__ == "__main__":
    main()  # `cfg` is supplied by the @hydra.main decorator