#!/usr/bin/python3
# src/madsbm/wt_peptide/main.py
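"""Training entry point for the wild-type-peptide MadSBM model.

Assumed invocation (run from the repository root so the `src.madsbm` imports
resolve; the config-path argument is optional and defaults to the path below):

    python -m src.madsbm.wt_peptide.main [configs/wt_pep.yaml]
"""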
import sys
import os
import torch
import wandb
import lightning.pytorch as pl
from omegaconf import OmegaConf
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from src.madsbm.wt_peptide.sbm_module import MadSBM
from src.madsbm.wt_peptide.dataloader import PeptideDataModule, get_datasets
# Authenticate with W&B via the environment instead of a hardcoded API key.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
# Load the YAML config; allow an optional path override on the command line.
config_path = sys.argv[1] if len(sys.argv) > 1 else "/scratch/pranamlab/sgoel/MadSBM/configs/wt_pep.yaml"
config = OmegaConf.load(config_path)
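# Minimal sketch of the config keys this script reads (inferred from the
# accesses below; the real wt_pep.yaml likely contains more fields):
#
#   wandb:
#     project: <wandb project>
#     name: <run name>
#   checkpointing:
#     save_dir: <directory for checkpoints>
#     best_ckpt_path: <path to best checkpoint, used in test mode>
#   training:
#     mode: train | test | resume_from_checkpoint
#     n_epochs: <int>
#     devices: <list of GPU indices>
#     resume_ckpt_path: <checkpoint to resume from>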
# Initialize the W&B run for logging; WandbLogger reuses the active run
# instead of starting a second one.
wandb.init(project=config.wandb.project, name=config.wandb.name)
wandb_logger = WandbLogger(**config.wandb)
# Callbacks: checkpointing policies plus learning-rate monitoring.
lr_monitor = LearningRateMonitor(logging_interval="step")
# Save a checkpoint after every epoch (save_top_k=-1 keeps all of them).
every_epoch_cb = ModelCheckpoint(
    dirpath=config.checkpointing.save_dir,
    filename="{epoch:02d}_{step}",
    save_top_k=-1,
    every_n_epochs=1,
    save_on_train_epoch_end=True,
    verbose=True,
)
# Separately track the single best checkpoint by validation loss.
best_ckpt_cb = ModelCheckpoint(
    monitor="val/loss",
    dirpath=config.checkpointing.save_dir,
    filename="best-model_{epoch:02d}_{step}",
    save_top_k=1,
    mode="min",
    verbose=True,
    save_last=False,
)
# PL trainer; training is epoch-based (no max_steps) so runs are comparable
# with MOG-DFM and DirichletFM.
trainer = pl.Trainer(
    max_epochs=config.training.n_epochs,
    accelerator="cuda" if torch.cuda.is_available() else "cpu",
    devices=config.training.devices if config.training.mode == "train" else [0],
    strategy=DDPStrategy(find_unused_parameters=True),
    callbacks=[every_epoch_cb, best_ckpt_cb, lr_monitor],
    logger=wandb_logger,
)
# Make sure the checkpoint directory exists.
ckpt_dir = config.checkpointing.save_dir
os.makedirs(ckpt_dir, exist_ok=True)
# PL Model for training
sbm_model = MadSBM(config)
sbm_model.validate_config()
# Build the datasets and wrap them in the Lightning data module.
datasets = get_datasets(config)
data_module = PeptideDataModule(
    config=config,
    train_dataset=datasets['train'],
    val_dataset=datasets['val'],
    test_dataset=datasets['test'],
    tokenizer=sbm_model.tokenizer,
)
# Start training, evaluate, or resume training from a checkpoint.
if config.training.mode == "train":
    trainer.fit(sbm_model, datamodule=data_module)
elif config.training.mode == "test":
    # get_state_dict may post-process the checkpoint, so load the weights
    # here once rather than also passing ckpt_path to trainer.test(),
    # which would reload (and overwrite) them from the raw checkpoint.
    state_dict = sbm_model.get_state_dict(config.checkpointing.best_ckpt_path)
    sbm_model.load_state_dict(state_dict)
    trainer.test(sbm_model, datamodule=data_module)
elif config.training.mode == "resume_from_checkpoint":
    state_dict = sbm_model.get_state_dict(config.training.resume_ckpt_path)
    sbm_model.load_state_dict(state_dict)
    # Resume from the checkpoint file itself (not the save directory) so
    # Lightning also restores the optimizer, scheduler, and loop state.
    trainer.fit(sbm_model, datamodule=data_module, ckpt_path=config.training.resume_ckpt_path)
else:
    raise ValueError(f"Unknown training mode: {config.training.mode!r}")
wandb.finish()