# aigv/core/utils1/config.py
# Uploaded by Qafig via huggingface_hub ("Upload folder using huggingface_hub"), commit 73e19ac (verified).
import argparse
import os
import sys
from abc import ABC
from typing import Type
class DefaultConfigs(ABC):
    """Default settings for a training / evaluation run.

    All settings are class attributes; an instance overrides them by plain
    attribute assignment. Assigning ``exp_name`` additionally derives the
    experiment paths and creates the corresponding directories on disk.

    NOTE: list-valued defaults (``gpus``, ``datasets``, ...) are class-level
    objects shared by every instance. Replace them with a new list rather
    than mutating them in place, or the change leaks into the class default.
    """

    ####### base setting ######
    gpus = [0]
    seed = 3407
    arch = "resnet50"
    datasets = ["zhaolian_train"]
    datasets_test = ["adm_res_abs_ddim20s"]
    mode = "binary"
    class_bal = False
    batch_size = 64
    loadSize = 256
    cropSize = 224
    epoch = "latest"
    num_workers = 20
    serial_batches = False
    isTrain = True

    # data augmentation
    rz_interp = ["bilinear"]
    # blur_prob = 0.0
    blur_prob = 0.1
    blur_sig = [0.5]
    # jpg_prob = 0.0
    jpg_prob = 0.1
    jpg_method = ["cv2"]
    jpg_qual = [75]
    gray_prob = 0.0
    aug_resize = True
    aug_crop = True
    aug_flip = True
    aug_norm = True

    ####### train setting ######
    warmup = False
    # warmup = True
    warmup_epoch = 3
    earlystop = True
    earlystop_epoch = 5
    optim = "adam"
    new_optim = False
    loss_freq = 400
    save_latest_freq = 2000
    save_epoch_freq = 20
    continue_train = False
    epoch_count = 1
    last_epoch = -1
    nepoch = 400
    beta1 = 0.9
    lr = 0.0001
    init_type = "normal"
    init_gain = 0.02
    pretrained = True

    # paths information
    # root_dir is the parent of the directory that contains this file.
    root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    dataset_root = os.path.join(root_dir, "data")
    exp_root = os.path.join(root_dir, "data", "exp")
    # The path fields below stay empty until ``exp_name`` is assigned;
    # ``ckpt_path`` is never filled in by this class itself.
    _exp_name = ""
    exp_dir = ""
    ckpt_dir = ""
    logs_path = ""
    ckpt_path = ""

    @property
    def exp_name(self):
        """Name of the current experiment ("" until assigned)."""
        return self._exp_name

    @exp_name.setter
    def exp_name(self, value: str):
        """Store the experiment name and derive its paths under ``exp_root``.

        Sets ``exp_dir``, ``ckpt_dir`` and ``logs_path``, and creates
        ``exp_dir`` and ``ckpt_dir`` if they do not exist yet.
        ``ckpt_path`` is not updated here.
        """
        self._exp_name = value
        self.exp_dir: str = os.path.join(self.exp_root, self.exp_name)
        self.ckpt_dir: str = os.path.join(self.exp_dir, "ckpt")
        self.logs_path: str = os.path.join(self.exp_dir, "logs.txt")
        os.makedirs(self.exp_dir, exist_ok=True)
        os.makedirs(self.ckpt_dir, exist_ok=True)

    def to_dict(self):
        """Return the public, non-callable settings of this instance as a dict.

        Both class-level defaults and instance overrides are included, as are
        properties such as ``exp_name`` (with their current value). Names that
        start with an underscore, dunders included, are skipped. Keys come out
        in ``dir()`` order, i.e. alphabetically.
        """
        dic = {}
        for fieldkey in dir(self):
            # A single "_" prefix test also covers dunders; filtering before
            # getattr avoids evaluating private/dunder attributes at all.
            if fieldkey.startswith("_"):
                continue
            fieldvalue = getattr(self, fieldkey)
            if not callable(fieldvalue):
                dic[fieldkey] = fieldvalue
        return dic
def args_list2dict(arg_list: list):
    """Turn a flat ``[key1, value1, key2, value2, ...]`` list into a dict.

    Later duplicates of a key overwrite earlier ones.

    Raises:
        ValueError: if ``arg_list`` has an odd length (a key without a value).
    """
    # Explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O``, after which ``zip`` would silently drop the dangling key.
    if len(arg_list) % 2 != 0:
        raise ValueError(f"Override list has odd length: {arg_list}; it must be a list of pairs")
    return dict(zip(arg_list[::2], arg_list[1::2]))
