# Uploaded by icarus112 via huggingface_hub ("Upload folder using huggingface_hub"), commit c383594 (verified).
"""Small helpers for safe HTM output caching.
HTM subsampling is only valid when the current sparse input pattern exactly
matches the cached pattern. Reusing the cache based on shape alone is unsafe
for MDLM, because each forward pass can sample a different random mask
pattern with an identical shape.
"""
from __future__ import annotations
import torch
def htm_cache_key(active_indices: torch.Tensor) -> torch.Tensor:
    """Snapshot compact SDR/input active indices as an exact cache key.

    The snapshot is detached from autograd, moved to the CPU, cast to int64,
    made contiguous, and finally cloned. It is small enough for compact
    active-index tensors and exact enough to prevent stale cache reuse
    across different mask patterns.

    Args:
        active_indices: Active indices for the current input pattern, on any
            device (assumed to hold integer index values; they are cast to
            ``torch.long``).

    Returns:
        A new contiguous CPU ``torch.long`` tensor with the same shape and
        values as ``active_indices``.
    """
    snapshot = active_indices.detach()
    snapshot = snapshot.to(device="cpu", dtype=torch.long)
    snapshot = snapshot.contiguous()
    # ``.to`` / ``.contiguous`` hand back the very same tensor when no
    # conversion is needed, so clone to guarantee the key owns its storage
    # and cannot be changed by later in-place edits to the caller's tensor.
    return snapshot.clone()
def htm_cache_matches(cache_key: torch.Tensor | None, active_indices: torch.Tensor) -> bool:
    """Report whether a stored cache key still describes ``active_indices``.

    ``active_indices`` is normalized the same way a cache key is built
    (detached, CPU, ``torch.long``, contiguous) but is not cloned, because it
    is only read for the comparison.

    Args:
        cache_key: Previously stored key, or ``None`` when nothing is cached.
        active_indices: Active indices for the current input pattern.

    Returns:
        ``True`` only when a key exists and both its shape and its contents
        exactly match the normalized ``active_indices``; otherwise ``False``.
    """
    if cache_key is None:
        return False
    current = (
        active_indices.detach()
        .to(device="cpu", dtype=torch.long)
        .contiguous()
    )
    # Cheap shape check first; only compare contents when shapes agree.
    if cache_key.shape != current.shape:
        return False
    return bool(torch.equal(cache_key, current))