File size: 5,888 Bytes
424919d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import argparse
import os
import sys
from abc import ABC
from typing import Type


class DefaultConfigs(ABC):
    """Default experiment configuration stored as plain class-level attributes.

    Assigning ``exp_name`` derives ``exp_dir``, ``ckpt_dir`` and ``logs_path``
    from ``exp_root`` and, while ``isTrain`` is true, creates the experiment
    and checkpoint directories on disk.
    """

    # ---------------- base settings ----------------
    gpus = [0]
    seed = 3407
    arch = "resnet50"
    datasets = [""]
    datasets_test = [""]
    class_bal = False
    batch_size = 256
    val_every = 1
    loadSize = 256
    cropSize = 224
    epoch = "latest"
    num_workers = 2
    isTrain = True

    # ---------------- train settings ----------------
    warmup = False
    warmup_epoch = 3
    earlystop = True
    earlystop_epoch = 5
    optim = "adam"
    new_optim = False
    loss_freq = 400
    save_latest_freq = 2000
    save_epoch_freq = 20
    continue_train = False
    epoch_count = 1
    last_epoch = -1
    nepoch = 20
    beta1 = 0.9
    lr = 1e-5
    init_type = "normal"
    init_gain = 0.02
    pretrained = True

    # ---------------- path information ----------------
    # root_dir is the parent of the directory that contains this file.
    root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    dataset_root = os.path.join(root_dir, "datasets")
    dataset_test_root = os.path.join(root_dir, "datasets")
    exp_root = os.path.join(root_dir, "experiments")
    # Backing field of the ``exp_name`` property; the four paths below are
    # filled in by the ``exp_name`` setter (``ckpt_path`` is not touched there).
    _exp_name = ""
    exp_dir = ""
    ckpt_dir = ""
    logs_path = ""
    ckpt_path = ""
    pretrained_weights = ""
    kd = True
    kd_weight = 1.0
    reproduce_dire = False
    only_eps = False
    only_img = False

    @property
    def exp_name(self):
        """Return the experiment name last assigned (``""`` by default)."""
        return self._exp_name

    @exp_name.setter
    def exp_name(self, value: str):
        """Store the name and recompute every path that is derived from it.

        When ``isTrain`` is true the experiment and checkpoint directories are
        created if they do not exist yet.
        """
        self._exp_name = value
        experiment_dir = os.path.join(self.exp_root, self.exp_name)
        self.exp_dir = experiment_dir
        self.ckpt_dir = os.path.join(experiment_dir, "ckpt")
        self.logs_path = os.path.join(experiment_dir, "logs.txt")
        if not self.isTrain:
            return
        for directory in (self.exp_dir, self.ckpt_dir):
            os.makedirs(directory, exist_ok=True)

    def to_dict(self):
        """Return a dict of every public, non-callable attribute and its value.

        Names starting with an underscore and callables (methods) are left
        out; the ``exp_name`` property is included through its getter.
        """
        public_fields = {}
        for name in dir(self):
            attr = getattr(self, name)
            if name.startswith("_") or callable(attr):
                continue
            public_fields[name] = attr
        return public_fields


def args_list2dict(arg_list: list):
    """Turn a flat ``[key1, value1, key2, value2, ...]`` list into a dict.

    Args:
        arg_list: Flat sequence that alternates keys and values.

    Returns:
        A dict mapping each even-indexed item to the item that follows it.
        A key that occurs more than once keeps its last value.

    Raises:
        ValueError: If ``arg_list`` holds an odd number of items.
    """
    # Raise explicitly instead of using ``assert``: assertions are removed
    # under ``python -O``, which would let ``zip`` silently drop the last item.
    if len(arg_list) % 2 != 0:
        raise ValueError(f"Override list has odd length: {arg_list}; it must be a list of pairs")
    return dict(zip(arg_list[::2], arg_list[1::2]))


def str2bool(v: str) -> bool:
    """Interpret a command-line style token as a boolean.

    Args:
        v: The value to interpret. Usually a string; any non-string value
            (including an actual ``bool``) is converted with ``bool()``.

    Returns:
        ``True`` for "true", "yes", "on", "y", "t", "1" and ``False`` for
        "false", "no", "off", "n", "f", "0". Matching ignores case and
        surrounding whitespace. Any other string falls back to its
        truthiness, so only the empty string is ``False``.
    """
    if not isinstance(v, str):
        # Covers real booleans as well as ints/None, which have no ``.lower()``.
        return bool(v)
    # Strip before matching so that " false" or "False\n" is not mistaken for
    # an unknown (and therefore truthy) token.
    token = v.strip().lower()
    if token in ("true", "yes", "on", "y", "t", "1"):
        return True
    if token in ("false", "no", "off", "n", "f", "0"):
        return False
    return bool(v)


