# STLDM_official / nowcasting / numba_accelerated.py
# (source page header: author sqfoo, commit "Upload 99 files", 6021dd1 verified)
import numpy as np
from numba import jit, float32, boolean, int32, float64
from nowcasting.hko_evaluation import rainfall_to_pixel
from nowcasting.config import cfg
@jit(nopython=True)
def get_GDL_numba(prediction, truth, mask):
    """Accelerated version of get_GDL using numba (http://numba.pydata.org/).

    For every (timestep, batch) pair, sums over all vertically and
    horizontally adjacent pixel pairs (a, b) whose pixels are both unmasked:

        | |prediction[a] - prediction[b]| - |truth[a] - truth[b]| |

    Parameters
    ----------
    prediction : np.ndarray
        Shape: (seq_len, batch_size, C, height, width); only channel 0 is read.
    truth : np.ndarray
        Same shape as ``prediction``.
    mask : np.ndarray
        Same shape as ``prediction``; a pixel pair contributes only when both
        of its pixels are truthy.

    Returns
    -------
    gdl : np.ndarray
        Shape: (seq_len, batch_size), dtype float32.
    """
    # No explicit signature: the arguments are 5-D arrays (a scalar signature
    # such as ``float32(float32, float32, boolean)`` can never match them), so
    # numba compiles lazily for the actual dtypes/layouts on the first call.
    seqlen, batch_size, _, height, width = prediction.shape
    gdl = np.zeros((seqlen, batch_size), dtype=np.float32)
    for i in range(seqlen):
        for j in range(batch_size):
            for m in range(height):
                for n in range(width):
                    # Both neighbour pairs below include (m, n), so skip the
                    # pixel entirely when it is masked out.
                    if not mask[i, j, 0, m, n]:
                        continue
                    pred_val = prediction[i, j, 0, m, n]
                    truth_val = truth[i, j, 0, m, n]
                    # Vertical neighbour (m + 1, n); the bound test comes
                    # first so the mask is never read out of range.
                    if m + 1 < height and mask[i, j, 0, m + 1, n]:
                        pred_diff_h = abs(prediction[i, j, 0, m + 1, n] - pred_val)
                        gt_diff_h = abs(truth[i, j, 0, m + 1, n] - truth_val)
                        gdl[i, j] += abs(pred_diff_h - gt_diff_h)
                    # Horizontal neighbour (m, n + 1).
                    if n + 1 < width and mask[i, j, 0, m, n + 1]:
                        pred_diff_w = abs(prediction[i, j, 0, m, n + 1] - pred_val)
                        gt_diff_w = abs(truth[i, j, 0, m, n + 1] - truth_val)
                        gdl[i, j] += abs(pred_diff_w - gt_diff_w)
    return gdl
def get_hit_miss_counts_numba(prediction, truth, mask, thresholds=None):
    """Calculate the overall hits and misses of the prediction per threshold.

    The counts can be used to get the skill scores and threat scores.
    ``prediction`` and ``truth`` are 5-dim tensors and should be between 0~1;
    the rainfall ``thresholds`` are converted with ``rainfall_to_pixel``
    and sorted ascending before comparison.

    Parameters
    ----------
    prediction : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width)
    truth : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width)
    mask : np.ndarray or None
        Shape: (seq_len, batch_size, 1, height, width)
        0 --> not use
        1 --> use
        None --> use every pixel
    thresholds : list or tuple, optional
        Rainfall thresholds; defaults to ``cfg.HKO.EVALUATION.THRESHOLDS``.

    Returns
    -------
    hits : np.ndarray
        (seq_len, batch_size, len(thresholds))
        TP
    misses : np.ndarray
        (seq_len, batch_size, len(thresholds))
        FN
    false_alarms : np.ndarray
        (seq_len, batch_size, len(thresholds))
        FP
    correct_negatives : np.ndarray
        (seq_len, batch_size, len(thresholds))
        TN
    The threshold axis follows ascending pixel-threshold order; all four
    arrays are int32 views into one shared result array.
    """
    if thresholds is None:
        thresholds = cfg.HKO.EVALUATION.THRESHOLDS
    assert 5 == prediction.ndim
    assert 5 == truth.ndim
    assert prediction.shape == truth.shape
    assert prediction.shape[2] == 1
    if mask is None:
        # Documented as optional: treat every pixel as valid.
        mask = np.ones(prediction.shape, dtype=np.bool_)
    # The numba kernel indexes ``mask`` with the prediction's shape and numba
    # does not bounds-check by default, so a mismatch must be caught here.
    assert mask.shape == prediction.shape
    # A contiguous ndarray instead of a Python list: numba's reflected-list
    # support is deprecated; float64 avoids rounding the thresholds.
    pixel_thresholds = np.asarray(
        sorted(rainfall_to_pixel(threshold) for threshold in thresholds),
        dtype=np.float64)
    ret = _get_hit_miss_counts_numba(prediction=prediction,
                                     truth=truth,
                                     mask=mask,
                                     thresholds=pixel_thresholds)
    return ret[:, :, :, 0], ret[:, :, :, 1], ret[:, :, :, 2], ret[:, :, :, 3]
