#!/usr/bin/python3
import os

import torch
import wandb
import lightning.pytorch as pl
from omegaconf import OmegaConf
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor

from src.madsbm.wt_peptide.sbm_module import MadSBM
from src.madsbm.wt_peptide.dataloader import PeptideDataModule, get_datasets

# Authenticate with W&B via the WANDB_API_KEY environment variable;
# never hardcode an API key in source.
wandb.login()

# Load YAML config
config = OmegaConf.load("/scratch/pranamlab/sgoel/MadSBM/configs/wt_pep.yaml")

# Initialize WandB for logging; WandbLogger attaches to the active run
wandb.init(project=config.wandb.project, name=config.wandb.name)
wandb_logger = WandbLogger(**config.wandb)

# PL callbacks: track the LR, save a checkpoint every epoch,
# and keep the single best checkpoint by validation loss
lr_monitor = LearningRateMonitor(logging_interval="step")
every_epoch_cb = ModelCheckpoint(
    dirpath=config.checkpointing.save_dir,
    filename="{epoch:02d}_{step}",
    save_top_k=-1,
    every_n_epochs=1,
    save_on_train_epoch_end=True,
    verbose=True,
)
best_ckpt_cb = ModelCheckpoint(
    monitor="val/loss",
    dirpath=config.checkpointing.save_dir,
    filename="best-model_{epoch:02d}_{step}",
    save_top_k=1,
    mode="min",
    verbose=True,
    save_last=False,
)

# PL trainer. Training is epoch-based (no max_steps) so runs are
# directly comparable with MOG-DFM and DirichletFM.
trainer = pl.Trainer(
    max_epochs=config.training.n_epochs,
    accelerator="cuda" if torch.cuda.is_available() else "cpu",
    devices=config.training.devices if config.training.mode == "train" else [0],
    strategy=DDPStrategy(find_unused_parameters=True),
    callbacks=[every_epoch_cb, best_ckpt_cb, lr_monitor],
    logger=wandb_logger,
)

# Folder to save checkpoints
ckpt_dir = config.checkpointing.save_dir
os.makedirs(ckpt_dir, exist_ok=True)

# PL model for training
sbm_model = MadSBM(config)
sbm_model.validate_config()

# Datasets and datamodule
datasets = get_datasets(config)
data_module = PeptideDataModule(
    config=config,
    train_dataset=datasets["train"],
    val_dataset=datasets["val"],
    test_dataset=datasets["test"],
    tokenizer=sbm_model.tokenizer,
)

# Start/resume training or evaluate the model
if config.training.mode == "train":
    trainer.fit(sbm_model, datamodule=data_module)
elif config.training.mode == "test":
    # ckpt_path makes Lightning restore the weights itself; the manual load
    # guards against checkpoints that need custom key handling.
    state_dict = sbm_model.get_state_dict(config.checkpointing.best_ckpt_path)
    sbm_model.load_state_dict(state_dict)
    trainer.test(
        sbm_model,
        datamodule=data_module,
        ckpt_path=config.checkpointing.best_ckpt_path,
    )
elif config.training.mode == "resume_from_checkpoint":
    state_dict = sbm_model.get_state_dict(config.training.resume_ckpt_path)
    sbm_model.load_state_dict(state_dict)
    # Resume from the checkpoint file (restores optimizer/epoch state),
    # not from the checkpoint directory.
    trainer.fit(
        sbm_model,
        datamodule=data_module,
        ckpt_path=config.training.resume_ckpt_path,
    )

wandb.finish()
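
# ---------------------------------------------------------------------------
# For reference, a sketch of the wt_pep.yaml structure this script expects,
# inferred from the config keys read above. Field values are placeholders;
# the actual file in the repo's configs/ directory is authoritative.
#
#   wandb:
#     project: MadSBM              # W&B project name
#     name: wt-peptide-run         # W&B run name
#   checkpointing:
#     save_dir: /path/to/checkpoints
#     best_ckpt_path: /path/to/checkpoints/best-model_XX_YYYY.ckpt
#   training:
#     mode: train                  # train | test | resume_from_checkpoint
#     n_epochs: 100
#     devices: [0, 1]              # GPU indices used when mode == train
#     resume_ckpt_path: /path/to/checkpoints/XX_YYYY.ckpt
#
# Example invocation (assuming the key is exported in the environment):
#   WANDB_API_KEY=<your-key> python train.py
# ---------------------------------------------------------------------------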