|
|
import numpy as np |
|
|
from numba import jit, float32, boolean, int32, float64 |
|
|
from nowcasting.hko_evaluation import rainfall_to_pixel |
|
|
from nowcasting.config import cfg |
|
|
|
|
|
# NOTE: the previous eager signature ``float32(float32, float32, boolean)``
# declared scalar arguments and a scalar return although the body works on
# 5-D arrays and returns a 2-D array.  Leaving the signature out lets numba
# infer the real array types on the first call.
@jit(nopython=True)
def get_GDL_numba(prediction, truth, mask):
    """Accelerated version of get_GDL using numba (http://numba.pydata.org/)

    For every pair of vertically or horizontally adjacent pixels whose mask
    entries are both truthy, the absolute difference between the absolute
    gradient of ``prediction`` and the absolute gradient of ``truth`` is
    accumulated per (sequence step, batch item).

    Only channel index 0 of the third axis is read.

    Parameters
    ----------
    prediction : np.ndarray
        Shape: (seq_len, batch_size, channel, height, width)
    truth : np.ndarray
        Indexed exactly like ``prediction``
    mask : np.ndarray
        Indexed exactly like ``prediction``; a pixel takes part only where
        its entry is truthy

    Returns
    -------
    gdl : np.ndarray
        Shape: (seq_len, batch_size), dtype float32
    """
    seqlen, batch_size, _, height, width = prediction.shape
    gdl = np.zeros((seqlen, batch_size), dtype=np.float32)
    for i in range(seqlen):
        for j in range(batch_size):
            for m in range(height):
                for n in range(width):
                    # Both neighbour checks below require the centre pixel
                    # to be valid, so skip it once up front.
                    if not mask[i, j, 0, m, n]:
                        continue
                    center_pred = prediction[i, j, 0, m, n]
                    center_truth = truth[i, j, 0, m, n]
                    # Vertical neighbour (next row).
                    if m + 1 < height and mask[i, j, 0, m + 1, n]:
                        pred_diff_h = abs(prediction[i, j, 0, m + 1, n] - center_pred)
                        gt_diff_h = abs(truth[i, j, 0, m + 1, n] - center_truth)
                        gdl[i, j] += abs(pred_diff_h - gt_diff_h)
                    # Horizontal neighbour (next column).
                    if n + 1 < width and mask[i, j, 0, m, n + 1]:
                        pred_diff_w = abs(prediction[i, j, 0, m, n + 1] - center_pred)
                        gt_diff_w = abs(truth[i, j, 0, m, n + 1] - center_truth)
                        gdl[i, j] += abs(pred_diff_w - gt_diff_w)
    return gdl
|
|
|
|
|
|
|
|
def get_hit_miss_counts_numba(prediction, truth, mask, thresholds=None):
    """Count hits, misses, false alarms and correct negatives per threshold.

    The counts can be used to derive skill scores and threat scores.

    ``prediction`` and ``truth`` must be 5-dimensional arrays of identical
    shape whose third axis has length 1 (checked with assertions).  The
    thresholds are mapped through ``rainfall_to_pixel`` and sorted in
    ascending order before the counting kernel is called, so the inputs are
    expected to be on that same pixel scale.

    Parameters
    ----------
    prediction : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width)
    truth : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width)
    mask : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width)
        0 --> not use
        1 --> use
        Forwarded unchanged to ``_get_hit_miss_counts_numba``, which indexes
        it element by element.
    thresholds : list or tuple, optional
        Rainfall thresholds.  Defaults to ``cfg.HKO.EVALUATION.THRESHOLDS``.

    Returns
    -------
    hits : np.ndarray
        (seq_len, batch_size, len(thresholds))
        TP
    misses : np.ndarray
        (seq_len, batch_size, len(thresholds))
        FN
    false_alarms : np.ndarray
        (seq_len, batch_size, len(thresholds))
        FP
    correct_negatives : np.ndarray
        (seq_len, batch_size, len(thresholds))
        TN
    """
    if thresholds is None:
        thresholds = cfg.HKO.EVALUATION.THRESHOLDS
    assert prediction.ndim == 5
    assert truth.ndim == 5
    assert prediction.shape == truth.shape
    assert prediction.shape[2] == 1
    # Convert rainfall intensities to the pixel scale, smallest first.
    pixel_thresholds = sorted([rainfall_to_pixel(value) for value in thresholds])
    counts = _get_hit_miss_counts_numba(prediction=prediction,
                                        truth=truth,
                                        mask=mask,
                                        thresholds=pixel_thresholds)
    # The last axis of ``counts`` holds (TP, FN, FP, TN) in that order.
    hits = counts[:, :, :, 0]
    misses = counts[:, :, :, 1]
    false_alarms = counts[:, :, :, 2]
    correct_negatives = counts[:, :, :, 3]
    return hits, misses, false_alarms, correct_negatives
