STLDM_official / nowcasting /hko_evaluation.py
sqfoo's picture
Upload 99 files
6021dd1 verified
try:
import cPickle as pickle
except:
import pickle
import numpy as np
import logging
import os
from collections import namedtuple
from nowcasting.config import cfg
from nowcasting.hko_iterator import get_exclude_mask
from nowcasting.helpers.msssim import _SSIMForMultiScale
def pixel_to_dBZ(img):
    """Map normalized pixel intensities onto the dBZ scale.

    The mapping is the affine transform ``dBZ = 70 * pixel - 10``.

    Parameters
    ----------
    img : np.ndarray or float
        Pixel value(s).

    Returns
    -------
    dBZ : np.ndarray or float
        Transformed value(s), same shape as `img`.
    """
    scale = 70.0
    offset = 10.0
    dBZ = img * scale - offset
    return dBZ
def dBZ_to_pixel(dBZ_img):
    """Map dBZ values back to pixel intensities, clipped to [0, 1].

    Computes ``(dBZ + 10) / 70`` and clips the result so that it never
    leaves the closed interval [0, 1].

    Parameters
    ----------
    dBZ_img : np.ndarray
        Values on the dBZ scale.

    Returns
    -------
    pixel_vals : np.ndarray
        Clipped pixel values.
    """
    scaled = (dBZ_img + 10.0) / 70.0
    return np.clip(scaled, 0.0, 1.0)
def pixel_to_rainfall(img, a=None, b=None):
    """Convert pixel values to rainfall intensity.

    The pixels are first mapped to dBZ with `pixel_to_dBZ`, then converted
    through ``dBR = (dBZ - 10 * log10(a)) / b`` and
    ``rainfall = 10 ** (dBR / 10)``.

    Parameters
    ----------
    img : np.ndarray
        Pixel values.
    a : float32, optional
        Falls back to ``cfg.HKO.EVALUATION.ZR.a`` when None.
    b : float32, optional
        Falls back to ``cfg.HKO.EVALUATION.ZR.b`` when None.

    Returns
    -------
    rainfall_intensity : np.ndarray
    """
    a = cfg.HKO.EVALUATION.ZR.a if a is None else a
    b = cfg.HKO.EVALUATION.ZR.b if b is None else b
    reflectivity = pixel_to_dBZ(img)
    log_rain = (reflectivity - 10.0 * np.log10(a)) / b
    return np.power(10, log_rain / 10.0)
def rainfall_to_pixel(rainfall_intensity, a=None, b=None):
    """Convert rainfall intensity to pixel values.

    Computes ``dBR = 10 * log10(rainfall)``, then
    ``dBZ = b * dBR + 10 * log10(a)`` and finally
    ``pixel = (dBZ + 10) / 70``. The result is not clipped.

    Parameters
    ----------
    rainfall_intensity : np.ndarray
        Rainfall intensity values.
    a : float32, optional
        Falls back to ``cfg.HKO.EVALUATION.ZR.a`` when None.
    b : float32, optional
        Falls back to ``cfg.HKO.EVALUATION.ZR.b`` when None.

    Returns
    -------
    pixel_vals : np.ndarray
    """
    a = cfg.HKO.EVALUATION.ZR.a if a is None else a
    b = cfg.HKO.EVALUATION.ZR.b if b is None else b
    log_rain = np.log10(rainfall_intensity) * 10.0
    reflectivity = log_rain * b + 10.0 * np.log10(a)
    return (reflectivity + 10.0) / 70.0
