| """Small helpers for safe HTM output caching. | |
| HTM subsampling is only valid when the current sparse input pattern matches the | |
| cached pattern. Shape-only reuse is unsafe for MDLM because each forward can | |
| sample a different random mask pattern with identical shape. | |
| """ | |
| from __future__ import annotations | |
| import torch | |
def htm_cache_key(active_indices: torch.Tensor) -> torch.Tensor:
    """Build an exact, inexpensive cache key from compact active indices.

    The result is deliberately a detached, contiguous int64 copy on CPU:
    compact active-index tensors are small, and an exact copy rules out
    stale cache hits across distinct mask patterns that merely share a shape.
    """
    # Break the autograd link and normalize device/dtype first.
    detached = active_indices.detach()
    normalized = detached.to(device="cpu", dtype=torch.long)
    # Force an independent, contiguous snapshot so later in-place edits to
    # the input can never silently mutate the stored key.
    return normalized.contiguous().clone()
| def htm_cache_matches(cache_key: torch.Tensor | None, active_indices: torch.Tensor) -> bool: | |
| """True only when shape and compact active-index contents match exactly.""" | |
| if cache_key is None: | |
| return False | |
| cur = active_indices.detach().to(device="cpu", dtype=torch.long).contiguous() | |
| return cache_key.shape == cur.shape and bool(torch.equal(cache_key, cur)) | |