Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import pytest | |
| import torch | |
| from mmcv.utils import digit_version, is_jit_tracing | |
def test_is_jit_tracing():
    """Check that ``is_jit_tracing`` tells eager calls apart from tracing.

    The local helper ``foo`` returns its tensor argument unchanged when
    ``is_jit_tracing()`` is true and a Python list otherwise, so the type of
    its result reveals which branch was taken.
    """
    # NOTE(review): detecting an active trace relies on
    # ``torch.jit.is_tracing``, which was introduced in PyTorch 1.6.0; on
    # older versions ``is_jit_tracing`` presumably cannot report True and the
    # traced assertion below would fail, so skip instead. The check runs
    # inside the test (rather than in a ``skipif`` decorator) so importing
    # this module never evaluates the version comparison.
    if digit_version(torch.__version__) < digit_version('1.6.0'):
        pytest.skip('torch.jit.is_tracing is not available before 1.6.0')

    def foo(x):
        # While being traced, keep the tensor so the trace has a tensor
        # output; in eager mode, convert it to a plain Python list.
        if is_jit_tracing():
            return x
        else:
            return x.tolist()

    x = torch.rand(3)
    # test without trace
    assert isinstance(foo(x), list)

    # test with trace: the trace is recorded on a size-1 example input and
    # then replayed on the size-3 tensor; the recorded branch returns ``x``.
    traced_foo = torch.jit.trace(foo, (torch.rand(1), ))
    assert isinstance(traced_foo(x), torch.Tensor)