|
|
|
|
|
|
|
|
# NOTE: the previous eager signature ``int32(float32, float32, boolean,
# float32)`` declared scalars although the function receives arrays and a
# threshold sequence and returns a 4-D array.  The kernel below is compiled
# lazily instead, with types inferred from the first call.
@jit(nopython=True)
def _hit_miss_counts_kernel(prediction, truth, mask, thresholds):
    """Compiled counting loop; ``thresholds`` must be a 1-D array."""
    seqlen, batch_size, _, height, width = prediction.shape
    threshold_num = thresholds.shape[0]
    ret = np.zeros((seqlen, batch_size, threshold_num, 4), dtype=np.int32)
    for i in range(seqlen):
        for j in range(batch_size):
            for m in range(height):
                for n in range(width):
                    if not mask[i, j, 0, m, n]:
                        continue
                    pred_value = prediction[i, j, 0, m, n]
                    truth_value = truth[i, j, 0, m, n]
                    for k in range(threshold_num):
                        pred_hit = pred_value >= thresholds[k]
                        truth_hit = truth_value >= thresholds[k]
                        # Same mapping as (1 - btruth) * 2 + (1 - bpred):
                        # 0 = hit, 1 = miss, 2 = false alarm,
                        # 3 = correct negative.
                        if truth_hit:
                            ind = 0 if pred_hit else 1
                        else:
                            ind = 2 if pred_hit else 3
                        ret[i, j, k, ind] += 1
    return ret


def _get_hit_miss_counts_numba(prediction, truth, mask, thresholds):
    """Count the contingency-table entries for every threshold.

    Only pixels whose mask entry is truthy are counted, and only channel
    index 0 of the third axis is read.

    Parameters
    ----------
    prediction : np.ndarray
        Shape: (seq_len, batch_size, channel, height, width)
    truth : np.ndarray
        Indexed exactly like ``prediction``
    mask : np.ndarray
        Indexed exactly like ``prediction``
    thresholds : sequence of float
        Thresholds on the same scale as ``prediction`` / ``truth``

    Returns
    -------
    ret : np.ndarray
        Shape: (seq_len, batch_size, len(thresholds), 4), dtype int32.
        The last axis holds, in order: hits (truth and prediction both
        reach the threshold), misses (only truth does), false alarms
        (only prediction does) and correct negatives (neither does).
    """
    # A plain Python list would reach the compiled code as a "reflected
    # list"; converting once to a float64 array keeps the comparisons
    # identical to comparing against Python floats and types cleanly.
    threshold_array = np.asarray(thresholds, dtype=np.float64)
    return _hit_miss_counts_kernel(prediction, truth, mask, threshold_array)
|
|
|
|
|
|
|
|
def get_balancing_weights_numba(data, mask, base_balancing_weights=None, thresholds=None):
    """Get the balancing weights

    The thresholds are mapped through ``rainfall_to_pixel`` and sorted in
    ascending order, then handed together with the other arguments to
    ``_get_balancing_weights_numba``.

    Parameters
    ----------
    data : np.ndarray
        Third axis must have length 1 (checked with an assertion)
    mask : np.ndarray
        Forwarded unchanged to ``_get_balancing_weights_numba``
    base_balancing_weights : sequence, optional
        Defaults to ``cfg.HKO.EVALUATION.BALANCING_WEIGHTS``
    thresholds : sequence, optional
        Rainfall thresholds.  Defaults to ``cfg.HKO.EVALUATION.THRESHOLDS``

    Returns
    -------
    weights
        Whatever ``_get_balancing_weights_numba`` returns for these inputs
    """
    if thresholds is None:
        thresholds = cfg.HKO.EVALUATION.THRESHOLDS
    if base_balancing_weights is None:
        base_balancing_weights = cfg.HKO.EVALUATION.BALANCING_WEIGHTS
    assert data.shape[2] == 1
    # Convert rainfall intensities to the pixel scale, smallest first.
    pixel_thresholds = sorted([rainfall_to_pixel(value) for value in thresholds])
    return _get_balancing_weights_numba(data=data,
                                        mask=mask,
                                        base_balancing_weights=base_balancing_weights,
                                        thresholds=pixel_thresholds)
|
|
|
|
|
|
|
|
# NOTE: the previous eager signature ``float32(float32, boolean, float32,
# float32)`` declared scalars although the function receives arrays and
# sequences and returns a 5-D array.  The kernel below is compiled lazily
# instead, with types inferred from the first call.
@jit(nopython=True)
def _balancing_weights_kernel(data, mask, weights, thresholds):
    """Compiled weighting loop; ``weights`` and ``thresholds`` are 1-D arrays."""
    seqlen, batch_size, _, height, width = data.shape
    threshold_num = thresholds.shape[0]
    last_threshold = thresholds[threshold_num - 1]
    ret = np.zeros((seqlen, batch_size, 1, height, width), dtype=np.float32)
    for i in range(seqlen):
        for j in range(batch_size):
            for m in range(height):
                for n in range(width):
                    if not mask[i, j, 0, m, n]:
                        continue
                    ele = data[i, j, 0, m, n]
                    # First threshold that lies strictly above the value
                    # selects the weight.
                    for k in range(threshold_num):
                        if ele < thresholds[k]:
                            ret[i, j, 0, m, n] = weights[k]
                            break
                    # Values at or above the last threshold get the final
                    # weight.  Kept as a separate comparison so that NaN
                    # (which fails both tests) keeps weight 0.
                    if ele >= last_threshold:
                        ret[i, j, 0, m, n] = weights[threshold_num]
    return ret


