|
|
import numpy as np |
|
|
import os |
|
|
import yaml |
|
|
import logging |
|
|
from collections import OrderedDict |
|
|
from .helpers.ordered_easydict import OrderedEasyDict as edict |
|
|
|
|
|
# Global configuration tree. `__C` is the short name used while the defaults
# below are built; `cfg` is an alias for the very same object.
# NOTE(review): `edict` is OrderedEasyDict, which presumably keeps attribute
# insertion order; `ordered_dump` emits mappings in `.items()` order, so the
# order of the assignments in this file shows up in files written by
# `save_cfg`. Keep the order stable.
__C = edict()
cfg = __C

# Random seed, None by default. Presumably None means "do not seed" -- confirm
# with the code that reads cfg.SEED.
__C.SEED = None

# Dataset selector, None by default. Valid values are defined by the code that
# reads this option, not in this file.
__C.DATASET = None

# Root directory under which the MNIST and HKO data directories are placed.
__C.ROOT_DIR = 'data/HKO-7'
|
|
|
|
|
# Directory holding the (Moving-)MNIST data; MOVINGMNIST.TEST_FILE lives here.
# Importing this module creates it if it does not exist yet.
__C.MNIST_PATH = os.path.join(__C.ROOT_DIR, 'mnist_data')
if not os.path.exists(__C.MNIST_PATH):
    os.makedirs(__C.MNIST_PATH)

# Base directory for derived HKO-7 files (the 'pd' pickles, day lists and
# benchmark statistics). The directory itself is created later in this file,
# after the radar PNG / mask directories have been located.
__C.HKO_DATA_BASE_PATH = os.path.join(__C.ROOT_DIR, 'hko_data')
|
|
|
|
|
|
|
|
# Candidate locations of the HKO-7 radar PNG frames and of their masks, tried
# in order; the first directory that exists wins. The first entry of each list
# is a developer-specific absolute Windows path.
possible_hko_png_paths = ['E:\\datasets\\HKO-data\\radarPNG\\radarPNG',
                          os.path.join(__C.HKO_DATA_BASE_PATH, 'radarPNG'),
                          'data/HKO-7/radarPNG']
possible_hko_mask_paths = ['E:\\datasets\\HKO-data\\radarPNG\\radarPNG_mask',
                           os.path.join(__C.HKO_DATA_BASE_PATH, 'radarPNG_mask'),
                           'data/HKO-7/radarPNG_mask']

# Resolve HKO_PNG_PATH at import time. The `else` clause runs only when the loop
# finishes without `break`, i.e. when no candidate exists; the error lists every
# location that was tried so the user knows where to put the data.
for ele in possible_hko_png_paths:
    if os.path.exists(ele):
        __C.HKO_PNG_PATH = ele
        break
else:
    raise RuntimeError("radarPNG is not found in any of {}! You can download the radarPNG using"
                       " `bash download_radar_png.bash`".format(possible_hko_png_paths))

# Same lookup for the mask directory (HKO_MASK_PATH).
for ele in possible_hko_mask_paths:
    if os.path.exists(ele):
        __C.HKO_MASK_PATH = ele
        break
else:
    raise RuntimeError("radarPNG_mask is not found in any of {}! You can download the radarPNG_mask using"
                       " `bash download_radar_png.bash`".format(possible_hko_mask_paths))
|
|
# Create the derived-data directories at import time if they are missing.
if not os.path.exists(__C.HKO_DATA_BASE_PATH):
    os.makedirs(__C.HKO_DATA_BASE_PATH)
__C.HKO_PD_BASE_PATH = os.path.join(__C.HKO_DATA_BASE_PATH, 'pd')
if not os.path.exists(__C.HKO_PD_BASE_PATH):
    os.makedirs(__C.HKO_PD_BASE_PATH)

