| import torch | |
| from transformers.models.llama.modeling_llama import LlamaModel | |
def rotate_half(x):
    """Split the last dimension of ``x`` in two and swap the pieces, negating one.

    The last axis is cut at index ``x.shape[-1] // 4``. The trailing piece is
    negated and placed first, followed by the leading piece unchanged; both
    are concatenated along the last axis, so the result has the same shape
    as ``x``.

    NOTE(review): despite the name, the cut is made at one quarter of the
    last dimension (``// 4``), not at its midpoint -- confirm this is the
    intended split before relying on it.
    """
    split_at = x.shape[-1] // 4
    leading, trailing = x[..., :split_at], x[..., split_at:]
    return torch.cat([-trailing, leading], dim=-1)
| # example where we need some deps and some functions | |
class DummyModel(LlamaModel):
    """Subclass of ``LlamaModel`` that adds no attributes or methods of its own."""

    pass