# evo2-7b / cache.py
# ishanjmukherjee: "Copy Python verbatim from vortex" (commit 43539ed)
# Copied verbatim from vortex
# Copyright (c) 2024, Michael Poli.
from dataclasses import dataclass, field
from typing import Optional
from torch import Tensor
# https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py
@dataclass
class InferenceParams:
    """Mutable decoding state threaded through the model during inference.

    Tracks the current position in the generated sequence and holds
    per-layer key/value memory so successive forward passes can reuse
    previously computed context instead of recomputing it.
    """

    max_seqlen: int
    max_batch_size: int
    seqlen_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen, max_batch_size):
        """Re-arm this state for a fresh generation of the given size."""
        self.max_seqlen, self.max_batch_size = max_seqlen, max_batch_size
        self.seqlen_offset = 0
        # NOTE(review): key_value_memory_dict and batch_size_offset are
        # deliberately left as-is here, matching the upstream behavior.
        lengths = self.lengths_per_sample
        if lengths is not None:
            lengths.zero_()
@dataclass
class HyenaCascadeIIRInferenceParams:
    """Inference-time state for long Hyena blocks running in recurrent mode.

    Carries the short FIR filter length, the IIR state dimensionality, the
    current decoding offset, and dicts holding cached filter/recurrence state.
    """

    fir_filter_length: int = 3
    state_dim: int = 16
    seqlen_offset: int = 0
    fir_state_dict: dict = field(default_factory=dict)
    state_dict: dict = field(default_factory=dict)

    def reset(self):
        """Restore the scalar settings to their defaults.

        NOTE(review): fir_state_dict and state_dict are not cleared here,
        mirroring the upstream behavior.
        """
        self.fir_filter_length, self.state_dim = 3, 16
        self.seqlen_offset = 0
@dataclass
class HyenaCascadeFIRInferenceParams:
    """Inference-time state for short and medium Hyena blocks.

    Holds the outer and inner FIR filter lengths, the current decoding
    offset, and dicts of cached filter state keyed by the caller.
    """

    fir_filter_length: int = 3
    fir_inner_filter_length: int = 4
    seqlen_offset: int = 0
    fir_inner_state_dict: dict = field(default_factory=dict)
    fir_state_dict: dict = field(default_factory=dict)
    state_dict: dict = field(default_factory=dict)

    def reset(self):
        """Put the scalar fields back to their defaults.

        NOTE(review): the cached state dicts are left untouched,
        mirroring the upstream behavior.
        """
        self.seqlen_offset = 0
        self.fir_filter_length = 3
        self.fir_inner_filter_length = 4