| |
|
|
| import sys |
| import os |
| import torch |
| import wandb |
| import lightning.pytorch as pl |
|
|
| from omegaconf import OmegaConf |
| from lightning.pytorch.strategies import DDPStrategy |
| from lightning.pytorch.loggers import WandbLogger |
| from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor |
|
|
|
|
| from src.madsbm.wt_peptide.sbm_module import MadSBM |
| from src.madsbm.wt_peptide.dataloader import PeptideDataModule, get_datasets |
|
|
# SECURITY: the original hard-coded a W&B API key in source control — that key
# is now public and must be rotated. Read the key from the environment instead;
# with key=None, wandb.login() also falls back to WANDB_API_KEY / ~/.netrc.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
|
|
|
|
| |
# Load the experiment config. The path can be overridden via MADSBM_CONFIG so
# the script is not hard-wired to one user's scratch directory; the original
# absolute path remains the default for backward compatibility.
config_path = os.environ.get(
    "MADSBM_CONFIG", "/scratch/pranamlab/sgoel/MadSBM/configs/wt_pep.yaml"
)
config = OmegaConf.load(config_path)

# NOTE(review): WandbLogger would call wandb.init itself on first use; the
# explicit init here makes the logger attach to this run. Kept as-is.
wandb.init(project=config.wandb.project, name=config.wandb.name)
wandb_logger = WandbLogger(**config.wandb)
|
|
| |
# Log the learning rate at every optimizer step so LR schedules show up in W&B.
lr_monitor = LearningRateMonitor(logging_interval="step")

# Both checkpoint callbacks write to the same directory and report verbosely.
_ckpt_common = dict(dirpath=config.checkpointing.save_dir, verbose=True)

# Snapshot at the end of every training epoch; save_top_k=-1 keeps all of them.
every_epoch_cb = ModelCheckpoint(
    filename="{epoch:02d}_{step}",
    save_top_k=-1,
    every_n_epochs=1,
    save_on_train_epoch_end=True,
    **_ckpt_common,
)

# Keep only the single best model as measured by validation loss (lower = better).
best_ckpt_cb = ModelCheckpoint(
    monitor="val/loss",
    filename="best-model_{epoch:02d}_{step}",
    save_top_k=1,
    mode="min",
    save_last=False,
    **_ckpt_common,
)
|
|
| |
# Pick hardware: CUDA when available, otherwise fall back to CPU.
accelerator = "cuda" if torch.cuda.is_available() else "cpu"
# Training uses the configured device set; any other mode pins a single device.
devices = config.training.devices if config.training.mode == "train" else [0]

trainer = pl.Trainer(
    max_epochs=config.training.n_epochs,
    accelerator=accelerator,
    devices=devices,
    strategy=DDPStrategy(find_unused_parameters=True),
    callbacks=[every_epoch_cb, best_ckpt_cb, lr_monitor],
    logger=wandb_logger,
)
|
|
|
|
| |
# Ensure the checkpoint directory exists before any callback writes to it.
# exist_ok=True replaces the original try/except-FileExistsError pattern, and —
# unlike that pattern — still raises if the path exists but is a regular file.
ckpt_path = config.checkpointing.save_dir
os.makedirs(ckpt_path, exist_ok=True)
|
|
| |
# The model must be built first: the datamodule borrows its tokenizer.
sbm_model = MadSBM(config)
sbm_model.validate_config()

# Assemble train/val/test splits into a single Lightning datamodule.
datasets = get_datasets(config)
data_module = PeptideDataModule(
    config=config,
    tokenizer=sbm_model.tokenizer,
    train_dataset=datasets["train"],
    val_dataset=datasets["val"],
    test_dataset=datasets["test"],
)
|
|
| |
# Dispatch on the configured run mode: fresh training, evaluation of the best
# checkpoint, or resuming an interrupted run.
mode = config.training.mode

if mode == "train":
    trainer.fit(sbm_model, datamodule=data_module)

elif mode == "test":
    # Restore the best weights before evaluating; passing ckpt_path as well
    # lets Lightning restore trainer state from the same checkpoint file.
    state_dict = sbm_model.get_state_dict(config.checkpointing.best_ckpt_path)
    sbm_model.load_state_dict(state_dict)
    trainer.test(
        sbm_model,
        datamodule=data_module,
        ckpt_path=config.checkpointing.best_ckpt_path,
    )

elif mode == "resume_from_checkpoint":
    # BUG FIX: the original passed `ckpt_path` — the checkpoint *directory*
    # (config.checkpointing.save_dir) — to trainer.fit, but Trainer.fit
    # expects the checkpoint *file* to resume from. Use the configured
    # resume path, which is also what the state_dict is loaded from.
    resume_path = config.training.resume_ckpt_path
    state_dict = sbm_model.get_state_dict(resume_path)
    sbm_model.load_state_dict(state_dict)
    trainer.fit(sbm_model, datamodule=data_module, ckpt_path=resume_path)

else:
    # Fail loudly on a misconfigured mode instead of silently doing nothing.
    raise ValueError(f"Unknown training mode: {mode!r}")

wandb.finish()
|
|
|
|