# CausalStyleAdv / test.py
# Uploaded by YuqianFu using huggingface_hub (commit 197d4ca, verified).
import numpy as np
import torch
import torch.optim
import os
import random
from methods import backbone
from methods.backbone_multiblock import model_dict
from data.datamgr import SimpleDataManager, SetDataManager
from methods.StyleAdv_RN_GNN import StyleAdvGNN
from options import parse_args, get_resume_file, load_warmup_state
from test_function_fwt_benchmark import test_bestmodel
from test_function_bscdfsl_benchmark import test_bestmodel_bscdfsl
def record_test_result(params, datasets=("miniImagenet",)):
    """Evaluate the best saved model on the FWT benchmark datasets and log results.

    A header row is written to ``tmp2.txt`` (overwriting any previous
    contents). Then ``test_bestmodel`` is called once per entry in
    *datasets*, with the open file handle passed along.

    Args:
        params: Parsed options object; ``name``, ``n_shot`` and ``method``
            are read from it.
        datasets: Names of the target datasets to evaluate, in order. The
            header also lists 'cub', 'cars', 'places' and 'plantae'; pass
            those names here to evaluate them too. Defaults to
            ``("miniImagenet",)``, which matches the previous hard-coded
            behaviour.
    """
    acc_file_path = "tmp2.txt"
    # NOTE(review): -1 looks like a sentinel meaning "best model" rather
    # than a real epoch index -- confirm against test_bestmodel.
    epoch_id = -1
    # The context manager guarantees the results file is closed even if an
    # evaluation call raises part-way through.
    with open(acc_file_path, "w", encoding="utf-8") as acc_file:
        print(
            "epoch",
            epoch_id,
            "miniImagenet:",
            "cub:",
            "cars:",
            "places:",
            "plantae:",
            file=acc_file,
        )
        name = params.name
        n_shot = params.n_shot
        method = params.method
        for dataset in datasets:
            # presumably test_bestmodel appends this dataset's accuracy
            # to acc_file -- verify in test_function_fwt_benchmark.
            test_bestmodel(acc_file, name, method, dataset, n_shot, epoch_id)
def record_test_result_bscdfsl(params, datasets=("EuroSAT",)):
    """Evaluate the best saved model on BSCD-FSL datasets and log results.

    A header row is written to ``tmp_bscdfsl2.txt`` (overwriting any
    previous contents). Then ``test_bestmodel_bscdfsl`` is called once per
    entry in *datasets*, with the open file handle passed along.

    Args:
        params: Parsed options object; ``name``, ``n_shot`` and ``method``
            are read from it.
        datasets: Names of the target datasets to evaluate, in order. The
            header also lists 'ChestX', 'ISIC' and 'CropDisease'; pass those
            names here to evaluate them too. Defaults to ``("EuroSAT",)``,
            which matches the previous hard-coded behaviour.
    """
    print("hhhhhhh testing for bscdfsl")
    acc_file_path = "tmp_bscdfsl2.txt"
    # NOTE(review): -1 looks like a sentinel meaning "best model" rather
    # than a real epoch index -- confirm against test_bestmodel_bscdfsl.
    epoch_id = -1
    # The context manager guarantees the results file is closed even if an
    # evaluation call raises part-way through.
    with open(acc_file_path, "w", encoding="utf-8") as acc_file:
        # Every column label carries a trailing colon (CropDisease was
        # previously missing it).
        print(
            "epoch",
            epoch_id,
            "ChestX:",
            "ISIC:",
            "EuroSAT:",
            "CropDisease:",
            file=acc_file,
        )
        name = params.name
        n_shot = params.n_shot
        method = params.method
        for dataset in datasets:
            # presumably test_bestmodel_bscdfsl appends this dataset's
            # accuracy to acc_file -- verify in test_function_bscdfsl_benchmark.
            test_bestmodel_bscdfsl(acc_file, name, method, dataset, n_shot, epoch_id)
# --- script entry point ---
if __name__ == "__main__":
    # Seed every random number generator used here (Python, NumPy, torch
    # CPU and all CUDA devices) so evaluation runs are repeatable.
    seed = 0
    print("set seed = %d" % seed)
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)

    # Ask cuDNN for deterministic kernels and turn off autotuning, which
    # could otherwise pick different algorithms from run to run.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # NOTE(review): the option set parsed here is "train" even though this is
    # the test script -- confirm that is intentional.
    params = parse_args("train")

    # FWT benchmark evaluation is currently disabled:
    # record_test_result(params)

    # BSCD-FSL benchmark evaluation.
    record_test_result_bscdfsl(params)