File size: 4,520 Bytes
73e19ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import argparse
import os
import sys
from abc import ABC
from typing import Type
class DefaultConfigs(ABC):
    """Default experiment settings, stored as class attributes.

    Instances read these defaults through normal attribute lookup; individual
    settings are overridden per instance with plain assignment / ``setattr``.
    Assigning :attr:`exp_name` derives the experiment, checkpoint and log
    paths and creates the directories on disk.

    NOTE(review): list-valued defaults (``gpus``, ``datasets``, ``blur_sig``,
    ...) are class attributes shared by every instance. Mutating one in place
    (e.g. ``cfg.gpus.append(1)``) changes the default for all instances;
    assign a new list instead.
    """

    # ------------------------------ base setting ------------------------------
    gpus = [0]
    seed = 3407
    arch = "resnet50"
    datasets = ["zhaolian_train"]
    datasets_test = ["adm_res_abs_ddim20s"]
    mode = "binary"
    class_bal = False
    batch_size = 64
    loadSize = 256
    cropSize = 224
    epoch = "latest"
    num_workers = 20
    serial_batches = False
    isTrain = True

    # ---------------------------- data augmentation ----------------------------
    rz_interp = ["bilinear"]
    # blur_prob = 0.0
    blur_prob = 0.1
    blur_sig = [0.5]
    # jpg_prob = 0.0
    jpg_prob = 0.1
    jpg_method = ["cv2"]
    jpg_qual = [75]
    gray_prob = 0.0
    aug_resize = True
    aug_crop = True
    aug_flip = True
    aug_norm = True

    # ------------------------------ train setting ------------------------------
    warmup = False
    # warmup = True
    warmup_epoch = 3
    earlystop = True
    earlystop_epoch = 5
    optim = "adam"
    new_optim = False
    loss_freq = 400
    save_latest_freq = 2000
    save_epoch_freq = 20
    continue_train = False
    epoch_count = 1
    last_epoch = -1
    nepoch = 400
    beta1 = 0.9
    lr = 0.0001
    init_type = "normal"
    init_gain = 0.02
    pretrained = True

    # ---------------------------- paths information ----------------------------
    # root_dir is the parent of the directory that contains this file.
    root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    dataset_root = os.path.join(root_dir, "data")
    exp_root = os.path.join(root_dir, "data", "exp")
    # Backing field of the ``exp_name`` property.
    _exp_name = ""
    # Derived paths; the ``exp_name`` setter fills these in.
    exp_dir = ""
    ckpt_dir = ""
    logs_path = ""
    # Not set anywhere in this class; callers assign it.
    ckpt_path = ""

    @property
    def exp_name(self) -> str:
        """Name of the current experiment (a sub-directory of ``exp_root``)."""
        return self._exp_name

    @exp_name.setter
    def exp_name(self, value: str):
        # Derive exp_dir = exp_root/<value>, ckpt_dir = exp_dir/ckpt and
        # logs_path = exp_dir/logs.txt, then create both directories.
        self._exp_name = value
        self.exp_dir: str = os.path.join(self.exp_root, self.exp_name)
        self.ckpt_dir: str = os.path.join(self.exp_dir, "ckpt")
        self.logs_path: str = os.path.join(self.exp_dir, "logs.txt")
        os.makedirs(self.exp_dir, exist_ok=True)
        os.makedirs(self.ckpt_dir, exist_ok=True)

    def to_dict(self) -> dict:
        """Return all public, non-callable attributes as a plain dict.

        Names beginning with an underscore (private fields and dunders) are
        skipped *before* the attribute is read, so Python internals are never
        evaluated. Properties such as ``exp_name`` are included with their
        current value; methods are excluded.
        """
        result = {}
        for key in dir(self):
            if key.startswith("_"):
                continue
            value = getattr(self, key)
            if not callable(value):
                result[key] = value
        return result


def args_list2dict(arg_list: list) -> dict:
    """Pair a flat ``[key1, value1, key2, value2, ...]`` list into a dict.

    Later duplicates of a key overwrite earlier ones.

    Raises:
        ValueError: if ``arg_list`` has an odd number of items. This is an
            explicit check rather than ``assert``, which ``python -O`` strips;
            without it ``zip`` would silently drop the unpaired last key.
    """
    if len(arg_list) % 2 != 0:
        raise ValueError(f"Override list has odd length: {arg_list}; it must be a list of pairs")
    return dict(zip(arg_list[::2], arg_list[1::2]))


