import torch
import inspect
import importlib

from typing import Callable, Optional, Union, Any, List
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.cache_utils import Cache
from transformers.processing_utils import Unpack

from .sep_cache_utils import SepCache

def truncate_input_ids_4_autoregression(input_ids, key_states):
    """Keep only the trailing `input_ids` that correspond to the current `key_states` length."""
    if input_ids.shape[-1] != key_states.shape[-2]:
        assert input_ids.shape[-1] >= key_states.shape[-2]
        truncated_input_ids = input_ids[..., -key_states.shape[-2]:]
        return truncated_input_ids
    else:
        return input_ids

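# Illustrative example (not part of the original code): during single-token decoding the full
# prompt ids are longer than the one new key position, so only the last id is kept. Shapes
# below are hypothetical.
#
#   ids = torch.arange(10).view(1, 10)            # (bsz=1, prompt_len=10)
#   k = torch.zeros(1, 8, 1, 64)                  # (bsz, num_heads, q_len=1, head_dim)
#   truncate_input_ids_4_autoregression(ids, k)   # -> tensor([[9]]), shape (1, 1)
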
def llama_atten_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_value: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
    input_shape = hidden_states.shape[:-1]

    # Different transformers versions expose the per-head dimension under different attribute names.
    if hasattr(self, "head_dim"):
        head_dim = self.head_dim
    elif hasattr(self, "head_size"):
        head_dim = self.head_size
    else:
        raise AttributeError("The attention module must expose either `head_dim` or `head_size`.")

    hidden_shape = (*input_shape, -1, head_dim)

    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    assert isinstance(past_key_value, SepCache), "`past_key_value` must be of the type: `SepCache`."
    APPLY_PE_SHIFT = past_key_value.APPLY_PE_SHIFT
    APPLY_PES_INSIDE = past_key_value.APPLY_PES_INSIDE

    # Fetch the rotary-embedding and attention helpers from the model's own module so that this
    # patched forward stays consistent with the installed transformers version.
    module = importlib.import_module(self.__module__)
    apply_rotary_pos_emb = module.apply_rotary_pos_emb
    rotate_half = module.rotate_half
    eager_attention_forward = module.eager_attention_forward
    ALL_ATTENTION_FUNCTIONS = module.ALL_ATTENTION_FUNCTIONS
    if not APPLY_PE_SHIFT:
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        if APPLY_PE_SHIFT and (not APPLY_PES_INSIDE):
            # NOTE (assumption): when the positional-encoding shift is applied outside SepCache,
            # `cos`/`sin` and the query-side `cos_q`/`sin_q` are expected to have been computed
            # by the PE-shift path before this point.
            cache_kwargs = {"sin": sin, "cos": cos, "cos_q": cos_q, "sin_q": sin_q, "cache_position": cache_position, "partial_rotation_size": None}
        else:
            cache_kwargs = {}
        # Depending on the transformers version, the extra forward arguments may arrive either as
        # `kwargs` or as `flash_attn_kwargs`.
        if "kwargs" in locals():
            pass
        elif "flash_attn_kwargs" in locals():
            kwargs = flash_attn_kwargs
        else:
            raise NameError("`kwargs` or `flash_attn_kwargs` should be given and they need to contain `sepllm_kwargs` (which contains `input_ids`) and `position_ids`.")

        # `SepCache.update()` needs the token ids to locate separator tokens, so recover `input_ids`
        # either directly from the kwargs or from the nested `sepllm_kwargs`.
        if "input_ids" not in locals():
            if "input_ids" in kwargs:
                input_ids = kwargs.get("input_ids", None)
            else:
                sepllm_kwargs = kwargs.get("sepllm_kwargs", None)
                assert sepllm_kwargs is not None, "`sepllm_kwargs` must be provided when `input_ids` is not given."
                input_ids = sepllm_kwargs.get("input_ids", None)

        assert input_ids is not None, "`input_ids` must be properly provided directly or through `sepllm_kwargs` when calling `update()` in `SepCache`."

        if "position_ids" not in locals():
            position_ids = kwargs.get("position_ids")

        assert input_ids is not None, "`input_ids` must be properly provided when calling `update()` in `SepCache`."
        bsz, q_len, _ = hidden_states.size()

        # During decoding, keep only the trailing ids that match the current key length.
        input_ids = truncate_input_ids_4_autoregression(input_ids=input_ids, key_states=key_states)
        if APPLY_PE_SHIFT:
            # When SepCache handles the positional-encoding shift, it also returns the (shifted) query states.
            key_states, value_states, query_states = past_key_value.update(
                key_states=key_states,
                value_states=value_states,
                query_states=query_states,
                input_ids=input_ids,
                layer_idx=self.layer_idx,
                position_ids=position_ids,
                PREFILLING_FLAG=q_len > 1,
                cache_kwargs=cache_kwargs,
            )
        else:
            key_states, value_states = past_key_value.update(
                key_states=key_states,
                value_states=value_states,
                input_ids=input_ids,
                layer_idx=self.layer_idx,
                position_ids=position_ids,
                PREFILLING_FLAG=q_len > 1,
                cache_kwargs=cache_kwargs,
            )

        seq_len = past_key_value.get_usable_length(self.layer_idx)

        # The compressed cache may hold fewer positions than the raw sequence, so trim the mask to match.
        if attention_mask is not None:
            attention_mask = attention_mask[..., :seq_len]
    attention_interface: Callable = eager_attention_forward
    if self.config._attn_implementation != "eager":
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

    attn_output, attn_weights = attention_interface(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=0.0 if not self.training else self.attention_dropout,
        scaling=self.scaling,
        **kwargs,
    )

    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights
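# --------------------------------------------------------------------------------------------
# Minimal usage sketch (illustrative, assumption-labeled): one way to install the patched
# forward on a Llama model and decode with a `SepCache`. The checkpoint name is a placeholder,
# the `SepCache(...)` constructor arguments are defined in `sep_cache_utils`, and how
# `sepllm_kwargs` is propagated down to the attention layers depends on the installed
# transformers version. The `_validate_model_kwargs` override below must also be bound so that
# `sepllm_kwargs` passes generation-argument validation.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   from transformers.models.llama.modeling_llama import LlamaAttention
#
#   LlamaAttention.forward = llama_atten_forward                       # patch every attention layer
#
#   model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")   # placeholder checkpoint
#   tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
#
#   inputs = tokenizer("A long prompt ...", return_tensors="pt")
#   past_key_values = SepCache(...)                                    # see `sep_cache_utils` for arguments
#   out = model.generate(**inputs, past_key_values=past_key_values,
#                        sepllm_kwargs={"input_ids": inputs["input_ids"]})
# --------------------------------------------------------------------------------------------
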
def _validate_model_kwargs(self, model_kwargs: dict[str, Any]):
    """Validates model kwargs for generation. Generate argument typos will also be caught here."""
    # If a `Cache` instance is passed, check that the model actually supports the new cache format.
    if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class:
        raise ValueError(
            f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please "
            "check the model documentation for supported cache formats."
        )

    # Exclude arguments that are handled before calling any model function.
    if self.config.is_encoder_decoder:
        for key in ["decoder_input_ids"]:
            model_kwargs.pop(key, None)

    unused_model_args = []
    model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
    # If `prepare_inputs_for_generation` accepts a catch-all argument, also allow anything that the
    # model's `forward` accepts.
    if "kwargs" in model_args or "model_kwargs" in model_args:
        model_args |= set(inspect.signature(self.forward).parameters)

    # Encoder-decoder models may also accept encoder/decoder arguments through `model_kwargs`.
    if self.config.is_encoder_decoder:
        base_model = getattr(self, self.base_model_prefix, None)

        encoder = getattr(self, "encoder", None)
        # Some composite models only expose the encoder on their base model.
        if encoder is None and base_model is not None:
            encoder = getattr(base_model, "encoder", None)

        if encoder is not None:
            encoder_model_args = set(inspect.signature(encoder.forward).parameters)
            model_args |= encoder_model_args

        decoder = getattr(self, "decoder", None)
        if decoder is None and base_model is not None:
            decoder = getattr(base_model, "decoder", None)

        if decoder is not None:
            decoder_model_args = set(inspect.signature(decoder.forward).parameters)
            model_args |= {f"decoder_{x}" for x in decoder_model_args}

    for key, value in model_kwargs.items():
        # SepLLM change: kwargs whose names contain "sep" (e.g. `sepllm_kwargs`) are allowed through
        # even though they are not declared in the model's signatures.
        if (value is not None) and (key not in model_args) and ("sep" not in str(key).lower()):
            unused_model_args.append(key)

    if unused_model_args:
        raise ValueError(
            f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
            " generate arguments will also show up in this list)"
        )
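# --------------------------------------------------------------------------------------------
# Minimal usage sketch (illustrative, assumption-labeled): binding the relaxed
# `_validate_model_kwargs` onto a loaded model instance so that `generate(..., sepllm_kwargs=...)`
# is not rejected as an unused argument. `model`, `inputs`, and `past_key_values` refer to the
# objects from the sketch above.
#
#   import types
#
#   model._validate_model_kwargs = types.MethodType(_validate_model_kwargs, model)
#   out = model.generate(**inputs, past_key_values=past_key_values,
#                        sepllm_kwargs={"input_ids": inputs["input_ids"]})
# --------------------------------------------------------------------------------------------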