import torch
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel
from ...cache_utils import Cache
# example where we need some deps and some functions
class SuperModel(LlamaModel):
    """Llama backbone whose forward pass scales the output logits by ``2**4``.

    The model delegates the entire computation to ``LlamaModel.forward`` and
    then multiplies the ``logits`` field of the result in place before
    returning it.
    """

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        cache_position: torch.LongTensor | None = None,
    ) -> tuple | CausalLMOutputWithPast:
        """Run the parent Llama forward pass, then scale its logits in place."""
        # Delegate to the stock implementation; keyword arguments keep the
        # hand-off explicit and independent of positional order.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        # In-place scale (``*=``) mutates the existing tensor rather than
        # rebinding a new one.
        # NOTE(review): ``LlamaModel.forward`` normally yields hidden states,
        # not a ``logits`` field, and the return annotation names
        # CausalLMOutputWithPast — this looks like an intentional toy
        # example; confirm against the enclosing project before reuse.
        outputs.logits *= 2**4
        return outputs