try:
    import cPickle as pickle  # Python 2 fast pickle
except ImportError:
    import pickle
import numpy as np
import logging
import os
from collections import namedtuple
from nowcasting.config import cfg
from nowcasting.hko_iterator import get_exclude_mask
from nowcasting.helpers.msssim import _SSIMForMultiScale


def pixel_to_dBZ(img):
    """Map normalized pixel values in [0, 1] to radar reflectivity in dBZ.

    Parameters
    ----------
    img : np.ndarray or float
        Normalized pixel value(s).

    Returns
    -------
    np.ndarray or float
        Reflectivity in dBZ, linearly mapped to [-10, 60].
    """
    return img * 70.0 - 10.0


def dBZ_to_pixel(dBZ_img):
    """Map radar reflectivity in dBZ back to normalized pixel values.

    Inverse of :func:`pixel_to_dBZ`, with the result clipped to [0, 1].

    Parameters
    ----------
    dBZ_img : np.ndarray

    Returns
    -------
    np.ndarray
    """
    return np.clip((dBZ_img + 10.0) / 70.0, a_min=0.0, a_max=1.0)


def pixel_to_rainfall(img, a=None, b=None):
    """Convert the pixel values to real rainfall intensity.

    Uses the Z-R relationship Z = a * R^b, i.e. dBZ = 10*log10(a) + b*dBR.

    Parameters
    ----------
    img : np.ndarray
    a : float32, optional
        Z-R coefficient; defaults to cfg.HKO.EVALUATION.ZR.a.
    b : float32, optional
        Z-R exponent; defaults to cfg.HKO.EVALUATION.ZR.b.

    Returns
    -------
    rainfall_intensity : np.ndarray
    """
    if a is None:
        a = cfg.HKO.EVALUATION.ZR.a
    if b is None:
        b = cfg.HKO.EVALUATION.ZR.b
    dBZ = pixel_to_dBZ(img)
    dBR = (dBZ - 10.0 * np.log10(a)) / b
    rainfall_intensity = np.power(10, dBR / 10.0)
    return rainfall_intensity


def rainfall_to_pixel(rainfall_intensity, a=None, b=None):
    """Convert the rainfall intensity to pixel values.

    Inverse of :func:`pixel_to_rainfall`.  NOTE: unlike dBZ_to_pixel, the
    result is deliberately NOT clipped to [0, 1] (matches original behavior).

    Parameters
    ----------
    rainfall_intensity : np.ndarray
    a : float32, optional
        Z-R coefficient; defaults to cfg.HKO.EVALUATION.ZR.a.
    b : float32, optional
        Z-R exponent; defaults to cfg.HKO.EVALUATION.ZR.b.

    Returns
    -------
    pixel_vals : np.ndarray
    """
    if a is None:
        a = cfg.HKO.EVALUATION.ZR.a
    if b is None:
        b = cfg.HKO.EVALUATION.ZR.b
    dBR = np.log10(rainfall_intensity) * 10.0
    dBZ = dBR * b + 10.0 * np.log10(a)
    pixel_vals = (dBZ + 10.0) / 70.0
    return pixel_vals


