"""Entry point for training and evaluating class-incremental learning models.

The script reads a JSON settings file, then for every seed listed in it trains
the selected model task by task, logging per-task accuracies, the accuracy
matrix and the forgetting measure.
"""

import argparse
import copy
import json
import logging
import os
import sys

import numpy as np
import torch

from utils import factory
from utils.data_manager import DataManager
from utils.toolkit import count_parameters


def main():
    """Parse the command line, merge in the JSON settings and start training."""
    args = setup_parser().parse_args()
    param = load_json(args.config)
    args = vars(args)  # Converting argparse Namespace to a dict.
    args.update(param)  # Add parameters from json
    train(args)


def train(args):
    """Run one complete training session per seed in ``args["seed"]``.

    ``_train`` overwrites ``args["device"]`` with ``torch.device`` objects, so
    the raw device specification is copied up front and restored before every
    run.
    """
    seed_list = copy.deepcopy(args["seed"])
    device = copy.deepcopy(args["device"])

    for seed in seed_list:
        args["seed"] = seed
        args["device"] = device
        _train(args)


def _train(args):
    """Train and evaluate the model on every incremental task for one seed.

    Reads ``model_name``, ``dataset``, ``init_cls``, ``increment``,
    ``convnet_type``, ``prefix``, ``seed``, ``shuffle`` and ``device`` from
    ``args``; writes a log file and checkpoints below ``logs/``.
    """
    init_cls = 0 if args["init_cls"] == args["increment"] else args["init_cls"]
    logs_name = "logs/{}/{}/init_cls_{}/per_classes_{}/{}".format(
        args["model_name"],
        args["dataset"],
        init_cls,
        args["increment"],
        args["convnet_type"],
    )
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(logs_name, exist_ok=True)

    logfilename = "{}/{}_{}_{}".format(
        logs_name,
        args["prefix"],
        args["seed"],
        args["convnet_type"],
    )
    _configure_logging(logfilename)

    # NOTE: the RNG seed stays at the historical fixed value; args["seed"] is
    # only handed to the DataManager below.
    _set_random()
    _set_device(args)
    print_args(args)

    data_manager = DataManager(
        args["dataset"],
        args["shuffle"],
        args["seed"],
        args["init_cls"],
        args["increment"],
    )
    model = factory.get_model(args["model_name"], args)

    cnn_curve, nme_curve = {"top1": [], "top5": []}, {"top1": [], "top5": []}
    cnn_matrix, nme_matrix = [], []

    for task in range(data_manager.nb_tasks):
        logging.info("All params: {}".format(count_parameters(model._network)))
        logging.info(
            "Trainable params: {}".format(count_parameters(model._network, True))
        )
        model.incremental_train(data_manager)
        cnn_accy, nme_accy = model.eval_task()
        model.save_checkpoint(logs_name, args["convnet_type"])
        model.after_task()

        results = [("CNN", cnn_accy, cnn_curve, cnn_matrix)]
        if nme_accy is None:
            logging.info("No NME accuracy.")
        else:
            results.append(("NME", nme_accy, nme_curve, nme_matrix))
        _log_task_results(results)

    num_tasks = data_manager.nb_tasks
    _report_accuracy_matrix("CNN", cnn_matrix, num_tasks)
    _report_accuracy_matrix("NME", nme_matrix, num_tasks)


def _configure_logging(logfilename):
    """Send INFO logs to ``logfilename + ".log"`` and to stdout.

    ``logging.basicConfig`` does nothing when the root logger already has
    handlers, so handlers left over from a previous seed are removed (and
    closed) first; otherwise every later seed would keep writing into the
    first seed's log file.
    """
    root_logger = logging.getLogger()
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)
        handler.close()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(filename)s] => %(message)s",
        handlers=[
            logging.FileHandler(filename=logfilename + ".log"),
            logging.StreamHandler(sys.stdout),
        ],
    )