@jit(nopython=True)
def _get_hit_miss_counts_numba(prediction, truth, mask, thresholds):
    """Numba kernel that counts contingency-table entries per threshold.

    Parameters
    ----------
    prediction, truth : np.ndarray
        Shape: (seq_len, batch_size, C, height, width); only channel 0 is read.
    mask : np.ndarray
        Same shape as ``prediction``; only truthy pixels are counted.
    thresholds : 1-D np.ndarray
        Pixel-space thresholds; a value counts as "positive" when it is
        ``>=`` the threshold.

    Returns
    -------
    ret : np.ndarray
        int32, shape (seq_len, batch_size, len(thresholds), 4); the last axis
        holds [hits (TP), misses (FN), false alarms (FP), correct negatives (TN)].
    """
    # No explicit signature: arguments and result are arrays, so numba
    # compiles lazily for the actual dtypes/layouts.
    seqlen, batch_size, _, height, width = prediction.shape
    threshold_num = len(thresholds)
    ret = np.zeros((seqlen, batch_size, threshold_num, 4), dtype=np.int32)
    for i in range(seqlen):
        for j in range(batch_size):
            for m in range(height):
                for n in range(width):
                    if not mask[i, j, 0, m, n]:
                        continue
                    pred_val = prediction[i, j, 0, m, n]
                    truth_val = truth[i, j, 0, m, n]
                    for k in range(threshold_num):
                        bpred = pred_val >= thresholds[k]
                        btruth = truth_val >= thresholds[k]
                        # Slot layout: 0 = TP, 1 = FN, 2 = FP, 3 = TN.
                        if btruth:
                            ind = 0 if bpred else 1
                        else:
                            ind = 2 if bpred else 3
                        ret[i, j, k, ind] += 1
    return ret
def get_balancing_weights_numba(data, mask, base_balancing_weights=None, thresholds=None):
    """Get the per-pixel balancing weights.

    The rainfall ``thresholds`` are converted with ``rainfall_to_pixel`` and
    sorted ascending. Each unmasked pixel then receives
    ``base_balancing_weights[k]`` for the first threshold ``k`` it is strictly
    below, or ``base_balancing_weights[len(thresholds)]`` when it is at or
    above every threshold. Masked-out pixels get weight 0. The weights are
    indexed in ascending-threshold order; they are not re-sorted together
    with the thresholds.

    Parameters
    ----------
    data : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width)
    mask : np.ndarray or None
        Same shape as ``data``; 0 --> not use, 1 --> use,
        None --> use every pixel.
    base_balancing_weights : sequence of float, optional
        Must contain ``len(thresholds) + 1`` entries; defaults to
        ``cfg.HKO.EVALUATION.BALANCING_WEIGHTS``.
    thresholds : sequence of float, optional
        Rainfall thresholds; defaults to ``cfg.HKO.EVALUATION.THRESHOLDS``.

    Returns
    -------
    weights : np.ndarray
        float32, shape (seq_len, batch_size, 1, height, width).

    Raises
    ------
    ValueError
        If ``thresholds`` is empty or ``base_balancing_weights`` does not have
        exactly ``len(thresholds) + 1`` entries.
    """
    if thresholds is None:
        thresholds = cfg.HKO.EVALUATION.THRESHOLDS
    if base_balancing_weights is None:
        base_balancing_weights = cfg.HKO.EVALUATION.BALANCING_WEIGHTS
    assert data.shape[2] == 1
    if mask is None:
        mask = np.ones(data.shape, dtype=np.bool_)
    # The kernel indexes ``mask`` with the data's shape and numba does not
    # bounds-check by default.
    assert mask.shape == data.shape
    # The kernel reads thresholds[-1] and weights[len(thresholds)]; reject
    # inputs for which those reads would fall outside the arrays.
    if len(thresholds) == 0:
        raise ValueError("thresholds must not be empty")
    if len(base_balancing_weights) != len(thresholds) + 1:
        raise ValueError(
            "base_balancing_weights must have len(thresholds) + 1 = %d entries, got %d"
            % (len(thresholds) + 1, len(base_balancing_weights)))
    # Pass ndarrays rather than Python lists/tuples: numba's reflected lists
    # are deprecated and mixed int/float lists fail to type.
    pixel_thresholds = np.asarray(
        sorted(rainfall_to_pixel(threshold) for threshold in thresholds),
        dtype=np.float64)
    weights = np.asarray(base_balancing_weights, dtype=np.float64)
    ret = _get_balancing_weights_numba(data=data,
                                       mask=mask,
                                       base_balancing_weights=weights,
                                       thresholds=pixel_thresholds)
    return ret
