File size: 3,176 Bytes
f34af6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
import torch

import hydra
from omegaconf import DictConfig, OmegaConf

# Pytorch lightning imports
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# from dataset.data import PdbDataModule
from dataset.classification_data import PdbDataModule
from models.classifier_wrapper_v2 import ClasfModule
from dataset.classification_data import get_dataloaders

from utils.experiments import get_pylogger, flatten_dict
import wandb

# Allow faster float32 matmul kernels: per the PyTorch docs, "high" lets
# float32 matmuls use reduced-precision internals (e.g. TF32) when available.
torch.set_float32_matmul_precision('high')

# Module-level logger from the project's logging helper.
log = get_pylogger(__name__)


def _configure_wandb_env(exp_cfg: DictConfig):
    """Export wandb environment variables requested by the experiment config.

    Reads ``exp_cfg.wandb``:
      * a truthy ``offline`` sets ``WANDB_MODE=offline``;
      * a truthy ``dir`` is copied into ``WANDB_DIR``.
    Absent or falsy keys leave the environment untouched.
    """
    settings = exp_cfg.wandb
    overrides = {}
    if settings.get("offline", False):
        overrides["WANDB_MODE"] = "offline"
    target_dir = settings.get("dir", None)
    if target_dir:
        overrides["WANDB_DIR"] = target_dir
    # Applied in insertion order: MODE first, then DIR (same as setting each directly).
    os.environ.update(overrides)


class ClassifierTrainer:
    """Set up and run Lightning training for the classifier.

    ``cfg.data`` is passed to ``PdbDataModule``, ``cfg.experiment`` supplies
    the wandb logger, checkpointing and Trainer options, and the full
    ``cfg`` is passed to ``ClasfModule``.
    """

    def __init__(self, *, cfg: DictConfig):
        """Build the data module and model from the Hydra config.

        Args:
            cfg: Root config; must contain ``data`` and ``experiment``
                sections.
        """
        self._cfg = cfg
        self._data_cfg = cfg.data
        self._exp_cfg = cfg.experiment
        self._datamodule: LightningDataModule = PdbDataModule(self._data_cfg)
        self._model: LightningModule = ClasfModule(self._cfg)

    def _build_logger(self) -> WandbLogger:
        """Apply wandb env overrides, then create a WandbLogger from ``experiment.wandb``."""
        _configure_wandb_env(self._exp_cfg)
        return WandbLogger(**self._exp_cfg.wandb)

    def _prepare_checkpoint_dir(self) -> str:
        """Create ``experiment.checkpointer.dirpath`` (if needed) and return it."""
        ckpt_dir = self._exp_cfg.checkpointer.dirpath
        os.makedirs(ckpt_dir, exist_ok=True)
        log.info(f"Checkpoints saved to {ckpt_dir}")
        return ckpt_dir

    def _save_config(self, ckpt_dir: str, logger: WandbLogger) -> None:
        """Write ``config.yaml`` into ``ckpt_dir`` and push a flattened resolved copy to wandb."""
        cfg_path = os.path.join(ckpt_dir, 'config.yaml')
        # OmegaConf.save opens the path itself; the previous wrapping
        # ``open(cfg_path, 'w')`` only created a second, unused handle.
        OmegaConf.save(config=self._cfg, f=cfg_path)
        cfg_dict = OmegaConf.to_container(self._cfg, resolve=True)
        flat_cfg = dict(flatten_dict(cfg_dict))
        # Only update when the experiment exposes a real wandb Config; presumably
        # guards against Lightning's stand-in experiment on non-zero ranks — TODO confirm.
        if isinstance(logger.experiment.config, wandb.sdk.wandb_config.Config):
            logger.experiment.config.update(flat_cfg)

    def train(self):
        """Build the logger, checkpoint callback and Trainer, then fit the model.

        Side effects: may set ``WANDB_MODE``/``WANDB_DIR``, creates the
        checkpoint directory, writes ``config.yaml`` into it, and uploads the
        flattened resolved config to the wandb run.

        ``experiment.num_devices`` defaults to 1. ``experiment.warm_start``
        is optional and is forwarded to ``Trainer.fit`` as ``ckpt_path``.
        """
        logger = self._build_logger()
        ckpt_dir = self._prepare_checkpoint_dir()
        callbacks = [ModelCheckpoint(**self._exp_cfg.checkpointer)]
        self._save_config(ckpt_dir, logger)

        num_devices = int(self._exp_cfg.get("num_devices", 1))
        log.info(f"Using {num_devices} device(s)")
        # ``experiment.trainer`` must not also set any keyword passed explicitly
        # below (e.g. ``devices``); Python raises TypeError on duplicate kwargs.
        trainer = Trainer(
            **self._exp_cfg.trainer,
            callbacks=callbacks,
            logger=logger,
            use_distributed_sampler=False,
            enable_progress_bar=True,
            enable_model_summary=True,
            devices=num_devices,
        )
        trainer.fit(
            model=self._model,
            datamodule=self._datamodule,
            # ``.get`` (like ``num_devices`` above) so a config without
            # ``warm_start`` trains from scratch instead of failing on the key.
            ckpt_path=self._exp_cfg.get("warm_start", None),
        )

@hydra.main(version_base=None, config_path="./configs", config_name="classifier.yaml")
def main(cfg: DictConfig):
    """Hydra entry point: train the classifier using ``classifier.yaml`` from ``./configs``."""
    ClassifierTrainer(cfg=cfg).train()


if __name__ == "__main__":
    # The hydra.main decorator builds and injects ``cfg``; call without arguments.
    main()