"""Train the latent diffusion model with PyTorch Lightning.

Builds the train/validation dataloaders and the ``Diffusion`` model from the
project ``Config``, then fits the model with a DDP ``Trainer`` configured from
``Config.training``.
"""
import pytorch_lightning as L
from pytorch_lightning.strategies import DDPStrategy

from configs.config import Config
from models.diffusion import Diffusion
from utils.data_loader import get_dataloaders


def build_trainer(training_cfg):
    """Create the Lightning ``Trainer`` from the ``Config.training`` mapping.

    ``training_cfg`` must provide the keys ``"epochs"``, ``"gpus"``,
    ``"precision"``, ``"accumulate_grad_batches"`` and ``"save_dir"``.
    """
    gpus = training_cfg["gpus"]
    return L.Trainer(
        max_epochs=training_cfg["epochs"],
        # ``gpus=`` was deprecated in Lightning 1.7 and removed in 2.0;
        # ``accelerator`` + ``devices`` is the supported equivalent. A falsy
        # value (0 / None) used to mean "no GPU", so fall back to CPU then.
        accelerator="gpu" if gpus else "cpu",
        devices=gpus if gpus else 1,
        precision=training_cfg["precision"],
        strategy=DDPStrategy(find_unused_parameters=False),
        accumulate_grad_batches=training_cfg["accumulate_grad_batches"],
        default_root_dir=training_cfg["save_dir"],
    )


def main():
    """Build data, model and trainer, then run training with validation."""
    # The third value returned by get_dataloaders is not used here.
    train_loader, val_loader, _ = get_dataloaders(Config)
    latent_diffusion_model = Diffusion(Config, latent_dim=Config.latent_dim)
    trainer = build_trainer(Config.training)
    trainer.fit(latent_diffusion_model, train_loader, val_loader)


# Guard the entry point so that re-importing this module (e.g. by DataLoader
# worker processes under the "spawn" start method) does not start training.
if __name__ == "__main__":
    main()