# transformers/examples/modular-transformers/modular_global_indexing.py
# Provenance: uploaded by AbdulElahGwaith via huggingface_hub (commit a9bd396, verified)
from transformers.modeling_utils import AttentionInterface
from transformers.models.llama.modeling_llama import LlamaAttention
def custom_flex(x, **kwargs):
    """Identity placeholder for a flex-attention implementation.

    Returns ``x`` untouched; any extra keyword arguments are accepted
    and silently ignored.
    """
    del kwargs  # explicitly unused
    return x
# Registry mapping attention-implementation names to callables.
# AttentionInterface supports mapping-style item assignment (used below).
ALL_ATTENTION_FUNCTIONS = AttentionInterface()
# This indexing statement and associated function should be exported correctly!
ALL_ATTENTION_FUNCTIONS["flex_attention"] = custom_flex
class GlobalIndexingAttention(LlamaAttention):
    """Subclass identical to ``LlamaAttention``.

    Exists so the modular converter processes this file; the fixture's
    purpose is checking that the module-level indexing assignment above
    is exported correctly (per the comment accompanying it).
    """
    pass