# Index files stored directly under HKO_DATA_BASE_PATH: pickles of the valid
# datetimes and of the sorted days, plus text lists of the rainy
# train / valid / test days.
__C.HKO_VALID_DATETIME_PATH = os.path.join(__C.HKO_DATA_BASE_PATH, 'valid_datetime.pkl')
__C.HKO_SORTED_DAYS_PATH = os.path.join(__C.HKO_DATA_BASE_PATH, 'sorted_day.pkl')
__C.HKO_RAINY_TRAIN_DAYS_PATH = os.path.join(__C.HKO_DATA_BASE_PATH, 'hko7_rainy_train_days.txt')
__C.HKO_RAINY_VALID_DAYS_PATH = os.path.join(__C.HKO_DATA_BASE_PATH, 'hko7_rainy_valid_days.txt')
__C.HKO_RAINY_TEST_DAYS_PATH = os.path.join(__C.HKO_DATA_BASE_PATH, 'hko7_rainy_test_days.txt')

# One pickle per data split, stored under HKO_PD_BASE_PATH.
# NOTE(review): presumably pandas objects (the directory is named 'pd') --
# confirm with the code that loads them.
__C.HKO_PD = edict()
__C.HKO_PD.ALL = os.path.join(__C.HKO_PD_BASE_PATH, 'hko7_all.pkl')
__C.HKO_PD.ALL_09_14 = os.path.join(__C.HKO_PD_BASE_PATH, 'hko7_all_09_14.pkl')
__C.HKO_PD.ALL_15 = os.path.join(__C.HKO_PD_BASE_PATH, 'hko7_all_15.pkl')
__C.HKO_PD.RAINY_TRAIN = os.path.join(__C.HKO_PD_BASE_PATH, 'hko7_rainy_train.pkl')
__C.HKO_PD.RAINY_VALID = os.path.join(__C.HKO_PD_BASE_PATH, 'hko7_rainy_valid.pkl')
__C.HKO_PD.RAINY_TEST = os.path.join(__C.HKO_PD_BASE_PATH, 'hko7_rainy_test.pkl')
|
|
|
|
|
# HKO-7 dataset options.
__C.HKO = edict()

# Data iterator settings. WIDTH/HEIGHT are presumably the frame size in pixels.
__C.HKO.ITERATOR = edict()
__C.HKO.ITERATOR.WIDTH = 480
__C.HKO.ITERATOR.HEIGHT = 480
# Whether the iterator filters frames by rainfall, and the threshold it uses.
# The exact filtering criterion is implemented by the iterator, not here.
__C.HKO.ITERATOR.FILTER_RAINFALL = True
__C.HKO.ITERATOR.FILTER_RAINFALL_THRESHOLD = 0.28

# Benchmark settings. STAT_PATH is created at import time if it is missing.
# IN_LEN / OUT_LEN are presumably the number of input / predicted frames, and
# STRIDE the step between consecutive benchmark samples -- confirm with the
# benchmark code.
__C.HKO.BENCHMARK = edict()
__C.HKO.BENCHMARK.STAT_PATH = os.path.join(__C.HKO_DATA_BASE_PATH, 'benchmark_stat')
if not os.path.exists(__C.HKO.BENCHMARK.STAT_PATH):
    os.makedirs(__C.HKO.BENCHMARK.STAT_PATH)
__C.HKO.BENCHMARK.VISUALIZE_SEQ_NUM = 10
__C.HKO.BENCHMARK.IN_LEN = 5
__C.HKO.BENCHMARK.OUT_LEN = 20
__C.HKO.BENCHMARK.STRIDE = 5