def get_hit_miss_counts(prediction, truth, mask=None, thresholds=None,
                        sum_batch=False):
    """Calculate overall hits/misses of the prediction for skill scores.

    Thresholds are given in rainfall intensity and converted to pixel space
    before comparison.  All inputs are expected to lie in [0, 1].

    Parameters
    ----------
    prediction : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width)
    truth : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width)
    mask : np.ndarray or None
        Shape: (seq_len, batch_size, 1, height, width)
        0 --> not use, 1 --> use
    thresholds : list or tuple, optional
        Rainfall thresholds; defaults to cfg.HKO.EVALUATION.THRESHOLDS.
    sum_batch : bool, optional
        If True, additionally reduce over the batch axis.

    Returns
    -------
    hits : np.ndarray
        (seq_len, len(thresholds)) or (seq_len, batch_size, len(thresholds)); TP
    misses : np.ndarray
        Same shape; FN
    false_alarms : np.ndarray
        Same shape; FP
    correct_negatives : np.ndarray
        Same shape; TN
    """
    if thresholds is None:
        thresholds = cfg.HKO.EVALUATION.THRESHOLDS
    assert 5 == prediction.ndim
    assert 5 == truth.ndim
    assert prediction.shape == truth.shape
    assert prediction.shape[2] == 1
    # Broadcast the channel axis (size 1) against len(thresholds).
    thresholds = rainfall_to_pixel(np.array(thresholds, dtype=np.float32)
                                   .reshape((1, 1, len(thresholds), 1, 1)))
    bpred = (prediction >= thresholds)
    btruth = (truth >= thresholds)
    bpred_n = np.logical_not(bpred)
    btruth_n = np.logical_not(btruth)
    if sum_batch:
        summation_axis = (1, 3, 4)
    else:
        summation_axis = (3, 4)
    if mask is None:
        hits = np.logical_and(bpred, btruth).sum(axis=summation_axis)
        misses = np.logical_and(bpred_n, btruth).sum(axis=summation_axis)
        false_alarms = np.logical_and(bpred, btruth_n).sum(axis=summation_axis)
        correct_negatives = np.logical_and(bpred_n, btruth_n).sum(axis=summation_axis)
    else:
        hits = np.logical_and(np.logical_and(bpred, btruth), mask)\
            .sum(axis=summation_axis)
        misses = np.logical_and(np.logical_and(bpred_n, btruth), mask)\
            .sum(axis=summation_axis)
        false_alarms = np.logical_and(np.logical_and(bpred, btruth_n), mask)\
            .sum(axis=summation_axis)
        correct_negatives = np.logical_and(np.logical_and(bpred_n, btruth_n), mask)\
            .sum(axis=summation_axis)
    return hits, misses, false_alarms, correct_negatives


def get_correlation(prediction, truth):
    """Per-timestep cosine-similarity-style correlation, summed over batch.

    Parameters
    ----------
    prediction : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width)
    truth : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width)

    Returns
    -------
    ret : np.ndarray
        Shape: (seq_len,)
    """
    assert truth.shape == prediction.shape
    assert 5 == prediction.ndim
    assert prediction.shape[2] == 1
    eps = 1E-12  # avoids division by zero for all-zero frames
    ret = (prediction * truth).sum(axis=(3, 4)) / (
        np.sqrt(np.square(prediction).sum(axis=(3, 4))) *
        np.sqrt(np.square(truth).sum(axis=(3, 4))) + eps)
    ret = ret.sum(axis=(1, 2))
    return ret


def get_rainfall_mse(prediction, truth):
    """MSE computed in rainfall-intensity space, summed over the batch axis.

    NOTE(review): the mean is over axes (2, 3), unlike get_PSNR which uses
    (2, 3, 4) — presumably intended for 4-dim inputs here; verify at call site.
    """
    ret = np.square(pixel_to_rainfall(prediction) -
                    pixel_to_rainfall(truth)).mean(axis=(2, 3))
    ret = ret.sum(axis=1)
    return ret


def get_PSNR(prediction, truth):
    """Peak Signal Noise Ratio (max value 1.0), summed over the batch axis.

    Parameters
    ----------
    prediction : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width)
    truth : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width)

    Returns
    -------
    ret : np.ndarray
        Shape: (seq_len,)
    """
    mse = np.square(prediction - truth).mean(axis=(2, 3, 4))
    ret = 10.0 * np.log10(1.0 / mse)
    ret = ret.sum(axis=1)
    return ret


def get_SSIM(prediction, truth):
    """Calculate the SSIM score following
    [TIP2004] Image Quality Assessment: From Error Visibility to Structural Similarity

    Same functionality as
    https://github.com/coupriec/VideoPredictionICLR2016/blob/master/image_error_measures.lua#L50-L75

    We use nowcasting.helpers.msssim, which is borrowed from Tensorflow to do
    the evaluation.

    Parameters
    ----------
    prediction : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width)
    truth : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width)

    Returns
    -------
    ret : np.ndarray
        Shape: (seq_len,), SSIM summed over the batch axis.
    """
    assert truth.shape == prediction.shape
    assert 5 == prediction.ndim
    assert prediction.shape[2] == 1
    seq_len = prediction.shape[0]
    batch_size = prediction.shape[1]
    # _SSIMForMultiScale expects NHWC with a trailing channel axis.
    prediction = prediction.reshape((prediction.shape[0] * prediction.shape[1],
                                     prediction.shape[3], prediction.shape[4], 1))
    truth = truth.reshape((truth.shape[0] * truth.shape[1],
                           truth.shape[3], truth.shape[4], 1))
    ssim, cs = _SSIMForMultiScale(img1=prediction, img2=truth, max_val=1.0)
    ret = ssim.reshape((seq_len, batch_size)).sum(axis=1)
    return ret


