SAE / trainCIL.py
Ttius's picture
Upload 192 files
998bb30 verified
import json
import argparse
import sys
import logging
import copy
import torch
from utils import factory
from utils.data_manager import DataManager
from utils.toolkit import count_parameters
import os
import numpy as np
def main():
    """Script entry point: merge CLI options with the JSON config and train.

    The command line is parsed first, then the JSON file named by
    ``--config`` is loaded and its entries are written on top of the parsed
    options (JSON values win on key collisions). The merged dict is handed
    to ``train``.
    """
    cli_namespace = setup_parser().parse_args()
    # vars() exposes the Namespace as a plain dict so it can be merged.
    settings = vars(cli_namespace)
    settings.update(load_json(cli_namespace.config))
    train(settings)
def train(args):
    """Run one complete training session per seed listed in ``args["seed"]``.

    ``args["seed"]`` is expected to be an iterable of seeds and
    ``args["device"]`` the raw device specification. Both are copied up
    front because each run overwrites them: ``args["seed"]`` becomes the
    scalar seed of the current run, and ``_set_device`` (called from
    ``_train``) replaces ``args["device"]`` with ``torch.device`` objects.
    """
    all_seeds = copy.deepcopy(args["seed"])
    raw_devices = copy.deepcopy(args["device"])

    for current_seed in all_seeds:
        # Restore the pristine device spec before every run.
        args.update(seed=current_seed, device=raw_devices)
        _train(args)
def _train(args):
    """Train and evaluate one model incrementally for a single seed.

    Reads ``init_cls``, ``increment``, ``model_name``, ``dataset``,
    ``convnet_type``, ``prefix``, ``seed`` and ``shuffle`` from ``args``
    (``device`` is consumed by ``_set_device``). For every task reported by
    the data manager it trains the model, evaluates it, saves a checkpoint
    and records the top-1/top-5 curves plus the per-class-range accuracies.
    After the last task the accuracy matrix and the forgetting measure are
    printed and logged for the CNN classifier and, when available, for NME.

    Args:
        args: Mutable configuration dict. ``args["device"]`` is rewritten
            in place by ``_set_device``.
    """
    init_cls = 0 if args["init_cls"] == args["increment"] else args["init_cls"]
    logs_name = "logs/{}/{}/init_cls_{}/per_classes_{}/{}".format(
        args["model_name"],
        args["dataset"],
        init_cls,
        args["increment"],
        args["convnet_type"],
    )
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(logs_name, exist_ok=True)

    logfilename = "{}/{}_{}_{}".format(
        logs_name,
        args["prefix"],
        args["seed"],
        args["convnet_type"],
    )

    # logging.basicConfig() does nothing when the root logger already has
    # handlers, so without this reset every seed after the first would keep
    # writing into the first seed's log file.
    _reset_root_logging()
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(filename)s] => %(message)s",
        handlers=[
            logging.FileHandler(filename=logfilename + ".log"),
            logging.StreamHandler(sys.stdout),
        ],
    )

    _set_random()
    _set_device(args)
    print_args(args)

    data_manager = DataManager(
        args["dataset"],
        args["shuffle"],
        args["seed"],
        args["init_cls"],
        args["increment"],
    )
    model = factory.get_model(args["model_name"], args)

    cnn_curve, nme_curve = {"top1": [], "top5": []}, {"top1": [], "top5": []}
    # One row per task; each row holds the per-class-range accuracies
    # measured right after that task.
    cnn_matrix, nme_matrix = [], []

    nb_tasks = data_manager.nb_tasks
    for task in range(nb_tasks):
        logging.info("All params: {}".format(count_parameters(model._network)))
        logging.info(
            "Trainable params: {}".format(count_parameters(model._network, True))
        )
        model.incremental_train(data_manager)
        cnn_accy, nme_accy = model.eval_task()
        model.save_checkpoint(logs_name, args["convnet_type"])
        model.after_task()

        if nme_accy is not None:
            logging.info("CNN: {}".format(cnn_accy["grouped"]))
            logging.info("NME: {}".format(nme_accy["grouped"]))

            cnn_matrix.append(_sorted_group_accuracies(cnn_accy["grouped"]))
            nme_matrix.append(_sorted_group_accuracies(nme_accy["grouped"]))

            cnn_curve["top1"].append(cnn_accy["top1"])
            cnn_curve["top5"].append(cnn_accy["top5"])
            nme_curve["top1"].append(nme_accy["top1"])
            nme_curve["top5"].append(nme_accy["top5"])

            logging.info("CNN top1 curve: {}".format(cnn_curve["top1"]))
            logging.info("CNN top5 curve: {}".format(cnn_curve["top5"]))
            logging.info("NME top1 curve: {}".format(nme_curve["top1"]))
            logging.info("NME top5 curve: {}\n".format(nme_curve["top5"]))

            cnn_average = sum(cnn_curve["top1"]) / len(cnn_curve["top1"])
            nme_average = sum(nme_curve["top1"]) / len(nme_curve["top1"])
            print('Average Accuracy (CNN):', cnn_average)
            print('Average Accuracy (NME):', nme_average)
            logging.info("Average Accuracy (CNN): {}".format(cnn_average))
            logging.info("Average Accuracy (NME): {}".format(nme_average))
        else:
            logging.info("No NME accuracy.")
            logging.info("CNN: {}".format(cnn_accy["grouped"]))

            cnn_matrix.append(_sorted_group_accuracies(cnn_accy["grouped"]))

            cnn_curve["top1"].append(cnn_accy["top1"])
            cnn_curve["top5"].append(cnn_accy["top5"])

            logging.info("CNN top1 curve: {}".format(cnn_curve["top1"]))
            logging.info("CNN top5 curve: {}\n".format(cnn_curve["top5"]))

            cnn_average = sum(cnn_curve["top1"]) / len(cnn_curve["top1"])
            print('Average Accuracy (CNN):', cnn_average)
            logging.info("Average Accuracy (CNN): {}".format(cnn_average))

    # Final summary, emitted once after the last task.
    if len(cnn_matrix) > 0:
        _report_accuracy_matrix("CNN", cnn_matrix, nb_tasks)
    if len(nme_matrix) > 0:
        _report_accuracy_matrix("NME", nme_matrix, nb_tasks)


