# Provenance (residue from a model-hub page scrape, not part of the code):
# Image Classification · English · uploaded by zhanwang · commit 377dccd ("update", verified)
# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from argparse import Namespace
from typing import List
import torch
import torch.nn as nn
from torch.optim import SGD
from utils.conf import get_device
class ContinualModel(nn.Module):
    """
    Continual learning model.

    Base class for continual-learning methods. It stores a backbone network,
    a loss, the parsed arguments and a data transform, and builds an SGD
    optimizer over the backbone's parameters. Subclasses must set the class
    attributes ``NAME`` and ``COMPATIBILITY`` and implement :meth:`observe`.
    """
    # Both are only annotated here (no value), so subclasses must assign them.
    NAME: str  # identifier of the method
    COMPATIBILITY: List[str]  # must be non-empty in concrete subclasses

    def __init__(self, backbone: nn.Module, loss: nn.Module,
                 args: Namespace, transform: nn.Module) -> None:
        """
        :param backbone: the network to train, stored as ``self.net``
        :param loss: loss module, stored as ``self.loss``
        :param args: parsed arguments; must provide ``lr``, ``optim_wd``,
                     ``optim_mom`` and ``device``
        :param transform: data transform, stored as ``self.transform``
        :raises NotImplementedError: if the subclass does not define a
                                     non-empty ``NAME`` and ``COMPATIBILITY``
        """
        super(ContinualModel, self).__init__()
        # NAME / COMPATIBILITY are bare annotations on this base class, so a
        # subclass that forgets them has no such attribute at all; plain
        # ``self.NAME`` would raise AttributeError instead of the intended
        # NotImplementedError. Check before building anything (fail fast).
        if not getattr(self, 'NAME', None) or not getattr(self, 'COMPATIBILITY', None):
            raise NotImplementedError('Please specify the name and the compatibility of the model.')

        self.net = backbone
        self.loss = loss
        self.args = args
        self.transform = transform
        self.opt = SGD(self.net.parameters(), lr=args.lr,
                       weight_decay=args.optim_wd, momentum=args.optim_mom)
        self.device = args.device

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Computes a forward pass through the backbone.
        :param x: batch of inputs
        :return: the output of ``self.net`` on ``x``
        """
        return self.net(x)

    def meta_observe(self, *args, **kwargs):
        """
        Delegates to :meth:`observe` with the same arguments.
        :return: whatever :meth:`observe` returns, unchanged
        """
        return self.observe(*args, **kwargs)

    def observe(self, inputs: torch.Tensor, labels: torch.Tensor,
                not_aug_inputs: torch.Tensor) -> float:
        """
        Compute a training step over a given batch of examples.
        Must be implemented by subclasses.
        :param inputs: batch of examples
        :param labels: ground-truth labels
        :param not_aug_inputs: batch of examples, presumably before data
                               augmentation (name-based; confirm with caller)
        :return: the value of the loss function
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError