|
|
try:
    # Python 2: cPickle is the fast C implementation; Python 3's pickle subsumes it.
    import cPickle as pickle
except ImportError:
    # Only catch the missing-module case; a bare except would hide real errors.
    import pickle
|
|
import numpy as np |
|
|
import logging |
|
|
import os |
|
|
from collections import namedtuple |
|
|
from nowcasting.config import cfg |
|
|
from nowcasting.hko_iterator import get_exclude_mask |
|
|
from nowcasting.helpers.msssim import _SSIMForMultiScale |
|
|
|
|
|
def pixel_to_dBZ(img): |
|
|
""" |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
img : np.ndarray or float |
|
|
|
|
|
Returns |
|
|
------- |
|
|
|
|
|
""" |
|
|
return img * 70.0 - 10.0 |
|
|
|
|
|
|
|
|
def dBZ_to_pixel(dBZ_img): |
|
|
""" |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
dBZ_img : np.ndarray |
|
|
|
|
|
Returns |
|
|
------- |
|
|
|
|
|
""" |
|
|
return np.clip((dBZ_img + 10.0) / 70.0, a_min=0.0, a_max=1.0) |
|
|
|
|
|
|
|
|
def pixel_to_rainfall(img, a=None, b=None): |
|
|
"""Convert the pixel values to real rainfall intensity |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
img : np.ndarray |
|
|
a : float32, optional |
|
|
b : float32, optional |
|
|
|
|
|
Returns |
|
|
------- |
|
|
rainfall_intensity : np.ndarray |
|
|
""" |
|
|
if a is None: |
|
|
a = cfg.HKO.EVALUATION.ZR.a |
|
|
if b is None: |
|
|
b = cfg.HKO.EVALUATION.ZR.b |
|
|
dBZ = pixel_to_dBZ(img) |
|
|
dBR = (dBZ - 10.0 * np.log10(a)) / b |
|
|
rainfall_intensity = np.power(10, dBR / 10.0) |
|
|
return rainfall_intensity |
|
|
|
|
|
|
|
|
def rainfall_to_pixel(rainfall_intensity, a=None, b=None): |
|
|
"""Convert the rainfall intensity to pixel values |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
rainfall_intensity : np.ndarray |
|
|
a : float32, optional |
|
|
b : float32, optional |
|
|
|
|
|
Returns |
|
|
------- |
|
|
pixel_vals : np.ndarray |
|
|
""" |
|
|
if a is None: |
|
|
a = cfg.HKO.EVALUATION.ZR.a |
|
|
if b is None: |
|
|
b = cfg.HKO.EVALUATION.ZR.b |
|
|
dBR = np.log10(rainfall_intensity) * 10.0 |
|
|
dBZ = dBR * b + 10.0 * np.log10(a) |
|
|
pixel_vals = (dBZ + 10.0) / 70.0 |
|
|
return pixel_vals |
|
|
|
|
|
|
|
|
def get_hit_miss_counts(prediction, truth, mask=None, thresholds=None, sum_batch=False): |
|
|
"""This function calculates the overall hits and misses for the prediction, which could be used |
|
|
to get the skill scores and threat scores: |
|
|
|
|
|
|
|
|
This function assumes the input, i.e, prediction and truth are 3-dim tensors, (timestep, row, col) |
|
|
and all inputs should be between 0~1 |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
prediction : np.ndarray |
|
|
Shape: (seq_len, batch_size, 1, height, width) |
|
|
truth : np.ndarray |
|
|
Shape: (seq_len, batch_size, 1, height, width) |
|
|
mask : np.ndarray or None |
|
|
Shape: (seq_len, batch_size, 1, height, width) |
|
|
0 --> not use |
|
|
1 --> use |
|
|
thresholds : list or tuple |
|
|
|
|
|
Returns |
|
|
------- |
|
|
hits : np.ndarray |
|
|
(seq_len, len(thresholds)) or (seq_len, batch_size, len(thresholds)) |
|
|
TP |
|
|
misses : np.ndarray |
|
|
(seq_len, len(thresholds)) or (seq_len, batch_size, len(thresholds)) |
|
|
FN |
|
|
false_alarms : np.ndarray |
|
|
(seq_len, len(thresholds)) or (seq_len, batch_size, len(thresholds)) |
|
|
FP |
|
|
correct_negatives : np.ndarray |
|
|
(seq_len, len(thresholds)) or (seq_len, batch_size, len(thresholds)) |
|
|
TN |
|
|
""" |
|
|
if thresholds is None: |
|
|
thresholds = cfg.HKO.EVALUATION.THRESHOLDS |
|
|
assert 5 == prediction.ndim |
|
|
assert 5 == truth.ndim |
|
|
assert prediction.shape == truth.shape |
|
|
assert prediction.shape[2] == 1 |
|
|
thresholds = rainfall_to_pixel(np.array(thresholds, |
|
|
dtype=np.float32) |
|
|
.reshape((1, 1, len(thresholds), 1, 1))) |
|
|
bpred = (prediction >= thresholds) |
|
|
btruth = (truth >= thresholds) |
|
|
bpred_n = np.logical_not(bpred) |
|
|
btruth_n = np.logical_not(btruth) |
|
|
if sum_batch: |
|
|
summation_axis = (1, 3, 4) |
|
|
else: |
|
|
summation_axis = (3, 4) |
|
|
if mask is None: |
|
|
hits = np.logical_and(bpred, btruth).sum(axis=summation_axis) |
|
|
misses = np.logical_and(bpred_n, btruth).sum(axis=summation_axis) |
|
|
false_alarms = np.logical_and(bpred, btruth_n).sum(axis=summation_axis) |
|
|
correct_negatives = np.logical_and(bpred_n, btruth_n).sum(axis=summation_axis) |
|
|
else: |
|
|
hits = np.logical_and(np.logical_and(bpred, btruth), mask)\ |
|
|
.sum(axis=summation_axis) |
|
|
misses = np.logical_and(np.logical_and(bpred_n, btruth), mask)\ |
|
|
.sum(axis=summation_axis) |
|
|
false_alarms = np.logical_and(np.logical_and(bpred, btruth_n), mask)\ |
|
|
.sum(axis=summation_axis) |
|
|
correct_negatives = np.logical_and(np.logical_and(bpred_n, btruth_n), mask)\ |
|
|
.sum(axis=summation_axis) |
|
|
return hits, misses, false_alarms, correct_negatives |
|
|
|
|
|
|
|
|
def get_correlation(prediction, truth): |
|
|
""" |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
prediction : np.ndarray |
|
|
truth : np.ndarray |
|
|
|
|
|
Returns |
|
|
------- |
|
|
|
|
|
""" |
|
|
assert truth.shape == prediction.shape |
|
|
assert 5 == prediction.ndim |
|
|
assert prediction.shape[2] == 1 |
|
|
eps = 1E-12 |
|
|
ret = (prediction * truth).sum(axis=(3, 4)) / ( |
|
|
np.sqrt(np.square(prediction).sum(axis=(3, 4))) * np.sqrt(np.square(truth).sum(axis=(3, 4))) + eps) |
|
|
ret = ret.sum(axis=(1, 2)) |
|
|
return ret |
|
|
|
|
|
|
|
|
def get_rainfall_mse(prediction, truth): |
|
|
ret = np.square(pixel_to_rainfall(prediction) - pixel_to_rainfall(truth)).mean(axis=(2, 3)) |
|
|
ret = ret.sum(axis=1) |
|
|
return ret |
|
|
|
|
|
|
|
|
def get_PSNR(prediction, truth): |
|
|
"""Peak Signal Noise Ratio |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
prediction : np.ndarray |
|
|
truth : np.ndarray |
|
|
|
|
|
Returns |
|
|
------- |
|
|
ret : np.ndarray |
|
|
""" |
|
|
mse = np.square(prediction - truth).mean(axis=(2, 3, 4)) |
|
|
ret = 10.0 * np.log10(1.0 / mse) |
|
|
ret = ret.sum(axis=1) |
|
|
return ret |
|
|
|
|
|
|
|
|
def get_SSIM(prediction, truth): |
|
|
"""Calculate the SSIM score following |
|
|
[TIP2004] Image Quality Assessment: From Error Visibility to Structural Similarity |
|
|
|
|
|
Same functionality as |
|
|
https://github.com/coupriec/VideoPredictionICLR2016/blob/master/image_error_measures.lua#L50-L75 |
|
|
|
|
|
We use nowcasting.helpers.msssim, which is borrowed from Tensorflow to do the evaluation |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
prediction : np.ndarray |
|
|
truth : np.ndarray |
|
|
|
|
|
Returns |
|
|
------- |
|
|
ret : np.ndarray |
|
|
""" |
|
|
assert truth.shape == prediction.shape |
|
|
assert 5 == prediction.ndim |
|
|
assert prediction.shape[2] == 1 |
|
|
seq_len = prediction.shape[0] |
|
|
batch_size = prediction.shape[1] |
|
|
prediction = prediction.reshape((prediction.shape[0] * prediction.shape[1], |
|
|
prediction.shape[3], prediction.shape[4], 1)) |
|
|
truth = truth.reshape((truth.shape[0] * truth.shape[1], |
|
|
truth.shape[3], truth.shape[4], 1)) |
|
|
ssim, cs = _SSIMForMultiScale(img1=prediction, img2=truth, max_val=1.0) |
|
|
print(ssim.shape) |
|
|
ret = ssim.reshape((seq_len, batch_size)).sum(axis=1) |
|
|
return ret |
|
|
|
|
|
|
|
|
def get_GDL(prediction, truth, mask, sum_batch=False): |
|
|
"""Calculate the masked gradient difference loss |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
prediction : np.ndarray |
|
|
Shape: (seq_len, batch_size, 1, height, width) |
|
|
truth : np.ndarray |
|
|
Shape: (seq_len, batch_size, 1, height, width) |
|
|
mask : np.ndarray or None |
|
|
Shape: (seq_len, batch_size, 1, height, width) |
|
|
0 --> not use |
|
|
1 --> use |
|
|
|
|
|
Returns |
|
|
------- |
|
|
gdl : np.ndarray |
|
|
Shape: (seq_len,) or (seq_len, batch_size) |
|
|
""" |
|
|
prediction_diff_h = np.abs(np.diff(prediction, axis=3)) |
|
|
prediction_diff_w = np.abs(np.diff(prediction, axis=4)) |
|
|
gt_diff_h = np.abs(np.diff(truth, axis=3)) |
|
|
gt_diff_w = np.abs(np.diff(truth, axis=4)) |
|
|
mask_h = mask[:, :, :, :-1, :] * mask[:, :, :, 1:, :] |
|
|
mask_w = mask[:, :, :, :, :-1] * mask[:, :, :, :, 1:] |
|
|
gd_h = np.abs(prediction_diff_h - gt_diff_h) |
|
|
gd_w = np.abs(prediction_diff_w - gt_diff_w) |
|
|
gd_h[:] *= mask_h |
|
|
gd_w[:] *= mask_w |
|
|
summation_axis = (1, 2, 3, 4) if sum_batch else (2, 3, 4) |
|
|
gdl = np.sum(gd_h, axis=summation_axis) + np.sum(gd_w, axis=summation_axis) |
|
|
return gdl |
|
|
|
|
|
|
|
|
def get_balancing_weights(data, mask, base_balancing_weights=None, thresholds=None): |
|
|
if thresholds is None: |
|
|
thresholds = cfg.HKO.EVALUATION.THRESHOLDS |
|
|
if base_balancing_weights is None: |
|
|
base_balancing_weights = cfg.HKO.EVALUATION.BALANCING_WEIGHTS |
|
|
thresholds = rainfall_to_pixel(np.array(thresholds, dtype=np.float32) |
|
|
.reshape((1, 1, 1, 1, 1, len(thresholds)))) |
|
|
weights = np.ones_like(data) * base_balancing_weights[0] |
|
|
threshold_mask = np.expand_dims(data, axis=5) >= thresholds |
|
|
base_weights = np.diff(np.array(base_balancing_weights, dtype=np.float32))\ |
|
|
.reshape((1, 1, 1, 1, 1, len(base_balancing_weights) - 1)) |
|
|
weights += (threshold_mask * base_weights).sum(axis=-1) |
|
|
weights *= mask |
|
|
return weights |
|
|
|
|
|
|
|
|
try:
    from nowcasting.numba_accelerated import get_GDL_numba, get_hit_miss_counts_numba,\
        get_balancing_weights_numba
