File size: 5,108 Bytes
377dccd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 | import torch
import copy
from models.utils.continual_model import ContinualModel
from utils.args import add_management_args, add_experiment_args, add_rehearsal_args, ArgumentParser
from utils.buffer import Buffer
import torch.nn.functional as F
from utils.args import *
def add_parser(parser):
    """Register the IDER-specific command-line options on ``parser``.

    ``Ider.observe`` compares ``weighta``/``weightb``/``weightc`` with 0 and
    multiplies loss terms by them, and compares ``weightmask`` with a random
    draw. If one of these flags were omitted its value would be ``None`` and
    training would eventually fail with a ``TypeError`` once the corresponding
    term is evaluated. They are therefore required, so a missing flag is
    reported at parse time instead.

    Args:
        parser: an argparse-style parser, extended in place.

    Returns:
        The same ``parser``, so calls can be chained.
    """
    parser.add_argument('--weighta', type=float, required=True,
                        help='Penalty weight for idempotence distillation.')
    parser.add_argument('--weightb', type=float, required=True,
                        help='Penalty weight for current idempotence distillation.')
    parser.add_argument('--weightc', type=float, required=True,
                        help='Penalty weight for er.')
    # observe() conditions the first pass on the one-hot label when
    # torch.rand(1) > weightmask, and on the uniform prior otherwise.
    parser.add_argument('--weightmask', type=float, required=True,
                        help='Penalty weight for mask ratio.')
    parser.add_argument("--class_balance", type=str2bool, default=True,
                        help="If set, the memory buffer will be balanced by class")
    return parser
def get_parser() -> ArgumentParser:
    """Build the full command-line parser for the IDER model.

    The shared management, experiment and rehearsal option groups are
    registered first, in that order, followed by the IDER-specific options
    added by ``add_parser``.
    """
    parser = ArgumentParser(
        description='Idempotent Continual learning via Experience Replay.')
    for register_args in (add_management_args,
                          add_experiment_args,
                          add_rehearsal_args):
        register_args(parser)
    return add_parser(parser)
