"""Run the best saved model against the cross-domain few-shot test benchmarks.

When executed as a script this fixes the random seeds, parses the command-line
options and runs the BSCDFSL evaluation (see the ``__main__`` section).
"""
import os
import random

import numpy as np
import torch
import torch.optim

from methods import backbone
from methods.backbone_multiblock import model_dict
from data.datamgr import SimpleDataManager, SetDataManager
from methods.StyleAdv_RN_GNN import StyleAdvGNN
from options import parse_args, get_resume_file, load_warmup_state
from test_function_fwt_benchmark import test_bestmodel
from test_function_bscdfsl_benchmark import test_bestmodel_bscdfsl


def record_test_result(params):
    """Evaluate the best model on the FWT benchmark and log to ``tmp2.txt``.

    A header line is written to the file, then the open file handle is passed
    to ``test_bestmodel`` for each enabled dataset (currently only
    ``miniImagenet``).

    Args:
        params: Parsed options object; the attributes ``name``, ``n_shot`` and
            ``method`` are read and forwarded to ``test_bestmodel``.
    """
    acc_file_path = "tmp2.txt"
    # -1 is forwarded as the epoch id; presumably it selects the "best"
    # checkpoint rather than a numbered epoch -- verify in test_bestmodel.
    epoch_id = -1
    name = params.name
    n_shot = params.n_shot
    method = params.method

    # The context manager guarantees the file is flushed and closed even if
    # the evaluation raises part-way through.
    with open(acc_file_path, "w") as acc_file:
        print(
            "epoch",
            epoch_id,
            "miniImagenet:",
            "cub:",
            "cars:",
            "places:",
            "plantae:",
            file=acc_file,
        )
        test_bestmodel(acc_file, name, method, "miniImagenet", n_shot, epoch_id)
        # Other datasets are disabled; uncomment to evaluate them as well.
        # test_bestmodel(acc_file, name, method, 'cub', n_shot, epoch_id)
        # test_bestmodel(acc_file, name, method, 'cars', n_shot, epoch_id)
        # test_bestmodel(acc_file, name, method, 'places', n_shot, epoch_id)
        # test_bestmodel(acc_file, name, method, 'plantae', n_shot, epoch_id)


def record_test_result_bscdfsl(params):
    """Evaluate the best model on the BSCDFSL benchmark and log to ``tmp_bscdfsl2.txt``.

    A header line is written to the file, then the open file handle is passed
    to ``test_bestmodel_bscdfsl`` for each enabled dataset (currently only
    ``EuroSAT``).

    Args:
        params: Parsed options object; the attributes ``name``, ``n_shot`` and
            ``method`` are read and forwarded to ``test_bestmodel_bscdfsl``.
    """
    print("hhhhhhh testing for bscdfsl")
    acc_file_path = "tmp_bscdfsl2.txt"
    # -1 is forwarded as the epoch id; presumably it selects the "best"
    # checkpoint rather than a numbered epoch -- verify in the test function.
    epoch_id = -1
    name = params.name
    n_shot = params.n_shot
    method = params.method

    # The context manager guarantees the file is flushed and closed even if
    # the evaluation raises part-way through.
    with open(acc_file_path, "w") as acc_file:
        print(
            "epoch",
            epoch_id,
            "ChestX:",
            "ISIC:",
            "EuroSAT:",
            "CropDisease",
            file=acc_file,
        )
        # Other datasets are disabled; uncomment to evaluate them as well.
        # test_bestmodel_bscdfsl(acc_file, name, method, 'ChestX', n_shot, epoch_id)
        # test_bestmodel_bscdfsl(acc_file, name, method, 'ISIC', n_shot, epoch_id)
        test_bestmodel_bscdfsl(acc_file, name, method, "EuroSAT", n_shot, epoch_id)
        # test_bestmodel_bscdfsl(acc_file, name, method, 'CropDisease', n_shot, epoch_id)


# --- main function ---
if __name__ == "__main__":
    # Fix every RNG in use (Python, NumPy, PyTorch CPU and all CUDA devices)
    # so repeated runs draw the same random numbers.
    seed = 0
    print("set seed = %d" % seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Ask cuDNN for deterministic algorithms and turn off its benchmark-based
    # algorithm selection, which can otherwise vary between runs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Parse command-line options using the "train" option set.
    params = parse_args("train")

    # FWT benchmark evaluation (disabled).
    # record_test_result(params)

    # BSCDFSL benchmark evaluation.
    record_test_result_bscdfsl(params)