# Evaluation settings.
__C.HKO.EVALUATION = edict()
# NOTE(review): presumably the coefficients of a Z-R relation Z = a * R**b used
# to convert radar reflectivity to rainfall rate -- confirm with the evaluation
# code.
__C.HKO.EVALUATION.ZR = edict()
__C.HKO.EVALUATION.ZR.a = 58.53
__C.HKO.EVALUATION.ZR.b = 1.56
# Rainfall thresholds (presumably mm/h). BALANCING_WEIGHTS has one more entry
# than THRESHOLDS (6 vs 5), i.e. one weight per interval the thresholds define.
__C.HKO.EVALUATION.THRESHOLDS = (0.5, 2, 5, 10, 30)
__C.HKO.EVALUATION.BALANCING_WEIGHTS = (1, 1, 2, 5, 10, 30)
# Presumably a pixel box inside the 480x480 frame -- confirm the coordinate
# order with the code that reads it.
__C.HKO.EVALUATION.CENTRAL_REGION = (120, 120, 360, 360)
|
|
|
|
|
# Moving-MNIST generator / evaluation options. The *_LOWER / *_UPPER pairs are
# presumably sampling ranges for the random transforms applied to each digit --
# confirm with the generator code.
__C.MOVINGMNIST = edict()
__C.MOVINGMNIST.DISTRACTOR_NUM = 0
__C.MOVINGMNIST.VELOCITY_LOWER = 0.0
__C.MOVINGMNIST.VELOCITY_UPPER = 3.6
__C.MOVINGMNIST.SCALE_VARIATION_LOWER = 1/1.1
__C.MOVINGMNIST.SCALE_VARIATION_UPPER = 1.1
# Presumably degrees.
__C.MOVINGMNIST.ROTATION_LOWER = -30
__C.MOVINGMNIST.ROTATION_UPPER = 30
__C.MOVINGMNIST.ILLUMINATION_LOWER = 0.6
__C.MOVINGMNIST.ILLUMINATION_UPPER = 1.0
__C.MOVINGMNIST.DIGIT_NUM = 3
# Sequence lengths: input frames, output frames and total testing length.
__C.MOVINGMNIST.IN_LEN = 10
__C.MOVINGMNIST.OUT_LEN = 10
__C.MOVINGMNIST.TESTING_LEN = 20
__C.MOVINGMNIST.IMG_SIZE = 64
# Pre-generated test set stored under MNIST_PATH.
__C.MOVINGMNIST.TEST_FILE = os.path.join(__C.MNIST_PATH, "movingmnist_10000_nodistr.npz")
|
|
|
|
|
# Model, checkpoint and loss options.
__C.MODEL = edict()
# Checkpoint handling: whether to resume / run in testing mode, where to load
# from (and at which iteration), and where to save.
__C.MODEL.RESUME = False
__C.MODEL.TESTING = False
__C.MODEL.LOAD_DIR = ""
__C.MODEL.LOAD_ITER = 79999
__C.MODEL.SAVE_DIR = ""
# Activation types of the convolutional and recurrent layers.
__C.MODEL.CNN_ACT_TYPE = "leaky"
__C.MODEL.RNN_ACT_TYPE = "leaky"
__C.MODEL.FRAME_STACK = 1
__C.MODEL.FRAME_SKIP = 1
# Number of input and output frames; same defaults as HKO.BENCHMARK.
__C.MODEL.IN_LEN = 5
__C.MODEL.OUT_LEN = 20
__C.MODEL.OUT_TYPE = "direct"
# Loss configuration. Presumably USE_BALANCED_LOSS applies
# HKO.EVALUATION.BALANCING_WEIGHTS, and TEMPORAL_WEIGHT_* weight the loss over
# the output frames -- confirm with the loss implementation.
__C.MODEL.NORMAL_LOSS_GLOBAL_SCALE = 0.00005
__C.MODEL.USE_BALANCED_LOSS = True
__C.MODEL.TEMPORAL_WEIGHT_TYPE = "same"
__C.MODEL.TEMPORAL_WEIGHT_UPPER = 5
# Weights of the L1, L2 and GDL loss terms (GDL disabled by default).
__C.MODEL.L1_LAMBDA = 1.0
__C.MODEL.L2_LAMBDA = 1.0
__C.MODEL.GDL_LAMBDA = 0.0
__C.MODEL.USE_SEASONALITY = False
|
|
|
|
|
# TrajGRU-specific options.
__C.MODEL.TRAJRNN = edict()
__C.MODEL.TRAJRNN.INIT_GRID = True
__C.MODEL.TRAJRNN.FLOW_LR_MULT = 1.0
__C.MODEL.TRAJRNN.SAVE_MID_RESULTS = False

