File size: 2,425 Bytes
197d4ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import numpy as np
import torch
import torch.optim
import os
import random

from methods import backbone
from methods.backbone_multiblock import model_dict
from data.datamgr import SimpleDataManager, SetDataManager
from methods.StyleAdv_RN_GNN import StyleAdvGNN

from options import parse_args, get_resume_file, load_warmup_state
from test_function_fwt_benchmark import test_bestmodel
from test_function_bscdfsl_benchmark import test_bestmodel_bscdfsl


def record_test_result(
    params,
    acc_file_path="tmp2.txt",
    datasets=("miniImagenet",),
):
    """Evaluate the best checkpoint on the FWT-style benchmark datasets.

    A header row naming all five benchmark columns is written to
    ``acc_file_path`` (overwriting it). Then ``test_bestmodel`` is called once
    per entry in ``datasets``, receiving the open file handle.

    Args:
        params: Parsed options; ``name``, ``n_shot`` and ``method`` are read.
        acc_file_path: Output file for the accuracy table (opened in "w" mode).
        datasets: Dataset names to evaluate, in order. The default evaluates
            only miniImagenet, matching the previous behaviour. Pass e.g.
            ("miniImagenet", "cub", "cars", "places", "plantae") to run all.
    """
    # -1 is forwarded as the epoch id. It presumably means "best model" rather
    # than a specific epoch; confirm against test_bestmodel.
    epoch_id = -1
    # The context manager closes (and flushes) the file even if evaluation raises.
    with open(acc_file_path, "w") as acc_file:
        print(
            "epoch",
            epoch_id,
            "miniImagenet:",
            "cub:",
            "cars:",
            "places:",
            "plantae:",
            file=acc_file,
        )
        for dataset in datasets:
            # NOTE(review): test_bestmodel receives acc_file and presumably
            # appends this dataset's accuracy to it; confirm.
            test_bestmodel(
                acc_file, params.name, params.method, dataset, params.n_shot, epoch_id
            )


def record_test_result_bscdfsl(
    params,
    acc_file_path="tmp_bscdfsl2.txt",
    datasets=("EuroSAT",),
):
    """Evaluate the best checkpoint on the BSCD-FSL benchmark datasets.

    A header row naming the four BSCD-FSL columns is written to
    ``acc_file_path`` (overwriting it). Then ``test_bestmodel_bscdfsl`` is
    called once per entry in ``datasets``, receiving the open file handle.

    Args:
        params: Parsed options; ``name``, ``n_shot`` and ``method`` are read.
        acc_file_path: Output file for the accuracy table (opened in "w" mode).
        datasets: Dataset names to evaluate, in order. The default evaluates
            only EuroSAT, matching the previous behaviour. Pass e.g.
            ("ChestX", "ISIC", "EuroSAT", "CropDisease") to run all four.
    """
    print("hhhhhhh testing for bscdfsl")
    # -1 is forwarded as the epoch id. It presumably means "best model" rather
    # than a specific epoch; confirm against test_bestmodel_bscdfsl.
    epoch_id = -1
    # The context manager closes (and flushes) the file even if evaluation raises.
    with open(acc_file_path, "w") as acc_file:
        print(
            "epoch", epoch_id, "ChestX:", "ISIC:", "EuroSAT:", "CropDisease", file=acc_file
        )
        for dataset in datasets:
            # NOTE(review): test_bestmodel_bscdfsl receives acc_file and
            # presumably appends this dataset's accuracy to it; confirm.
            test_bestmodel_bscdfsl(
                acc_file, params.name, params.method, dataset, params.n_shot, epoch_id
            )


# --- main function ---
if __name__ == "__main__":
    # Seed every RNG this script touches (Python, NumPy, torch CPU and all
    # CUDA devices) so repeated evaluation runs are reproducible.
    seed = 0
    print(f"set seed = {seed}")
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)
    # Ask cuDNN for deterministic kernels and disable its autotuner, whose
    # choice of algorithm can differ from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    params = parse_args("train")

    # Only the BSCD-FSL evaluation runs; the FWT-benchmark evaluation
    # (record_test_result) is currently disabled.
    record_test_result_bscdfsl(params)