Spaces:
Sleeping
Sleeping
| import pytest | |
| import torch | |
| from lzero.policy.scaling_transform import inverse_scalar_transform, InverseScalarTransform | |
def test_scaling_transform():
    """Check that the functional and class-based inverse scalar transforms agree.

    ``inverse_scalar_transform`` and ``InverseScalarTransform`` are both imported
    from ``lzero.policy.scaling_transform``. This test feeds the same random
    logits through each one. It asserts that both return a ``(batch_size, 1)``
    tensor and that the two outputs are element-wise equal. The time taken by
    each call is printed for informal comparison only and is not asserted on.
    """
    import time

    batch_size = 16
    support_size = 300
    # 601 logits for support_size=300. Presumably this is one bin per integer
    # in [-support_size, support_size] -- TODO confirm against the transform.
    num_bins = 2 * support_size + 1

    # Use a private, seeded generator for two reasons:
    # - a failing input can be reproduced;
    # - the global torch RNG state seen by other tests is left untouched.
    generator = torch.Generator().manual_seed(0)
    logit = torch.randn(batch_size, num_bins, generator=generator)

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    output_1 = inverse_scalar_transform(logit, support_size)
    print('t1', time.perf_counter() - start)

    handle = InverseScalarTransform(support_size)
    start = time.perf_counter()
    output_2 = handle(logit)
    print('t2', time.perf_counter() - start)

    assert output_1.shape == output_2.shape == (batch_size, 1)
    # Exact element-wise equality: both code paths must yield identical values.
    assert (output_1 == output_2).all()