Spaces:
Sleeping
Sleeping
| import logging | |
| import numpy as np | |
| from tqdm import tqdm | |
| import torch | |
| from torch import nn | |
| from torch import optim | |
| from torch.nn import functional as F | |
| from torch.utils.data import DataLoader | |
| from models.base import BaseLearner | |
| from utils.inc_net import FOSTERNet | |
| from utils.toolkit import count_parameters, target2onehot, tensor2numpy | |
| # Please refer to https://github.com/G-U-N/ECCV22-FOSTER for the full source code to reproduce foster. | |
| EPSILON = 1e-8 | |
class FOSTER(BaseLearner):
    """Class-incremental learner following the FOSTER training recipe.

    Reference implementation: https://github.com/G-U-N/ECCV22-FOSTER

    Training flow as implemented in this class:

    * Task 0 -- plain cross-entropy training (``_init_train``).
    * Later tasks -- ``_feature_boosting`` trains ``self._network`` with a
      re-scaled classification loss, an auxiliary loss on ``fe_logits`` and a
      distillation loss against ``old_logits``; ``_feature_compression`` then
      distils that network's logits into a freshly built ``FOSTERNet``
      (``self._snet``).

    Attributes such as ``_device``, ``_cur_task``, ``_known_classes``,
    ``_total_classes``, ``_multiple_gpus``, ``_fixed_memory`` and ``topk`` are
    not defined here; they are presumably provided by ``BaseLearner`` --
    confirm against that class.
    """

    def __init__(self, args):
        """Store the configuration and build the initial network.

        Args:
            args: Mapping of configuration values. Keys read in this method:
                ``beta1``, ``beta2``, ``is_teacher_wa``, ``is_student_wa``,
                ``lambda_okd``, ``wa_value`` and ``oofc`` (lower-cased here).
                Other methods additionally read ``model_name``,
                ``batch_size``, ``num_workers``, ``init_lr``,
                ``init_weight_decay``, ``init_epochs``, ``lr``,
                ``weight_decay``, ``boosting_epochs``, ``compression_epochs``,
                ``T`` and ``dataset``.
        """
        super().__init__(args)
        self.args = args
        self._network = FOSTERNet(args, False)
        self._snet = None
        self.beta1 = args["beta1"]
        self.beta2 = args["beta2"]
        self.per_cls_weights = None
        self.is_teacher_wa = args["is_teacher_wa"]
        self.is_student_wa = args["is_student_wa"]
        self.lambda_okd = args["lambda_okd"]
        self.wa_value = args["wa_value"]
        self.oofc = args["oofc"].lower()

    def after_task(self):
        """Mark all classes seen so far as known and log the exemplar size."""
        self._known_classes = self._total_classes
        logging.info("Exemplar size: {}".format(self.exemplar_size))

    def load_checkpoint(self, filename):
        """Restore the class list, label list and network weights from a file.

        ``_cur_task`` is reset to 0, so the next ``incremental_train`` call
        runs as task 1.

        Args:
            filename: Path handed to ``torch.load``; the file is expected to
                contain the keys ``"classes"``, ``"label_list"`` and
                ``"network"`` as written by ``save_checkpoint``.
        """
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(filename)
        self._known_classes = len(checkpoint["classes"])
        self.class_list = np.array(checkpoint["classes"])
        self.label_list = checkpoint["label_list"]
        self._network.update_fc(self._known_classes)
        self._network.load_checkpoint(checkpoint["network"])
        self._network.to(self._device)
        self._cur_task = 0

    def save_checkpoint(self, filename):
        """Write the class list, label list and network weights to disk.

        Only ``convnets[0]`` and ``fc`` of the network are saved.  The network
        is moved to CPU first and is not moved back here; ``_train`` sends it
        to ``self._device`` again at the start of the next task.

        Args:
            filename: Directory component of the output path
                ``./<filename>/<model_name>_<cur_task>.pkl``.
        """
        self._network.cpu()
        save_dict = {
            "classes": self.data_manager.get_class_list(self._cur_task),
            "network": {
                "convnet": self._network.convnets[0].state_dict(),
                "fc": self._network.fc.state_dict(),
            },
            "label_list": self.data_manager.get_label_list(self._cur_task),
        }
        torch.save(
            save_dict,
            "./{}/{}_{}.pkl".format(filename, self.args['model_name'], self._cur_task),
        )

    def incremental_train(self, data_manager):
        """Advance to the next task, build its data loaders and train on it.

        Args:
            data_manager: Project data manager; this method uses its
                ``get_task_size`` and ``get_dataset`` methods and, when
                present, its ``label_list`` attribute.
        """
        self.data_manager = data_manager
        if hasattr(self.data_manager, 'label_list') and hasattr(self, 'label_list'):
            # Labels restored by load_checkpoint are placed in front of the
            # data manager's own labels.
            self.data_manager.label_list = (
                list(self.label_list.values()) + self.data_manager.label_list
            )

        self._cur_task += 1
        if self._cur_task > 1:
            # Continue from the student produced by the previous task.
            self._network = self._snet
        self._total_classes = self._known_classes + data_manager.get_task_size(
            self._cur_task
        )
        self._network.update_fc(self._total_classes)
        self._network_module_ptr = self._network
        logging.info(
            "Learning on {}-{}".format(self._known_classes, self._total_classes)
        )

        if self._cur_task > 0:
            # Freeze the first backbone and the old classifier head.
            for p in self._network.convnets[0].parameters():
                p.requires_grad = False
            for p in self._network.oldfc.parameters():
                p.requires_grad = False

        logging.info("All params: {}".format(count_parameters(self._network)))
        logging.info(
            "Trainable params: {}".format(count_parameters(self._network, True))
        )

        train_dataset = data_manager.get_dataset(
            np.arange(self._known_classes, self._total_classes),
            source="train",
            mode="train",
            appendent=self._get_memory(),
        )
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=self.args["batch_size"],
            shuffle=True,
            num_workers=self.args["num_workers"],
            pin_memory=True,
        )
        test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source="test", mode="test"
        )
        self.test_loader = DataLoader(
            test_dataset,
            batch_size=self.args["batch_size"],
            shuffle=False,
            num_workers=self.args["num_workers"],
        )

        use_data_parallel = len(self._multiple_gpus) > 1
        if use_data_parallel:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._train(self.train_loader, self.test_loader)
        # NOTE: rehearsal memory is not rebuilt here; the
        # build_rehearsal_memory(data_manager, samples_per_class) call is
        # disabled in this version.
        if use_data_parallel:
            self._network = self._network.module

    def train(self):
        """Put the network in training mode, keeping ``convnets[0]`` in eval
        mode from task 1 onwards."""
        self._network_module_ptr.train()
        self._network_module_ptr.convnets[-1].train()
        if self._cur_task >= 1:
            self._network_module_ptr.convnets[0].eval()

    def _train(self, train_loader, test_loader):
        """Run the training schedule for the current task.

        Task 0 uses ``_init_train``.  Later tasks run feature boosting,
        optional teacher weight alignment and then feature compression.

        Raises:
            NotImplementedError: If ``self.oofc`` is neither ``"az"`` nor
                ``"ft"`` on a task after the first.
        """
        self._network.to(self._device)
        if hasattr(self._network, "module"):
            self._network_module_ptr = self._network.module

        if self._cur_task == 0:
            optimizer = optim.SGD(
                filter(lambda p: p.requires_grad, self._network.parameters()),
                momentum=0.9,
                lr=self.args["init_lr"],
                weight_decay=self.args["init_weight_decay"],
            )
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer=optimizer, T_max=self.args["init_epochs"]
            )
            self._init_train(train_loader, test_loader, optimizer, scheduler)
            return

        # Boosting stage uses weights derived from beta1.
        self.per_cls_weights = self._class_balanced_weights(self.beta1)
        optimizer = optim.SGD(
            filter(lambda p: p.requires_grad, self._network.parameters()),
            lr=self.args["lr"],
            momentum=0.9,
            weight_decay=self.args["weight_decay"],
        )
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer, T_max=self.args["boosting_epochs"]
        )
        if self._oofc_zeroes_block():
            weight = self._first_fc_parameter()
            if weight is not None:
                # Zero the rows of the new classes over the first out_dim
                # input columns before boosting starts.
                weight.data[
                    self._known_classes:, : self._network_module_ptr.out_dim
                ] = 0.0
        self._feature_boosting(train_loader, test_loader, optimizer, scheduler)

        if self.is_teacher_wa:
            self._network_module_ptr.weight_align(
                self._known_classes,
                self._total_classes - self._known_classes,
                self.wa_value,
            )
        else:
            logging.info("do not weight align teacher!")

        # Compression stage uses weights derived from beta2.
        self.per_cls_weights = self._class_balanced_weights(self.beta2)
        self._feature_compression(train_loader, test_loader)

    def _class_balanced_weights(self, beta):
        """Compute, log and return per-class weights for the given ``beta``.

        Each class with ``n`` samples gets ``(1 - beta) / (1 - beta ** n)``;
        the weights are then rescaled so that they sum to the number of
        classes.  Old classes use ``samples_old_class()`` and new classes use
        ``samples_new_class(i)`` as their sample count.

        Returns:
            A ``torch.FloatTensor`` of length ``_total_classes`` on
            ``self._device``.
        """
        # samples_old_class is a method and must be called; putting the bound
        # method itself in the list makes np.power fail.
        cls_num_list = [self.samples_old_class()] * self._known_classes + [
            self.samples_new_class(i)
            for i in range(self._known_classes, self._total_classes)
        ]
        effective_num = 1.0 - np.power(beta, cls_num_list)
        per_cls_weights = (1.0 - beta) / np.array(effective_num)
        per_cls_weights = (
            per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
        )
        logging.info("per cls weights : {}".format(per_cls_weights))
        return torch.FloatTensor(per_cls_weights).to(self._device)

    def _oofc_zeroes_block(self):
        """Return True for ``oofc == "az"`` and False for ``"ft"``.

        Raises:
            NotImplementedError: For any other ``oofc`` value.
        """
        if self.oofc == "az":
            return True
        if self.oofc == "ft":
            return False
        raise NotImplementedError("not implemented")

    def _first_fc_parameter(self):
        """Return the first parameter yielded by ``fc.parameters()``, or None
        when ``fc`` has no parameters."""
        return next(iter(self._network_module_ptr.fc.parameters()), None)

    def _init_train(self, train_loader, test_loader, optimizer, scheduler):
        """Train on the first task with plain cross-entropy.

        Runs ``args["init_epochs"]`` epochs, steps ``scheduler`` once per
        epoch, and evaluates on ``test_loader`` every fifth epoch.
        """
        num_epochs = self.args["init_epochs"]
        prog_bar = tqdm(range(num_epochs))
        for epoch in prog_bar:
            self.train()
            losses = 0.0
            correct, total = 0, 0
            for _, inputs, targets in train_loader:
                inputs = inputs.to(self._device, non_blocking=True)
                targets = targets.to(self._device, non_blocking=True)
                logits = self._network(inputs)["logits"]
                loss = F.cross_entropy(logits, targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                losses += loss.item()
                preds = torch.max(logits, dim=1)[1]
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)
            scheduler.step()

            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
            info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                self._cur_task,
                epoch + 1,
                num_epochs,
                losses / len(train_loader),
                train_acc,
            )
            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info += ", Test_accy {:.2f}".format(test_acc)
            prog_bar.set_description(info)
            logging.info(info)

    def _feature_boosting(self, train_loader, test_loader, optimizer, scheduler):
        """Train ``self._network`` on the current task (boosting stage).

        The loss is the sum of
        * cross-entropy on ``logits / self.per_cls_weights``,
        * cross-entropy on ``fe_logits``, and
        * ``lambda_okd`` times ``_KD_loss`` between the first
          ``_known_classes`` logits and the detached ``old_logits``.

        With ``oofc == "az"`` the gradient of the first ``fc`` parameter is
        zeroed for rows ``_known_classes:`` and columns ``:out_dim`` before
        every optimizer step.

        Raises:
            NotImplementedError: If ``self.oofc`` is neither ``"az"`` nor
                ``"ft"``.
        """
        zero_block_grad = self._oofc_zeroes_block()
        num_epochs = self.args["boosting_epochs"]
        prog_bar = tqdm(range(num_epochs))
        for epoch in prog_bar:
            self.train()
            losses = 0.0
            losses_clf = 0.0
            losses_fe = 0.0
            losses_kd = 0.0
            correct, total = 0, 0
            for _, inputs, targets in train_loader:
                inputs = inputs.to(self._device, non_blocking=True)
                targets = targets.to(self._device, non_blocking=True)
                outputs = self._network(inputs)
                logits = outputs["logits"]
                fe_logits = outputs["fe_logits"]
                old_logits = outputs["old_logits"].detach()

                loss_clf = F.cross_entropy(logits / self.per_cls_weights, targets)
                loss_fe = F.cross_entropy(fe_logits, targets)
                loss_kd = self.lambda_okd * _KD_loss(
                    logits[:, : self._known_classes], old_logits, self.args["T"]
                )
                loss = loss_clf + loss_fe + loss_kd

                optimizer.zero_grad()
                loss.backward()
                if zero_block_grad:
                    weight = self._first_fc_parameter()
                    if weight is not None:
                        weight.grad.data[
                            self._known_classes:,
                            : self._network_module_ptr.out_dim,
                        ] = 0.0
                optimizer.step()

                losses += loss.item()
                losses_fe += loss_fe.item()
                losses_clf += loss_clf.item()
                # The logged KD loss is scaled by the share of old classes.
                losses_kd += (
                    self._known_classes / self._total_classes
                ) * loss_kd.item()
                preds = torch.max(logits, dim=1)[1]
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)
            scheduler.step()

            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
            num_batches = len(train_loader)
            info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fe {:.3f}, Loss_kd {:.3f}, Train_accy {:.2f}".format(
                self._cur_task,
                epoch + 1,
                num_epochs,
                losses / num_batches,
                losses_clf / num_batches,
                losses_fe / num_batches,
                losses_kd / num_batches,
                train_acc,
            )
            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info += ", Test_accy {:.2f}".format(test_acc)
            prog_bar.set_description(info)
            logging.info(info)

    def _feature_compression(self, train_loader, test_loader):
        """Distil ``self._network`` into a new ``FOSTERNet`` (``self._snet``).

        The student is initialised from the teacher's ``convnets[0]`` weights
        and ``oldfc`` (via ``copy_fc``), trained for
        ``args["compression_epochs"]`` epochs with the ``BKD`` loss only,
        optionally weight-aligned, and finally evaluated on ``test_loader``.
        """
        self._snet = FOSTERNet(self.args, False)
        self._snet.update_fc(self._total_classes)
        if len(self._multiple_gpus) > 1:
            self._snet = nn.DataParallel(self._snet, self._multiple_gpus)
        # Unwrapped view of the student for attribute access.
        self._snet_module_ptr = getattr(self._snet, "module", self._snet)
        self._snet.to(self._device)
        self._snet_module_ptr.convnets[0].load_state_dict(
            self._network_module_ptr.convnets[0].state_dict()
        )
        self._snet_module_ptr.copy_fc(self._network_module_ptr.oldfc)

        optimizer = optim.SGD(
            filter(lambda p: p.requires_grad, self._snet.parameters()),
            lr=self.args["lr"],
            momentum=0.9,
        )
        num_epochs = self.args["compression_epochs"]
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer, T_max=num_epochs
        )

        self._network.eval()
        prog_bar = tqdm(range(num_epochs))
        for epoch in prog_bar:
            self._snet.train()
            losses = 0.0
            correct, total = 0, 0
            for _, inputs, targets in train_loader:
                inputs = inputs.to(self._device, non_blocking=True)
                targets = targets.to(self._device, non_blocking=True)
                dark_logits = self._snet(inputs)["logits"]
                with torch.no_grad():
                    # Teacher forward pass; only its logits are needed.
                    logits = self._network(inputs)["logits"]
                loss = self.BKD(dark_logits, logits, self.args["T"])

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                losses += loss.item()
                preds = torch.max(dark_logits[: targets.shape[0]], dim=1)[1]
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)
            scheduler.step()

            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
            info = "SNet: Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                self._cur_task,
                epoch + 1,
                num_epochs,
                losses / len(train_loader),
                train_acc,
            )
            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._snet, test_loader)
                info += ", Test_accy {:.2f}".format(test_acc)
            prog_bar.set_description(info)
            logging.info(info)

        if len(self._multiple_gpus) > 1:
            self._snet = self._snet.module
        if self.is_student_wa:
            self._snet.weight_align(
                self._known_classes,
                self._total_classes - self._known_classes,
                self.wa_value,
            )
        else:
            logging.info("do not weight align student!")
        if self._cur_task > 1:
            self._network = self._snet

        # Evaluate the student on the test set and log the accuracy curves.
        self._snet.eval()
        y_pred, y_true = [], []
        for _, inputs, targets in test_loader:
            inputs = inputs.to(self._device, non_blocking=True)
            with torch.no_grad():
                outputs = self._snet(inputs)["logits"]
            predicts = torch.topk(
                outputs, k=self.topk, dim=1, largest=True, sorted=True
            )[1]
            y_pred.append(predicts.cpu().numpy())
            y_true.append(targets.cpu().numpy())
        y_pred = np.concatenate(y_pred)
        y_true = np.concatenate(y_true)
        cnn_accy = self._evaluate(y_pred, y_true)
        logging.info("darknet eval: ")
        logging.info("CNN top1 curve: {}".format(cnn_accy["top1"]))
        logging.info("CNN top5 curve: {}".format(cnn_accy["top5"]))

    def samples_old_class(self):
        """Return the number of stored samples per old class.

        Returns ``_memory_per_class`` when ``_fixed_memory`` is set, otherwise
        ``_memory_size // _known_classes``.

        Raises:
            ValueError: If memory is not fixed and ``_known_classes`` is 0.
        """
        if self._fixed_memory:
            return self._memory_per_class
        # Guard the actual divisor rather than _total_classes.
        if self._known_classes == 0:
            raise ValueError("Known classes is 0")
        return self._memory_size // self._known_classes

    def samples_new_class(self, index):
        """Return the number of training samples of new class ``index``.

        ``cifar100`` is hard-coded to 500; every other dataset asks
        ``data_manager.getlen(index)``.
        """
        if self.args["dataset"] == "cifar100":
            return 500
        return self.data_manager.getlen(index)

    def BKD(self, pred, soft, T):
        """Distillation loss with class-re-weighted teacher probabilities.

        The teacher distribution ``softmax(soft / T)`` is multiplied by
        ``self.per_cls_weights`` and renormalised per row; the result is the
        cross-entropy against ``log_softmax(pred / T)`` averaged over the
        first dimension of ``pred``.

        Args:
            pred: Student logits.
            soft: Teacher logits.
            T: Softmax temperature.
        """
        pred = torch.log_softmax(pred / T, dim=1)
        soft = torch.softmax(soft / T, dim=1)
        soft = soft * self.per_cls_weights
        soft = soft / soft.sum(1)[:, None]
        return -1 * torch.mul(soft, pred).sum() / pred.shape[0]
def _KD_loss(pred, soft, T):
    """Temperature-scaled distillation loss between two sets of logits.

    Computes the cross-entropy between ``softmax(soft / T)`` and
    ``log_softmax(pred / T)`` along dimension 1, summed over all entries and
    divided by the size of the first dimension of ``pred``.

    Args:
        pred: Logits of the model being trained.
        soft: Logits providing the target distribution.
        T: Softmax temperature applied to both inputs.

    Returns:
        A scalar tensor holding the loss.
    """
    student_log_probs = torch.log_softmax(pred / T, dim=1)
    target_probs = torch.softmax(soft / T, dim=1)
    weighted_sum = (target_probs * student_log_probs).sum()
    return -1 * weighted_sum / student_log_probs.shape[0]