def str2bool(v: str) -> bool:
    """Parse a command-line style boolean value.

    Booleans are returned unchanged. Anything else is converted with
    ``str()``, stripped of surrounding whitespace and matched
    case-insensitively against common true/false spellings. An empty string
    yields ``False``.

    Raises:
        ValueError: for any other value. Previously such input fell through
            to ``bool(v)``, so a typo like ``"Flase"`` silently became True.
    """
    if isinstance(v, bool):
        return v
    text = str(v).strip().lower()
    if text in ("true", "yes", "on", "y", "t", "1"):
        return True
    # "" keeps the old ``bool("") is False`` result.
    if text in ("false", "no", "off", "n", "f", "0", ""):
        return False
    raise ValueError(f"Cannot interpret {v!r} as a boolean")
def str2list(v: str, element_type=None) -> list:
    """Parse a list override such as ``"[a, b, c]"`` or ``"1,2,3"``.

    For string input, surrounding whitespace and square brackets are removed
    and the remainder is split on commas. ``""`` and ``"[]"`` give ``[]``.
    A list, tuple or set is taken as-is, converted to a list. String items
    are whitespace-stripped. Non-string items are left untouched. If
    ``element_type`` is given, it is applied to every item.
    """
    if isinstance(v, (list, tuple, set)):
        items = list(v)
    else:
        body = v.strip().lstrip("[").rstrip("]").strip()
        # "".split(",") would yield [""], so treat an empty body as no items.
        items = body.split(",") if body else []
    items = [item.strip() if isinstance(item, str) else item for item in items]
    if element_type is not None:
        items = [element_type(item) for item in items]
    return items
CONFIGCLASS = Type[DefaultConfigs]

parser = argparse.ArgumentParser()
parser.add_argument("--gpus", default=[0], type=int, nargs="+")
parser.add_argument("--exp_name", default="", type=str)
parser.add_argument("--ckpt", default="model_epoch_latest.pth", type=str)
# Everything after the named flags is collected verbatim as KEY VALUE pairs
# that override attributes of ``cfg`` (see the loop below).
parser.add_argument("opts", default=[], nargs=argparse.REMAINDER)
args = parser.parse_args()

# Prefer a config.py inside the experiment directory (it must expose ``cfg``);
# otherwise fall back to the built-in defaults.
# NOTE(review): if this module is itself importable as top-level ``config``,
# ``from config import cfg`` may resolve to this module instead - confirm.
if os.path.exists(os.path.join(DefaultConfigs.exp_root, args.exp_name, "config.py")):
    sys.path.insert(0, os.path.join(DefaultConfigs.exp_root, args.exp_name))
    from config import cfg

    cfg: CONFIGCLASS
else:
    cfg = DefaultConfigs()

if args.opts:
    opts = args_list2dict(args.opts)
    for k, v in opts.items():
        if not hasattr(cfg, k):
            raise ValueError(f"Unrecognized option: {k}")
        current = getattr(cfg, k)
        original_type = type(current)
        if original_type is bool:
            setattr(cfg, k, str2bool(v))
        elif original_type in (list, tuple, set):
            # Infer the element type from the first default element. ``next``
            # works for sets too (no indexing), and an empty default has
            # nothing to infer from, so its items are kept as strings.
            first = next(iter(current), None)
            if first is None:
                element_type = None
            elif type(first) is bool:
                # bool("False") is True, so boolean items need str2bool.
                element_type = str2bool
            else:
                element_type = type(first)
            setattr(cfg, k, str2list(v, element_type))
        elif current is None:
            # NoneType cannot be called as a converter; keep the raw string.
            setattr(cfg, k, v)
        else:
            setattr(cfg, k, original_type(v))

# --gpus always wins over any "gpus" entry given in opts.
cfg.gpus: list = args.gpus
# Use the canonical comma-separated form without spaces, e.g. "0,1".
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu) for gpu in cfg.gpus)
cfg.exp_name = args.exp_name
cfg.ckpt_path: str = os.path.join(cfg.ckpt_dir, args.ckpt)
if isinstance(cfg.datasets, str):
    # Strip around commas so "a, b" gives ["a", "b"], not ["a", " b"].
    cfg.datasets = [name.strip() for name in cfg.datasets.split(",")]