File size: 4,355 Bytes
60df24a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""Inference-time caches for Evo2 blocks.

StripedHyena2 has four block types with different caching needs:

  * `attn` blocks   -> InferenceParams (standard KV cache; stored as `Evo2Cache.mha`)
  * `hcl` blocks    -> HyenaCascadeIIRInferenceParams (FIR window + IIR state)
  * `hcm` blocks    -> HyenaCascadeFIRInferenceParams (outer FIR + inner FIR)
  * `hcs` blocks    -> HyenaCascadeFIRInferenceParams (outer FIR + inner FIR)

These per-block-type caches are bundled into a single container, `Evo2Cache`
(deliberately not a `transformers.Cache` subclass), so model.generate() can
drive autoregressive decoding without the user having to instantiate four
separate caches by hand.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional

import torch
from torch import Tensor


@dataclass
class InferenceParams:
    """Key/value cache bookkeeping for attention blocks.

    Tracks the cache capacity (``max_seqlen`` by ``max_batch_size``), the
    current sequence position and an optional per-sample length tensor. This
    class only stores ``key_value_memory_dict``; it never reads or clears it.
    """

    max_seqlen: int
    max_batch_size: int
    # Current sequence position; rewound to 0 by reset().
    seqlen_offset: int = 0
    # Left untouched by reset().
    batch_size_offset: int = 0
    # Key/value storage filled in by the caller -- presumably keyed by layer
    # index (verify against the attention block). Fresh dict per instance.
    key_value_memory_dict: dict = field(default_factory=dict)
    # Optional per-sample lengths; zeroed in place (not replaced) by reset().
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen: int, max_batch_size: int) -> None:
        """Install new cache bounds and rewind to the start of a sequence.

        Args:
            max_seqlen: New maximum sequence length.
            max_batch_size: New maximum batch size.

        ``key_value_memory_dict`` and ``batch_size_offset`` keep their current
        values. When ``lengths_per_sample`` is set, it is zeroed in place.
        """
        self.max_seqlen, self.max_batch_size = max_seqlen, max_batch_size
        self.seqlen_offset = 0
        lengths = self.lengths_per_sample
        if lengths is None:
            return
        lengths.zero_()


@dataclass
class HyenaCascadeIIRInferenceParams:
    """Decode-time state holder for `hcl` blocks (short FIR window + IIR state).

    The two dicts start empty (a fresh dict per instance). This class never
    writes to them; the blocks that consume it are expected to fill them.
    """

    # Length of the short FIR filter whose window is cached.
    fir_filter_length: int = 3
    # Size of the IIR state.
    state_dim: int = 16
    # Current sequence position; rewound to 0 by reset().
    seqlen_offset: int = 0
    # FIR window storage -- presumably keyed by layer index (verify).
    fir_state_dict: dict = field(default_factory=dict)
    # IIR state storage -- presumably keyed by layer index (verify).
    state_dict: dict = field(default_factory=dict)

    def reset(self) -> None:
        """Rewind the sequence position to 0.

        Both state dicts and the filter sizes are left untouched.
        NOTE(review): this assumes the next prefill overwrites the stale
        entries -- confirm against the `hcl` block implementation.
        """
        self.seqlen_offset = 0


@dataclass
class HyenaCascadeFIRInferenceParams:
    """Decode-time state holder for `hcm` / `hcs` blocks.

    These blocks cascade an outer short FIR with an inner FIR of
    ``fir_inner_filter_length`` taps. The three dicts start empty (a fresh
    dict per instance). This class never writes to them.
    """

    # Length of the outer short FIR filter.
    fir_filter_length: int = 3
    # Length of the inner FIR filter.
    fir_inner_filter_length: int = 4
    # Current sequence position; rewound to 0 by reset().
    seqlen_offset: int = 0
    # Inner-FIR window storage -- presumably keyed by layer index (verify).
    fir_inner_state_dict: dict = field(default_factory=dict)
    # Outer-FIR window storage -- presumably keyed by layer index (verify).
    fir_state_dict: dict = field(default_factory=dict)
    # Additional per-layer state; its consumer is not visible here.
    state_dict: dict = field(default_factory=dict)

    def reset(self) -> None:
        """Rewind the sequence position to 0.

        The state dicts and filter lengths are kept as they are.
        NOTE(review): this assumes the next prefill overwrites the stale
        entries -- confirm against the `hcm` / `hcs` block implementations.
        """
        self.seqlen_offset = 0


