# Please note that the current implementation of DER only contains the dynamic
# expansion process, since masking and pruning are not implemented by the
# source repo.
import logging

import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.base import BaseLearner
from utils.inc_net import DERNet, IncrementalNet
from utils.toolkit import count_parameters, target2onehot, tensor2numpy

EPSILON = 1e-8

# Hyper-parameters for the first (base) task.
init_epoch = 200
init_lr = 0.1
init_milestones = [60, 120, 170]
init_lr_decay = 0.1
init_weight_decay = 0.0005

# Hyper-parameters for all subsequent incremental tasks.
epochs = 170
lrate = 0.1
milestones = [80, 120, 150]
lrate_decay = 0.1
batch_size = 128
weight_decay = 2e-4
num_workers = 0
T = 2


class DER(BaseLearner):
    """Dynamically Expandable Representation (DER) learner.

    A new backbone (``convnet``) is appended to the network for every
    incremental task, while all backbones trained on earlier tasks are
    frozen.  Masking and pruning are not implemented (see the note at the
    top of the file).
    """

    def __init__(self, args):
        super().__init__(args)
        self.args = args
        self._network = DERNet(args, False)

    def after_task(self):
        """Bookkeeping after a task: record how many classes are now known."""
        self._known_classes = self._total_classes
        logging.info("Exemplar size: {}".format(self.exemplar_size))

    def _build_test_loader(self, data_manager):
        """Return a test DataLoader over all classes seen so far (0..total)."""
        test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source="test", mode="test"
        )
        return DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )

    def incremental_train(self, data_manager):
        """Expand the network for the next task and train on it.

        In attack mode (``args["attack"]`` truthy) no training is performed;
        only the test loader for evaluation is constructed.
        """
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(
            self._cur_task
        )
        self._network.update_fc(self._total_classes)
        logging.info(
            "Learning on {}-{}".format(self._known_classes, self._total_classes)
        )

        if not self.args["attack"]:
            # Freeze every backbone trained on earlier tasks; only the newly
            # added convnet (plus the classifier heads) stays trainable.
            if self._cur_task > 0:
                for i in range(self._cur_task):
                    for p in self._network.convnets[i].parameters():
                        p.requires_grad = False

            logging.info("All params: {}".format(count_parameters(self._network)))
            logging.info(
                "Trainable params: {}".format(count_parameters(self._network, True))
            )

            # New-task samples plus the rehearsal memory of old exemplars.
            train_dataset = data_manager.get_dataset(
                np.arange(self._known_classes, self._total_classes),
                source="train",
                mode="train",
                appendent=self._get_memory(),
            )
            self.train_loader = DataLoader(
                train_dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=num_workers,
            )
            self.test_loader = self._build_test_loader(data_manager)

            if len(self._multiple_gpus) > 1:
                self._network = nn.DataParallel(self._network, self._multiple_gpus)
            self._train(self.train_loader, self.test_loader)
            self.build_rehearsal_memory(data_manager, self.samples_per_class)
            if len(self._multiple_gpus) > 1:
                # Unwrap DataParallel so later tasks see the bare module.
                self._network = self._network.module
        else:
            # Attack/evaluation mode: no training, only a test loader.
            self.test_loader = self._build_test_loader(data_manager)

    def train(self):
        """Switch to training mode, keeping frozen backbones in eval mode.

        Old-task convnets are set to ``eval()`` so their BatchNorm running
        statistics are not updated while the new backbone trains.
        """
        self._network.train()
        if len(self._multiple_gpus) > 1:
            self._network_module_ptr = self._network.module
        else:
            self._network_module_ptr = self._network
        self._network_module_ptr.convnets[-1].train()
        if self._cur_task >= 1:
            for i in range(self._cur_task):
                self._network_module_ptr.convnets[i].eval()

    def _train(self, train_loader, test_loader):
        """Dispatch to base-task or incremental-task training.

        The first task uses the ``init_*`` schedule; later tasks use the
        incremental schedule and finish with weight alignment between old
        and new class logits.
        """
        self._network.to(self._device)
        if self._cur_task == 0:
            optimizer = optim.SGD(
                filter(lambda p: p.requires_grad, self._network.parameters()),
                momentum=0.9,
                lr=init_lr,
                weight_decay=init_weight_decay,
            )
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay
            )
            self._init_train(train_loader, test_loader, optimizer, scheduler)
        else:
            optimizer = optim.SGD(
                filter(lambda p: p.requires_grad, self._network.parameters()),
                lr=lrate,
                momentum=0.9,
                weight_decay=weight_decay,
            )
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer=optimizer, milestones=milestones, gamma=lrate_decay
            )
            self._update_representation(train_loader, test_loader, optimizer, scheduler)
            # Rescale the new-class weights to counter the class-imbalance
            # bias introduced by the small rehearsal memory.
            if len(self._multiple_gpus) > 1:
                self._network.module.weight_align(
                    self._total_classes - self._known_classes
                )
            else:
                self._network.weight_align(self._total_classes - self._known_classes)

    def _init_train(self, train_loader, test_loader, optimizer, scheduler):
        """Train the first (base) task with plain cross-entropy."""
        prog_bar = tqdm(range(init_epoch))
        for epoch in prog_bar:
            self.train()
            losses = 0.0
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                logits = self._network(inputs)["logits"]

                loss = F.cross_entropy(logits, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)

            # Test accuracy is expensive; evaluate only every 5th epoch.
            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    init_epoch,
                    losses / len(train_loader),
                    train_acc,
                    test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    init_epoch,
                    losses / len(train_loader),
                    train_acc,
                )
            prog_bar.set_description(info)

        logging.info(info)

    def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
        """Train an incremental task with classification + auxiliary loss.

        The auxiliary head discriminates new classes (labels 1..n_new) from
        all old classes collapsed into label 0, encouraging the new backbone
        to learn features specific to the new task.
        """
        prog_bar = tqdm(range(epochs))
        for epoch in prog_bar:
            self.train()
            losses = 0.0
            losses_clf = 0.0
            losses_aux = 0.0
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                outputs = self._network(inputs)
                logits, aux_logits = outputs["logits"], outputs["aux_logits"]

                loss_clf = F.cross_entropy(logits, targets)
                # Remap targets for the auxiliary head: old classes -> 0,
                # new classes -> 1..n_new.
                aux_targets = targets.clone()
                aux_targets = torch.where(
                    aux_targets - self._known_classes + 1 > 0,
                    aux_targets - self._known_classes + 1,
                    0,
                )
                loss_aux = F.cross_entropy(aux_logits, aux_targets)
                loss = loss_clf + loss_aux

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()
                losses_aux += loss_aux.item()
                losses_clf += loss_clf.item()

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)

            # Test accuracy is expensive; evaluate only every 5th epoch.
            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_aux {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    epochs,
                    losses / len(train_loader),
                    losses_clf / len(train_loader),
                    losses_aux / len(train_loader),
                    train_acc,
                    test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_aux {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    epochs,
                    losses / len(train_loader),
                    losses_clf / len(train_loader),
                    losses_aux / len(train_loader),
                    train_acc,
                )
            prog_bar.set_description(info)

        logging.info(info)