Spaces:
Runtime error
Runtime error
| import argparse | |
| import pytorch_lightning as pl | |
| from datamodules import CIFAR10QADataModule, ImageDataModule | |
| from datamodules.utils import datamodule_factory | |
| from models import ImageClassificationNet | |
| from models.utils import model_factory | |
| from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint | |
| from pytorch_lightning.loggers import WandbLogger | |
def main(args: argparse.Namespace):
    """Train (or resume training of) an image classifier, logging to W&B.

    Args:
        args: Parsed command-line arguments. Read directly here: ``seed``,
            ``checkpoint``, ``num_epochs``, ``optimizer``, ``weight_decay``,
            ``lr``, ``dataset``, ``base_model``, ``from_pretrained`` and
            ``enable_progress_bar``. ``model_factory`` and
            ``datamodule_factory`` also receive the full namespace and may
            read further fields.
    """
    # Metric shared by checkpointing and early stopping so they agree on
    # what "best" means.
    monitor_metric = "val_acc"

    # Seed before building the model/data so runs are reproducible.
    pl.seed_everything(args.seed)

    # Backbone that ImageClassificationNet wraps.
    base = model_factory(args)

    # Set up the datamodule eagerly: the number of training batches is
    # needed below to compute num_train_steps.
    dm = datamodule_factory(args)
    dm.prepare_data()
    dm.setup("fit")

    if args.checkpoint:
        # Restore the classifier from the checkpoint, re-attaching the backbone.
        model = ImageClassificationNet.load_from_checkpoint(args.checkpoint, model=base)
    else:
        # Fresh classifier. num_train_steps = epochs * batches per epoch
        # (presumably consumed by an LR schedule inside the model — confirm).
        model = ImageClassificationNet(
            model=base,
            num_train_steps=args.num_epochs * len(dm.train_dataloader()),
            optimizer=args.optimizer,
            weight_decay=args.weight_decay,
            lr=args.lr,
        )

    wandb_logger = WandbLogger(
        name=f"{args.dataset}_training_{args.base_model} ({args.from_pretrained})",
        project="Patch-DiffMask",
    )

    # Without `monitor`, ModelCheckpoint keeps only the most recent epoch.
    # After early stopping fires, that epoch did not improve the metric, so
    # the best model would be lost. Track the same metric/mode as early
    # stopping. `save_last` additionally keeps the latest state so an
    # interrupted run can still be resumed from where it stopped.
    ckpt_cb = ModelCheckpoint(
        dirpath=f"checkpoints/{wandb_logger.version}",
        monitor=monitor_metric,
        mode="max",
        save_last=True,
    )
    es_cb = EarlyStopping(monitor=monitor_metric, mode="max", patience=5)

    trainer = pl.Trainer(
        accelerator="auto",
        callbacks=[ckpt_cb, es_cb],
        logger=wandb_logger,
        max_epochs=args.num_epochs,
        enable_progress_bar=args.enable_progress_bar,
    )

    # When resuming, pass ckpt_path so the trainer also restores its own
    # training state (epoch/optimizer) from the checkpoint.
    trainer_args = {"ckpt_path": args.checkpoint} if args.checkpoint else {}
    trainer.fit(model, dm, **trainer_args)
if __name__ == "__main__":
    # Build the CLI; argument registration order determines --help order.
    cli = argparse.ArgumentParser()
    add_arg = cli.add_argument

    add_arg("--checkpoint", type=str, help="Checkpoint to resume the training from.")

    # --- Trainer options ---
    add_arg(
        "--enable_progress_bar",
        action="store_true",
        help="Whether to show progress bar during training. NOT recommended when logging to files.",
    )
    add_arg("--num_epochs", type=int, default=5, help="Number of epochs to train.")
    add_arg("--seed", type=int, default=123, help="Random seed for reproducibility.")

    # --- Base (classification) model options ---
    ImageClassificationNet.add_model_specific_args(cli)
    add_arg(
        "--base_model",
        type=str,
        default="ViT",
        choices=["ViT", "ConvNeXt"],
        help="Base model architecture to train.",
    )
    # No default (None). Example value:
    # tanlq/vit-base-patch16-224-in21k-finetuned-cifar10
    add_arg(
        "--from_pretrained",
        type=str,
        help="The name of the pretrained HF model to fine-tune from.",
    )

    # --- Datamodule options ---
    for dm_cls in (ImageDataModule, CIFAR10QADataModule):
        dm_cls.add_model_specific_args(cli)
    add_arg(
        "--dataset",
        type=str,
        default="toy",
        choices=["MNIST", "CIFAR10", "CIFAR10_QA", "toy"],
        help="The dataset to use.",
    )

    main(cli.parse_args())