@jit(nopython=True)
def _get_balancing_weights_numba(data, mask, base_balancing_weights, thresholds):
    """Numba kernel that maps each unmasked pixel to its balancing weight.

    ``thresholds`` must be non-empty and sorted ascending, and
    ``base_balancing_weights`` must have ``len(thresholds) + 1`` entries.
    Numba does not bounds-check array indexing by default, so violating this
    reads out of range.

    Returns a float32 array of shape (seq_len, batch_size, 1, height, width);
    masked-out pixels keep weight 0.
    """
    # No explicit signature: arguments and result are arrays, so numba
    # compiles lazily for the actual dtypes/layouts.
    seqlen, batch_size, _, height, width = data.shape
    threshold_num = len(thresholds)
    ret = np.zeros((seqlen, batch_size, 1, height, width), dtype=np.float32)
    for i in range(seqlen):
        for j in range(batch_size):
            for m in range(height):
                for n in range(width):
                    if not mask[i, j, 0, m, n]:
                        continue
                    ele = data[i, j, 0, m, n]
                    # First threshold the value is strictly below picks the weight.
                    for k in range(threshold_num):
                        if ele < thresholds[k]:
                            ret[i, j, 0, m, n] = base_balancing_weights[k]
                            break
                    # At or above the largest threshold: use the extra final weight.
                    if ele >= thresholds[threshold_num - 1]:
                        ret[i, j, 0, m, n] = base_balancing_weights[threshold_num]
    return ret
if __name__ == '__main__':
    # Self-test: compare the numba kernels with the numpy reference
    # implementations on random data and report timings.
    import time
    from numpy.testing import assert_allclose
    from nowcasting.hko_evaluation import get_GDL, get_hit_miss_counts, get_balancing_weights

    prediction = np.random.uniform(size=(10, 16, 1, 480, 480))
    truth = np.random.uniform(size=(10, 16, 1, 480, 480))
    # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is the
    # supported dtype spelling.
    mask = np.random.randint(low=0, high=2, size=(10, 16, 1, 480, 480)).astype(bool)

    # Trigger numba's lazy compilation on tiny contiguous slices (same dtypes
    # and layout as the full arrays) so the timings below measure execution
    # rather than JIT compilation.
    warm_prediction = np.ascontiguousarray(prediction[:1, :1, :, :4, :4])
    warm_truth = np.ascontiguousarray(truth[:1, :1, :, :4, :4])
    warm_mask = np.ascontiguousarray(mask[:1, :1, :, :4, :4])
    get_GDL_numba(prediction=warm_prediction, truth=warm_truth, mask=warm_mask)
    get_hit_miss_counts_numba(prediction=warm_prediction, truth=warm_truth, mask=warm_mask)
    get_balancing_weights_numba(data=warm_truth, mask=warm_mask)

    begin = time.perf_counter()
    gdl = get_GDL(prediction=prediction, truth=truth, mask=mask)
    print("numpy gdl:", time.perf_counter() - begin)
    begin = time.perf_counter()
    gdl_numba = get_GDL_numba(prediction=prediction, truth=truth, mask=mask)
    print("numba gdl:", time.perf_counter() - begin)
    assert_allclose(gdl, gdl_numba, rtol=1E-4, atol=1E-3)

    begin = time.perf_counter()
    for _ in range(5):
        hits, misses, false_alarms, correct_negatives = get_hit_miss_counts(
            prediction=prediction, truth=truth, mask=mask)
    print("numpy hits misses:", time.perf_counter() - begin)
    begin = time.perf_counter()
    for _ in range(5):
        hits_numba, misses_numba, false_alarms_numba, correct_negatives_numba = \
            get_hit_miss_counts_numba(prediction=prediction, truth=truth, mask=mask)
    print("numba hits misses:", time.perf_counter() - begin)
    print(np.abs(hits - hits_numba).max())
    print(np.abs(misses - misses_numba).max(), np.abs(misses - misses_numba).argmax())
    print(np.abs(false_alarms - false_alarms_numba).max(),
          np.abs(false_alarms - false_alarms_numba).argmax())
    print(np.abs(correct_negatives - correct_negatives_numba).max(),
          np.abs(correct_negatives - correct_negatives_numba).argmax())

    begin = time.perf_counter()
    for _ in range(5):
        weights_npy = get_balancing_weights(data=truth, mask=mask,
                                            base_balancing_weights=None, thresholds=None)
    print("numpy balancing weights:", time.perf_counter() - begin)
    begin = time.perf_counter()
    for _ in range(5):
        weights_numba = get_balancing_weights_numba(data=truth, mask=mask,
                                                    base_balancing_weights=None, thresholds=None)
    print("numba balancing weights:", time.perf_counter() - begin)
    print("Inconsistent Number:", (np.abs(weights_npy - weights_numba) > 1E-5).sum())