| |
| |
| |
| |
|
|
| import torch |
|
|
| from models.utils.continual_model import ContinualModel |
| from utils.args import add_management_args, add_experiment_args, add_rehearsal_args, ArgumentParser |
| from utils.buffer import Buffer |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import matplotlib.pyplot as plt |
| from sklearn.manifold import TSNE |
|
|
def get_parser() -> ArgumentParser:
    """Return the command-line parser for the Experience Replay model.

    The parser is created with a fixed description. The management,
    experiment and rehearsal argument groups are then registered on it,
    in that order.
    """
    description = 'Continual learning via Experience Replay.'
    parser = ArgumentParser(description=description)
    # Registration order matches the original, in case later groups depend on earlier ones.
    for register_group in (add_management_args,
                           add_experiment_args,
                           add_rehearsal_args):
        register_group(parser)
    return parser
|
|
|
|
class Er(ContinualModel):
    """Experience Replay (ER) continual-learning model.

    At every training step, the incoming batch is extended with a minibatch
    drawn from a rehearsal buffer (when the buffer is not empty). After the
    optimiser step, the current batch's examples are added to the buffer.
    """

    NAME = 'er'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super().__init__(backbone, loss, args, transform)
        # Rehearsal memory: capacity is args.buffer_size, and self.device is
        # handed to Buffer (presumably where its tensors are kept -- confirm).
        self.buffer = Buffer(self.args.buffer_size, self.device)
        # Count of completed tasks; advanced by end_task.
        self.task = 0

    def observe(self, inputs, labels, not_aug_inputs):
        """Perform one optimisation step on current plus replayed data.

        Args:
            inputs: batch fed to the network for this step.
            labels: targets for ``inputs``.
            not_aug_inputs: the examples written into the buffer (by name,
                presumably the non-augmented version of ``inputs``).

        Returns:
            The loss of this step, converted with ``.item()``.
        """
        n_current = inputs.shape[0]
        self.opt.zero_grad()

        batch_x, batch_y = inputs, labels
        if not self.buffer.is_empty():
            # The buffer is asked for args.minibatch_size samples, with self.transform passed through.
            replay_x, replay_y = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            batch_x = torch.cat((batch_x, replay_x))
            batch_y = torch.cat((batch_y, replay_y))

        predictions = self.net(batch_x)
        step_loss = self.loss(predictions, batch_y)
        step_loss.backward()
        self.opt.step()

        # The first n_current rows of batch_y are the current batch's labels;
        # replayed labels were appended after them.
        self.buffer.add_data(examples=not_aug_inputs,
                             labels=batch_y[:n_current])

        return step_loss.item()

    def end_task(self, dataset):
        """Print a separator, bump the task counter and print its new value.

        ``dataset`` is not used by this implementation.
        """
        print('\n\n')
        self.task = self.task + 1
        print(self.task)
|
|