|
|
import numpy as np |
|
|
import torch |
|
|
import torch.optim |
|
|
import os |
|
|
import random |
|
|
|
|
|
from methods import backbone |
|
|
from methods.backbone_multiblock import model_dict |
|
|
from data.datamgr import SimpleDataManager, SetDataManager |
|
|
from methods.StyleAdv_RN_GNN import StyleAdvGNN |
|
|
|
|
|
from options import parse_args, get_resume_file, load_warmup_state |
|
|
from test_function_fwt_benchmark import test_bestmodel |
|
|
from test_function_bscdfsl_benchmark import test_bestmodel_bscdfsl |
|
|
|
|
|
|
|
|
def record_test_result(params, acc_file_path="tmp2.txt"):
    """Evaluate the best saved model on miniImagenet and log results to a file.

    Writes a header line, then hands the open file handle to
    ``test_bestmodel``, which writes its own results into it.

    Args:
        params: Parsed options object; only ``name``, ``n_shot`` and
            ``method`` are read.
        acc_file_path: Path of the output text file. It is overwritten.
            Defaults to the previously hard-coded ``"tmp2.txt"``.

    Returns:
        None.
    """
    # NOTE(review): -1 is forwarded unchanged to test_bestmodel; it presumably
    # means "best checkpoint / no specific epoch" -- confirm in that helper.
    epoch_id = -1
    name = params.name
    n_shot = params.n_shot
    method = params.method

    # The context manager closes the file even if evaluation raises.
    # The previous bare open()/close() pair leaked the handle in that case.
    with open(acc_file_path, "w") as acc_file:
        # NOTE(review): the header names five domains, but only miniImagenet is
        # evaluated below. It is kept byte-identical so readers of this file
        # see the same layout.
        print(
            "epoch",
            epoch_id,
            "miniImagenet:",
            "cub:",
            "cars:",
            "places:",
            "plantae:",
            file=acc_file,
        )
        test_bestmodel(acc_file, name, method, "miniImagenet", n_shot, epoch_id)
|
|
|
|
|
|
|
|
def record_test_result_bscdfsl(params, acc_file_path="tmp_bscdfsl2.txt"):
    """Evaluate the best saved model on the EuroSAT (BSCD-FSL) target set.

    Writes a header line, then hands the open file handle to
    ``test_bestmodel_bscdfsl``, which writes its own results into it.

    Args:
        params: Parsed options object; only ``name``, ``n_shot`` and
            ``method`` are read.
        acc_file_path: Path of the output text file. It is overwritten.
            Defaults to the previously hard-coded ``"tmp_bscdfsl2.txt"``.

    Returns:
        None.
    """
    print("hhhhhhh testing for bscdfsl")

    # NOTE(review): -1 is forwarded unchanged to test_bestmodel_bscdfsl; it
    # presumably means "best checkpoint / no specific epoch" -- confirm.
    epoch_id = -1
    name = params.name
    n_shot = params.n_shot
    method = params.method

    # The context manager closes the file even if evaluation raises.
    # The previous bare open()/close() pair leaked the handle in that case.
    with open(acc_file_path, "w") as acc_file:
        # NOTE(review): the header names four domains (and "CropDisease" lacks
        # the trailing colon the others have), but only EuroSAT is evaluated
        # below. It is kept byte-identical so readers of this file see the
        # same layout.
        print(
            "epoch", epoch_id, "ChestX:", "ISIC:", "EuroSAT:", "CropDisease", file=acc_file
        )
        test_bestmodel_bscdfsl(acc_file, name, method, "EuroSAT", n_shot, epoch_id)
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Seed every RNG this script touches (Python, NumPy, PyTorch on the CPU and
    # on all CUDA devices) with the same value.
    seed = 0
    print("set seed = %d" % seed)
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)

    # Ask cuDNN for deterministic kernels and disable autotuned algorithm
    # selection, so repeated runs pick the same convolution implementations.
    torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

    # Options come from the "train" argument set; record_test_result_bscdfsl
    # reads name / n_shot / method from them.
    params = parse_args("train")
    record_test_result_bscdfsl(params)
|
|
|