def get_hit_miss_counts(prediction, truth, mask=None, thresholds=None, sum_batch=False):
    """Count hits, misses, false alarms and correct negatives per threshold.

    Both inputs are binarized against every rainfall threshold (converted
    to pixel values with `rainfall_to_pixel`) and the four contingency-table
    entries are accumulated over the spatial axes, and additionally over the
    batch axis when `sum_batch` is set.

    Parameters
    ----------
    prediction : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width)
    truth : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width)
    mask : np.ndarray or None
        Shape: (seq_len, batch_size, 1, height, width)
        0 --> not use
        1 --> use
    thresholds : list or tuple
        Rainfall thresholds; falls back to ``cfg.HKO.EVALUATION.THRESHOLDS``
        when None.
    sum_batch : bool
        Whether to also sum over the batch axis.

    Returns
    -------
    hits : np.ndarray
        (seq_len, len(thresholds)) or (seq_len, batch_size, len(thresholds))
        TP
    misses : np.ndarray
        (seq_len, len(thresholds)) or (seq_len, batch_size, len(thresholds))
        FN
    false_alarms : np.ndarray
        (seq_len, len(thresholds)) or (seq_len, batch_size, len(thresholds))
        FP
    correct_negatives : np.ndarray
        (seq_len, len(thresholds)) or (seq_len, batch_size, len(thresholds))
        TN
    """
    if thresholds is None:
        thresholds = cfg.HKO.EVALUATION.THRESHOLDS
    assert prediction.ndim == 5
    assert truth.ndim == 5
    assert prediction.shape == truth.shape
    assert prediction.shape[2] == 1

    # Put the thresholds on the (singleton) channel axis so that the
    # comparisons broadcast to (seq_len, batch_size, n_thresholds, H, W).
    rain_levels = np.array(thresholds, dtype=np.float32)
    pixel_levels = rainfall_to_pixel(rain_levels.reshape((1, 1, len(thresholds), 1, 1)))

    pred_pos = prediction >= pixel_levels
    truth_pos = truth >= pixel_levels
    pred_neg = np.logical_not(pred_pos)
    truth_neg = np.logical_not(truth_pos)

    reduce_axes = (1, 3, 4) if sum_batch else (3, 4)

    def _count(pred_flags, truth_flags):
        # Number of positions where both flags hold (and the mask allows it).
        joint = np.logical_and(pred_flags, truth_flags)
        if mask is not None:
            joint = np.logical_and(joint, mask)
        return joint.sum(axis=reduce_axes)

    hits = _count(pred_pos, truth_pos)
    misses = _count(pred_neg, truth_pos)
    false_alarms = _count(pred_pos, truth_neg)
    correct_negatives = _count(pred_neg, truth_neg)
    return hits, misses, false_alarms, correct_negatives
def get_correlation(prediction, truth):
    """Sum the per-frame (uncentered) correlation between two sequences.

    For every frame the spatial inner product of `prediction` and `truth`
    is divided by the product of their spatial L2 norms (plus a small
    epsilon); the ratios are then summed over the batch and channel axes.

    Parameters
    ----------
    prediction : np.ndarray
        5-dim array whose third axis has size 1.
    truth : np.ndarray
        Same shape as `prediction`.

    Returns
    -------
    ret : np.ndarray
        One value per index of the first axis.
    """
    assert truth.shape == prediction.shape
    assert prediction.ndim == 5
    assert prediction.shape[2] == 1
    eps = 1E-12
    spatial = (3, 4)
    inner = (prediction * truth).sum(axis=spatial)
    pred_norm = np.sqrt(np.square(prediction).sum(axis=spatial))
    truth_norm = np.sqrt(np.square(truth).sum(axis=spatial))
    per_frame = inner / (pred_norm * truth_norm + eps)
    return per_frame.sum(axis=(1, 2))
def get_rainfall_mse(prediction, truth):
    """Squared error between two pixel arrays, measured in rainfall units.

    Both inputs are converted with `pixel_to_rainfall`; the squared
    difference is averaged over axes 2 and 3 and then summed over axis 1.

    NOTE(review): the other metrics in this module reduce 5-dim inputs over
    axes (2, 3, 4) or (3, 4); the axes used here look like they expect a
    4-dim layout -- confirm against the callers.

    Parameters
    ----------
    prediction : np.ndarray
    truth : np.ndarray

    Returns
    -------
    ret : np.ndarray
    """
    rain_pred = pixel_to_rainfall(prediction)
    rain_truth = pixel_to_rainfall(truth)
    squared_error = np.square(rain_pred - rain_truth)
    return squared_error.mean(axis=(2, 3)).sum(axis=1)
def get_PSNR(prediction, truth):
    """Peak Signal Noise Ratio, summed over the batch axis.

    The MSE of each frame is taken over axes (2, 3, 4), turned into
    ``10 * log10(1 / mse)`` (i.e. a peak value of 1.0) and the per-frame
    values are then summed over axis 1.

    Parameters
    ----------
    prediction : np.ndarray
    truth : np.ndarray

    Returns
    -------
    ret : np.ndarray
    """
    frame_mse = np.square(prediction - truth).mean(axis=(2, 3, 4))
    frame_psnr = 10.0 * np.log10(1.0 / frame_mse)
    return frame_psnr.sum(axis=1)
