"""Tests for the cache-key helpers exported by ``hydra.htm_cache``."""

from __future__ import annotations

import torch

from hydra.htm_cache import htm_cache_key, htm_cache_matches


def test_htm_cache_key_changes_when_same_shape_mask_pattern_changes():
    """Same-shape inputs that differ in a single element must get distinct keys.

    In other words, the key cannot be derived from the input's shape alone:
    ``htm_cache_matches`` has to accept the tensor a key was built from and
    reject a same-shape tensor with different contents.
    """
    first = torch.tensor([[[1, 2, 3], [4, 5, 6]]], dtype=torch.long)
    # Identical to ``first`` except for the very last element (6 -> 7).
    second = torch.tensor([[[1, 2, 3], [4, 5, 7]]], dtype=torch.long)

    first_key = htm_cache_key(first)
    second_key = htm_cache_key(second)

    assert not torch.equal(first_key, second_key)
    # Equivalent to the chained check ``first == second == (1, 2, 3)``.
    assert first_key.shape == second_key.shape
    assert second_key.shape == (1, 2, 3)
    assert htm_cache_matches(first_key, first)
    assert not htm_cache_matches(first_key, second)


def test_htm_cache_key_keeps_device_dtype_shape_contract():
    """The key for a CPU ``long`` tensor of shape (1, 4, 3) keeps that shape, dtype and device."""
    source = torch.arange(12, dtype=torch.long).view(1, 4, 3)
    key = htm_cache_key(source)

    expected_shape = (1, 4, 3)
    assert key.shape == expected_shape
    assert key.dtype == torch.long
    assert key.device.type == "cpu"