import torch

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel

from ...cache_utils import Cache


# example where we need some deps and some functions
class SuperModel(LlamaModel):
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        cache_position: torch.LongTensor | None = None,
    ) -> tuple | CausalLMOutputWithPast:
        # Delegate to LlamaModel's forward, then post-process its output
        out = super().forward(
            input_ids,
            attention_mask,
            position_ids,
            past_key_values,
            inputs_embeds,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
            cache_position,
        )
        # Scale the returned logits by a constant factor (2**4 = 16)
        out.logits *= 2**4
        return out
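

# --- Hedged illustration (added for clarity, not part of the modular file above) ---
# A minimal, self-contained sketch of the same override-and-post-process idiom used by
# SuperModel: the subclass delegates to the parent's forward and only rescales the result.
# The toy classes and sizes below are arbitrary assumptions, not Transformers APIs.
from torch import nn


class ToyBase(nn.Module):
    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class ToyScaled(ToyBase):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reuse the parent's computation, then scale, mirroring `out.logits *= 2**4` above.
        out = super().forward(x)
        return out * 2**4


print(ToyScaled()(torch.randn(2, 8)).shape)  # torch.Size([2, 8])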