File size: 800 Bytes
3319b2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from __future__ import annotations

import torch

from hydra.htm_cache import htm_cache_key, htm_cache_matches


def test_htm_cache_key_changes_when_same_shape_mask_pattern_changes():
    """Same-shape inputs differing in one element must get distinct cache keys.

    Also checks that ``htm_cache_matches`` accepts the tensor a key was built
    from and rejects the perturbed tensor.
    """
    original = torch.tensor([[[1, 2, 3], [4, 5, 6]]], dtype=torch.long)
    # Identical to ``original`` except for the last element (6 -> 7), so the
    # two inputs share shape/dtype/device and differ only in content.
    perturbed = original.clone()
    perturbed[0, 1, 2] = 7

    original_key = htm_cache_key(original)
    perturbed_key = htm_cache_key(perturbed)

    assert not torch.equal(original_key, perturbed_key)

    expected_shape = (1, 2, 3)
    assert original_key.shape == perturbed_key.shape
    assert perturbed_key.shape == expected_shape

    assert htm_cache_matches(original_key, original)
    assert not htm_cache_matches(original_key, perturbed)


def test_htm_cache_key_keeps_device_dtype_shape_contract():
    """A (1, 4, 3) long CPU input yields a (1, 4, 3) long key on the CPU."""
    # 12 consecutive integers laid out as a single (1, 4, 3) batch.
    values = torch.arange(12, dtype=torch.long).reshape(1, 4, 3)

    cache_key = htm_cache_key(values)

    assert tuple(cache_key.shape) == (1, 4, 3)
    assert cache_key.dtype == torch.long
    assert cache_key.device.type == "cpu"