def get_GDL(prediction, truth, mask, sum_batch=False):
    """Calculate the masked gradient difference loss.

    Parameters
    ----------
    prediction : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width)
    truth : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width)
    mask : np.ndarray or None
        Shape: (seq_len, batch_size, 1, height, width)
        0 --> not use, 1 --> use
    sum_batch : bool, optional
        If True, additionally reduce over the batch axis.

    Returns
    -------
    gdl : np.ndarray
        Shape: (seq_len,) or (seq_len, batch_size)
    """
    prediction_diff_h = np.abs(np.diff(prediction, axis=3))
    prediction_diff_w = np.abs(np.diff(prediction, axis=4))
    gt_diff_h = np.abs(np.diff(truth, axis=3))
    gt_diff_w = np.abs(np.diff(truth, axis=4))
    # A gradient is valid only when both pixels that form it are valid.
    mask_h = mask[:, :, :, :-1, :] * mask[:, :, :, 1:, :]
    mask_w = mask[:, :, :, :, :-1] * mask[:, :, :, :, 1:]
    gd_h = np.abs(prediction_diff_h - gt_diff_h)
    gd_w = np.abs(prediction_diff_w - gt_diff_w)
    gd_h *= mask_h
    gd_w *= mask_w
    summation_axis = (1, 2, 3, 4) if sum_batch else (2, 3, 4)
    gdl = np.sum(gd_h, axis=summation_axis) + np.sum(gd_w, axis=summation_axis)
    return gdl


def get_balancing_weights(data, mask, base_balancing_weights=None, thresholds=None):
    """Per-pixel loss weights based on which rainfall thresholds are exceeded.

    A pixel's weight starts at base_balancing_weights[0] and steps up to
    base_balancing_weights[i + 1] once data >= thresholds[i]; the mask then
    zeroes out unused pixels.

    Parameters
    ----------
    data : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width), pixel values.
    mask : np.ndarray
        Same shape; 0 --> not use, 1 --> use.
    base_balancing_weights : sequence, optional
        len(thresholds) + 1 weights; defaults to cfg.HKO.EVALUATION.BALANCING_WEIGHTS.
    thresholds : list or tuple, optional
        Rainfall thresholds; defaults to cfg.HKO.EVALUATION.THRESHOLDS.

    Returns
    -------
    weights : np.ndarray
        Same shape as data.
    """
    if thresholds is None:
        thresholds = cfg.HKO.EVALUATION.THRESHOLDS
    if base_balancing_weights is None:
        base_balancing_weights = cfg.HKO.EVALUATION.BALANCING_WEIGHTS
    thresholds = rainfall_to_pixel(np.array(thresholds, dtype=np.float32)
                                   .reshape((1, 1, 1, 1, 1, len(thresholds))))
    weights = np.ones_like(data) * base_balancing_weights[0]
    threshold_mask = np.expand_dims(data, axis=5) >= thresholds
    # Successive weight increments so crossing threshold i adds
    # base_balancing_weights[i + 1] - base_balancing_weights[i].
    base_weights = np.diff(np.array(base_balancing_weights, dtype=np.float32))\
        .reshape((1, 1, 1, 1, 1, len(base_balancing_weights) - 1))
    weights += (threshold_mask * base_weights).sum(axis=-1)
    weights *= mask
    return weights


try:
    from nowcasting.numba_accelerated import get_GDL_numba, get_hit_miss_counts_numba,\
        get_balancing_weights_numba
except ImportError:
    # The pure-python implementations above could be substituted here, but the
    # numba versions are required for acceptable evaluation speed.
    raise ImportError("Numba has not been installed correctly!")