except ImportError:
    # Only the missing-dependency case is translated into this message;
    # a bare except would also mask genuine errors (e.g. a syntax error
    # inside nowcasting.numba_accelerated) behind a misleading hint.
    raise ImportError("Numba has not been installed correctly!")
|
|
|
|
|
class HKOEvaluation(object):
    """Online accumulator for HKO-7 nowcasting evaluation statistics.

    Accumulates per-lead-time contingency tables (hits / misses / false alarms
    / correct negatives at each rainfall threshold) plus masked MSE, MAE,
    balanced MSE/MAE and GDL over successive `update` calls, and derives the
    skill scores (POD, FAR, CSI, GSS, HSS) in `calculate_stat`.
    """

    def __init__(self, seq_len, use_central, no_ssim=True, threholds=None,
                 central_region=None):
        """
        Parameters
        ----------
        seq_len : int
            Number of forecast lead times per sequence.
        use_central : bool
            If True, evaluate only the central crop given by `central_region`.
        no_ssim : bool
            SSIM accumulation is not implemented; must stay True.
        threholds : list or None
            Rainfall thresholds. The misspelled parameter name is kept for
            backward compatibility with existing callers. Defaults to
            cfg.HKO.EVALUATION.THRESHOLDS.
        central_region : tuple or None
            (x1, y1, x2, y2) crop; defaults to cfg.HKO.EVALUATION.CENTRAL_REGION.
        """
        if central_region is None:
            central_region = cfg.HKO.EVALUATION.CENTRAL_REGION
        self._thresholds = cfg.HKO.EVALUATION.THRESHOLDS if threholds is None else threholds
        self._seq_len = seq_len
        self._no_ssim = no_ssim
        self._use_central = use_central
        self._central_region = central_region
        self._exclude_mask = get_exclude_mask()
        self.begin()

    def begin(self):
        """(Re-)allocate every accumulator."""
        # np.int64 replaces the np.int alias, which was removed in NumPy >= 1.24.
        count_shape = (self._seq_len, len(self._thresholds))
        self._total_hits = np.zeros(count_shape, dtype=np.int64)
        self._total_misses = np.zeros(count_shape, dtype=np.int64)
        self._total_false_alarms = np.zeros(count_shape, dtype=np.int64)
        self._total_correct_negatives = np.zeros(count_shape, dtype=np.int64)
        self._mse = np.zeros((self._seq_len,), dtype=np.float32)
        self._mae = np.zeros((self._seq_len,), dtype=np.float32)
        self._balanced_mse = np.zeros((self._seq_len,), dtype=np.float32)
        self._balanced_mae = np.zeros((self._seq_len,), dtype=np.float32)
        self._gdl = np.zeros((self._seq_len,), dtype=np.float32)
        self._ssim = np.zeros((self._seq_len,), dtype=np.float32)
        self._datetime_dict = {}
        self._total_batch_num = 0

    def clear_all(self):
        """Zero every accumulator in place.

        Fixed: the previous version forgot to reset `_balanced_mse`,
        `_balanced_mae` and `_datetime_dict`, leaking state across rounds.
        """
        self._total_hits[:] = 0
        self._total_misses[:] = 0
        self._total_false_alarms[:] = 0
        self._total_correct_negatives[:] = 0
        self._mse[:] = 0
        self._mae[:] = 0
        self._balanced_mse[:] = 0
        self._balanced_mae[:] = 0
        self._gdl[:] = 0
        self._ssim[:] = 0
        self._datetime_dict.clear()
        self._total_batch_num = 0

    def update(self, gt, pred, mask, start_datetimes=None):
        """Accumulate statistics for one batch of sequences.

        Parameters
        ----------
        gt : np.ndarray
            Shape: (seq_len, batch_size, 1, height, width)
        pred : np.ndarray
            Shape: (seq_len, batch_size, 1, height, width)
        mask : np.ndarray
            0 indicates not use and 1 indicates that the location will be taken into account
        start_datetimes : list
            The starting datetimes of all the testing instances
        """
        if start_datetimes is not None:
            batch_size = len(start_datetimes)
            assert gt.shape[1] == batch_size
        else:
            batch_size = gt.shape[1]
        assert gt.shape[0] == self._seq_len
        assert gt.shape == pred.shape
        assert gt.shape == mask.shape

        if self._use_central:
            # Crop all tensors to the central region (x1, y1, x2, y2).
            pred = pred[:, :, :,
                        self._central_region[1]:self._central_region[3],
                        self._central_region[0]:self._central_region[2]]
            gt = gt[:, :, :,
                    self._central_region[1]:self._central_region[3],
                    self._central_region[0]:self._central_region[2]]
            mask = mask[:, :, :,
                        self._central_region[1]:self._central_region[3],
                        self._central_region[0]:self._central_region[2]]
        self._total_batch_num += batch_size

        # Masked per-frame error sums; division by the number of sequences
        # happens later in calculate_stat.
        mse = (mask * np.square(pred - gt)).sum(axis=(2, 3, 4))
        mae = (mask * np.abs(pred - gt)).sum(axis=(2, 3, 4))
        weights = get_balancing_weights_numba(
            data=gt, mask=mask,
            base_balancing_weights=cfg.HKO.EVALUATION.BALANCING_WEIGHTS,
            thresholds=self._thresholds)
        balanced_mse = (weights * np.square(pred - gt)).sum(axis=(2, 3, 4))
        balanced_mae = (weights * np.abs(pred - gt)).sum(axis=(2, 3, 4))
        gdl = get_GDL_numba(prediction=pred, truth=gt, mask=mask)
        self._mse += mse.sum(axis=1)
        self._mae += mae.sum(axis=1)
        self._balanced_mse += balanced_mse.sum(axis=1)
        self._balanced_mae += balanced_mae.sum(axis=1)
        self._gdl += gdl.sum(axis=1)
        if not self._no_ssim:
            raise NotImplementedError

        hits, misses, false_alarms, correct_negatives = \
            get_hit_miss_counts_numba(prediction=pred, truth=gt, mask=mask,
                                      thresholds=self._thresholds)
        self._total_hits += hits.sum(axis=1)
        self._total_misses += misses.sum(axis=1)
        self._total_false_alarms += false_alarms.sum(axis=1)
        self._total_correct_negatives += correct_negatives.sum(axis=1)

    def calculate_stat(self):
        """The following measurements will be used to measure the score of the forecaster

        See Also
        [Weather and Forecasting 2010] Equitability Revisited: Why the "Equitable Threat Score" Is Not Equitable
        http://www.wxonline.info/topics/verif2.html

        We will denote
        (a b    (hits          false alarms
         c d) =  misses   correct negatives)

        We will report the
        POD = a / (a + c)
        FAR = b / (a + b)
        CSI = a / (a + b + c)
        Heidke Skill Score (HSS) = 2(ad - bc) / ((a+c) (c+d) + (a+b)(b+d))
        Gilbert Skill Score (GSS) = HSS / (2 - HSS), also known as the Equitable Threat Score
            HSS = 2 * GSS / (GSS + 1)
        MSE = mask * (pred - gt) **2
        MAE = mask * abs(pred - gt)
        GDL = valid_mask_h * abs(gd_h(pred) - gd_h(gt)) + valid_mask_w * abs(gd_w(pred) - gd_w(gt))

        Returns
        -------
        pod, far, csi, hss, gss, mse, mae, balanced_mse, balanced_mae, gdl
            Score arrays; ratio scores may contain nan/inf when a table
            entry sum is zero (e.g. before any update).
        """
        a = self._total_hits.astype(np.float64)
        b = self._total_false_alarms.astype(np.float64)
        c = self._total_misses.astype(np.float64)
        d = self._total_correct_negatives.astype(np.float64)
        pod = a / (a + c)
        far = b / (a + b)
        csi = a / (a + b + c)
        n = a + b + c + d
        # aref is the expected number of hits from a random forecast.
        aref = (a + b) / n * (a + c)
        gss = (a - aref) / (a + b + c - aref)
        hss = 2 * gss / (gss + 1)
        mse = self._mse / self._total_batch_num
        mae = self._mae / self._total_batch_num
        balanced_mse = self._balanced_mse / self._total_batch_num
        balanced_mae = self._balanced_mae / self._total_batch_num
        gdl = self._gdl / self._total_batch_num
        if not self._no_ssim:
            raise NotImplementedError
        return pod, far, csi, hss, gss, mse, mae, balanced_mse, balanced_mae, gdl

    def print_stat_readable(self, prefix=""):
        """Log a human-readable summary of all accumulated statistics."""
        logging.info("%sTotal Sequence Number: %d, Use Central: %d"
                     % (prefix, self._total_batch_num, self._use_central))
        pod, far, csi, hss, gss, mse, mae, balanced_mse, balanced_mae, gdl = self.calculate_stat()
        # Each score line reports "mean over lead times / value at the last lead time".
        logging.info("   Hits: " + ', '.join([">%g:%g/%g" % (threshold,
                                                             self._total_hits[:, i].mean(),
                                                             self._total_hits[-1, i])
                                              for i, threshold in enumerate(self._thresholds)]))
        logging.info("   POD: " + ', '.join([">%g:%g/%g" % (threshold, pod[:, i].mean(), pod[-1, i])
                                             for i, threshold in enumerate(self._thresholds)]))
        logging.info("   FAR: " + ', '.join([">%g:%g/%g" % (threshold, far[:, i].mean(), far[-1, i])
                                             for i, threshold in enumerate(self._thresholds)]))
        logging.info("   CSI: " + ', '.join([">%g:%g/%g" % (threshold, csi[:, i].mean(), csi[-1, i])
                                             for i, threshold in enumerate(self._thresholds)]))
        logging.info("   GSS: " + ', '.join([">%g:%g/%g" % (threshold, gss[:, i].mean(), gss[-1, i])
                                             for i, threshold in enumerate(self._thresholds)]))
        logging.info("   HSS: " + ', '.join([">%g:%g/%g" % (threshold, hss[:, i].mean(), hss[-1, i])
                                             for i, threshold in enumerate(self._thresholds)]))
        logging.info("   MSE: %g/%g" % (mse.mean(), mse[-1]))
        logging.info("   MAE: %g/%g" % (mae.mean(), mae[-1]))
        logging.info("   Balanced MSE: %g/%g" % (balanced_mse.mean(), balanced_mse[-1]))
        logging.info("   Balanced MAE: %g/%g" % (balanced_mae.mean(), balanced_mae[-1]))
        logging.info("   GDL: %g/%g" % (gdl.mean(), gdl[-1]))
        if not self._no_ssim:
            raise NotImplementedError

    def save_pkl(self, path):
        """Pickle the whole evaluator to `path`, creating parent dirs if needed."""
        dir_path = os.path.dirname(path)
        # Guard: dirname is "" when `path` has no directory component, and
        # os.makedirs("") would raise.
        if dir_path and not os.path.exists(dir_path):
            os.makedirs(dir_path)
        logging.info("Saving HKOEvaluation to %s" %path)
        # `with` guarantees the handle is closed even if pickling fails.
        with open(path, 'wb') as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)

    def save_txt_readable(self, path):
        """Write a human-readable statistics report to a text file."""
        dir_path = os.path.dirname(path)
        if dir_path and not os.path.exists(dir_path):
            os.makedirs(dir_path)
        pod, far, csi, hss, gss, mse, mae, balanced_mse, balanced_mae, gdl = self.calculate_stat()
        logging.info("Saving readable txt of HKOEvaluation to %s" % path)
        with open(path, 'w') as f:
            f.write("Total Sequence Num: %d, Out Seq Len: %d, Use Central: %d\n"
                    %(self._total_batch_num,
                      self._seq_len,
                      self._use_central))
            for (i, threshold) in enumerate(self._thresholds):
                f.write("Threshold = %g:\n" %threshold)
                f.write("   POD: %s\n" %str(list(pod[:, i])))
                f.write("   FAR: %s\n" % str(list(far[:, i])))
                f.write("   CSI: %s\n" % str(list(csi[:, i])))
                f.write("   GSS: %s\n" % str(list(gss[:, i])))
                f.write("   HSS: %s\n" % str(list(hss[:, i])))
                f.write("   POD stat: avg %g/final %g\n" %(pod[:, i].mean(), pod[-1, i]))
                f.write("   FAR stat: avg %g/final %g\n" %(far[:, i].mean(), far[-1, i]))
                f.write("   CSI stat: avg %g/final %g\n" %(csi[:, i].mean(), csi[-1, i]))
                f.write("   GSS stat: avg %g/final %g\n" %(gss[:, i].mean(), gss[-1, i]))
                f.write("   HSS stat: avg %g/final %g\n" % (hss[:, i].mean(), hss[-1, i]))
            f.write("MSE: %s\n" % str(list(mse)))
            f.write("MAE: %s\n" % str(list(mae)))
            f.write("Balanced MSE: %s\n" % str(list(balanced_mse)))
            f.write("Balanced MAE: %s\n" % str(list(balanced_mae)))
            f.write("GDL: %s\n" % str(list(gdl)))
            f.write("MSE stat: avg %g/final %g\n" % (mse.mean(), mse[-1]))
            f.write("MAE stat: avg %g/final %g\n" % (mae.mean(), mae[-1]))
            f.write("Balanced MSE stat: avg %g/final %g\n" % (balanced_mse.mean(), balanced_mse[-1]))
            f.write("Balanced MAE stat: avg %g/final %g\n" % (balanced_mae.mean(), balanced_mae[-1]))
            f.write("GDL stat: avg %g/final %g\n" % (gdl.mean(), gdl[-1]))

    def save(self, prefix):
        """Write both the readable report (.txt) and the pickled state (.pkl)."""
        self.save_txt_readable(prefix + ".txt")
        self.save_pkl(prefix + ".pkl")
|
|
|