File size: 838 Bytes
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from numpy.testing import assert_array_almost_equal
from mmaction.models import BinaryLogisticRegressionLoss, BMNLoss
def test_binary_logistic_regression_loss():
    """Check ``BMNLoss.tem_loss`` against two summed BLR losses.

    The start and end predictions are each scored with a standalone
    ``BinaryLogisticRegressionLoss``; their sum must agree with the value
    returned by ``BMNLoss.tem_loss`` to 4 decimal places.
    """
    criterion = BMNLoss()

    # test tem_loss
    start_pred = torch.tensor([0.9, 0.1])
    end_pred = torch.tensor([0.1, 0.9])
    start_gt = torch.tensor([1., 0.])
    end_gt = torch.tensor([0., 1.])
    actual = criterion.tem_loss(start_pred, end_pred, start_gt, end_gt)

    # Reference value: score the start and end branches independently
    # and add the two terms.
    blr_loss = BinaryLogisticRegressionLoss()
    start_term = blr_loss(start_pred, start_gt)
    end_term = blr_loss(end_pred, end_gt)
    expected = start_term + end_term

    assert_array_almost_equal(actual.numpy(), expected.numpy(), decimal=4)
|