def _reset_root_logging():
    """Detach and close every handler currently attached to the root logger."""
    root_logger = logging.getLogger()
    for handler in list(root_logger.handlers):
        root_logger.removeHandler(handler)
        handler.close()


def _sorted_group_accuracies(grouped):
    """Return the values of the ``-``-containing keys of ``grouped`` in range order.

    Keys are presumably of the form ``"<start>-<end>"`` (TODO confirm against
    the evaluator); they are ordered by the integer before the first ``-``
    so that e.g. ``"100-109"`` follows ``"90-99"`` instead of preceding
    ``"20-29"`` as a plain string sort would place it. If a key does not
    start with an integer, the plain string order is used instead.
    """
    range_keys = [key for key in grouped if '-' in key]
    try:
        ordered_keys = sorted(range_keys, key=lambda key: int(key.split('-')[0]))
    except ValueError:
        ordered_keys = sorted(range_keys)
    return [grouped[key] for key in ordered_keys]


def _report_accuracy_matrix(label, task_rows, nb_tasks):
    """Print and log the accuracy matrix and forgetting for one classifier.

    Args:
        label: Tag inserted in the messages (``"CNN"`` or ``"NME"``).
        task_rows: One list of per-class-range accuracies per evaluated task.
        nb_tasks: Total number of tasks; the matrix is ``nb_tasks`` square.
    """
    last_task = nb_tasks - 1
    table = np.zeros([nb_tasks, nb_tasks])
    for task_idx, row in enumerate(task_rows):
        table[task_idx, :len(row)] = np.array(row)
    # After the transpose, rows index class ranges and columns index tasks.
    table = table.T
    # Forgetting: best accuracy ever reached on a range minus its accuracy
    # after the final task, averaged over the first ``last_task`` ranges.
    forgetting = np.mean((np.max(table, axis=1) - table[:, last_task])[:last_task])

    print('Accuracy Matrix ({}):'.format(label))
    print(table)
    formatted_rows = "\n".join(
        "\t[" + ", ".join("{:.1f}".format(float(val)) for val in row) + "]"
        for row in table
    )
    logging.info("Accuracy Matrix ({}):\n{}".format(label, formatted_rows))
    print('Forgetting ({}):'.format(label), forgetting)
    logging.info('Forgetting ({}): {}'.format(label, forgetting))
def _set_device(args):
    """Replace ``args["device"]`` with a list of ``torch.device`` objects.

    Each entry of the incoming ``args["device"]`` iterable is either ``-1``
    (or the string ``"-1"``), meaning CPU, or a CUDA device index.

    Args:
        args: Configuration dict; ``args["device"]`` is overwritten in place.
    """
    devices = []
    for device_id in args["device"]:
        # The CPU sentinel must be tested per entry; comparing the whole
        # list against -1 is never true and made the CPU branch unreachable.
        if device_id in (-1, "-1"):
            devices.append(torch.device("cpu"))
        else:
            devices.append(torch.device("cuda:{}".format(device_id)))
    args["device"] = devices
def _set_random(seed=1):
    """Seed PyTorch's CPU and CUDA generators and force deterministic cuDNN.

    Args:
        seed: Value passed to every PyTorch seeding call. Defaults to 1,
            the value that was previously hard-coded, so existing callers
            that pass no argument behave exactly as before.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Select deterministic cuDNN kernels and disable the auto-tuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def print_args(args):
    """Write every ``key: value`` pair of ``args`` to the log at INFO level."""
    for entry in args.items():
        logging.info("{}: {}".format(*entry))
def load_json(settings_path):
    """Load and return the parsed content of the JSON file at ``settings_path``.

    The file is read as UTF-8 rather than with the platform's default
    encoding, so configs containing non-ASCII text load the same way on
    every system.

    Args:
        settings_path: Path of the JSON settings file.

    Returns:
        The decoded JSON value.
    """
    with open(settings_path, encoding="utf-8") as data_file:
        return json.load(data_file)
def setup_parser():
    """Build the command-line parser, whose only option is ``--config``.

    Returns:
        An ``argparse.ArgumentParser`` accepting ``--config`` (string,
        default ``./exps/gem.json``).
    """
    cli = argparse.ArgumentParser(
        description='Reproduce of multiple continual learning algorithms.',
    )
    cli.add_argument(
        '--config',
        default='./exps/gem.json',
        type=str,
        help='Json file of settings.',
    )
    return cli
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()