| from typing import Optional, Union | |
| import torch | |
| from transformers.modeling_outputs import CausalLMOutputWithPast | |
| from transformers.models.llama.modeling_llama import LlamaModel | |
| from ...cache_utils import Cache | |
# example where we need some deps and some functions
class SuperModel(LlamaModel):
    """LlamaModel variant whose forward pass scales the output logits by 2**4 (16)."""

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[tuple, CausalLMOutputWithPast]:
        """Run the base model's forward pass, then multiply its logits by 16.

        All parameters mirror ``LlamaModel.forward`` and are passed through
        unchanged; see the base class for their semantics.

        Returns:
            The base model's output object with its ``logits`` attribute
            scaled in place by ``2**4``.
        """
        # Forward by keyword rather than position so a change in the parent's
        # parameter order cannot silently misroute arguments.
        out = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        # NOTE(review): the annotation promises CausalLMOutputWithPast, but the
        # base class is LlamaModel (not LlamaForCausalLM) — confirm the parent's
        # output actually carries a `logits` field; with return_dict=False the
        # parent may also return a plain tuple, which has no `logits` attribute.
        out.logits *= 2**4
        return out