def get_SSIM(prediction, truth):
    """Calculate the SSIM score following

    [TIP2004] Image Quality Assessment: From Error Visibility to Structural Similarity

    Same functionality as
    https://github.com/coupriec/VideoPredictionICLR2016/blob/master/image_error_measures.lua#L50-L75

    We use nowcasting.helpers.msssim, which is borrowed from Tensorflow to do the evaluation

    Parameters
    ----------
    prediction : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width)
    truth : np.ndarray
        Same shape as `prediction`.

    Returns
    -------
    ret : np.ndarray
        Shape: (seq_len,), the per-frame SSIM summed over the batch axis.
    """
    assert truth.shape == prediction.shape
    assert 5 == prediction.ndim
    assert prediction.shape[2] == 1
    seq_len, batch_size, _, height, width = prediction.shape
    # Fold (seq_len, batch_size) into one leading axis and move the singleton
    # channel to the end before handing the frames to the SSIM helper.
    prediction = prediction.reshape((seq_len * batch_size, height, width, 1))
    truth = truth.reshape((seq_len * batch_size, height, width, 1))
    # The second return value is not needed here.
    ssim, _ = _SSIMForMultiScale(img1=prediction, img2=truth, max_val=1.0)
    # Diagnostics go through logging instead of an unconditional print to stdout.
    logging.debug("SSIM output shape: %s", ssim.shape)
    ret = ssim.reshape((seq_len, batch_size)).sum(axis=1)
    return ret
def get_GDL(prediction, truth, mask, sum_batch=False):
    """Calculate the masked gradient difference loss

    The absolute finite differences of `prediction` and `truth` are taken
    along the height and width axes; the absolute difference between the two
    gradient magnitudes is accumulated. A gradient entry is only counted when
    both pixels it is computed from are enabled in `mask`.

    Parameters
    ----------
    prediction : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width)
    truth : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width)
    mask : np.ndarray or None
        Shape: (seq_len, batch_size, 1, height, width)
        0 --> not use
        1 --> use
        None --> use every position
    sum_batch : bool
        Whether to also sum over the batch axis.

    Returns
    -------
    gdl : np.ndarray
        Shape: (seq_len,) or (seq_len, batch_size)
    """
    pred_grad_h = np.abs(np.diff(prediction, axis=3))
    pred_grad_w = np.abs(np.diff(prediction, axis=4))
    truth_grad_h = np.abs(np.diff(truth, axis=3))
    truth_grad_w = np.abs(np.diff(truth, axis=4))
    gd_h = np.abs(pred_grad_h - truth_grad_h)
    gd_w = np.abs(pred_grad_w - truth_grad_w)
    if mask is not None:
        # A difference is valid only if both neighbouring pixels are valid.
        mask_h = mask[:, :, :, :-1, :] * mask[:, :, :, 1:, :]
        mask_w = mask[:, :, :, :, :-1] * mask[:, :, :, :, 1:]
        # In-place so the dtype of the gradient arrays is preserved.
        gd_h *= mask_h
        gd_w *= mask_w
    summation_axis = (1, 2, 3, 4) if sum_batch else (2, 3, 4)
    gdl = np.sum(gd_h, axis=summation_axis) + np.sum(gd_w, axis=summation_axis)
    return gdl
def get_balancing_weights(data, mask, base_balancing_weights=None, thresholds=None):
    """Build per-pixel weights that grow with the rainfall level of `data`.

    Every pixel starts at ``base_balancing_weights[0]``; for each rainfall
    threshold (converted with `rainfall_to_pixel`) that the pixel reaches,
    the increment between two consecutive base weights is added. The result
    is finally multiplied by `mask`.

    Parameters
    ----------
    data : np.ndarray
        5-dim array of pixel values (a sixth, threshold axis is appended
        internally).
    mask : np.ndarray
        Multiplied element-wise into the weights.
    base_balancing_weights : list or tuple, optional
        Falls back to ``cfg.HKO.EVALUATION.BALANCING_WEIGHTS`` when None.
    thresholds : list or tuple, optional
        Falls back to ``cfg.HKO.EVALUATION.THRESHOLDS`` when None.

    Returns
    -------
    weights : np.ndarray
        Same shape as `data`.
    """
    if thresholds is None:
        thresholds = cfg.HKO.EVALUATION.THRESHOLDS
    if base_balancing_weights is None:
        base_balancing_weights = cfg.HKO.EVALUATION.BALANCING_WEIGHTS

    # Thresholds and weight increments both live on a trailing sixth axis.
    rain_levels = np.array(thresholds, dtype=np.float32)
    pixel_levels = rainfall_to_pixel(rain_levels.reshape((1, 1, 1, 1, 1, len(thresholds))))
    increments = np.diff(np.array(base_balancing_weights, dtype=np.float32))
    increments = increments.reshape((1, 1, 1, 1, 1, len(base_balancing_weights) - 1))

    weights = np.ones_like(data) * base_balancing_weights[0]
    reached = np.expand_dims(data, axis=5) >= pixel_levels
    weights += (reached * increments).sum(axis=-1)
    weights *= mask
    return weights
