feather-a10g-large-runtime / overlay /tests /test_htm_cache_contract.py
icarus112's picture
Update Feather a10g-large training runtime image
3319b2a verified
from __future__ import annotations
import torch
from hydra.htm_cache import htm_cache_key, htm_cache_matches
def test_htm_cache_key_changes_when_same_shape_mask_pattern_changes():
    """Check that same-shape inputs with different contents get different keys.

    Two (1, 2, 3) long tensors are built that differ only in their final
    element. Their keys must differ, must both have shape (1, 2, 3), and the
    first key must match its own input but not the other one.
    """
    # Identical layout; only the last value differs (6 vs 7).
    baseline = torch.tensor([[[1, 2, 3], [4, 5, 6]]], dtype=torch.long)
    perturbed = torch.tensor([[[1, 2, 3], [4, 5, 7]]], dtype=torch.long)

    baseline_key = htm_cache_key(baseline)
    perturbed_key = htm_cache_key(perturbed)

    # A content change alone must be reflected in the key.
    assert not torch.equal(baseline_key, perturbed_key)

    expected_shape = (1, 2, 3)
    assert baseline_key.shape == perturbed_key.shape == expected_shape

    # The key validates against the tensor it was derived from, and only that.
    assert htm_cache_matches(baseline_key, baseline)
    assert not htm_cache_matches(baseline_key, perturbed)
def test_htm_cache_key_keeps_device_dtype_shape_contract():
    """Check the shape, dtype and device type of a key for a (1, 4, 3) input.

    The input is a long tensor holding 0..11 viewed as (1, 4, 3), created
    without an explicit device. The resulting key is expected to have shape
    (1, 4, 3), dtype ``torch.long`` and device type ``"cpu"``.
    """
    source = torch.arange(12, dtype=torch.long).view(1, 4, 3)

    cache_key = htm_cache_key(source)

    assert cache_key.shape == (1, 4, 3)
    assert cache_key.dtype == torch.long
    assert cache_key.device.type == "cpu"