|
|
import json |
|
|
import os |
|
|
import numpy as np |
|
|
from nowcasting.config import cfg |
|
|
from nowcasting.helpers.visualization import save_hko_movie |
|
|
from nowcasting.hko_iterator import HKOIterator |
|
|
from nowcasting.hko_evaluation import HKOEvaluation |
|
|
|
|
|
class HKOBenchmarkEnv(object):
    """The Benchmark environment for the HKO7 Dataset

    There are two settings for the Benchmark, the "fixed" setting and the "online" setting.

    In the "fixed" setting, pre-defined input sequences that have the same length will be
    fed into the model for prediction.
    This setting tests the model's ability to use the instant past to predict the future.

    In the "online" setting, M frames will be given each time and the forecasting model
    is required to predict the next K frames every stride steps
    (M is cfg.HKO.BENCHMARK.STRIDE and K is cfg.HKO.BENCHMARK.OUT_LEN).
    If the begin_new_episode flag is turned on, a new episode has begun, which means that
    the current received images have no relationship with the previous images.
    If the need_upload_prediction flag is turned on, the model is required to predict the
    next K frames and hand them back through `upload_prediction` before requesting the
    next observation. When the flag is off, `upload_prediction` must not be called.
    This setting tests both the model's ability to adapt in an online fashion and
    the ability to capture the long-term dependency.

    The input frame will be missing in some timestamps.

    To run the benchmark in the fixed setting:

        env = HKOBenchmarkEnv(...)
        while not env.done:
            # Get the observation
            (in_frame_dat, in_datetime_clips, out_datetime_clips,
             begin_new_episode, need_upload_prediction) = env.get_observation(batch_size=batch_size)
            # Running your algorithm to get the prediction
            prediction = ...
            # Upload prediction to the environment
            # (in the online setting, only do this when need_upload_prediction is True)
            env.upload_prediction(prediction)

    """

    def __init__(self,
                 pd_path,
                 save_dir="hko7_benchmark",
                 mode="fixed"):
        """Build the iterator, the evaluator and the (possibly cached) benchmark statistics.

        Parameters
        ----------
        pd_path : str
            Path handed to `HKOIterator`; its base name is also part of the fingerprint.
        save_dir : str
            Root directory for the saved videos and evaluation results. Created if missing.
        mode : str
            Either "fixed" or "online".
        """
        assert mode == "fixed" or mode == "online"
        self._pd_path = pd_path
        self._save_dir = save_dir
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        self._mode = mode
        self._out_seq_len = cfg.HKO.BENCHMARK.OUT_LEN
        self._stride = cfg.HKO.BENCHMARK.STRIDE
        if mode == "fixed":
            self._in_seq_len = cfg.HKO.BENCHMARK.IN_LEN
        else:
            # In the online setting exactly `stride` new frames are fed at every step.
            self._in_seq_len = cfg.HKO.BENCHMARK.STRIDE
        self._hko_iter = HKOIterator(pd_path=pd_path,
                                     sample_mode="sequent",
                                     seq_len=self._in_seq_len + self._out_seq_len,
                                     stride=self._stride)
        self._stat_dict = self._get_benchmark_stat()
        self._begin_new_episode = True
        self._received_pred_seq_num = 0
        self._need_upload_prediction = False

        # Indices (counted over all uploaded sequences) whose input / ground-truth /
        # prediction are additionally dumped as videos: VISUALIZE_SEQ_NUM evenly spaced
        # multiples of pred_seq_num // VISUALIZE_SEQ_NUM.
        self._save_seq_inds = set(np.arange(1, cfg.HKO.BENCHMARK.VISUALIZE_SEQ_NUM + 1) *
                                  (self._stat_dict['pred_seq_num'] //
                                   cfg.HKO.BENCHMARK.VISUALIZE_SEQ_NUM))
        self._all_eval = HKOEvaluation(seq_len=self._out_seq_len, use_central=False)

        # Holders of the most recent observation, filled by get_observation()
        self._in_frame_dat = None
        self._in_mask_dat = None
        self._in_datetime_clips = None
        self._out_frame_dat = None
        self._out_mask_dat = None
        self._out_datetime_clips = None

    def reset(self):
        """Rewind the iterator, clear the accumulated evaluation and reset the flags."""
        self._hko_iter.reset()
        self._all_eval.clear_all()
        self._begin_new_episode = True
        self._received_pred_seq_num = 0
        self._need_upload_prediction = False

    @property
    def _fingerprint(self):
        """Name identifying this benchmark configuration (data file, lengths, stride, mode).

        The input length is only part of the name in the "fixed" mode.
        """
        pd_file_name = os.path.splitext(os.path.basename(self._pd_path))[0]
        if self._mode == "fixed":
            fingerprint = pd_file_name + "_in" + str(self._in_seq_len)\
                          + "_out" + str(self._out_seq_len) + "_stride" + str(self._stride)\
                          + "_" + self._mode
        else:
            fingerprint = pd_file_name + "_out" + str(self._out_seq_len)\
                          + "_stride" + str(self._stride)\
                          + "_" + self._mode
        return fingerprint

    @property
    def _stat_filepath(self):
        """Path of the JSON file that caches the benchmark statistics."""
        filename = self._fingerprint + ".json"
        return os.path.join(cfg.HKO.BENCHMARK.STAT_PATH, filename)

    @property
    def _output_dir(self):
        """Directory that receives the videos and evaluation files of this configuration."""
        return os.path.join(self._save_dir, self._fingerprint)

    def _get_benchmark_stat(self):
        """Get the general statistics of the benchmark

        The statistics are loaded from the cached JSON file if it exists. Otherwise the
        whole iterator is walked once (datetimes only), the iterator is reset, and the
        result is written to the cache file.

        Returns
        -------
        stat_dict : dict
            'pred_seq_num' --> Total number of predictions the model needs to make
            'episode_num' --> Number of episodes counted while walking the iterator
            'episode_start_datetime' --> '%Y%m%d%H%M' strings of the episode starts
                (only collected in the "online" mode)
        """
        stat_path = self._stat_filepath
        if os.path.exists(stat_path):
            # Context manager so the handle is closed deterministically.
            with open(stat_path) as stat_file:
                return json.load(stat_file)

        seq_num = 0
        episode_num = 0
        episode_start_datetime = []
        while not self._hko_iter.use_up:
            if self._mode == "fixed":
                datetime_clips, new_start =\
                    self._hko_iter.sample(batch_size=1024, only_return_datetime=True)
                if len(datetime_clips) == 0:
                    continue
                # Every fixed-length clip is one prediction and is counted as one episode.
                seq_num += len(datetime_clips)
                episode_num += len(datetime_clips)
            elif self._mode == "online":
                datetime_clips, new_start = \
                    self._hko_iter.sample(batch_size=1, only_return_datetime=True)
                if len(datetime_clips) == 0:
                    continue
                episode_num += new_start
                if new_start:
                    episode_start_datetime.append(
                        datetime_clips[0][0].strftime('%Y%m%d%H%M'))
                    # Mirrors get_observation(): with stride 1 the first step of an
                    # episode does not require a prediction.
                    if self._stride != 1:
                        seq_num += 1
                else:
                    seq_num += 1
        print(self._fingerprint, seq_num, episode_num)
        self._hko_iter.reset()
        stat_dict = {'pred_seq_num': seq_num,
                     'episode_num': episode_num,
                     'episode_start_datetime': episode_start_datetime}
        # Make sure the cache directory exists before writing into it.
        stat_dir = os.path.dirname(stat_path)
        if stat_dir and not os.path.exists(stat_dir):
            os.makedirs(stat_dir)
        with open(stat_path, 'w') as stat_file:
            json.dump(stat_dict, stat_file, indent=3)
        return stat_dict

    @property
    def done(self):
        """True once at least 'pred_seq_num' predictions have been uploaded."""
        return self._received_pred_seq_num >= self._stat_dict["pred_seq_num"]

    def get_observation(self, batch_size=1):
        """Fetch the next input clip(s) and remember the matching ground truth.

        Must not be called while a prediction is still pending, after the benchmark
        is done, or after the iterator is used up (all enforced by assertions).

        Parameters
        ----------
        batch_size : int
            Number of sequences to sample. Must be 1 in the "online" mode.

        Returns
        -------
        in_frame_dat : np.ndarray
            First `in_seq_len` steps (axis 0) of the sampled frames, cast to float32
            and divided by 255. Will be between 0 and 1
            (presumably the iterator yields 8-bit frames -- TODO confirm).
        in_datetime_clips : list
            Per sequence, the datetimes of the input frames.
        out_datetime_clips : list
            Per sequence, the datetimes of the frames to be predicted.
        begin_new_episode : bool
            Always True in the "fixed" mode; in the "online" mode it is the
            `new_start` flag reported by the iterator.
        need_upload_prediction : bool
            Whether `upload_prediction` must be called before the next observation.
            Only False at the start of an episode in the "online" mode with stride 1.
        """
        if self._mode == "online":
            assert batch_size == 1
        assert not self._need_upload_prediction
        assert not self.done
        assert not self._hko_iter.use_up
        # Keep sampling until the iterator returns a non-empty batch.
        # NOTE(review): relies on the iterator eventually yielding data -- confirm.
        while True:
            frame_dat, mask_dat, datetime_clips, new_start =\
                self._hko_iter.sample(batch_size=batch_size, only_return_datetime=False)
            if len(datetime_clips) == 0:
                continue
            else:
                break
        frame_dat = frame_dat.astype(np.float32) / 255.0
        self._need_upload_prediction = True
        if self._mode == "online":
            self._begin_new_episode = new_start
            if new_start and self._stride == 1:
                self._need_upload_prediction = False
        else:
            self._begin_new_episode = True
        # Split every sampled clip into the visible input part and the hidden target part.
        self._in_datetime_clips = [ele[:self._in_seq_len] for ele in datetime_clips]
        self._out_datetime_clips = [ele[self._in_seq_len:(self._in_seq_len +
                                                          self._out_seq_len)]
                                    for ele in datetime_clips]
        self._in_frame_dat = frame_dat[:self._in_seq_len, ...]
        self._out_frame_dat = frame_dat[self._in_seq_len:(self._in_seq_len + self._out_seq_len),
                                        ...]
        self._in_mask_dat = mask_dat[:self._in_seq_len, ...]
        self._out_mask_dat = mask_dat[self._in_seq_len:(self._in_seq_len + self._out_seq_len), ...]
        return self._in_frame_dat,\
               self._in_datetime_clips,\
               self._out_datetime_clips,\
               self._begin_new_episode, \
               self._need_upload_prediction

    def _save_prediction_movies(self, prediction, ind):
        """Dump input, ground truth and prediction of batch element `ind` as videos."""
        out_dir = self._output_dir
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        print("Saving prediction videos to %s" % out_dir)
        # All three files are named after the first input timestamp of the sequence.
        name_stem = self._in_datetime_clips[ind][0].strftime('%Y%m%d%H%M')
        # Index 0 on axis 2 is presumably the single channel -- TODO confirm.
        save_hko_movie(im_dat=self._in_frame_dat[:, ind, 0, ...],
                       mask_dat=self._in_mask_dat[:, ind, 0, :, :],
                       datetime_list=self._in_datetime_clips[ind],
                       save_path=os.path.join(out_dir, "%s_in.mp4" % name_stem))
        save_hko_movie(im_dat=self._out_frame_dat[:, ind, 0, ...],
                       mask_dat=self._out_mask_dat[:, ind, 0, :, :],
                       masked=False,
                       datetime_list=self._out_datetime_clips[ind],
                       save_path=os.path.join(out_dir, "%s_out.mp4" % name_stem))
        save_hko_movie(im_dat=prediction[:, ind, 0, ...],
                       mask_dat=self._out_mask_dat[:, ind, 0, :, :],
                       masked=False,
                       datetime_list=self._out_datetime_clips[ind],
                       save_path=os.path.join(out_dir, "%s_pred.mp4" % name_stem))

    def upload_prediction(self, prediction):
        """Receive the prediction for the latest observation and update the evaluation.

        Parameters
        ----------
        prediction : np.ndarray
            Predicted frames for the latest observation. Axis 1 is the batch axis:
            `prediction.shape[1]` sequences are counted as received.

        Raises
        ------
        AssertionError
            If no prediction is currently expected.
        """
        assert self._need_upload_prediction, "Must call get_observation first!" \
                                             " Also, check the value of need_upload_predction" \
                                             " after calling"
        self._need_upload_prediction = False
        batch_num = prediction.shape[1]
        received_seq_inds = range(self._received_pred_seq_num,
                                  self._received_pred_seq_num + batch_num)
        save_ind_set = set(received_seq_inds).intersection(self._save_seq_inds)
        # A large batch may cover several visualization indices; save each of them
        # instead of failing when more than one falls into the same batch.
        for global_ind in sorted(save_ind_set):
            self._save_prediction_movies(prediction,
                                         global_ind - self._received_pred_seq_num)
        self._received_pred_seq_num += batch_num
        if self._mode == "online":
            if self._stride == 1:
                # With stride 1 no prediction is requested at the start of an episode.
                assert not self._begin_new_episode
        self._all_eval.update(gt=self._out_frame_dat,
                              pred=prediction,
                              mask=self._out_mask_dat,
                              start_datetimes=[ele[0] for ele in self._out_datetime_clips])

    def print_stat_readable(self):
        """Print the current evaluation, prefixed with the number of received sequences."""
        self._all_eval.print_stat_readable(prefix="Received:%d " % self._received_pred_seq_num)

    def save_eval(self):
        """Save the final evaluation under `<save_dir>/<fingerprint>/eval_all`.

        Raises
        ------
        AssertionError
            If not exactly 'pred_seq_num' predictions have been uploaded.
        """
        assert self._received_pred_seq_num == self._stat_dict['pred_seq_num'],\
            "Must upload all the predictions to the testbed!"
        out_dir = self._output_dir
        # The directory is otherwise only created when a video has been saved.
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        print("Saving evaluation result to %s" % out_dir)
        self._all_eval.save(prefix=os.path.join(out_dir, "eval_all"))
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Build (and thereby cache) the benchmark statistics of every dataset split in
    # both modes, then report how many predictions each configuration requires.
    # Each entry: (attribute of cfg.HKO_PD, label for fixed mode, label for online mode)
    _benchmark_runs = (
        ("RAINY_TEST", "Fixed Rainy Test SeqNum:", "Online Rainy Test SeqNum:"),
        ("RAINY_VALID", "Fixed Rainy Valid SeqNum:", "Online Rainy Valid SeqNum:"),
        ("ALL_15", "Fixed Rainy2015 ALL SeqNum:", "Online Rainy2015 ALL SeqNum:"),
    )
    for _pd_attr, _fixed_label, _online_label in _benchmark_runs:
        # Looked up lazily so a split is only resolved when its turn comes.
        _pd_path = getattr(cfg.HKO_PD, _pd_attr)
        env_fixed = HKOBenchmarkEnv(pd_path=_pd_path, mode="fixed")
        env_online = HKOBenchmarkEnv(pd_path=_pd_path, mode="online")
        print(_fixed_label, env_fixed._stat_dict['pred_seq_num'])
        print(_online_label, env_online._stat_dict['pred_seq_num'])
|
|
|