# Encoder-forecaster network layout.
__C.MODEL.ENCODER_FORECASTER = edict()
__C.MODEL.ENCODER_FORECASTER.HAS_MASK = True
# Spatial size of the feature maps at each RNN level.
# NOTE(review): FIRST_CONV / LAST_DECONV look like (num_filter, kernel, stride,
# pad) and DOWNSAMPLE / UPSAMPLE like per-level (kernel, stride, pad). That
# reading is consistent with FEATMAP_SIZE: 480 -> 96 (k7 s5 p1), 96 -> 32
# (k5 s3 p1), 32 -> 16 (k3 s2 p1). Confirm with the model code.
__C.MODEL.ENCODER_FORECASTER.FEATMAP_SIZE = [96, 32, 16]
__C.MODEL.ENCODER_FORECASTER.FIRST_CONV = (8, 7, 5, 1)
__C.MODEL.ENCODER_FORECASTER.LAST_DECONV = (8, 7, 5, 1)
__C.MODEL.ENCODER_FORECASTER.DOWNSAMPLE = [(5, 3, 1),
                                            (3, 2, 1)]
__C.MODEL.ENCODER_FORECASTER.UPSAMPLE = [(5, 3, 1),
                                          (4, 2, 1)]

# Per-level recurrent blocks; every list below has one entry per level (3).
__C.MODEL.ENCODER_FORECASTER.RNN_BLOCKS = edict()
__C.MODEL.ENCODER_FORECASTER.RNN_BLOCKS.RES_CONNECTION = True
__C.MODEL.ENCODER_FORECASTER.RNN_BLOCKS.LAYER_TYPE = ["ConvGRU", "ConvGRU", "ConvGRU"]
__C.MODEL.ENCODER_FORECASTER.RNN_BLOCKS.STACK_NUM = [2, 3, 3]
__C.MODEL.ENCODER_FORECASTER.RNN_BLOCKS.NUM_FILTER = [32, 64, 64]
# Hidden-to-hidden and input-to-hidden convolution kernels / dilation / padding.
__C.MODEL.ENCODER_FORECASTER.RNN_BLOCKS.H2H_KERNEL = [(5, 5), (5, 5), (3, 3)]
__C.MODEL.ENCODER_FORECASTER.RNN_BLOCKS.H2H_DILATE = [(1, 1), (1, 1), (1, 1)]
__C.MODEL.ENCODER_FORECASTER.RNN_BLOCKS.I2H_KERNEL = [(3, 3), (3, 3), (3, 3)]
__C.MODEL.ENCODER_FORECASTER.RNN_BLOCKS.I2H_PAD = [(1, 1), (1, 1), (1, 1)]
# Presumably the number of links L used by TrajGRU layers -- confirm.
__C.MODEL.ENCODER_FORECASTER.RNN_BLOCKS.L = [5, 5, 5]

