"""Train the latent diffusion model with PyTorch Lightning.

Builds the dataloaders, the ESM-2 tokenizer, and the Diffusion model,
then runs ``Trainer.fit`` with best-val-loss checkpointing.
"""
import sys

import pytorch_lightning as L
import wandb
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy

import config
from data_loader import get_dataloaders
from diffusion import Diffusion
from esm_utils import load_esm2_model


def main() -> None:
    """Run one full training session as configured in ``config``."""
    # Test split is unused during training.
    train_loader, val_loader, _ = get_dataloaders(config)

    # Only the tokenizer is needed here; the ESM-2 model itself is
    # presumably constructed inside Diffusion — confirm in diffusion.py.
    tokenizer, _, _ = load_esm2_model(config.MODEL_NAME)

    latent_diffusion_model = Diffusion(
        config, latent_dim=config.LATENT_DIM, tokenizer=tokenizer
    )
    print(latent_diffusion_model)
    sys.stdout.flush()

    # Keep only the single best checkpoint, ranked by validation loss.
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        save_top_k=1,
        mode='min',
        # NOTE(review): hardcoded path — consider config.Training.SAVE_DIR.
        dirpath="/workspace/a03-sgoel/MDpLM/",
        filename="best_model_epoch{epoch:02d}",
    )

    trainer = L.Trainer(
        max_epochs=config.Training.NUM_EPOCHS,
        precision=config.Training.PRECISION,
        devices=1,
        accelerator='gpu',
        # find_unused_parameters=False skips DDP's per-step unused-parameter
        # scan; safe as long as every registered parameter gets a gradient.
        strategy=DDPStrategy(find_unused_parameters=False),
        accumulate_grad_batches=config.Training.ACCUMULATE_GRAD_BATCHES,
        default_root_dir=config.Training.SAVE_DIR,
        callbacks=[checkpoint_callback],
    )

    print(trainer)
    print("Training model...")
    sys.stdout.flush()

    trainer.fit(latent_diffusion_model, train_loader, val_loader)

    # Close the wandb run if a logger/callback opened one.
    # NOTE(review): no wandb.init() is visible in this file — verify a
    # WandbLogger (or similar) actually starts a run before relying on this.
    wandb.finish()


# Guarding the entry point is required with DDP: spawned worker processes
# re-import this module, and unguarded module-level training code would
# otherwise re-execute in every worker.
if __name__ == "__main__":
    main()
| |
|