# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import csv
import math
import os
import sys
from argparse import Namespace
from typing import Callable, List, Tuple

import numpy as np
import torch
import torch.nn.functional as F

from datasets import get_dataset
from datasets.utils.continual_dataset import ContinualDataset
from models.utils.continual_model import ContinualModel
from utils.loggers import *
from utils.mlflow_logger import MLFlowLogger
from utils.status import ProgressBar
import utils.metrics


def mask_classes(outputs: torch.Tensor, dataset: ContinualDataset, k: int) -> None:
    """
    Given the output tensor, the dataset at hand and the current task,
    masks the former by setting the responses for the other tasks at -inf.
    It is used to obtain the results for the task-il setting.
    The tensor is modified in place.
    :param outputs: the output tensor
    :param dataset: the continual dataset
    :param k: the task index
    """
    task_start = k * dataset.N_CLASSES_PER_TASK
    task_end = (k + 1) * dataset.N_CLASSES_PER_TASK
    all_end = dataset.N_TASKS * dataset.N_CLASSES_PER_TASK
    outputs[:, 0:task_start] = -float('inf')
    outputs[:, task_end:all_end] = -float('inf')


def _forward_default(model: ContinualModel, dataset: ContinualDataset,
                     inputs: torch.Tensor, k: int) -> torch.Tensor:
    """Runs the model's regular forward pass; the task index is passed only to
    models that do not list 'class-il' in their COMPATIBILITY."""
    if 'class-il' not in model.COMPATIBILITY:
        return model(inputs, k)
    return model(inputs)


def _forward_ider(model: ContinualModel, dataset: ContinualDataset,
                  inputs: torch.Tensor, k: int) -> torch.Tensor:
    """Runs the two-stage forward pass (net.f1 followed by net.f2) used for the 'ider' model."""
    if 'class-il' not in model.COMPATIBILITY:
        return model(inputs, k)
    # Second input of f2: a (batch, N_CLASSES) tensor whose entries all equal 1 / N_CLASSES.
    y_0 = torch.ones(inputs.shape[0], dataset.N_CLASSES, device=inputs.device) / dataset.N_CLASSES
    z = model.net.f1(inputs)
    outputs, _ = model.net.f2(z, y_0)
    return outputs


def _evaluate_loaders(model: ContinualModel, dataset: ContinualDataset, last: bool,
                      forward: Callable, collect_ece: bool = False) -> Tuple[list, list, list]:
    """Shared evaluation loop: per-loader class-il accuracy, task-il accuracy and, optionally, ECE (in percent)."""
    status = model.net.training
    model.net.eval()
    accs, accs_mask_classes, ece_list = [], [], []
    n_loaders = len(dataset.test_loaders)
    try:
        for k, test_loader in enumerate(dataset.test_loaders):
            # With last=True only the most recent test loader is evaluated.
            if last and k < n_loaders - 1:
                continue
            correct, correct_mask_classes, total = 0.0, 0.0, 0.0
            softmaxes, targets = [], []
            for data in test_loader:
                with torch.no_grad():
                    inputs, labels = data
                    inputs, labels = inputs.to(model.device), labels.to(model.device)
                    outputs = forward(model, dataset, inputs, k)

                    if collect_ece:
                        # Softmax yields a new tensor, so the in-place masking
                        # below cannot alter what is recorded here.
                        softmaxes.append(F.softmax(outputs, dim=1).detach().cpu().numpy())
                        targets.append(labels.detach().cpu().numpy())

                    _, pred = torch.max(outputs.data, 1)
                    correct += torch.sum(pred == labels).item()
                    total += labels.shape[0]

                    if dataset.SETTING == 'class-il':
                        mask_classes(outputs, dataset, k)
                        _, pred = torch.max(outputs.data, 1)
                        correct_mask_classes += torch.sum(pred == labels).item()

            if collect_ece:
                ece = utils.metrics.calc_ece(np.concatenate(softmaxes),
                                             np.concatenate(targets), bins=15)
                ece_list.append(ece * 100)
            accs.append(correct / total * 100
                        if 'class-il' in model.COMPATIBILITY else 0)
            accs_mask_classes.append(correct_mask_classes / total * 100)
    finally:
        # Restore the train/eval mode the network had on entry, even on error.
        model.net.train(status)
    return accs, accs_mask_classes, ece_list