# Deconvolution baseline model options.
__C.MODEL.DECONVBASELINE = edict()
__C.MODEL.DECONVBASELINE.BASE_NUM_FILTER = 16
__C.MODEL.DECONVBASELINE.USE_3D = True
__C.MODEL.DECONVBASELINE.ENCODER = "separate"
__C.MODEL.DECONVBASELINE.BN = True
__C.MODEL.DECONVBASELINE.BN_GLOBAL_STATS = False
__C.MODEL.DECONVBASELINE.COMPAT = edict()
__C.MODEL.DECONVBASELINE.COMPAT.CONV_INSTEADOF_FC_IN_ENCODER = False
__C.MODEL.DECONVBASELINE.FC_BETWEEN_ENCDEC = 0
|
|
|
|
|
# Training options.
__C.MODEL.TRAIN = edict()
__C.MODEL.TRAIN.BATCH_SIZE = 3
__C.MODEL.TRAIN.TBPTT = False
# Optimizer and its hyper-parameters. Which of GAMMA1 / BETA1 / EPS a given
# optimizer consumes is decided by the training code, not here.
__C.MODEL.TRAIN.OPTIMIZER = "adam"
__C.MODEL.TRAIN.LR = 1E-4
__C.MODEL.TRAIN.GAMMA1 = 0.9
__C.MODEL.TRAIN.BETA1 = 0.5
__C.MODEL.TRAIN.EPS = 1E-8
__C.MODEL.TRAIN.MIN_LR = 1E-6
__C.MODEL.TRAIN.GRAD_CLIP = 50.0
# Weight decay (disabled by default).
__C.MODEL.TRAIN.WD = 0
__C.MODEL.TRAIN.MAX_ITER = 180000
# NOTE(review): VALID_ITER and SAVE_ITER are attached to MODEL, not MODEL.TRAIN,
# even though they sit among the training options. Callers must read
# cfg.MODEL.VALID_ITER / cfg.MODEL.SAVE_ITER.
__C.MODEL.VALID_ITER = 5000
__C.MODEL.SAVE_ITER = 15000
# Presumably step decay of LR by LR_DECAY_FACTOR every LR_DECAY_ITER
# iterations, floored at MIN_LR -- confirm with the scheduler code.
__C.MODEL.TRAIN.LR_DECAY_ITER = 10000
__C.MODEL.TRAIN.LR_DECAY_FACTOR = 0.7
|
|
|
|
|
# Testing options.
__C.MODEL.TEST = edict()
__C.MODEL.TEST.FINETUNE = True
__C.MODEL.TEST.MAX_ITER = 1
__C.MODEL.TEST.MODE = "online"
__C.MODEL.TEST.DISABLE_TBPTT = True
# Optimizer settings for the "online" test mode. Presumably used to fine-tune
# the model while testing when FINETUNE is enabled -- confirm with the test
# code.
__C.MODEL.TEST.ONLINE = edict()
__C.MODEL.TEST.ONLINE.OPTIMIZER = "adagrad"
__C.MODEL.TEST.ONLINE.LR = 1E-4
__C.MODEL.TEST.ONLINE.FINETUNE_MIN_MSE = 0.0
__C.MODEL.TEST.ONLINE.GAMMA1 = 0.9
__C.MODEL.TEST.ONLINE.BETA1 = 0.5
__C.MODEL.TEST.ONLINE.EPS = 1E-6
__C.MODEL.TEST.ONLINE.GRAD_CLIP = 50.0
__C.MODEL.TEST.ONLINE.WD = 0
|
|
|
|
|
|
|
|
def _merge_two_config(user_cfg, default_cfg):
    """Recursively merge ``user_cfg`` into ``default_cfg`` in place.

    Every key present in ``user_cfg`` overwrites the same key of
    ``default_cfg``. Nested ``edict`` values are merged key by key instead of
    being replaced wholesale.

    Type rules for a key whose default value is not ``None``:
      * the user value must have exactly the default's type, except that
      * an ``np.ndarray`` default accepts any array-like value, converted with
        the default's dtype, and
      * a ``tuple`` default accepts a ``list``, converted to a tuple. YAML has
        no tuple type, so tuples written by ``save_cfg`` come back as lists.
    A key whose default is ``None`` accepts a value of any type, including a
    nested mapping, which is stored as-is.

    Args:
        user_cfg: overriding options; ignored unless it is an ``edict``.
        default_cfg: the config to update in place.

    Raises:
        KeyError: ``user_cfg`` has a key that ``default_cfg`` does not define.
        ValueError: a value's type does not match the default's type.
    """
    if type(user_cfg) is not edict:
        return
    for key, val in user_cfg.items():
        # Only keys that already exist in the defaults may be overridden.
        if key not in default_cfg:
            raise KeyError('{} is not a valid config key'.format(key))

        default_val = default_cfg[key]
        if default_val is not None and type(default_val) is not type(val):
            if isinstance(default_val, np.ndarray):
                val = np.array(val, dtype=default_val.dtype)
            elif type(default_val) is tuple and isinstance(val, list):
                val = tuple(val)
            else:
                raise ValueError(
                    'Type mismatch ({} vs. {}) '
                    'for config key: {}'.format(type(default_val),
                                                type(val), key))

        # Recurse only when there is an existing sub-config to merge into;
        # a None default simply takes the user's mapping.
        if type(val) is edict and default_val is not None:
            try:
                _merge_two_config(val, default_val)
            except Exception:
                # Report which level failed, then let the original error propagate.
                logging.error('Error under config key: {}'.format(key))
                raise
        else:
            default_cfg[key] = val
