"""SHRAM top-level cache β€” model-wide owner for the full SHRAM decoder stack.

The HuggingFace Cache protocol expects a single top-level Cache object that owns one
CacheLayerMixin per decoder layer. The actual SHRAM caching responsibilities live one level
lower, in ShramLayerCache; each layer cache owns a LocalSlidingWindowLayerCache and a MoSRAHCache.
ShramCache bridges those two levels: it constructs one ShramLayerCache per decoder layer,
presents them through the Cache interface, and transparently forwards model-wide operations
across all of them.

ShramCache does not define a composite update() interface. The two attention paths inside each
SHRAM layer have different update semantics, and neither the layer-level boundary (Unit 6.B)
nor the model-level boundary here can meaningfully unify them. Callers must reach down to the
relevant sub-cache directly. ShramCache's role is ownership, construction, and model-wide
coordination of the layer caches, not routing attention inputs.

Sequence length is reported by delegating to the local sliding-window sub-cache of the
specified layer, which tracks the cumulative count of token positions processed. This is
what HuggingFace generation reads through get_seq_length().
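
A sketch of the resulting access pattern (illustrative only; the update() calls mirror the
sub-cache signatures documented on ShramCache.update() below, and the variable names are
placeholders rather than symbols defined by this module):

    layer_cache = cache.layers[layer_idx]
    layer_cache.sliding_window_cache.update(local_key_states, local_value_states)
    layer_cache.mosrah_cache.update(mosrah_key_states, mosrah_value_states, active_mask)
    seen_tokens = cache.get_seq_length()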
"""

import torch
from transformers.cache_utils import Cache

from .__cache__shram_layer_cache import ShramLayerCache


class ShramCache(Cache):
    """Top-level cache for the full SHRAM model.

    Owns one ShramLayerCache per decoder layer. Satisfies the HuggingFace top-level Cache
    role and transparently forwards reset, reorder, and sequence-length queries across all
    owned layer caches.

    No composite update() interface is provided. The two attention paths inside each SHRAM
    layer have materially different update semantics; callers must update sub-caches directly
    via cache.layers[layer_idx].sliding_window_cache or cache.layers[layer_idx].mosrah_cache.

    Args:
        num_hidden_layers: Number of SHRAM decoder layers. Determines how many
            ShramLayerCache objects are constructed.
        sliding_window: Token window size passed to each layer's LocalSlidingWindowLayerCache.
        num_local_heads: Number of local attention heads per layer.
        local_head_dim: Per-head embedding width for the local path.
        num_mosrah_heads: Total number of MoSRAH expert heads (L) per layer.
        mosrah_head_dim: Bottlenecked head embedding width (u) for the MoSRAH path.
        batch_size: Number of sequences in the batch.
        device: Device on which to allocate cache tensors.
        initial_buffer_size: Initial per-(batch, head) capacity for each MoSRAHCache.
            Doubled when any slot overflows. Defaults to 64 to avoid repeated reallocation
            during prompt processing.
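
    Example:
        A minimal construction sketch; the numeric values are placeholders rather than
        defaults from any particular SHRAM configuration:

            cache = ShramCache(
                num_hidden_layers=24,
                sliding_window=1024,
                num_local_heads=8,
                local_head_dim=64,
                num_mosrah_heads=16,
                mosrah_head_dim=32,
                batch_size=2,
                device=torch.device("cpu"),
            )
            # One ShramLayerCache per decoder layer, usable immediately.
            assert len(cache.layers) == 24
            assert cache.is_initialized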
    """

    def __init__(
        self,
        num_hidden_layers: int,
        sliding_window: int,
        num_local_heads: int,
        local_head_dim: int,
        num_mosrah_heads: int,
        mosrah_head_dim: int,
        batch_size: int,
        device: torch.device,
        initial_buffer_size: int = 64,
    ) -> None:
        layers = [
            ShramLayerCache(
                sliding_window=sliding_window,
                num_local_heads=num_local_heads,
                local_head_dim=local_head_dim,
                num_mosrah_heads=num_mosrah_heads,
                mosrah_head_dim=mosrah_head_dim,
                batch_size=batch_size,
                device=device,
                initial_buffer_size=initial_buffer_size,
            )
            for _ in range(num_hidden_layers)
        ]
        super().__init__(layers=layers)

    # ---------------------------------------------------------------------------
    # Cache: composite-meaningful methods
    # ---------------------------------------------------------------------------
    #
    # reset(): Inherited. Iterates all layer caches and calls reset() on each.
    #
    # reorder_cache(beam_idx): Inherited. Iterates all layer caches and reorders each.
    #
    # is_initialized: Inherited property. True iff all layer caches are initialized.
    #   Since ShramLayerCache.is_initialized is True from construction, this is True
    #   immediately after ShramCache.__init__ returns.
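    #
    # Illustrative model-wide calls (a sketch; beam_idx follows the usual HuggingFace
    # convention of a LongTensor of selected beam indices):
    #
    #     cache.reorder_cache(beam_idx)  # reorder every layer cache for the chosen beams
    #     cache.reset()                  # clear every layer cache before the next prompt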

    def get_seq_length(self, layer_idx: int = 0) -> int:  # type: ignore[override]
        """Return the cumulative sequence length for the specified layer.

        Delegates to the layer cache at layer_idx, which in turn delegates to the
        local sliding-window sub-cache. That sub-cache is authoritative for sequence
        progress: it sees every token presented to the layer and accumulates a truthful
        total count. Defaults to layer 0, which is sufficient for HuggingFace generation.
        """
        return self.layers[layer_idx].get_seq_length()

    # ---------------------------------------------------------------------------
    # Cache: unsupported methods
    # ---------------------------------------------------------------------------

    def update(  # type: ignore[override]
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: dict | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Not supported β€” ShramCache has no composite update interface.

        The two attention paths inside each SHRAM layer have different update semantics.
        Callers must update sub-caches directly:
          cache.layers[layer_idx].sliding_window_cache.update(key_states, value_states)
          cache.layers[layer_idx].mosrah_cache.update(key_states, value_states, active_mask)
        """
        raise NotImplementedError(
            "ShramCache has no composite update interface. "
            "Update sliding_window_cache or mosrah_cache on the relevant layer directly."
        )

    def crop(self, max_length: int) -> None:
        """Not supported β€” ShramCache layers do not implement crop()."""
        raise NotImplementedError("ShramCache does not support crop().")

    @property
    def max_batch_size(self) -> int:
        """Not supported β€” ShramCache does not track a uniform batch size across layers."""
        raise NotImplementedError("ShramCache does not expose max_batch_size.")

    @property
    def max_cache_len(self) -> int:
        """Not supported β€” ShramCache has no single maximum cache length.

        The sliding-window side is bounded by sliding_window; the MoSRAH side is unbounded.
        No truthful scalar maximum represents the composite.
        """
        raise NotImplementedError("ShramCache does not expose max_cache_len.")