def _report_ece(ece_list: List[float]) -> None:
    """Prints the per-loader ECE values and their mean."""
    mean_ece = np.mean(ece_list)
    print(f'ECE per loader: {ece_list}')
    print(f'Mean ECE: {mean_ece:.4f}')


def evaluate_ece(model: ContinualModel, dataset: ContinualDataset,
                 last=False) -> Tuple[List[float], List[float]]:
    """
    Evaluates accuracy and the expected calibration error (ECE, 15 bins)
    for each test loader. The accuracies and the per-loader and mean ECE
    (in percent) are printed; the ECE values are not returned.
    :param model: model to evaluate
    :param dataset: continual dataset
    :param last: if True, only the last test loader is evaluated
    :return: a tuple of lists, containing the class-il and task-il accuracy for each task
    """
    accs, accs_mask_classes, ece_list = _evaluate_loaders(
        model, dataset, last, _forward_default, collect_ece=True)
    print('evaluation acc:', accs)
    _report_ece(ece_list)
    return accs, accs_mask_classes


def evaluate_eceid(model: ContinualModel, dataset: ContinualDataset,
                   last=False) -> Tuple[List[float], List[float]]:
    """
    Same as evaluate_ece, but class-il compatible models are run through
    net.f1 / net.f2 instead of the regular forward pass. Only the ECE
    values are printed.
    :param model: model to evaluate
    :param dataset: continual dataset
    :param last: if True, only the last test loader is evaluated
    :return: a tuple of lists, containing the class-il and task-il accuracy for each task
    """
    accs, accs_mask_classes, ece_list = _evaluate_loaders(
        model, dataset, last, _forward_ider, collect_ece=True)
    _report_ece(ece_list)
    return accs, accs_mask_classes


def evaluate(model: ContinualModel, dataset: ContinualDataset, last=False) -> Tuple[list, list]:
    """
    Evaluates the accuracy of the model for each past task.
    :param model: the model to be evaluated
    :param dataset: the continual dataset at hand
    :param last: if True, only the last test loader is evaluated
    :return: a tuple of lists, containing the class-il
             and task-il accuracy for each task
    """
    accs, accs_mask_classes, _ = _evaluate_loaders(model, dataset, last, _forward_default)
    print('evaluation acc:')
    print(accs)
    return accs, accs_mask_classes


def evaluateid(model: ContinualModel, dataset: ContinualDataset, last=False) -> Tuple[list, list]:
    """
    Evaluates the accuracy of the model for each past task, running
    class-il compatible models through net.f1 / net.f2 instead of the
    regular forward pass.
    :param model: the model to be evaluated
    :param dataset: the continual dataset at hand
    :param last: if True, only the last test loader is evaluated
    :return: a tuple of lists, containing the class-il
             and task-il accuracy for each task
    """
    accs, accs_mask_classes, _ = _evaluate_loaders(model, dataset, last, _forward_ider)
    print('evaluation acc:')
    print(accs)
    return accs, accs_mask_classes


