| |
| |
| |
| |
|
|
|
|
| from argparse import Namespace |
| from typing import List |
|
|
| import torch |
| import torch.nn as nn |
| from torch.optim import SGD |
|
|
| from utils.conf import get_device |
|
|
class ContinualModel(nn.Module):
    """
    Base class for continual learning models.

    Bundles a backbone network with its loss, an input transform and an SGD
    optimizer built over the backbone's parameters. Subclasses must assign
    non-empty ``NAME`` and ``COMPATIBILITY`` class attributes and override
    :meth:`observe`.
    """
    # Annotation-only declarations: no value is bound here, so each subclass
    # has to assign both (checked at the end of ``__init__``).
    NAME: str
    COMPATIBILITY: List[str]

    def __init__(self, backbone: nn.Module, loss: nn.Module,
                 args: Namespace, transform: nn.Module) -> None:
        """
        Stores the components of the model and creates its optimizer.
        :param backbone: network wrapped by the model, exposed as ``self.net``
        :param loss: loss function, exposed as ``self.loss``
        :param args: namespace providing ``lr``, ``optim_wd``, ``optim_mom``
                     and ``device``
        :param transform: transform stored as ``self.transform``
        :raises NotImplementedError: if the subclass does not define a
                                     non-empty ``NAME`` and ``COMPATIBILITY``
        """
        super().__init__()

        self.net = backbone
        self.loss = loss
        self.args = args
        self.transform = transform
        self.opt = SGD(self.net.parameters(), lr=args.lr,
                       weight_decay=args.optim_wd, momentum=args.optim_mom)
        self.device = args.device

        # NAME / COMPATIBILITY are only annotated on this class, so a plain
        # attribute access would raise AttributeError when a subclass forgets
        # them; getattr with a default lets the intended error surface.
        name = getattr(self, 'NAME', None)
        compatibility = getattr(self, 'COMPATIBILITY', None)
        if not name or not compatibility:
            raise NotImplementedError('Please specify the name and the compatibility of the model.')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Computes a forward pass through the backbone.
        :param x: batch of inputs
        :return: the output of the backbone on ``x``
        """
        return self.net(x)

    def meta_observe(self, *args, **kwargs):
        """
        Forwards every argument to :meth:`observe` and returns its result.
        :param args: positional arguments passed on to ``observe``
        :param kwargs: keyword arguments passed on to ``observe``
        :return: whatever ``observe`` returns
        """
        return self.observe(*args, **kwargs)

    def observe(self, inputs: torch.Tensor, labels: torch.Tensor,
                not_aug_inputs: torch.Tensor) -> float:
        """
        Compute a training step over a given batch of examples.
        Must be overridden by subclasses.
        :param inputs: batch of examples
        :param labels: ground-truth labels
        :param not_aug_inputs: batch of examples, presumably without
                               augmentation (inferred from the name)
        :return: the value of the loss function
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError
|
|