class Ider(ContinualModel):
    """Idempotent continual learning via experience replay (IDER).

    The backbone is used in two stages:

    * ``net.f1(x)`` computes a representation.
    * ``net.f2(z, prior)`` maps it, conditioned on a class-prior vector, to a
      pair whose first element is treated as class logits. The second element
      is unused here.

    Each training step combines the following loss terms:

    * a supervised loss on two chained passes (prior -> y_1 -> y_2);
    * the same two-pass loss on a replay-buffer sample, weighted by
      ``weightc``;
    * idempotence distillation against a frozen snapshot of the network taken
      in ``end_task``, on stream data (weighted by ``weightb``) and on buffer
      data (weighted by ``weighta``). Both distillation terms are active only
      after the first task.
    """
    NAME = 'ider'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(Ider, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device,
                             class_balance=self.args.class_balance)
        self.ft = True  # not read anywhere in this class; kept for compatibility
        self.task = 0  # number of completed tasks, incremented by end_task()
        self.num_classs = backbone.num_classes  # width of prior / one-hot vectors
        self.s = backbone.num_classes  # the uniform prior is filled with 1 / s
        self.first_task = True
        # Placeholder teacher. observe() only consults it when self.task > 0,
        # by which point end_task() has replaced it with a fresh snapshot.
        self.old_model = self._freeze(self.deepcopy_model(self.net))

    def observe(self, inputs, labels, not_aug_inputs):
        """Run one optimisation step on a stream batch and store it in the buffer.

        Args:
            inputs: batch fed to the network (first dim is the batch size).
            labels: integer class labels for ``inputs`` (used with one_hot).
            not_aug_inputs: what gets written to the replay buffer;
                presumably the non-augmented version of ``inputs``.

        Returns:
            The Python float value of the total loss for this step.
        """
        self.opt.zero_grad()

        # Supervised two-pass loss on the stream batch, averaged over passes.
        y_0_current, y_1_current, y_2_current, loss_1, loss_2 = \
            self._two_pass_forward(inputs, labels)
        loss = 0.5 * (loss_1 + loss_2)

        # Idempotence distillation on the stream batch (needs a past task).
        if self.args.weightb != 0 and self.task > 0:
            loss += self.args.weightb * self._idempotence_distillation(inputs)

        # Experience replay: the same two-pass loss on a buffer sample,
        # summed (not averaged) over the two passes.
        if not self.buffer.is_empty() and self.args.weightc != 0:
            buf_inputs, buf_labels, _, _, _ = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            _, _, _, loss_1_buf, loss_2_buf = self._two_pass_forward(
                buf_inputs, buf_labels)
            loss += self.args.weightc * (loss_1_buf + loss_2_buf)

        # Idempotence distillation on a second buffer draw (a separate
        # get_data call from the replay term above).
        if not self.buffer.is_empty() and self.task > 0 and self.args.weighta != 0:
            buf_inputs, _, _, _, _ = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            loss += self.args.weighta * self._idempotence_distillation(buf_inputs)

        loss.backward()
        self.opt.step()

        # Store the stream batch with both passes' logits and the prior used.
        self.buffer.add_data(examples=not_aug_inputs, labels=labels,
                             logits=y_1_current.detach(),
                             logits2=y_2_current.detach(),
                             mask=y_0_current)
        return loss.item()

    def _uniform_prior(self, batch_size):
        """Return a (batch_size, num_classs) prior filled with 1 / s on self.device."""
        return torch.ones(batch_size, self.num_classs, device=self.device) / self.s

    def _two_pass_forward(self, x, y):
        """Run two chained passes of ``self.net`` and compute their supervised losses.

        The prior for the first pass is chosen per call. It is the one-hot label
        when ``torch.rand(1) > weightmask``, that is, with probability
        ``1 - weightmask`` for ``weightmask`` in [0, 1]. Otherwise it is the
        uniform prior. The second pass is conditioned on the softmax of the
        first pass's logits.

        Returns:
            ``(y_0, y_1, y_2, loss(y_1, y), loss(y_2, y))``.
        """
        use_labels = torch.rand(1) > self.args.weightmask
        if use_labels:
            y_0 = F.one_hot(y, self.num_classs).float()
        else:
            y_0 = self._uniform_prior(x.shape[0])
        z = self.net.f1(x)
        y_1, _ = self.net.f2(z, y_0)
        y_2, _ = self.net.f2(z, y_1.softmax(-1))
        return y_0, y_1, y_2, self.loss(y_1, y), self.loss(y_2, y)

    def _idempotence_distillation(self, x):
        """Return the MSE between current logits and the teacher's re-prediction of them.

        ``y_1`` is the current net's output under the uniform prior. ``y_2`` is
        the snapshot's output when its prior is ``softmax(y_1)``. Gradients
        reach ``self.net`` both directly through ``y_1`` and through the
        snapshot's ``f2``, whose input depends on ``y_1``.
        """
        y_0 = self._uniform_prior(x.shape[0])
        z = self.net.f1(x)
        y_1, _ = self.net.f2(z, y_0)
        with torch.no_grad():
            # The snapshot's features do not depend on self.net, so no
            # autograd graph is needed for them.
            z_old = self.old_model.f1(x)
        y_2, _ = self.old_model.f2(z_old, y_1.softmax(-1))
        return F.mse_loss(y_1, y_2)

    def end_task(self, dataset):
        """Finish a task: bump the task counter and snapshot the network.

        ``dataset`` is not used here.
        """
        print('\n\n')
        self.task += 1
        print(self.task)
        # The snapshot is the same for every task; first_task is only kept
        # up to date as a state flag.
        self.first_task = False
        # NOTE(review): the snapshot keeps whatever train/eval mode self.net is
        # in. If the backbone has BatchNorm/Dropout, confirm whether
        # old_model.eval() is intended.
        self.old_model = self._freeze(self.deepcopy_model(self.net).to(self.device))

    @staticmethod
    def _freeze(model):
        """Disable gradients for every parameter of ``model`` and return it.

        The snapshot only serves as a fixed reference in observe(). Freezing it
        stops backward() from computing and accumulating ``.grad`` on its
        parameters at every step. Gradients still flow through its operations
        to its inputs, which the idempotence term relies on.
        """
        for param in model.parameters():
            param.requires_grad_(False)
        return model

    @staticmethod
    def deepcopy_model(model):
        """Return an independent deep copy of ``model``."""
        model_copy = copy.deepcopy(model)
        return model_copy
|