import os
import sys
import logging
import time

import numpy as np
import torch

from utils.data_manager import DataManager
from utils.toolkit import count_parameters
from methods.prompt2guard import Prompt2Guard


def train(args):
    # One log directory per run, timestamped so repeated runs never collide.
    logfilename = "logs/{}/{}".format(
        args["run_name"].replace("_", "/"),
        time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()),
    )
    os.makedirs(logfilename)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(filename)s] => %(message)s",
        handlers=[
            logging.FileHandler(filename=logfilename + "/info.log"),
            logging.StreamHandler(sys.stdout),
        ],
    )
    os.makedirs(f'./checkpoint/{args["run_name"]}/weights', exist_ok=True)

    _set_random(args)
    print_args(args)

    data_manager = DataManager(
        args["dataset"],
        args["shuffle"],
        args["seed"],
        args["init_cls"],
        args["increment"],
        args,
    )
    args["class_order"] = data_manager._class_order
    args["filename"] = os.path.join(logfilename, "task")

    model = Prompt2Guard(args)

    # Square accuracy matrices: row = task just trained, column = task evaluated.
    acc_matrix = {
        "top1": np.zeros((data_manager.nb_tasks, data_manager.nb_tasks)),
        "mean": np.zeros((data_manager.nb_tasks, data_manager.nb_tasks)),
        "mix_top_mean": np.zeros((data_manager.nb_tasks, data_manager.nb_tasks)),
    }
    label_history = []

    for task in range(data_manager.nb_tasks):
        logging.info("All params: {}".format(count_parameters(model.network)))
        logging.info(
            "Trainable params: {}".format(count_parameters(model.network, True))
        )
        model.incremental_train(data_manager)
        record_task_accuracy(task, model.eval_task(), acc_matrix, label_history)
        model.after_task(data_manager.nb_tasks)
        model.save_checkpoint()


def _compute_AF(matrix):
    """Average forgetting (backward transfer) over a lower-triangular accuracy
    matrix: how much accuracy on each task degrades, on average, after
    training on later tasks."""
    total_bwt = 0
    N = matrix.shape[0]
    for i in range(N - 1):  # iterate through each task except the last one
        bwt_i = 0
        for j in range(i + 1, N):  # evaluations of task i after tasks i+1..N-1
            bwt_i += matrix[j, i] - matrix[i, i]
        bwt_i /= N - i - 1  # normalize by the number of later evaluations
        total_bwt += bwt_i
    af = total_bwt / (N - 1)  # average of all BWT_i
    return af


def compute_forgetting(model: Prompt2Guard, acc_matrix):
    for k in acc_matrix.keys():
        forgetting = _compute_AF(acc_matrix[k])
        logging.info("Avg Forgetting of {}: {:.4f}".format(k, forgetting))


def record_task_accuracy(
    current_task, current_task_acc: dict, matrix_dict: dict, label_history: list
):
    # Labels assume two classes per task: task 0 -> "00-01", task 1 -> "02-03", ...
    label_history.append(
        "{}-{}".format(
            str(current_task * 2).zfill(2), str(current_task * 2 + 1).zfill(2)
        )
    )
    for logit_ops in current_task_acc.keys():
        dict_subset = {
            k: current_task_acc[logit_ops][k]
            for k in label_history
            if k in current_task_acc[logit_ops]
        }
        for idx_label, label_task in enumerate(dict_subset):
            matrix_dict[logit_ops][current_task][idx_label] = dict_subset[label_task]
    for key, value in current_task_acc.items():
        logging.info(f"Performance Task {current_task} for {key}: {value}")


def _set_device(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info("Device: " + device.type)
    args["device"] = device


def _set_random(args):
    # Seed every PyTorch RNG and force deterministic cuDNN kernels so runs
    # are reproducible (at the cost of some speed).
    torch.manual_seed(args["torch_seed"])
    torch.cuda.manual_seed(args["torch_seed"])
    torch.cuda.manual_seed_all(args["torch_seed"])
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def print_args(args):
    for key, value in args.items():
        logging.info("{}: {}".format(key, value))
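

# A minimal sanity-check sketch for _compute_AF on a toy 3-task run; all
# numbers are illustrative, not measured results. The matrix layout matches
# acc_matrix above (row = task just trained, column = task evaluated).
def _demo_compute_AF():
    demo = np.array(
        [
            [0.90, 0.00, 0.00],  # after task 0
            [0.80, 0.85, 0.00],  # after task 1: task 0 dropped 0.90 -> 0.80
            [0.70, 0.80, 0.95],  # after task 2: task 0 -> 0.70, task 1 -> 0.80
        ]
    )
    # BWT_0 = ((0.80 - 0.90) + (0.70 - 0.90)) / 2 = -0.15
    # BWT_1 =  (0.80 - 0.85) / 1                  = -0.05
    # AF    = (-0.15 + -0.05) / 2                 = -0.10
    assert abs(_compute_AF(demo) - (-0.10)) < 1e-9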
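

# A minimal usage sketch, assuming this script is run directly. Every value
# in example_args is a placeholder assumption for illustration; the real keys
# and values come from the repository's config files.
if __name__ == "__main__":
    example_args = {
        "run_name": "prompt2guard_demo",  # hypothetical run name
        "dataset": "CDDB",                # assumed dataset key; check repo configs
        "shuffle": False,
        "seed": 1993,                     # arbitrary illustrative seeds
        "torch_seed": 1993,
        "init_cls": 2,                    # two classes per task, matching
        "increment": 2,                   #   record_task_accuracy's labels
    }
    train(example_args)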