"""Train the DR (diabetic-retinopathy) classifier with PyTorch Lightning.

Flat training script: seeds all RNGs, builds the datamodule and model,
wires up TensorBoard logging plus checkpoint / LR-monitor / early-stop
callbacks, then runs ``trainer.fit``.
"""

import lightning as L
import torch
from lightning.pytorch.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from lightning.pytorch.loggers import TensorBoardLogger

from src.dataset import DRDataModule
from src.model import DRModel

# Seed Python/NumPy/torch RNGs (workers=True also seeds dataloader workers)
# for reproducibility.
SEED = 42
L.seed_everything(SEED, workers=True)

# Trade a little float32 matmul precision for Tensor Core speed.
torch.set_float32_matmul_precision("high")

# Init DataModule. setup() is invoked eagerly (Trainer.fit would call it
# anyway) because the model construction below needs dm.num_classes and
# dm.class_weights.
# NOTE(review): assumes DRDataModule.setup() takes no required `stage`
# argument — confirm against its definition in src/dataset.py.
dm = DRDataModule(batch_size=128, num_workers=24)
dm.setup()

# Init model from the datamodule's attributes.
model = DRModel(
    num_classes=dm.num_classes,
    learning_rate=3e-4,
    class_weights=dm.class_weights,
)

# TensorBoard event files are written under artifacts/.
logger = TensorBoardLogger(save_dir="artifacts")

# Keep the 2 best checkpoints ranked by (minimal) validation loss; the
# filename embeds the validation metrics for quick inspection.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=2,
    dirpath="artifacts/checkpoints",
    filename="{epoch}-{step}-{val_loss:.2f}-{val_acc:.2f}-{val_kappa:.2f}",
)

# Log the learning rate every optimizer step (useful with LR schedulers).
lr_monitor = LearningRateMonitor(logging_interval="step")

# Stop training once val_loss has not improved for 5 validation checks.
early_stopping = EarlyStopping(
    monitor="val_loss",
    patience=5,
    verbose=True,
    mode="min",
)

# Init trainer: auto-select accelerator/devices, attach logger + callbacks.
trainer = L.Trainer(
    max_epochs=20,
    accelerator="auto",
    devices="auto",
    logger=logger,
    callbacks=[checkpoint_callback, lr_monitor, early_stopping],
    # check_val_every_n_epoch=4,
)

# Passing the datamodule lets it override the model's data hooks.
trainer.fit(model, dm)