import os
import inspect

import torch
import torch.nn as nn
import torch.optim.lr_scheduler as lrs
import pytorch_lightning as pl


class MInterface_base(pl.LightningModule):
    def __init__(self, model_name=None, loss=None, lr=None, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.load_model()
        self.configure_loss()
        os.makedirs(os.path.join(self.hparams.res_dir, self.hparams.ex_name), exist_ok=True)

    def forward(self, input):
        pass

    def training_step(self, batch, batch_idx, **kwargs):
        pass

    def validation_step(self, batch, batch_idx):
        pass

    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)

    def on_validation_epoch_end(self):
        # Print a blank line so the progress bar of the finished epoch stays visible
        self.print('')

    def get_schedular(self, optimizer, lr_scheduler='onecycle'):
        if lr_scheduler == 'step':
            scheduler = lrs.StepLR(optimizer,
                                   step_size=self.hparams.lr_decay_steps,
                                   gamma=self.hparams.lr_decay_rate)
        elif lr_scheduler == 'cosine':
            scheduler = lrs.CosineAnnealingLR(optimizer,
                                              T_max=self.hparams.lr_decay_steps,
                                              eta_min=self.hparams.lr_decay_min_lr)
        elif lr_scheduler == 'onecycle':
            scheduler = lrs.OneCycleLR(optimizer,
                                       max_lr=self.hparams.lr,
                                       steps_per_epoch=self.hparams.steps_per_epoch,
                                       epochs=self.hparams.epoch,
                                       three_phase=False)
        else:
            raise ValueError('Invalid lr_scheduler type!')
        return scheduler

    def configure_optimizers(self):
        if hasattr(self.hparams, 'weight_decay'):
            weight_decay = self.hparams.weight_decay
        else:
            weight_decay = 0
        optimizer_g = torch.optim.AdamW(self.model.parameters(),
                                        lr=self.hparams.lr,
                                        weight_decay=weight_decay,
                                        betas=(0.9, 0.98),
                                        eps=1e-8)
        scheduler_g = self.get_schedular(optimizer_g, self.hparams.lr_scheduler)
        return [optimizer_g], [{"scheduler": scheduler_g, "interval": "step"}]

    def lr_scheduler_step(self, *args, **kwargs):
        scheduler = self.lr_schedulers()
        scheduler.step()

    def configure_devices(self):
        self.device = torch.device(self.hparams.device)

    def configure_loss(self):
        self.loss_function = nn.CrossEntropyLoss(reduction='none')

    def load_model(self):
        self.model = None

    def instancialize(self, Model, **other_args):
        """Instantiate a model using the corresponding parameters from the
        self.hparams dictionary. Any keyword argument passed here overrides
        the corresponding value in self.hparams.
        """
        # getfullargspec replaces inspect.getargspec, which was removed in Python 3.11
        class_args = inspect.getfullargspec(Model.__init__).args[1:]
        inkeys = self.hparams.keys()
        args1 = {}
        for arg in class_args:
            if arg in inkeys:
                args1[arg] = getattr(self.hparams, arg)
        args1.update(other_args)
        return Model(**args1)
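

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file). It shows
# how a subclass is expected to fill in load_model() via instancialize():
# constructor arguments of the model class are looked up in self.hparams.
# `ToyModel`, `ToyInterface`, and all hyper-parameter values below are
# assumptions for illustration; the real project supplies its own model,
# hparams, and Trainer configuration.
# ---------------------------------------------------------------------------
class ToyModel(nn.Module):
    def __init__(self, hidden_dim=16):
        super().__init__()
        self.net = nn.Linear(8, hidden_dim)

    def forward(self, x):
        return self.net(x)


class ToyInterface(MInterface_base):
    def load_model(self):
        # instancialize() reads `hidden_dim` from self.hparams automatically.
        self.model = self.instancialize(ToyModel)

    def forward(self, input):
        return self.model(input)


if __name__ == '__main__':
    # hidden_dim is forwarded to ToyModel; res_dir/ex_name are required by
    # MInterface_base.__init__ for the output directory it creates.
    interface = ToyInterface(
        model_name='toy', lr=1e-3, lr_scheduler='onecycle',
        steps_per_epoch=10, epoch=1, hidden_dim=16,
        res_dir='./results', ex_name='toy_run',
    )
    print(interface.model)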