# The evaluation class below relies on the numba-accelerated kernels. Any
# failure while importing them is reported as an ImportError.
try:
    from nowcasting.numba_accelerated import get_GDL_numba, get_hit_miss_counts_numba,\
        get_balancing_weights_numba
except:
    # Disabled fallback: the pure-numpy implementations defined above could be
    # bound to the *_numba names instead of aborting the import.
    # get_GDL_numba = get_GDL
    # get_hit_miss_counts_numba = get_hit_miss_counts
    # get_balancing_weights_numba = get_balancing_weights
    # print("Numba has not been installed correctly!")
    raise ImportError("Numba has not been installed correctly!")
class HKOEvaluation(object):
    """Accumulate nowcasting scores over batches and report them.

    Per-timestep contingency counts (hits, misses, false alarms, correct
    negatives) are kept for every rainfall threshold, together with running
    sums of MSE, MAE, their balanced variants and GDL. `update` adds one
    batch, `calculate_stat` turns the accumulators into scores and the
    ``save*`` / ``print_stat_readable`` methods export them.
    """

    def __init__(self, seq_len, use_central, no_ssim=True, threholds=None,
                 central_region=None):
        """
        Parameters
        ----------
        seq_len : int
            Length of the evaluated sequences.
        use_central : bool
            Whether to crop every input to `central_region` before scoring.
        no_ssim : bool
            SSIM is not implemented; passing False makes `update` raise.
        threholds : list or tuple, optional
            Rainfall thresholds (the misspelt name is kept for backward
            compatibility). Falls back to ``cfg.HKO.EVALUATION.THRESHOLDS``.
        central_region : sequence, optional
            Indexed as (x_begin, y_begin, x_end, y_end) when cropping. Falls
            back to ``cfg.HKO.EVALUATION.CENTRAL_REGION``.
        """
        if central_region is None:
            central_region = cfg.HKO.EVALUATION.CENTRAL_REGION
        self._thresholds = cfg.HKO.EVALUATION.THRESHOLDS if threholds is None else threholds
        self._seq_len = seq_len
        self._no_ssim = no_ssim
        self._use_central = use_central
        self._central_region = central_region
        self._exclude_mask = get_exclude_mask()
        self.begin()

    def begin(self):
        """Allocate all accumulators, zero-initialised."""
        # `np.int` was only an alias of the builtin `int` and no longer exists
        # in recent NumPy releases, so the builtin is used directly.
        count_shape = (self._seq_len, len(self._thresholds))
        self._total_hits = np.zeros(count_shape, dtype=int)
        self._total_misses = np.zeros(count_shape, dtype=int)
        self._total_false_alarms = np.zeros(count_shape, dtype=int)
        self._total_correct_negatives = np.zeros(count_shape, dtype=int)
        self._mse = np.zeros((self._seq_len,), dtype=np.float32)
        self._mae = np.zeros((self._seq_len,), dtype=np.float32)
        self._balanced_mse = np.zeros((self._seq_len,), dtype=np.float32)
        self._balanced_mae = np.zeros((self._seq_len,), dtype=np.float32)
        self._gdl = np.zeros((self._seq_len,), dtype=np.float32)
        self._ssim = np.zeros((self._seq_len,), dtype=np.float32)
        self._datetime_dict = {}
        self._total_batch_num = 0

    def clear_all(self):
        """Reset every accumulator in place."""
        self._total_hits[:] = 0
        self._total_misses[:] = 0
        self._total_false_alarms[:] = 0
        self._total_correct_negatives[:] = 0
        self._mse[:] = 0
        self._mae[:] = 0
        # The balanced scores are accumulated in `update` as well and must be
        # cleared together with the others.
        self._balanced_mse[:] = 0
        self._balanced_mae[:] = 0
        self._gdl[:] = 0
        self._ssim[:] = 0
        self._total_batch_num = 0

    def update(self, gt, pred, mask, start_datetimes=None):
        """Add one batch to the accumulators.

        Parameters
        ----------
        gt : np.ndarray
            Ground truth; the first axis must have length `seq_len` and the
            second axis is the batch axis.
        pred : np.ndarray
            Prediction, same shape as `gt`.
        mask : np.ndarray
            Same shape as `gt`.
            0 indicates not use and 1 indicates that the location will be taken into account
        start_datetimes : list
            The starting datetimes of all the testing instances

        Returns
        -------

        """
        if start_datetimes is not None:
            batch_size = len(start_datetimes)
            assert gt.shape[1] == batch_size
        else:
            batch_size = gt.shape[1]
        assert gt.shape[0] == self._seq_len
        assert gt.shape == pred.shape
        assert gt.shape == mask.shape

        if self._use_central:
            # Crop the central regions for evaluation
            rows = slice(self._central_region[1], self._central_region[3])
            cols = slice(self._central_region[0], self._central_region[2])
            pred = pred[:, :, :, rows, cols]
            gt = gt[:, :, :, rows, cols]
            mask = mask[:, :, :, rows, cols]
        self._total_batch_num += batch_size
        #TODO Save all the mse, mae, gdl, hits, misses, false_alarms and correct_negatives
        mse = (mask * np.square(pred - gt)).sum(axis=(2, 3, 4))
        mae = (mask * np.abs(pred - gt)).sum(axis=(2, 3, 4))
        weights = get_balancing_weights_numba(data=gt, mask=mask,
                                              base_balancing_weights=cfg.HKO.EVALUATION.BALANCING_WEIGHTS,
                                              thresholds=self._thresholds)
        balanced_mse = (weights * np.square(pred - gt)).sum(axis=(2, 3, 4))
        balanced_mae = (weights * np.abs(pred - gt)).sum(axis=(2, 3, 4))
        gdl = get_GDL_numba(prediction=pred, truth=gt, mask=mask)
        # Everything above is per (timestep, sample); fold the batch axis.
        self._mse += mse.sum(axis=1)
        self._mae += mae.sum(axis=1)
        self._balanced_mse += balanced_mse.sum(axis=1)
        self._balanced_mae += balanced_mae.sum(axis=1)
        self._gdl += gdl.sum(axis=1)
        if not self._no_ssim:
            raise NotImplementedError
            # self._ssim += get_SSIM(prediction=pred, truth=gt)
        hits, misses, false_alarms, correct_negatives = \
            get_hit_miss_counts_numba(prediction=pred, truth=gt, mask=mask,
                                      thresholds=self._thresholds)
        self._total_hits += hits.sum(axis=1)
        self._total_misses += misses.sum(axis=1)
        self._total_false_alarms += false_alarms.sum(axis=1)
        self._total_correct_negatives += correct_negatives.sum(axis=1)

    def calculate_stat(self):
        """The following measurements will be used to measure the score of the forecaster

        See Also
        [Weather and Forecasting 2010] Equitability Revisited: Why the "Equitable Threat Score" Is Not Equitable
        http://www.wxonline.info/topics/verif2.html

        We will denote
        (a b    (hits       false alarms
         c d) =  misses     correct negatives)

        We will report the
        POD = a / (a + c)
        FAR = b / (a + b)
        CSI = a / (a + b + c)
        Heidke Skill Score (HSS) = 2(ad - bc) / ((a+c) (c+d) + (a+b)(b+d))
        Gilbert Skill Score (GSS) = HSS / (2 - HSS), also known as the Equitable Threat Score
            HSS = 2 * GSS / (GSS + 1)
        MSE = mask * (pred - gt) **2
        MAE = mask * abs(pred - gt)
        GDL = valid_mask_h * abs(gd_h(pred) - gd_h(gt)) + valid_mask_w * abs(gd_w(pred) - gd_w(gt))

        The error sums are divided by the number of accumulated sequences.
        Ratios with a zero denominator follow numpy's floating point rules.

        Returns
        -------
        pod, far, csi, hss, gss : np.ndarray
            Shape: (seq_len, len(thresholds))
        mse, mae, balanced_mse, balanced_mae, gdl : np.ndarray
            Shape: (seq_len,)
        """
        a = self._total_hits.astype(np.float64)
        b = self._total_false_alarms.astype(np.float64)
        c = self._total_misses.astype(np.float64)
        d = self._total_correct_negatives.astype(np.float64)
        pod = a / (a + c)
        far = b / (a + b)
        csi = a / (a + b + c)
        n = a + b + c + d
        # Number of hits expected from a random forecast.
        aref = (a + b) / n * (a + c)
        gss = (a - aref) / (a + b + c - aref)
        hss = 2 * gss / (gss + 1)
        mse = self._mse / self._total_batch_num
        mae = self._mae / self._total_batch_num
        balanced_mse = self._balanced_mse / self._total_batch_num
        balanced_mae = self._balanced_mae / self._total_batch_num
        gdl = self._gdl / self._total_batch_num
        if not self._no_ssim:
            raise NotImplementedError
            # ssim = self._ssim / self._total_batch_num
        return pod, far, csi, hss, gss, mse, mae, balanced_mse, balanced_mae, gdl

    def print_stat_readable(self, prefix=""):
        """Log every score as "<mean over timesteps>/<last timestep>"."""
        logging.info("%sTotal Sequence Number: %d, Use Central: %d"
                     % (prefix, self._total_batch_num, self._use_central))
        pod, far, csi, hss, gss, mse, mae, balanced_mse, balanced_mae, gdl = self.calculate_stat()
        per_threshold = (("Hits", self._total_hits), ("POD", pod), ("FAR", far),
                         ("CSI", csi), ("GSS", gss), ("HSS", hss))
        for label, score in per_threshold:
            logging.info(" %s: " % label
                         + ', '.join([">%g:%g/%g" % (threshold, score[:, i].mean(), score[-1, i])
                                      for i, threshold in enumerate(self._thresholds)]))
        per_timestep = (("MSE", mse), ("MAE", mae), ("Balanced MSE", balanced_mse),
                        ("Balanced MAE", balanced_mae), ("GDL", gdl))
        for label, score in per_timestep:
            logging.info(" %s: %g/%g" % (label, score.mean(), score[-1]))
        if not self._no_ssim:
            raise NotImplementedError

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the directory that will contain `path` if it is missing."""
        dir_path = os.path.dirname(path)
        # A bare filename has an empty dirname, which os.makedirs rejects.
        if dir_path and not os.path.exists(dir_path):
            os.makedirs(dir_path)

    def save_pkl(self, path):
        """Pickle the whole evaluator to `path`."""
        self._ensure_parent_dir(path)
        logging.info("Saving HKOEvaluation to %s" % path)
        with open(path, 'wb') as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)

    def save_txt_readable(self, path):
        """Write all scores to `path` as human readable text."""
        self._ensure_parent_dir(path)
        pod, far, csi, hss, gss, mse, mae, balanced_mse, balanced_mae, gdl = self.calculate_stat()
        per_threshold = (("POD", pod), ("FAR", far), ("CSI", csi), ("GSS", gss), ("HSS", hss))
        per_timestep = (("MSE", mse), ("MAE", mae), ("Balanced MSE", balanced_mse),
                        ("Balanced MAE", balanced_mae), ("GDL", gdl))
        logging.info("Saving readable txt of HKOEvaluation to %s" % path)
        with open(path, 'w') as f:
            f.write("Total Sequence Num: %d, Out Seq Len: %d, Use Central: %d\n"
                    % (self._total_batch_num,
                       self._seq_len,
                       self._use_central))
            for i, threshold in enumerate(self._thresholds):
                f.write("Threshold = %g:\n" % threshold)
                for label, score in per_threshold:
                    f.write(" %s: %s\n" % (label, str(list(score[:, i]))))
                for label, score in per_threshold:
                    f.write(" %s stat: avg %g/final %g\n"
                            % (label, score[:, i].mean(), score[-1, i]))
            for label, score in per_timestep:
                f.write("%s: %s\n" % (label, str(list(score))))
            for label, score in per_timestep:
                f.write("%s stat: avg %g/final %g\n" % (label, score.mean(), score[-1]))

    def save(self, prefix):
        """Write both ``prefix + ".txt"`` and ``prefix + ".pkl"``."""
        self.save_txt_readable(prefix + ".txt")
        self.save_pkl(prefix + ".pkl")