def train(model: ContinualModel, dataset: ContinualDataset,
          args: Namespace) -> None:
    """
    The training process, including evaluations and loggers.
    :param model: the module to be trained
    :param dataset: the continual dataset at hand
    :param args: the arguments of the current execution
    """
    model.net.to(model.device)
    results, results_mask_classes = [], []

    # The 'ider' model is evaluated through its dedicated forward path.
    is_ider = model.NAME == 'ider'
    eval_fn = evaluateid if is_ider else evaluate
    ece_fn = evaluate_eceid if is_ider else evaluate_ece

    logging_enabled = not args.disable_log and not args.debug
    other_metrics_enabled = not args.ignore_other_metrics and not args.debug
    supports_fwt = model.NAME not in ('icarl', 'pnn')

    if logging_enabled:
        logger = MLFlowLogger(dataset.SETTING, dataset.NAME, model.NAME,
                              experiment_name=args.experiment_name,
                              parent_run_id=args.parent_run_id,
                              run_name=args.run_name)
        logger.log_args(args.__dict__)

    progress_bar = ProgressBar(verbose=not args.non_verbose)

    if other_metrics_enabled:
        # Evaluate the model before any training on a fresh dataset copy that
        # has been advanced through all tasks; used later by add_fwt.
        dataset_copy = get_dataset(args)
        for t in range(dataset.N_TASKS):
            model.net.train()
            _, _ = dataset_copy.get_data_loaders()
        if supports_fwt:
            random_results_class, random_results_task = eval_fn(model, dataset_copy)

    # NOTE(review): the nesting of this cleanup was reconstructed from a
    # whitespace-collapsed source; it is assumed to run unconditionally.
    for stale_file in ('old_model.pt', 'net.pt'):
        if os.path.exists(stale_file):
            os.remove(stale_file)

    print(file=sys.stderr)
    for t in range(dataset.N_TASKS):
        model.net.train()
        train_loader, test_loader = dataset.get_data_loaders()
        if hasattr(model, 'begin_task'):
            model.begin_task(dataset)
        if t and other_metrics_enabled:
            # Accuracy on the new task before training on it, appended to
            # the results row of the previous task.
            accs = eval_fn(model, dataset, last=True)
            results[t - 1] = results[t - 1] + accs[0]
            if dataset.SETTING == 'class-il':
                results_mask_classes[t - 1] = results_mask_classes[t - 1] + accs[1]

        scheduler = dataset.get_scheduler(model, args)
        for epoch in range(model.args.n_epochs):
            if args.model == 'joint':
                continue
            for i, data in enumerate(train_loader):
                if args.debug and i > 3:
                    break
                if hasattr(dataset.train_loader.dataset, 'logits'):
                    inputs, labels, not_aug_inputs, logits = data
                    inputs = inputs.to(model.device)
                    labels = labels.to(model.device)
                    not_aug_inputs = not_aug_inputs.to(model.device)
                    logits = logits.to(model.device)
                    loss = model.meta_observe(inputs, labels, not_aug_inputs, logits)
                else:
                    inputs, labels, not_aug_inputs = data
                    inputs = inputs.to(model.device)
                    labels = labels.to(model.device)
                    not_aug_inputs = not_aug_inputs.to(model.device)
                    loss = model.meta_observe(inputs, labels, not_aug_inputs)
                assert not math.isnan(loss)
                progress_bar.prog(i, len(train_loader), epoch, t, loss)

            if scheduler is not None:
                scheduler.step()
            if hasattr(model, 'end_epoch'):
                model.end_epoch(dataset)

        if hasattr(model, 'end_task'):
            model.end_task(dataset)

        accs = eval_fn(model, dataset)
        results.append(accs[0])
        results_mask_classes.append(accs[1])

        # Called for its printed ECE report; the return value is not used.
        ece_fn(model, dataset)

        mean_acc = np.mean(accs, axis=1)
        print_mean_accuracy(mean_acc, t + 1, dataset.SETTING)

        if logging_enabled:
            logger.log(mean_acc)
            logger.log_fullacc(accs)

    if logging_enabled and not args.ignore_other_metrics:
        logger.add_bwt(results, results_mask_classes)
        logger.add_forgetting(results, results_mask_classes)
        if supports_fwt:
            logger.add_fwt(results, random_results_class,
                           results_mask_classes, random_results_task)

    # NOTE(review): assumed to run once after the last task; the collapsed
    # source does not show whether it was nested inside the task loop.
    if args.savecheckpoint:
        dataset_name = args.dataset if hasattr(args, 'dataset') and args.dataset else 'unknown_dataset'
        buffer_tag = f"buffer_{args.buffer_size}" if hasattr(args, 'buffer_size') and args.buffer_size is not None else "buffer_none"
        save_dir = os.path.join("./experiments", dataset_name, buffer_tag)
        os.makedirs(save_dir, exist_ok=True)

        model_filename = f"{model.NAME}_seed_{args.seed}.pth"
        model_path = os.path.join(save_dir, model_filename)
        torch.save(model.net.state_dict(), model_path)
        print(f"Model saved to: {model_path}")