SHRAM / __cache__shram_cache.py
smithblack-0's picture
Update architecture and tokenizer
7bf638f verified
"""SHRAM top-level cache — model-wide owner for the full SHRAM decoder stack.
The HuggingFace Cache protocol expects a single top-level Cache object that owns one
CacheLayerMixin per decoder layer. The actual SHRAM caching responsibilities live one level
lower in ShramLayerCache — each of which owns a LocalSlidingWindowLayerCache and a MoSRAHCache.
ShramCache bridges those two levels: it constructs one ShramLayerCache per decoder layer,
presents them through the Cache interface, and transparently forwards model-wide operations
across all of them.
ShramCache does not define a composite update() interface. The two attention paths inside each
SHRAM layer have different update semantics, and neither the layer-level boundary (Unit 6.B)
nor the model-level boundary here can meaningfully unify them. Callers must reach down to the
relevant sub-cache directly. ShramCache's role is ownership, construction, and model-wide
coordination of the layer caches — not routing attention inputs.
Sequence length is reported by delegating to the local sliding-window sub-cache of the
specified layer, which tracks the cumulative count of token positions processed. This is
what HuggingFace generation reads through get_seq_length().
"""
import torch
from transformers.cache_utils import Cache
from .__cache__shram_layer_cache import ShramLayerCache
class ShramCache(Cache):
    """Top-level cache for the full SHRAM model.

    Owns one ShramLayerCache per decoder layer. Satisfies the HuggingFace top-level Cache
    role and transparently forwards reset, reorder, and sequence-length queries across all
    owned layer caches.

    No composite update() interface is provided. The two attention paths inside each SHRAM
    layer have materially different update semantics; callers must update sub-caches directly
    via cache.layers[layer_idx].sliding_window_cache or cache.layers[layer_idx].mosrah_cache.

    Args:
        num_hidden_layers: Number of SHRAM decoder layers. Determines how many
            ShramLayerCache objects are constructed.
        sliding_window: Token window size passed to each layer's LocalSlidingWindowLayerCache.
        num_local_heads: Number of local attention heads per layer.
        local_head_dim: Per-head embedding width for the local path.
        num_mosrah_heads: Total number of MoSRAH expert heads (L) per layer.
        mosrah_head_dim: Bottlenecked head embedding width (u) for the MoSRAH path.
        batch_size: Number of sequences in the batch.
        device: Device on which to allocate cache tensors.
        initial_buffer_size: Initial per-(batch, head) capacity for each MoSRAHCache.
            Doubled when any slot overflows. Defaults to 64 to avoid repeated reallocation
            during prompt processing.
    """

    def __init__(
        self,
        num_hidden_layers: int,
        sliding_window: int,
        num_local_heads: int,
        local_head_dim: int,
        num_mosrah_heads: int,
        mosrah_head_dim: int,
        batch_size: int,
        device: torch.device,
        initial_buffer_size: int = 64,
    ) -> None:
        # The comprehension constructs a brand-new ShramLayerCache on every iteration,
        # so each decoder layer gets its own independent cache state (no aliasing).
        layers = [
            ShramLayerCache(
                sliding_window=sliding_window,
                num_local_heads=num_local_heads,
                local_head_dim=local_head_dim,
                num_mosrah_heads=num_mosrah_heads,
                mosrah_head_dim=mosrah_head_dim,
                batch_size=batch_size,
                device=device,
                initial_buffer_size=initial_buffer_size,
            )
            for _ in range(num_hidden_layers)
        ]
        # Hand the per-layer caches to the HuggingFace base class, which stores them in
        # self.layers and uses them for the inherited model-wide operations below.
        super().__init__(layers=layers)

    # ---------------------------------------------------------------------------
    # Cache — composite-meaningful methods
    # ---------------------------------------------------------------------------
    #
    # reset(): Inherited. Iterates all layer caches and calls reset() on each.
    #
    # reorder_cache(beam_idx): Inherited. Iterates all layer caches and reorders each.
    #
    # is_initialized: Inherited property. True iff all layer caches are initialized.
    # Since ShramLayerCache.is_initialized is True from construction, this is True
    # immediately after ShramCache.__init__ returns.

    def get_seq_length(self, layer_idx: int = 0) -> int:  # type: ignore[override]
        """Return the cumulative sequence length for the specified layer.

        Delegates to the layer cache at layer_idx, which in turn delegates to the
        local sliding-window sub-cache. That sub-cache is authoritative for sequence
        progress: it sees every token presented to the layer and accumulates a truthful
        total count. Defaults to layer 0, which is sufficient for HuggingFace generation.

        Args:
            layer_idx: Index of the decoder layer to query.

        Returns:
            The number of token positions processed by that layer so far, or 0 when
            layer_idx is not smaller than the number of owned layers (for example on a
            cache built with num_hidden_layers=0). The 0 fallback mirrors the guard in
            the transformers base Cache.get_seq_length, so callers probing a layer that
            does not exist see "nothing cached" instead of an IndexError.
        """
        if layer_idx >= len(self.layers):
            return 0
        return self.layers[layer_idx].get_seq_length()

    # ---------------------------------------------------------------------------
    # Cache — unsupported methods
    # ---------------------------------------------------------------------------

    def update(  # type: ignore[override]
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: dict | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Not supported — ShramCache has no composite update interface.

        The two attention paths inside each SHRAM layer have different update semantics.
        Callers must update sub-caches directly:
            cache.layers[layer_idx].sliding_window_cache.update(key_states, value_states)
            cache.layers[layer_idx].mosrah_cache.update(key_states, value_states, active_mask)

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "ShramCache has no composite update interface. "
            "Update sliding_window_cache or mosrah_cache on the relevant layer directly."
        )

    def crop(self, max_length: int) -> None:
        """Not supported — ShramCache layers do not implement crop().

        Overrides the inherited per-layer crop forwarding so the failure is reported
        at the model level with a clear message.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("ShramCache does not support crop().")

    @property
    def max_batch_size(self) -> int:
        """Not supported — ShramCache does not track a uniform batch size across layers.

        Raises:
            NotImplementedError: Always. Because this is NotImplementedError rather than
                AttributeError, hasattr(cache, "max_batch_size") propagates the exception
                instead of returning False.
        """
        raise NotImplementedError("ShramCache does not expose max_batch_size.")

    @property
    def max_cache_len(self) -> int:
        """Not supported — ShramCache has no single maximum cache length.

        The sliding-window side is bounded by sliding_window; the MoSRAH side is unbounded.
        No truthful scalar maximum represents the composite.

        Raises:
            NotImplementedError: Always. Because this is NotImplementedError rather than
                AttributeError, hasattr(cache, "max_cache_len") propagates the exception
                instead of returning False.
        """
        raise NotImplementedError("ShramCache does not expose max_cache_len.")