File size: 1,125 Bytes
c383594
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
"""Small helpers for safe HTM output caching.

HTM subsampling is only valid when the current sparse input pattern matches the
cached pattern. Shape-only reuse is unsafe for MDLM because each forward can
sample a different random mask pattern with identical shape.
"""
from __future__ import annotations

import torch


def htm_cache_key(active_indices: torch.Tensor) -> torch.Tensor:
    """Snapshot ``active_indices`` into an exact, host-side cache key.

    The snapshot is detached from autograd, moved to the CPU, cast to int64,
    laid out contiguously and always copied. It stays small for compact
    active-index tensors, and an exact copy is what prevents a stale cache
    entry from being reused for a different mask pattern of the same shape.
    """
    snapshot = active_indices.detach()
    snapshot = snapshot.to(device="cpu", dtype=torch.long)
    snapshot = snapshot.contiguous()
    # ``.to``/``.contiguous`` may hand back the caller's own tensor unchanged;
    # the explicit clone guarantees later in-place edits cannot alter the key.
    return snapshot.clone()


def htm_cache_matches(cache_key: torch.Tensor | None, active_indices: torch.Tensor) -> bool:
    """True only when shape and compact active-index contents match exactly."""
    if cache_key is None:
        return False
    cur = active_indices.detach().to(device="cpu", dtype=torch.long).contiguous()
    return cache_key.shape == cur.shape and bool(torch.equal(cache_key, cur))