def _grouped_row(accy):
    """Return the values of ``accy["grouped"]`` whose key contains '-', ordered by key."""
    grouped = accy["grouped"]
    return [grouped[key] for key in sorted(key for key in grouped if "-" in key)]


def _mean(values):
    """Return the arithmetic mean of a non-empty sequence."""
    return sum(values) / len(values)


def _log_task_results(results):
    """Record and log the accuracies of one finished task.

    ``results`` is a list of ``(label, accy, curve, matrix)`` tuples, one per
    evaluated classifier. ``curve`` and ``matrix`` are extended in place.
    """
    for label, accy, _, _ in results:
        logging.info("{}: {}".format(label, accy["grouped"]))

    for _, accy, curve, matrix in results:
        matrix.append(_grouped_row(accy))
        curve["top1"].append(accy["top1"])
        curve["top5"].append(accy["top5"])

    last_label = results[-1][0]
    for label, _, curve, _ in results:
        logging.info("{} top1 curve: {}".format(label, curve["top1"]))
        # Only the final curve line is followed by a blank line in the log.
        suffix = "\n" if label == last_label else ""
        logging.info("{} top5 curve: {}{}".format(label, curve["top5"], suffix))

    for label, _, curve, _ in results:
        print("Average Accuracy ({}):".format(label), _mean(curve["top1"]))
    for label, _, curve, _ in results:
        logging.info(
            "Average Accuracy ({}): {}".format(label, _mean(curve["top1"]))
        )


def _report_accuracy_matrix(label, matrix, num_tasks):
    """Print and log the accuracy matrix and the forgetting measure.

    ``matrix`` holds one row of grouped accuracies per evaluated task; nothing
    is reported when it is empty. Rows are zero-padded to ``num_tasks`` entries
    and the table is transposed before reporting.
    """
    if not matrix:
        return

    table = np.zeros([num_tasks, num_tasks])
    for row_idx, row in enumerate(matrix):
        table[row_idx, :len(row)] = np.array(row)
    table = table.T

    last = num_tasks - 1
    if last > 0:
        forgetting = np.mean((np.max(table, axis=1) - table[:, last])[:last])
    else:
        # A single task leaves nothing to forget; np.mean over the empty slice
        # would yield nan together with a RuntimeWarning.
        forgetting = 0.0

    print("Accuracy Matrix ({}):".format(label))
    print(table)
    formatted_str = "\n".join(
        "\t[" + ", ".join("{:.1f}".format(float(val)) for val in row) + "]"
        for row in table
    )
    logging.info("Accuracy Matrix ({}):\n{}".format(label, formatted_str))
    print("Forgetting ({}):".format(label), forgetting)
    logging.info("Forgetting ({}): {}".format(label, forgetting))


def _set_device(args):
    """Replace ``args["device"]`` with the matching list of ``torch.device`` objects.

    An entry of ``-1`` (int or string) selects the CPU; any other entry is used
    as a CUDA device index.
    """
    gpus = []
    for device in args["device"]:
        # Compare the individual entry: testing the whole list against -1 is
        # always False, which made the CPU branch unreachable.
        if device == -1 or device == "-1":
            gpus.append(torch.device("cpu"))
        else:
            gpus.append(torch.device("cuda:{}".format(device)))

    args["device"] = gpus


def _set_random(seed=1):
    """Seed the torch CPU/CUDA generators and force deterministic cuDNN."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def print_args(args):
    """Log every ``key: value`` pair of ``args`` at INFO level."""
    for key, value in args.items():
        logging.info("{}: {}".format(key, value))


def load_json(settings_path):
    """Return the parsed content of the JSON file at ``settings_path``."""
    with open(settings_path) as data_file:
        param = json.load(data_file)

    return param


def setup_parser():
    """Build the command-line parser; ``--config`` names the JSON settings file."""
    parser = argparse.ArgumentParser(
        description='Reproduce of multiple continual learning algorithms.'
    )
    parser.add_argument('--config', type=str, default='./exps/gem.json',
                        help='Json file of settings.')

    return parser


if __name__ == '__main__':
    main()