| import torch |
| import copy |
| from models.utils.continual_model import ContinualModel |
| from utils.args import add_management_args, add_experiment_args, add_rehearsal_args, ArgumentParser |
| from utils.buffer import Buffer |
| import torch.nn.functional as F |
| from utils.args import * |
|
|
|
|
|
|
| def add_parser(parser): |
| parser.add_argument('--weighta', type=float, help='Penalty weight for idempotence distillation.') |
| parser.add_argument('--weightb', type=float, help='Penalty weight for current idempotence distillation.') |
| parser.add_argument('--weightc', type=float, help='Penalty weight for er.') |
| parser.add_argument('--weightmask', type=float, help='Penalty weight for mask ratio.') |
| parser.add_argument("--class_balance", type=str2bool, default=True, |
| help="If set, the memory buffer will be balanced by class") |
| return parser |
| def get_parser() -> ArgumentParser: |
| parser = ArgumentParser(description='Idempotent Continual learning via' |
| ' Experience Replay.') |
| add_management_args(parser) |
| add_experiment_args(parser) |
| add_rehearsal_args(parser) |
| parser = add_parser(parser) |
| return parser |
|
|
| class Ider(ContinualModel): |
| NAME = 'ider' |
| COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual'] |
|
|
| def __init__(self, backbone, loss, args, transform): |
| super(Ider, self).__init__(backbone, loss, args, transform) |
| self.buffer = Buffer(self.args.buffer_size, self.device,class_balance = self.args.class_balance) |
| self.ft=True |
| self.task=0 |
| self.num_classs = backbone.num_classes |
| self.s = backbone.num_classes |
| self.first_task = True |
| self.old_model=self.deepcopy_model(self.net) |
|
|
| |
| |
| def observe(self, inputs, labels, not_aug_inputs): |
| batch_size, _, H, W = inputs.shape |
|
|
| self.opt.zero_grad() |
| mask_current = torch.rand(1) > self.args.weightmask |
| y_0_current = F.one_hot(labels, self.num_classs).float() if mask_current else torch.ones(batch_size, self.num_classs).to(self.device) /self.s |
| z_current = self.net.f1(inputs) |
| y_1_current, z1_current = self.net.f2(z_current, y_0_current) |
| y_2_current, z2_current = self.net.f2(z_current , y_1_current.softmax(-1)) |
| loss_supervised_1 = self.loss(y_1_current, labels) |
| loss_supervised_2 = self.loss(y_2_current, labels) |
|
|
| loss = 0.5*(loss_supervised_1 + loss_supervised_2) |
| |
| if self.args.weightb!=0 and self.task>0: |
| y_current_mask = torch.ones(batch_size, self.num_classs).to(self.device) /self.s |
| z = self.net.f1(inputs) |
| y_1, z1 = self.net.f2(z, y_current_mask) |
| z_old = self.old_model.f1(inputs) |
| y_2, z2 = self.old_model.f2(z_old, y_1.softmax(-1)) |
| loss += self.args.weightb*F.mse_loss(y_1, y_2) |
| |
| if not self.buffer.is_empty() and self.args.weightc !=0: |
| buf_inputs, buf_labels,_,_,_ = self.buffer.get_data( |
| self.args.minibatch_size, transform=self.transform) |
| batch_size, _, H, W = buf_inputs.shape |
| mask = torch.rand(1) > self.args.weightmask |
| y_0_buf = F.one_hot(buf_labels, self.num_classs).float() if mask else torch.ones(batch_size, self.num_classs).to(self.device) /self.s |
| z_buf = self.net.f1(buf_inputs) |
| y_1_buf, z1_buf = self.net.f2(z_buf, y_0_buf) |
| y_2_buf, z2_buf = self.net.f2(z_buf , y_1_buf.softmax(-1)) |
| loss_supervised_1_buf = self.loss(y_1_buf, buf_labels) |
| loss_supervised_2_buf = self.loss(y_2_buf, buf_labels) |
|
|
| loss += self.args.weightc*(loss_supervised_1_buf + loss_supervised_2_buf) |
|
|
| if not self.buffer.is_empty() and self.task>0 and self.args.weighta!=0: |
| buf_inputs, buf_labels,_,_,_ = self.buffer.get_data( |
| self.args.minibatch_size, transform=self.transform) |
| batch_size, _, H, W = buf_inputs.shape |
| |
| y_0 = torch.ones(batch_size, self.num_classs).to(self.device) /self.s |
| |
| z = self.net.f1(buf_inputs) |
| y_1, z1 = self.net.f2(z, y_0) |
| |
| z_old = self.old_model.f1(buf_inputs) |
| y_2, z2 = self.old_model.f2(z_old, y_1.softmax(-1)) |
| loss_unsupervised_y = F.mse_loss(y_1, y_2) |
| loss += self.args.weighta * loss_unsupervised_y |
| |
| |
| loss.backward() |
| self.opt.step() |
| self.buffer.add_data(examples=not_aug_inputs, labels=labels,logits=y_1_current.data, logits2=y_2_current.data,mask=y_0_current) |
|
|
| return loss.item() |
| def end_task(self, dataset): |
| print('\n\n') |
| self.task+=1 |
| print(self.task) |
|
|
| if self.first_task: |
| self.first_task = False |
| self.old_model = self.deepcopy_model(self.net).to(self.device) |
| else: |
| self.old_model = self.deepcopy_model(self.net).to(self.device) |
| |
| @staticmethod |
| def deepcopy_model(model): |
| model_copy = copy.deepcopy(model) |
| return model_copy |
|
|