vrevar
Add application file
04c78c7
from pathlib import Path
import jsonargparse
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from ..source import Vanilla, DenseReg
from ..callbacks import VisualizeCallback
from ..data.module import DataModule
# TODO: refactor — this parser class needs simplification.
class LightningArgumentParser(jsonargparse.ArgumentParser):
    """
    Extension of jsonargparse.ArgumentParser to parse pl.classes and more.

    Each ``add_*`` helper registers the constructor arguments of a class under
    its own nested config key: ``data``, ``loss``, ``routine``, ``logger`` and
    ``trainer`` (the latter always for ``pl.Trainer``).
    """

    def add_datamodule(self, datamodule_obj: 'type[pl.LightningDataModule]'):
        """Register the arguments of ``datamodule_obj.__init__`` under ``data``."""
        self.add_method_arguments(datamodule_obj, '__init__', 'data', as_group=True)

    def add_lossmodule(self, lossmodule_obj: 'type[nn.Module]'):
        """Register the loss class's constructor arguments under ``loss``."""
        self.add_class(lossmodule_obj, 'loss')

    def add_routine(self, model_obj: 'type[pl.LightningModule]'):
        """Register the routine's constructor arguments under ``routine``.

        The parameters listed in ``skip`` are not exposed as CLI/config options.
        NOTE(review): presumably they are sub-modules supplied in code — confirm.
        """
        skip = {'ae', 'decoder', 'loss', 'transnet', 'model', 'discr', 'adv_loss', 'stage'}
        self.add_class(model_obj, 'routine', skip=skip)

    def add_logger(self, logger_obj):
        """Register the logger class's constructor arguments under ``logger``.

        ``version``, ``config``, ``name`` and ``save_dir`` are skipped.
        """
        skip = {'version', 'config', 'name', 'save_dir'}
        self.add_class(logger_obj, 'logger', skip=skip)

    def add_class(self, cls, group, **kwargs):
        """Register ``cls`` constructor arguments under the nested key ``group``.

        Extra keyword arguments are forwarded to ``add_class_arguments``.
        ``as_group`` defaults to ``True``, but a caller may now override it
        instead of triggering a duplicate-keyword ``TypeError``.
        """
        kwargs.setdefault('as_group', True)
        self.add_class_arguments(cls, group, **kwargs)

    def add_trainer(self):
        """Register ``pl.Trainer`` constructor arguments under ``trainer``.

        ``default_root_dir``, ``logger`` and ``callbacks`` are skipped.
        """
        skip = {'default_root_dir', 'logger', 'callbacks'}
        self.add_class(pl.Trainer, 'trainer', skip=skip)
def get_args(datamodule=DataModule, loss=DenseReg, routine=Vanilla, viz=VisualizeCallback, argv=None):
    """
    Build the application's command-line/config parser and parse arguments.

    Args:
        datamodule: class whose ``__init__`` arguments populate the ``data``
            group; ``None`` omits the group.
        loss: class whose constructor arguments populate ``loss``; ``None``
            omits the group.
        routine: class whose constructor arguments populate ``routine``;
            ``None`` omits the group.
        viz: class whose constructor arguments populate ``viz`` (``out_dir``
            and ``exist_ok`` are skipped); ``None`` omits the group.
        argv: optional list of argument strings to parse. ``None`` (the
            default) keeps the original behaviour of reading the command line.

    Returns:
        The namespace returned by ``parser.parse_args``.
    """
    parser = LightningArgumentParser()

    # General run options.
    parser.add_argument('--config', action=jsonargparse.ActionConfigFile, required=True)
    parser.add_argument('--archi', type=str, required=True)
    parser.add_argument('--out_dir', type=lambda x: Path(x), required=True)
    parser.add_argument('--seed', default=666, type=int)
    parser.add_argument('--load_weights_from', type=lambda x: Path(x))
    parser.add_argument('--save_ckpt_every', default=10, type=int)
    parser.add_argument('--wandb', action='store_true', default=False)
    parser.add_argument('--mode', choices=['train', 'eval', 'test', 'predict'], default='train', type=str)
    parser.add_argument('--resume_from', default=None, type=str)

    # Optional component groups.
    if datamodule is not None:
        parser.add_datamodule(datamodule)
    if loss is not None:
        parser.add_lossmodule(loss)
    if routine is not None:
        parser.add_routine(routine)
    if viz is not None:
        parser.add_class_arguments(viz, 'viz', skip={'out_dir', 'exist_ok'})

    # bindings between modules (data/routine/loss): the routine's batch size
    # follows the datamodule's. Only link when both groups were registered,
    # since link_arguments rejects keys for which no argument exists.
    if datamodule is not None and routine is not None:
        parser.link_arguments('data.batch_size', 'routine.batch_size')

    parser.add_logger(WandbLogger)
    parser.add_trainer()
    return parser.parse_args(argv)