import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import os
import hydra
from omegaconf import DictConfig, OmegaConf
from datetime import datetime
import wandb

import data
from lightning_modules import (
    MaskedDiffusionModule,
    AutoregressiveModule,
    AnyOrderInsertionFlowModule,
)
from pytorch_lightning.utilities import rank_zero_only

torch.set_printoptions(threshold=10_000)
torch.set_float32_matmul_precision("high")


def train(config: DictConfig):
    # set the random seed
    pl.seed_everything(42)
    torch.manual_seed(42)
| if "wandb" in config and rank_zero_only.rank == 0: | |
| init_kwargs = dict( | |
| project="interpretable-flow", | |
| entity=config.wandb.entity, | |
| config=OmegaConf.to_container(config, resolve=True), | |
| name=config.wandb.name, | |
| ) | |
| # resume wandb run if we're resuming from a checkpoint | |
| if "resume_path" in config.training: | |
| init_kwargs["resume"] = "allow" | |
| wandb.init(**init_kwargs) | |
| wandb_logger = WandbLogger( | |
| project=wandb.run.project, | |
| name=wandb.run.name, | |
| log_model=True, | |
| ) | |
| else: | |
| wandb_logger = None | |

    # Create a timestamped checkpoint directory
    time_string = datetime.now().strftime("%Y%m%d-%H%M%S")
    config.training.checkpoint_dir = os.path.join(
        config.training.checkpoint_dir, time_string
    )
    os.makedirs(config.training.checkpoint_dir, exist_ok=True)

    dataset_bundle = data.setup_data_and_update_config(config)

    # Select the LightningModule to train
    match config.trainer:
        case "mdm":
            module = MaskedDiffusionModule(config)
        case "autoregressive":
            module = AutoregressiveModule(config)
        case "any-order-flow":
            module = AnyOrderInsertionFlowModule(config)
        case _:
            raise NotImplementedError(f"Trainer {config.trainer} is not supported")

    # Configure trainer arguments
    trainer_kwargs = dict(
        num_nodes=config.training.nodes,
        accelerator="gpu",
        devices=config.training.devices,
        strategy="ddp",
        # Accumulate gradients so the effective batch size matches
        # config.training.batch_size across all nodes and devices
        accumulate_grad_batches=(
            config.training.batch_size
            // (
                config.training.per_gpu_batch_size
                * config.training.nodes
                * config.training.devices
            )
        ),
        log_every_n_steps=10,
        enable_checkpointing=True,
        default_root_dir=config.training.checkpoint_dir,
        gradient_clip_val=1.0,
    )
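    # Example with hypothetical values: batch_size=512, per_gpu_batch_size=16,
    # nodes=1, devices=8 gives accumulate_grad_batches = 512 // (16 * 1 * 8) = 4,
    # i.e. an effective global batch of 16 * 8 * 4 = 512 samples per optimizer step.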

    # Exactly one of max_steps or num_epochs should be set; max_steps takes precedence
    if config.training.max_steps is not None:
        trainer_kwargs["max_steps"] = config.training.max_steps
    elif config.training.num_epochs is not None:
        trainer_kwargs["max_epochs"] = config.training.num_epochs
        config.training.max_steps = config.training.num_epochs * len(
            dataset_bundle.train_loader
        )
    else:
        raise ValueError(
            "Either max_steps or num_epochs must be specified in the config"
        )

    # Default the warmup to 1% of the total training steps
    if config.training.warmup_steps is None:
        config.training.warmup_steps = int(config.training.max_steps * 0.01)
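    # e.g. with a hypothetical max_steps of 100_000 this yields warmup_steps = 1_000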

    # Save a checkpoint whenever the validation loss reaches a new low
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        save_top_k=config.training.save_top_k,
        save_last=True,
        filename="epoch-{epoch:02d}-val_loss-{val_loss:.4f}",
        dirpath=config.training.checkpoint_dir,
        every_n_train_steps=10000,
        # every_n_epochs=config.training.save_every_n_epochs,
    )
    trainer_kwargs["callbacks"] = [checkpoint_callback]

    if wandb_logger is not None:
        trainer_kwargs["logger"] = wandb_logger

    trainer = pl.Trainer(**trainer_kwargs)

    # Train the model, resuming from a checkpoint if one is given
    ckpt_path = None
    if "resume_path" in config.training:
        ckpt_path = config.training.resume_path
    trainer.fit(
        module,
        train_dataloaders=dataset_bundle.train_loader,
        val_dataloaders=dataset_bundle.val_loader,
        ckpt_path=ckpt_path,
    )
| if "wandb" in config: | |
| wandb.finish() | |


# NOTE: config_path/config_name below are placeholders; point them at the
# project's actual Hydra config before running.
@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig):
    train(cfg)


if __name__ == "__main__":
    main()
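# Example invocation (assuming this file is saved as train.py and the placeholder
# Hydra config above exists); the override keys mirror the fields read in train():
#
#   python train.py trainer=mdm training.devices=4 training.nodes=1 \
#       training.batch_size=256 training.per_gpu_batch_size=32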