def str2bool(v: str) -> bool:
    """Interpret a command-line value as a boolean.

    ``bool`` inputs are returned unchanged. Anything else is converted with
    ``str()``, stripped and lower-cased, then matched against common spellings:
    true/yes/on/y/t/1 and false/no/off/n/f/0.

    Raises:
        ValueError: if the value is not a recognized spelling. Previously any
            other non-empty string (e.g. the typo ``"ture"`` or ``"flase"``)
            silently became ``True``.
    """
    if isinstance(v, bool):
        return v
    # str() lets ints such as 0/1 go through the same table instead of
    # failing on a missing .lower().
    text = str(v).strip().lower()
    if text in ("true", "yes", "on", "y", "t", "1"):
        return True
    if text in ("false", "no", "off", "n", "f", "0"):
        return False
    raise ValueError(f"Cannot interpret {v!r} as a boolean")


def str2list(v: str, element_type=None) -> list:
    """Parse a comma-separated value such as ``"[a, b]"`` or ``"a,b"`` into a list.

    Args:
        v: the raw value. A list, tuple or set is taken element-wise as-is
            instead of being parsed from text.
        element_type: optional callable applied to every element (e.g. ``int``).

    Returns:
        A new list. Surrounding whitespace and brackets are removed, string
        elements are stripped, and an empty string or ``"[]"`` yields ``[]``
        (previously ``[""]``, which also broke ``element_type=int``).
    """
    if isinstance(v, (list, tuple, set)):
        items = list(v)
    else:
        body = v.strip().lstrip("[").rstrip("]").strip()
        items = body.split(",") if body else []
    # Only strings can be stripped; non-string elements pass through unchanged.
    items = [item.strip() if isinstance(item, str) else item for item in items]
    if element_type is not None:
        items = [element_type(item) for item in items]
    return items


# Annotation alias for a configuration *class* (DefaultConfigs or a subclass).
CONFIGCLASS = Type[DefaultConfigs]

# Command-line interface: known flags plus trailing free-form
# ``key value key value ...`` pairs that override attributes of ``cfg``.
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", default=[0], type=int, nargs="+")
parser.add_argument("--exp_name", default="", type=str)
parser.add_argument("--ckpt", default="model_epoch_latest.pth", type=str)
parser.add_argument("opts", default=[], nargs=argparse.REMAINDER)
args = parser.parse_args()

# Use the experiment's own config.py (which must define ``cfg``) when it
# exists; otherwise fall back to the defaults.
# NOTE(review): ``from config import cfg`` resolves through sys.modules, so an
# already-imported module named ``config`` would be reused instead of the
# experiment's file — confirm how this module itself is imported.
if os.path.exists(os.path.join(DefaultConfigs.exp_root, args.exp_name, "config.py")):
    sys.path.insert(0, os.path.join(DefaultConfigs.exp_root, args.exp_name))
    from config import cfg

    cfg: DefaultConfigs  # an instance, not the class object
else:
    cfg = DefaultConfigs()

# Apply ``key value`` overrides, converting each string value to the type of
# the attribute's current value.
if args.opts:
    opts = args_list2dict(args.opts)
    for k, v in opts.items():
        if not hasattr(cfg, k):
            raise ValueError(f"Unrecognized option: {k}")
        current = getattr(cfg, k)
        original_type = type(current)
        if original_type is bool:
            setattr(cfg, k, str2bool(v))
        elif original_type in (list, tuple, set):
            # Element type comes from the first existing element (works for
            # sets too); an empty default gives no hint, so keep strings.
            element_type = type(next(iter(current))) if current else None
            setattr(cfg, k, str2list(v, element_type))
        elif current is None:
            # NoneType cannot convert a value; keep the raw string.
            setattr(cfg, k, v)
        else:
            setattr(cfg, k, original_type(v))

# NOTE: the --gpus and --exp_name flags are applied after the overrides above,
# so ``gpus``/``exp_name`` given as trailing opts are overwritten here.
cfg.gpus: list = args.gpus
# Canonical comma-separated form, without spaces.
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu) for gpu in cfg.gpus)
cfg.exp_name = args.exp_name
cfg.ckpt_path: str = os.path.join(cfg.ckpt_dir, args.ckpt)
if isinstance(cfg.datasets, str):
    cfg.datasets = cfg.datasets.split(",")
|