import logging
import os
import sys
import time

import numpy as np
import torch

from methods.prompt2guard import Prompt2Guard
from utils.data_manager import DataManager
from utils.toolkit import count_parameters


def train(args):
    # One timestamped log directory per run, derived from the run name.
    logfilename = "logs/{}/{}".format(
        args["run_name"].replace("_", "/"),
        time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()),
    )
    os.makedirs(logfilename)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(filename)s] => %(message)s",
        handlers=[
            logging.FileHandler(filename=os.path.join(logfilename, "info.log")),
            logging.StreamHandler(sys.stdout),
        ],
    )
    os.makedirs(f'./checkpoint/{args["run_name"]}/weights', exist_ok=True)

    _set_random(args)
    print_args(args)

    data_manager = DataManager(
        args["dataset"],
        args["shuffle"],
        args["seed"],
        args["init_cls"],
        args["increment"],
        args,
    )
    args["class_order"] = data_manager._class_order
    args["filename"] = os.path.join(logfilename, "task")
    model = Prompt2Guard(args)

    # One (nb_tasks x nb_tasks) accuracy matrix per scoring strategy:
    # row t holds the per-task accuracies measured after training on task t.
    acc_matrix = {
        "top1": np.zeros((data_manager.nb_tasks, data_manager.nb_tasks)),
        "mean": np.zeros((data_manager.nb_tasks, data_manager.nb_tasks)),
        "mix_top_mean": np.zeros((data_manager.nb_tasks, data_manager.nb_tasks)),
    }
    label_history = []

    for task in range(data_manager.nb_tasks):
        logging.info("All params: {}".format(count_parameters(model.network)))
        logging.info(
            "Trainable params: {}".format(count_parameters(model.network, True))
        )
        model.incremental_train(data_manager)
        record_task_accuracy(task, model.eval_task(), acc_matrix, label_history)
        model.after_task(data_manager.nb_tasks)
        model.save_checkpoint()


def _compute_AF(matrix):
    """Average forgetting: mean backward transfer (BWT_i) over all tasks but the last."""
    total_bwt = 0
    N = matrix.shape[0]
    for i in range(N - 1):  # iterate over every task except the last one
        bwt_i = 0
        for j in range(i + 1, N):  # accuracy change on task i after learning tasks j > i
            bwt_i += matrix[j, i] - matrix[i, i]
        bwt_i /= N - i - 1  # normalize by the number of later tasks
        total_bwt += bwt_i
    af = total_bwt / (N - 1)  # average of all BWT_i
    return af

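
# Worked example (illustrative values): for the 2x2 accuracy matrix
#   A = [[0.90, 0.00],
#        [0.85, 0.92]]
# BWT_0 = A[1, 0] - A[0, 0] = -0.05, so _compute_AF(A) returns -0.05;
# negative values indicate accuracy lost on earlier tasks (forgetting).
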

def compute_forgetting(model: Prompt2Guard, acc_matrix):
    for k in acc_matrix.keys():
        forgetting = _compute_AF(acc_matrix[k])
        logging.info("Avg Forgetting of {}: {:.4f}".format(k, forgetting))


def record_task_accuracy(
    current_task, current_task_acc: dict, matrix_dict: dict, label_history: list
):
    # Each task covers two class labels (2t and 2t+1), keyed as e.g. "00-01".
    label_history.append(
        "{}-{}".format(
            str(current_task * 2).zfill(2), str(current_task * 2 + 1).zfill(2)
        )
    )
    for logit_ops in current_task_acc.keys():
        # Keep only the per-task entries seen so far, in task order.
        dict_subset = {
            k: current_task_acc[logit_ops][k]
            for k in label_history
            if k in current_task_acc[logit_ops]
        }
        for idx_label, label_task in enumerate(dict_subset):
            matrix_dict[logit_ops][current_task][idx_label] = current_task_acc[
                logit_ops
            ][label_task]
    for key, value in current_task_acc.items():
        logging.info(f"Performance Task {current_task} for {key}: {value}")


def _set_device(args):
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    logging.info("Device: " + device.type)
    args["device"] = device


def _set_random(args):
    # Seed all RNG sources and force deterministic cuDNN kernels for reproducibility.
    torch.manual_seed(args["torch_seed"])
    torch.cuda.manual_seed(args["torch_seed"])
    torch.cuda.manual_seed_all(args["torch_seed"])
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def print_args(args):
    for key, value in args.items():
        logging.info("{}: {}".format(key, value))