Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import torch | |
| from hydra.htm_cache import htm_cache_key, htm_cache_matches | |
def test_htm_cache_key_changes_when_same_shape_mask_pattern_changes():
    """Same-shape inputs that differ in one element must get distinct keys.

    Each key must also match the input it was built from and reject the
    other input.
    """
    original = torch.tensor([[[1, 2, 3], [4, 5, 6]]], dtype=torch.long)
    # Identical shape and contents except the last element (6 -> 7).
    perturbed = original.clone()
    perturbed[0, 1, 2] = 7

    original_key = htm_cache_key(original)
    perturbed_key = htm_cache_key(perturbed)

    assert not torch.equal(original_key, perturbed_key)

    # Both keys share the input's (1, 2, 3) shape. This rules out the
    # torch.equal mismatch above being caused only by a shape difference.
    expected_shape = (1, 2, 3)
    assert original_key.shape == expected_shape
    assert perturbed_key.shape == expected_shape

    assert htm_cache_matches(original_key, original)
    assert not htm_cache_matches(original_key, perturbed)
def test_htm_cache_key_keeps_device_dtype_shape_contract():
    """The cache key for a CPU int64 (1, 4, 3) tensor keeps that layout."""
    source = torch.arange(12, dtype=torch.long).reshape(1, 4, 3)
    cache_key = htm_cache_key(source)

    # Compare shape, dtype and device type in a single assertion.
    observed = (cache_key.shape, cache_key.dtype, cache_key.device.type)
    assert observed == ((1, 4, 3), torch.long, "cpu")