| # Note that zamba does not have the `apply_rotary_pos_emb` function! | |
| from transformers.models.llama.modeling_llama import apply_rotary_pos_emb | |
| from transformers.models.zamba.modeling_zamba import ZambaAttention | |
# When ZambaAttention's dependencies are followed, the function `apply_rotary_pos_emb` is not
# included by default, because it is absent from the class definition (and from the file altogether).
# This syntax should add both `apply_rotary_pos_emb`, which is imported directly, and `rotate_half`,
# which is a dependency of the imported function!
class TestAttention(ZambaAttention):
    """ZambaAttention subclass whose ``forward`` references ``apply_rotary_pos_emb``.

    The only thing this class adds is a call to ``apply_rotary_pos_emb`` (imported from
    the Llama modeling module), which makes that function a dependency of the subclass.
    NOTE(review): this looks like a fixture for dependency tracing rather than a usable
    attention layer; confirm it is never instantiated or run as a real module.
    """

    def __init__(self):
        """Do nothing: ``ZambaAttention.__init__`` is not called, so no attributes are set up."""

    def forward(self):
        """Call ``apply_rotary_pos_emb`` with placeholder arguments and return ``None``."""
        # The result is discarded; only the reference to the function matters here.
        # NOTE(review): the integer arguments look like placeholders, not real tensors.
        apply_rotary_pos_emb(1, 1, 1, 1)