class Evo2Cache:
    """Container for per-block-type inference params.

    Holds one cache object per StripedHyena 2 block type:

      * ``mha`` -- InferenceParams for attention blocks
      * ``hcl`` -- HyenaCascadeIIRInferenceParams
      * ``hcm`` -- HyenaCascadeFIRInferenceParams (inner FIR length
        ``hcm_filter_length``)
      * ``hcs`` -- HyenaCascadeFIRInferenceParams (inner FIR length
        ``hcs_filter_length``)

    Not a transformers.Cache subclass (the new Cache API requires per-layer
    dataclasses, which doesn't fit StripedHyena 2's 4 block-type-specific
    state structures). Instead we set Evo2PreTrainedModel._supports_cache_class
    = False so HF's generate() treats this as an opaque past_key_values object.

    The four sub-caches each carry their own ``seqlen_offset``. Move them with
    advance() / set_offset() so they stay in sync. The sequence-length queries
    below read only the attention cache's values.
    """

    is_compileable = False

    # Attribute names of the per-block-type caches, in a fixed order.
    _BLOCK_NAMES = ("mha", "hcl", "hcm", "hcs")
    # Extra spellings accepted by by_block_name(). Attention blocks are also
    # referred to as `attn`.
    _BLOCK_ALIASES = {"attn": "mha"}

    def __init__(
        self,
        max_seqlen: int,
        max_batch_size: int,
        short_filter_length: int,
        hcm_filter_length: int,
        hcs_filter_length: int,
        state_size: int,
    ):
        """Build the four sub-caches.

        Args:
            max_seqlen: Maximum sequence length for the attention KV cache.
            max_batch_size: Maximum batch size for the attention KV cache.
            short_filter_length: Outer short-FIR length shared by the
                ``hcl``, ``hcm`` and ``hcs`` caches.
            hcm_filter_length: Inner FIR length for ``hcm`` blocks.
            hcs_filter_length: Inner FIR length for ``hcs`` blocks.
            state_size: IIR state dimension for ``hcl`` blocks.
        """
        self.mha = InferenceParams(
            max_seqlen=max_seqlen,
            max_batch_size=max_batch_size,
        )
        self.hcl = HyenaCascadeIIRInferenceParams(
            fir_filter_length=short_filter_length,
            state_dim=state_size,
        )
        self.hcm = HyenaCascadeFIRInferenceParams(
            fir_filter_length=short_filter_length,
            fir_inner_filter_length=hcm_filter_length,
        )
        self.hcs = HyenaCascadeFIRInferenceParams(
            fir_filter_length=short_filter_length,
            fir_inner_filter_length=hcs_filter_length,
        )

    def _caches(self) -> tuple:
        """Return the four sub-caches in ``_BLOCK_NAMES`` order."""
        return tuple(getattr(self, name) for name in self._BLOCK_NAMES)

    @property
    def seqlen_offset(self) -> int:
        """Current sequence position, as recorded by the attention cache."""
        return self.mha.seqlen_offset

    def get_seq_length(self, layer_idx: int = 0) -> int:
        """Number of positions processed so far.

        ``layer_idx`` is accepted for HF API compatibility. It is ignored
        because every layer shares one position.
        """
        return self.mha.seqlen_offset

    def get_max_cache_shape(self) -> int:
        """Maximum sequence length the attention cache was sized for."""
        return self.mha.max_seqlen

    def get_max_length(self) -> int:
        """Older name for get_max_cache_shape(); returns the same value."""
        return self.mha.max_seqlen

    def advance(self, n: int = 1) -> None:
        """Move every sub-cache's sequence position forward by ``n``."""
        for cache in self._caches():
            cache.seqlen_offset += n

    def set_offset(self, offset: int) -> None:
        """Set every sub-cache's sequence position to ``offset``."""
        for cache in self._caches():
            cache.seqlen_offset = offset

    def reset(self) -> None:
        """Rewind all sub-caches, keeping the attention cache's current bounds."""
        self.mha.reset(self.mha.max_seqlen, self.mha.max_batch_size)
        self.hcl.reset()
        self.hcm.reset()
        self.hcs.reset()

    def by_block_name(self, name: str):
        """Return the sub-cache for block type ``name``.

        Accepts ``"mha"``, ``"hcl"``, ``"hcm"``, ``"hcs"`` and the alias
        ``"attn"`` (resolved to ``"mha"``).

        Raises:
            AttributeError: if ``name`` is not a known block type. The check
                exists because a plain ``getattr`` would hand back unrelated
                attributes for a wrong name, e.g. the ``reset`` method or
                the ``seqlen_offset`` int.
        """
        resolved = self._BLOCK_ALIASES.get(name, name)
        if resolved not in self._BLOCK_NAMES:
            valid = ", ".join(
                repr(n) for n in (*self._BLOCK_NAMES, *self._BLOCK_ALIASES)
            )
            raise AttributeError(
                f"unknown Evo2 block type {name!r}; expected one of {valid}"
            )
        return getattr(self, resolved)