|
|
|
|
|
|
|
|
def cfg_from_file(file_name, target=__C):
    """Load a YAML config file and merge it into ``target``.

    Args:
        file_name: path of the YAML file to read.
        target: config updated in place; defaults to the global config.

    Raises:
        KeyError, ValueError: propagated from ``_merge_two_config`` when the
            file contains an unknown key or a value of the wrong type.
    """
    import yaml
    with open(file_name, 'r') as f:
        # Log the path, not the file object.
        print('Loading YAML config file from %s' % file_name)
        # safe_load: yaml.load without an explicit Loader can construct
        # arbitrary Python objects and raises TypeError on PyYAML >= 6.
        # An empty file loads as None; treat it as "no overrides".
        yaml_cfg = edict(yaml.safe_load(f) or {})
    _merge_two_config(yaml_cfg, target)
|
|
|
|
|
|
|
|
def ordered_dump(data, stream=None, Dumper=yaml.SafeDumper, **kwds):
    """Serialize ``data`` to YAML, writing mappings in their iteration order.

    A fresh subclass of ``Dumper`` is created on every call, and representers
    are registered on that subclass:
      * ``OrderedDict`` and ``edict`` become block-style YAML mappings built
        from their ``.items()``;
      * ``np.ndarray`` becomes a YAML list via ``tolist()``.

    Args:
        data: object to serialize.
        stream: file-like object to write into, or None.
        Dumper: PyYAML dumper class to extend.
        **kwds: extra keyword arguments forwarded to ``yaml.dump``.

    Returns:
        The result of ``yaml.dump``.
    """
    def represent_ordered_mapping(dumper, mapping):
        tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
        return dumper.represent_mapping(tag, mapping.items(), flow_style=False)

    def represent_array(dumper, array):
        as_list = array.tolist()
        return dumper.represent_list(as_list)

    class OrderedDumper(Dumper):
        pass

    representers = (
        (OrderedDict, represent_ordered_mapping),
        (edict, represent_ordered_mapping),
        (np.ndarray, represent_array),
    )
    for data_type, representer in representers:
        OrderedDumper.add_representer(data_type, representer)

    return yaml.dump(data, stream, OrderedDumper, **kwds)
|
|
|
|
|
|
|
|
def save_cfg(dir_path, source=__C):
    """Dump ``source`` to the first free ``cfg<N>.yml`` file inside ``dir_path``.

    N starts at 0 and is incremented until the name is unused, so earlier
    snapshots are never overwritten.

    Args:
        dir_path: existing directory that receives the file.
        source: config to save; defaults to the global config.
    """
    cfg_count = 0
    while True:
        file_path = os.path.join(dir_path, 'cfg%d.yml' % cfg_count)
        if not os.path.exists(file_path):
            break
        cfg_count += 1

    with open(file_path, 'w') as f:
        logging.info("Save YAML config file to %s" % file_path)
        ordered_dump(source, f, yaml.SafeDumper, default_flow_style=None)
|
|
|
|
|
|
|
|
def load_latest_cfg(dir_path, target=__C):
    """Merge the highest-numbered ``cfg<N>.yml`` found in ``dir_path`` into ``target``.

    Only files whose whole name is ``cfg<N>.yml`` are considered (the naming
    used by ``save_cfg``). Names such as ``cfg3.yml.bak`` or ``old_cfg3.yml``
    are ignored.

    Args:
        dir_path: directory to scan.
        target: config updated in place; defaults to the global config.

    Raises:
        FileNotFoundError: ``dir_path`` contains no ``cfg<N>.yml`` file.
    """
    import re
    pattern = re.compile(r'cfg(\d+)\.yml')
    cfg_count = None
    source_cfg_path = None
    for fname in os.listdir(dir_path):
        ret = pattern.fullmatch(fname)
        if ret is None:
            continue
        count = int(ret.group(1))
        if cfg_count is None or count > cfg_count:
            cfg_count = count
            # Use the real file name; the match covers the whole name anyway.
            source_cfg_path = os.path.join(dir_path, fname)
    if source_cfg_path is None:
        raise FileNotFoundError('No cfg<N>.yml file found in {}'.format(dir_path))
    cfg_from_file(file_name=source_cfg_path, target=target)
|
|
|