import numpy as np
import mxnet as mx
from nowcasting.config import cfg
from nowcasting.hko_evaluation import rainfall_to_pixel
from nowcasting.encoder_forecaster import EncoderForecasterBaseFactory
from nowcasting.operators import *
from nowcasting.ops import *


def get_loss_weight_symbol(data, mask, seq_len):
    """Construct the pixel-wise loss weights.

    When cfg.MODEL.USE_BALANCED_LOSS is set, pixels are weighted by rainfall
    intensity: every pixel starts at balancing_weights[0] and gains
    balancing_weights[i + 1] - balancing_weights[i] each time it exceeds the
    i-th rainfall threshold. The weights are then optionally rescaled along
    the temporal axis according to cfg.MODEL.TEMPORAL_WEIGHT_TYPE.
    """
    if cfg.MODEL.USE_BALANCED_LOSS:
        balancing_weights = cfg.HKO.EVALUATION.BALANCING_WEIGHTS
        weights = mx.sym.ones_like(data) * balancing_weights[0]
        thresholds = [rainfall_to_pixel(ele) for ele in cfg.HKO.EVALUATION.THRESHOLDS]
        for i, threshold in enumerate(thresholds):
            weights = weights + (balancing_weights[i + 1] - balancing_weights[i]) * (data >= threshold)
        weights = weights * mask
    else:
        weights = mask
    if cfg.MODEL.TEMPORAL_WEIGHT_TYPE == "same":
        return weights
    elif cfg.MODEL.TEMPORAL_WEIGHT_TYPE == "linear":
        # Weights increase linearly from 1 at the first frame to `upper` at the last.
        upper = cfg.MODEL.TEMPORAL_WEIGHT_UPPER
        assert upper >= 1.0
        temporal_mult = 1 + \
            mx.sym.arange(start=0, stop=seq_len) * (upper - 1.0) / (seq_len - 1.0)
        temporal_mult = mx.sym.reshape(temporal_mult, shape=(seq_len, 1, 1, 1, 1))
        weights = mx.sym.broadcast_mul(weights, temporal_mult)
        return weights
    elif cfg.MODEL.TEMPORAL_WEIGHT_TYPE == "exponential":
        # Weights increase exponentially from 1 at the first frame to `upper` at the last.
        upper = cfg.MODEL.TEMPORAL_WEIGHT_UPPER
        assert upper >= 1.0
        base_factor = np.log(upper) / (seq_len - 1.0)
        temporal_mult = mx.sym.exp(mx.sym.arange(start=0, stop=seq_len) * base_factor)
        temporal_mult = mx.sym.reshape(temporal_mult, shape=(seq_len, 1, 1, 1, 1))
        weights = mx.sym.broadcast_mul(weights, temporal_mult)
        return weights
    else:
        raise NotImplementedError(
            "TEMPORAL_WEIGHT_TYPE %s is not supported" % cfg.MODEL.TEMPORAL_WEIGHT_TYPE)
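
# Illustration (not part of the original module): the temporal schedules can
# be previewed with NumPy. Assuming seq_len = 5 and TEMPORAL_WEIGHT_UPPER = 5.0,
# "linear" yields [1.0, 2.0, 3.0, 4.0, 5.0] and "exponential" yields
# 5 ** (i / 4) for frame index i, i.e. about [1.0, 1.50, 2.24, 3.34, 5.0]:
#
#     seq_len, upper = 5, 5.0
#     linear = 1 + np.arange(seq_len) * (upper - 1.0) / (seq_len - 1.0)
#     exponential = np.exp(np.arange(seq_len) * np.log(upper) / (seq_len - 1.0))
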
class HKONowcastingFactory(EncoderForecasterBaseFactory):
    def __init__(self,
                 batch_size,
                 in_seq_len,
                 out_seq_len,
                 name="hko_nowcasting"):
        super(HKONowcastingFactory, self).__init__(batch_size=batch_size,
                                                   in_seq_len=in_seq_len,
                                                   out_seq_len=out_seq_len,
                                                   height=cfg.HKO.ITERATOR.HEIGHT,
                                                   width=cfg.HKO.ITERATOR.WIDTH,
                                                   name=name)
        self._central_region = cfg.HKO.EVALUATION.CENTRAL_REGION

    def _slice_central(self, data):
        """Slice the central region out of the given symbol.

        Parameters
        ----------
        data : mx.sym.Symbol
            Shape (seq_len, batch_size, C, H, W)

        Returns
        -------
        ret : mx.sym.Symbol
        """
        x_begin, y_begin, x_end, y_end = self._central_region
        return mx.sym.slice(data,
                            begin=(0, 0, 0, y_begin, x_begin),
                            end=(None, None, None, y_end, x_end))

    def _concat_month_code(self):
        # TODO: concatenate a month code to the input, as the name suggests.
        raise NotImplementedError

    def loss_sym(self,
                 pred=mx.sym.Variable('pred'),
                 mask=mx.sym.Variable('mask'),
                 target=mx.sym.Variable('target')):
        """Construct the loss symbol.

        Optional args:
            pred: Shape (out_seq_len, batch_size, C, H, W)
            mask: Shape (out_seq_len, batch_size, C, H, W)
            target: Shape (out_seq_len, batch_size, C, H, W)
        """
        self.reset_all()
        weights = get_loss_weight_symbol(data=target, mask=mask, seq_len=self._out_seq_len)
        mse = weighted_mse(pred=pred, gt=target, weight=weights)
        mae = weighted_mae(pred=pred, gt=target, weight=weights)
        gdl = masked_gdl_loss(pred=pred, gt=target, mask=mask)
        avg_mse = mx.sym.mean(mse)
        avg_mae = mx.sym.mean(mae)
        avg_gdl = mx.sym.mean(gdl)
        global_grad_scale = cfg.MODEL.NORMAL_LOSS_GLOBAL_SCALE
        # A term only backpropagates when its lambda is positive; otherwise it
        # is still computed as a monitoring output but its gradient is blocked.
        if cfg.MODEL.L2_LAMBDA > 0:
            avg_mse = mx.sym.MakeLoss(avg_mse,
                                      grad_scale=global_grad_scale * cfg.MODEL.L2_LAMBDA,
                                      name="mse")
        else:
            avg_mse = mx.sym.BlockGrad(avg_mse, name="mse")
        if cfg.MODEL.L1_LAMBDA > 0:
            avg_mae = mx.sym.MakeLoss(avg_mae,
                                      grad_scale=global_grad_scale * cfg.MODEL.L1_LAMBDA,
                                      name="mae")
        else:
            avg_mae = mx.sym.BlockGrad(avg_mae, name="mae")
        if cfg.MODEL.GDL_LAMBDA > 0:
            avg_gdl = mx.sym.MakeLoss(avg_gdl,
                                      grad_scale=global_grad_scale * cfg.MODEL.GDL_LAMBDA,
                                      name="gdl")
        else:
            avg_gdl = mx.sym.BlockGrad(avg_gdl, name="gdl")
        loss = mx.sym.Group([avg_mse, avg_mae, avg_gdl])
        return loss
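

# Minimal usage sketch (added for illustration; assumes the default cfg has
# been loaded): build the factory and inspect the grouped loss symbol.
if __name__ == "__main__":
    factory = HKONowcastingFactory(batch_size=2, in_seq_len=5, out_seq_len=20)
    loss = factory.loss_sym()
    # Output names come from the MakeLoss/BlockGrad names above,
    # e.g. ['mse_output', 'mae_output', 'gdl_output'].
    print(loss.list_outputs())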