"""Small helpers for safe HTM output caching.

HTM subsampling is only valid when the current sparse input pattern matches
the cached pattern. Shape-only reuse is unsafe for MDLM because each forward
can sample a different random mask pattern with identical shape.
"""

from __future__ import annotations

import torch


def _as_cpu_long(active_indices: torch.Tensor) -> torch.Tensor:
    """Normalize an index tensor to a detached, contiguous CPU int64 tensor.

    May return the input itself (no copy) when it is already a contiguous CPU
    int64 tensor, so callers that need an independent copy must clone.
    """
    return active_indices.detach().to(device="cpu", dtype=torch.long).contiguous()


def htm_cache_key(active_indices: torch.Tensor) -> torch.Tensor:
    """Return a cheap exact key for compact SDR/input active indices.

    The key is intentionally a detached CPU int64 copy: small enough for
    compact active-index tensors and exact enough to prevent stale cache reuse
    across different mask patterns.

    Args:
        active_indices: Active-index tensor on any device. Non-int64 inputs are
            cast to int64, so the key is only exact for integer-valued indices.

    Returns:
        A new contiguous CPU int64 tensor with the same shape and values.
    """
    # clone() guarantees the key never aliases the caller's tensor: when the
    # input is already contiguous CPU int64, the normalization is a no-op and
    # an in-place edit of the input would otherwise silently change the key.
    return _as_cpu_long(active_indices).clone()


def htm_cache_matches(cache_key: torch.Tensor | None, active_indices: torch.Tensor) -> bool:
    """True only when shape and compact active-index contents match exactly.

    Args:
        cache_key: Key previously produced by ``htm_cache_key``, or ``None``
            when nothing has been cached yet.
        active_indices: Current active-index tensor on any device.

    Returns:
        ``False`` if there is no key, the shapes differ, or any element differs
        after casting ``active_indices`` to CPU int64; ``True`` otherwise.
    """
    if cache_key is None:
        return False
    # Compare shapes first: ``.shape`` is metadata only, so a mismatch is
    # rejected without copying the indices to the host (for CUDA inputs that
    # copy is blocking and forces a device synchronization). Normalization
    # does not change shape, so this is equivalent to comparing afterwards.
    if cache_key.shape != active_indices.shape:
        return False
    return bool(torch.equal(cache_key, _as_cpu_long(active_indices)))