def str2list(v: str, element_type=None) -> list:
    """Split a comma-separated string into a list of stripped items.

    Args:
        v: Text such as ``"[a, b, c]"`` or ``"a,b,c"``. A value that is
            already a list, tuple or set is returned unchanged.
        element_type: Optional callable applied to every item after
            stripping (for example ``int``).

    Returns:
        The list of items. All leading ``[`` and trailing ``]`` characters
        are removed before splitting, and an empty string yields ``[""]``.
    """
    if isinstance(v, (list, tuple, set)):
        return v
    inner = v.lstrip("[").rstrip("]")
    items = [piece.strip() for piece in inner.split(",")]
    if element_type is None:
        return items
    return [element_type(item) for item in items]


# Type alias: the DefaultConfigs class object or one of its subclasses.
CONFIGCLASS = Type[DefaultConfigs]

# Command-line interface. The arguments are parsed as soon as this module is
# imported, so `args` is available as a module-level name.
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", default="0", type=str)
parser.add_argument("--batch", default=256, type=int)
# The default is an int to match `type=int` (it used to be the string "100",
# which argparse converted through `type` to the same value).
parser.add_argument("--epoch", default=100, type=int)
parser.add_argument("--exp_name", default="", type=str)
parser.add_argument("--datasets", default="", type=str)
parser.add_argument("--dataset_root", default="", type=str)
parser.add_argument("--datasets_test", default="", type=str)
parser.add_argument("--dataset_test_root", default="", type=str)
parser.add_argument("--pretrained_weights", default="", type=str)
parser.add_argument("--lr", default=0.00001, type=float)
parser.add_argument("--kd_weight", default=1., type=float)
# Boolean flags take an explicit value (e.g. `--test true`) parsed by str2bool.
parser.add_argument("--test", default=False, type=str2bool)
parser.add_argument("--reproduce_dire", default=False, type=str2bool)
parser.add_argument("--only_eps", default=False, type=str2bool)
parser.add_argument("--only_img", default=False, type=str2bool)
parser.add_argument("--kd", default=True, type=str2bool)
# Everything left on the command line is collected verbatim into `opts`.
parser.add_argument("opts", default=[], nargs=argparse.REMAINDER)
args = parser.parse_args()

# Prefer an experiment-specific configuration: when
# <exp_root>/<exp_name>/config.py exists, import the `cfg` object it defines;
# otherwise fall back to a fresh DefaultConfigs instance.
if os.path.exists(os.path.join(DefaultConfigs.exp_root, args.exp_name, "config.py")):
    # Put the experiment directory first on sys.path so that `config`
    # resolves to that experiment's config.py.
    sys.path.insert(0, os.path.join(DefaultConfigs.exp_root, args.exp_name))
    from config import cfg

    cfg: CONFIGCLASS
else:
    cfg = DefaultConfigs()

# Apply the trailing `key value` pairs from the command line. Each value is
# converted to the type of the attribute it replaces.
if args.opts:
    opts = args_list2dict(args.opts)
    for k, v in opts.items():
        if not hasattr(cfg, k):
            raise ValueError(f"Unrecognized option: {k}")
        current = getattr(cfg, k)
        original_type = type(current)
        if original_type is bool:
            setattr(cfg, k, str2bool(v))
        elif original_type in (list, tuple, set):
            # Infer the element type from the first existing element without
            # indexing: `[0]` raises on an empty container and on a set. With
            # nothing to infer from, the items are kept as strings.
            first = next(iter(current), None)
            element_type = None if first is None else type(first)
            setattr(cfg, k, str2list(v, element_type))
        elif current is None:
            # NoneType cannot be called with an argument; keep the raw string.
            setattr(cfg, k, v)
        else:
            setattr(cfg, k, original_type(v))

# Copy the named command-line flags onto cfg.
# NOTE: these assignments are unconditional, so each flag (including its
# argparse default) replaces whatever the loaded config or the positional
# override pairs set for the same field earlier in this module.

# The GPU selection is exported as given; cfg.gpus is not updated from it.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

# isTrain is settled before exp_name is assigned because the
# DefaultConfigs.exp_name setter creates directories only while isTrain is true.
if args.test:
    cfg.isTrain = False
cfg.exp_name = args.exp_name

cfg.batch_size = args.batch
cfg.lr = args.lr
cfg.nepoch = args.epoch
cfg.pretrained_weights = args.pretrained_weights

# cfg.datasets / cfg.datasets_test keep the raw command-line strings (they are
# not split into lists here); the test set defaults to the training set.
cfg.datasets = args.datasets
cfg.datasets_test = args.datasets_test or args.datasets

# Dataset roots default to <root_dir>/datasets/<name>; a non-empty
# --dataset_root / --dataset_test_root takes precedence.
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
cfg.dataset_root = args.dataset_root or os.path.join(root_dir, 'datasets', cfg.datasets)
cfg.dataset_test_root = args.dataset_test_root or os.path.join(root_dir, 'datasets', cfg.datasets_test)

# Flags whose name is identical on args and cfg.
for _shared_name in ("kd", "reproduce_dire", "only_eps", "only_img", "kd_weight"):
    setattr(cfg, _shared_name, getattr(args, _shared_name))