def _get_balancing_weights_numba(data, mask, base_balancing_weights, thresholds):
    """Assign a weight to every masked-in pixel according to its value bin.

    For a pixel value ``v`` the weight is ``base_balancing_weights[k]`` for
    the first ``k`` with ``v < thresholds[k]``; if ``v`` is greater than or
    equal to the last threshold the weight is
    ``base_balancing_weights[len(thresholds)]``.  Pixels whose mask entry is
    falsy keep weight 0.  Only channel index 0 of the third axis is read.

    Parameters
    ----------
    data : np.ndarray
        Shape: (seq_len, batch_size, channel, height, width)
    mask : np.ndarray
        Indexed exactly like ``data``
    base_balancing_weights : sequence of numbers
        Needs at least ``len(thresholds) + 1`` entries
    thresholds : sequence of float
        Expected in ascending order, on the same scale as ``data``

    Returns
    -------
    ret : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width), dtype float32

    Raises
    ------
    ValueError
        If ``thresholds`` is empty or ``base_balancing_weights`` has fewer
        than ``len(thresholds) + 1`` entries.
    """
    # Arrays type cleanly in compiled code regardless of whether the
    # caller passed lists or tuples, or mixed ints and floats.
    threshold_array = np.asarray(thresholds, dtype=np.float64)
    weight_array = np.asarray(base_balancing_weights, dtype=np.float64)
    # Compiled code does not bounds-check, so validate the lengths here.
    if threshold_array.size == 0:
        raise ValueError("thresholds must contain at least one value")
    if weight_array.size < threshold_array.size + 1:
        raise ValueError("base_balancing_weights needs len(thresholds) + 1 entries")
    return _balancing_weights_kernel(data, mask, weight_array, threshold_array)
|
|
|
|
|
if __name__ == '__main__':
    # Ad-hoc benchmark: time the numpy reference implementations against the
    # numba versions in this file on random data and report how far the
    # results differ.  The first call of each numba function also pays for
    # its compilation, which is included in the printed time.
    import time

    from nowcasting.hko_evaluation import get_GDL, get_hit_miss_counts, get_balancing_weights
    from numpy.testing import assert_allclose

    data_shape = (10, 16, 1, 480, 480)
    prediction = np.random.uniform(size=data_shape)
    truth = np.random.uniform(size=data_shape)
    # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` yields the
    # same boolean dtype on every NumPy version.
    mask = np.random.randint(low=0, high=2, size=data_shape).astype(bool)

    # --- gradient difference loss -------------------------------------
    begin = time.time()
    gdl = get_GDL(prediction=prediction, truth=truth, mask=mask)
    end = time.time()
    print("numpy gdl:", end - begin)

    begin = time.time()
    gdl_numba = get_GDL_numba(prediction=prediction, truth=truth, mask=mask)
    end = time.time()
    print("numba gdl:", end - begin)

    assert_allclose(gdl, gdl_numba, rtol=1E-4, atol=1E-3)

    # --- hit / miss counts --------------------------------------------
    begin = time.time()
    for _ in range(5):
        hits, misses, false_alarms, correct_negatives = get_hit_miss_counts(
            prediction=prediction,
            truth=truth,
            mask=mask)
    end = time.time()
    print("numpy hits misses:", end - begin)

    begin = time.time()
    for _ in range(5):
        hits_numba, misses_numba, false_alarms_numba, correct_negatives_numba = \
            get_hit_miss_counts_numba(
                prediction=prediction,
                truth=truth,
                mask=mask)
    end = time.time()
    print("numba hits misses:", end - begin)

    # Largest absolute disagreement (and where it occurs) per category.
    print(np.abs(hits - hits_numba).max())
    print(np.abs(misses - misses_numba).max(), np.abs(misses - misses_numba).argmax())
    print(np.abs(false_alarms - false_alarms_numba).max(),
          np.abs(false_alarms - false_alarms_numba).argmax())
    print(np.abs(correct_negatives - correct_negatives_numba).max(),
          np.abs(correct_negatives - correct_negatives_numba).argmax())

    # --- balancing weights --------------------------------------------
    begin = time.time()
    for _ in range(5):
        weights_npy = get_balancing_weights(data=truth, mask=mask,
                                            base_balancing_weights=None, thresholds=None)
    end = time.time()
    print("numpy balancing weights:", end - begin)

    begin = time.time()
    for _ in range(5):
        weights_numba = get_balancing_weights_numba(data=truth, mask=mask,
                                                    base_balancing_weights=None, thresholds=None)
    end = time.time()
    print("numba balancing weights:", end - begin)
    print("Inconsistent Number:", (np.abs(weights_npy - weights_numba) > 1E-5).sum())
|
|
|