class HKOEvaluation(object):
    """Accumulates nowcasting evaluation statistics over mini-batches.

    Call update() per batch, then calculate_stat() / print_stat_readable() /
    save() to obtain POD, FAR, CSI, GSS, HSS, (balanced) MSE/MAE and GDL.
    """

    def __init__(self, seq_len, use_central, no_ssim=True, threholds=None,
                 central_region=None):
        """
        Parameters
        ----------
        seq_len : int
            Output sequence length being evaluated.
        use_central : bool
            If True, crop to central_region before computing statistics.
        no_ssim : bool, optional
            SSIM accumulation is currently not implemented; must stay True.
        threholds : list or tuple, optional
            Rainfall thresholds (parameter name kept as-is for caller
            compatibility); defaults to cfg.HKO.EVALUATION.THRESHOLDS.
        central_region : tuple, optional
            (x1, y1, x2, y2) crop; defaults to cfg.HKO.EVALUATION.CENTRAL_REGION.
        """
        if central_region is None:
            central_region = cfg.HKO.EVALUATION.CENTRAL_REGION
        self._thresholds = cfg.HKO.EVALUATION.THRESHOLDS if threholds is None \
            else threholds
        self._seq_len = seq_len
        self._no_ssim = no_ssim
        self._use_central = use_central
        self._central_region = central_region
        self._exclude_mask = get_exclude_mask()
        self.begin()

    def begin(self):
        """(Re-)allocate all accumulators and reset the batch counter."""
        self._total_hits = np.zeros((self._seq_len, len(self._thresholds)),
                                    dtype=int)
        self._total_misses = np.zeros((self._seq_len, len(self._thresholds)),
                                      dtype=int)
        self._total_false_alarms = np.zeros((self._seq_len, len(self._thresholds)),
                                            dtype=int)
        self._total_correct_negatives = np.zeros((self._seq_len, len(self._thresholds)),
                                                 dtype=int)
        self._mse = np.zeros((self._seq_len,), dtype=np.float32)
        self._mae = np.zeros((self._seq_len,), dtype=np.float32)
        self._balanced_mse = np.zeros((self._seq_len,), dtype=np.float32)
        self._balanced_mae = np.zeros((self._seq_len,), dtype=np.float32)
        self._gdl = np.zeros((self._seq_len,), dtype=np.float32)
        self._ssim = np.zeros((self._seq_len,), dtype=np.float32)
        self._datetime_dict = {}
        self._total_batch_num = 0

    def clear_all(self):
        """Zero every accumulator in place without reallocating."""
        self._total_hits[:] = 0
        self._total_misses[:] = 0
        self._total_false_alarms[:] = 0
        self._total_correct_negatives[:] = 0
        self._mse[:] = 0
        self._mae[:] = 0
        # Bug fix: these were previously left un-reset, leaking statistics
        # across evaluation rounds.
        self._balanced_mse[:] = 0
        self._balanced_mae[:] = 0
        self._gdl[:] = 0
        self._ssim[:] = 0
        self._datetime_dict = {}
        self._total_batch_num = 0

    def update(self, gt, pred, mask, start_datetimes=None):
        """Accumulate statistics for one mini-batch.

        Parameters
        ----------
        gt : np.ndarray
            Shape: (seq_len, batch_size, 1, height, width)
        pred : np.ndarray
            Shape: (seq_len, batch_size, 1, height, width)
        mask : np.ndarray
            0 indicates not use and 1 indicates that the location will be taken
            into account
        start_datetimes : list, optional
            The starting datetimes of all the testing instances

        Returns
        -------
        """
        if start_datetimes is not None:
            batch_size = len(start_datetimes)
            assert gt.shape[1] == batch_size
        else:
            batch_size = gt.shape[1]
        assert gt.shape[0] == self._seq_len
        assert gt.shape == pred.shape
        assert gt.shape == mask.shape
        if self._use_central:
            # Crop the central regions for evaluation; region is (x1, y1, x2, y2).
            pred = pred[:, :, :,
                        self._central_region[1]:self._central_region[3],
                        self._central_region[0]:self._central_region[2]]
            gt = gt[:, :, :,
                    self._central_region[1]:self._central_region[3],
                    self._central_region[0]:self._central_region[2]]
            mask = mask[:, :, :,
                        self._central_region[1]:self._central_region[3],
                        self._central_region[0]:self._central_region[2]]
        self._total_batch_num += batch_size
        #TODO Save all the mse, mae, gdl, hits, misses, false_alarms and correct_negatives
        mse = (mask * np.square(pred - gt)).sum(axis=(2, 3, 4))
        mae = (mask * np.abs(pred - gt)).sum(axis=(2, 3, 4))
        weights = get_balancing_weights_numba(
            data=gt, mask=mask,
            base_balancing_weights=cfg.HKO.EVALUATION.BALANCING_WEIGHTS,
            thresholds=self._thresholds)
        balanced_mse = (weights * np.square(pred - gt)).sum(axis=(2, 3, 4))
        balanced_mae = (weights * np.abs(pred - gt)).sum(axis=(2, 3, 4))
        gdl = get_GDL_numba(prediction=pred, truth=gt, mask=mask)
        self._mse += mse.sum(axis=1)
        self._mae += mae.sum(axis=1)
        self._balanced_mse += balanced_mse.sum(axis=1)
        self._balanced_mae += balanced_mae.sum(axis=1)
        self._gdl += gdl.sum(axis=1)
        if not self._no_ssim:
            raise NotImplementedError
            # self._ssim += get_SSIM(prediction=pred, truth=gt)
        hits, misses, false_alarms, correct_negatives = \
            get_hit_miss_counts_numba(prediction=pred, truth=gt, mask=mask,
                                      thresholds=self._thresholds)
        self._total_hits += hits.sum(axis=1)
        self._total_misses += misses.sum(axis=1)
        self._total_false_alarms += false_alarms.sum(axis=1)
        self._total_correct_negatives += correct_negatives.sum(axis=1)

    def calculate_stat(self):
        """The following measurements will be used to measure the score of the
        forecaster.

        See Also
        [Weather and Forecasting 2010] Equitability Revisited: Why the
        "Equitable Threat Score" Is Not Equitable
        http://www.wxonline.info/topics/verif2.html

        Contingency table (per timestep and threshold):

            a = hits, b = false alarms, c = misses, d = correct negatives

        We will report:
            POD = a / (a + c)
            FAR = b / (a + b)
            CSI = a / (a + b + c)
            Heidke Skill Score (HSS) = 2(ad - bc) / ((a+c)(c+d) + (a+b)(b+d))
            Gilbert Skill Score (GSS) = HSS / (2 - HSS),
                also known as the Equitable Threat Score
                HSS = 2 * GSS / (GSS + 1)
            MSE = mask * (pred - gt) ** 2
            MAE = mask * abs(pred - gt)
            GDL = valid_mask_h * abs(gd_h(pred) - gd_h(gt)) +
                  valid_mask_w * abs(gd_w(pred) - gd_w(gt))

        NOTE(review): ratios may contain NaN/inf (and trigger numpy runtime
        warnings) when a denominator is zero, e.g. no event above a threshold.

        Returns
        -------
        pod, far, csi, hss, gss : np.ndarray
            Shape: (seq_len, len(thresholds))
        mse, mae, balanced_mse, balanced_mae, gdl : np.ndarray
            Shape: (seq_len,), averaged over the accumulated sequences.
        """
        a = self._total_hits.astype(np.float64)
        b = self._total_false_alarms.astype(np.float64)
        c = self._total_misses.astype(np.float64)
        d = self._total_correct_negatives.astype(np.float64)
        pod = a / (a + c)
        far = b / (a + b)
        csi = a / (a + b + c)
        n = a + b + c + d
        aref = (a + b) / n * (a + c)  # hits expected by chance
        gss = (a - aref) / (a + b + c - aref)
        hss = 2 * gss / (gss + 1)
        mse = self._mse / self._total_batch_num
        mae = self._mae / self._total_batch_num
        balanced_mse = self._balanced_mse / self._total_batch_num
        balanced_mae = self._balanced_mae / self._total_batch_num
        gdl = self._gdl / self._total_batch_num
        if not self._no_ssim:
            raise NotImplementedError
            # ssim = self._ssim / self._total_batch_num
        return pod, far, csi, hss, gss, mse, mae, balanced_mse, balanced_mae, gdl

    def print_stat_readable(self, prefix=""):
        """Log all accumulated statistics in a human-readable form.

        Each threshold entry is formatted as ">threshold:mean/final-step".
        """
        logging.info("%sTotal Sequence Number: %d, Use Central: %d"
                     % (prefix, self._total_batch_num, self._use_central))
        pod, far, csi, hss, gss, mse, mae, balanced_mse, balanced_mae, gdl = \
            self.calculate_stat()
        logging.info(" Hits: " + ', '.join(
            [">%g:%g/%g" % (threshold, self._total_hits[:, i].mean(),
                            self._total_hits[-1, i])
             for i, threshold in enumerate(self._thresholds)]))
        logging.info(" POD: " + ', '.join(
            [">%g:%g/%g" % (threshold, pod[:, i].mean(), pod[-1, i])
             for i, threshold in enumerate(self._thresholds)]))
        logging.info(" FAR: " + ', '.join(
            [">%g:%g/%g" % (threshold, far[:, i].mean(), far[-1, i])
             for i, threshold in enumerate(self._thresholds)]))
        logging.info(" CSI: " + ', '.join(
            [">%g:%g/%g" % (threshold, csi[:, i].mean(), csi[-1, i])
             for i, threshold in enumerate(self._thresholds)]))
        logging.info(" GSS: " + ', '.join(
            [">%g:%g/%g" % (threshold, gss[:, i].mean(), gss[-1, i])
             for i, threshold in enumerate(self._thresholds)]))
        logging.info(" HSS: " + ', '.join(
            [">%g:%g/%g" % (threshold, hss[:, i].mean(), hss[-1, i])
             for i, threshold in enumerate(self._thresholds)]))
        logging.info(" MSE: %g/%g" % (mse.mean(), mse[-1]))
        logging.info(" MAE: %g/%g" % (mae.mean(), mae[-1]))
        logging.info(" Balanced MSE: %g/%g" % (balanced_mse.mean(), balanced_mse[-1]))
        logging.info(" Balanced MAE: %g/%g" % (balanced_mae.mean(), balanced_mae[-1]))
        logging.info(" GDL: %g/%g" % (gdl.mean(), gdl[-1]))
        if not self._no_ssim:
            raise NotImplementedError

    def save_pkl(self, path):
        """Pickle this evaluation object to path, creating directories as needed."""
        dir_path = os.path.dirname(path)
        # dirname may be "" when path has no directory component; makedirs("")
        # would raise, so only create when a directory part exists.
        if dir_path and not os.path.exists(dir_path):
            os.makedirs(dir_path)
        logging.info("Saving HKOEvaluation to %s" % path)
        with open(path, 'wb') as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)

    def save_txt_readable(self, path):
        """Write all statistics to a human-readable text file at path."""
        dir_path = os.path.dirname(path)
        if dir_path and not os.path.exists(dir_path):
            os.makedirs(dir_path)
        pod, far, csi, hss, gss, mse, mae, balanced_mse, balanced_mae, gdl = \
            self.calculate_stat()
        logging.info("Saving readable txt of HKOEvaluation to %s" % path)
        with open(path, 'w') as f:
            f.write("Total Sequence Num: %d, Out Seq Len: %d, Use Central: %d\n"
                    % (self._total_batch_num, self._seq_len, self._use_central))
            for (i, threshold) in enumerate(self._thresholds):
                f.write("Threshold = %g:\n" % threshold)
                f.write(" POD: %s\n" % str(list(pod[:, i])))
                f.write(" FAR: %s\n" % str(list(far[:, i])))
                f.write(" CSI: %s\n" % str(list(csi[:, i])))
                f.write(" GSS: %s\n" % str(list(gss[:, i])))
                f.write(" HSS: %s\n" % str(list(hss[:, i])))
                f.write(" POD stat: avg %g/final %g\n" % (pod[:, i].mean(), pod[-1, i]))
                f.write(" FAR stat: avg %g/final %g\n" % (far[:, i].mean(), far[-1, i]))
                f.write(" CSI stat: avg %g/final %g\n" % (csi[:, i].mean(), csi[-1, i]))
                f.write(" GSS stat: avg %g/final %g\n" % (gss[:, i].mean(), gss[-1, i]))
                f.write(" HSS stat: avg %g/final %g\n" % (hss[:, i].mean(), hss[-1, i]))
            f.write("MSE: %s\n" % str(list(mse)))
            f.write("MAE: %s\n" % str(list(mae)))
            f.write("Balanced MSE: %s\n" % str(list(balanced_mse)))
            f.write("Balanced MAE: %s\n" % str(list(balanced_mae)))
            f.write("GDL: %s\n" % str(list(gdl)))
            f.write("MSE stat: avg %g/final %g\n" % (mse.mean(), mse[-1]))
            f.write("MAE stat: avg %g/final %g\n" % (mae.mean(), mae[-1]))
            f.write("Balanced MSE stat: avg %g/final %g\n"
                    % (balanced_mse.mean(), balanced_mse[-1]))
            f.write("Balanced MAE stat: avg %g/final %g\n"
                    % (balanced_mae.mean(), balanced_mae[-1]))
            f.write("GDL stat: avg %g/final %g\n" % (gdl.mean(), gdl[-1]))

    def save(self, prefix):
        """Save both the readable text report and the pickle under prefix."""
        self.save_txt_readable(prefix + ".txt")
        self.save_pkl(prefix + ".pkl")