Gausson committed on
Commit
a8296ac
·
verified ·
1 Parent(s): 1fca93a

Delete functions_2_patch.py

Browse files
Files changed (1) hide show
  1. functions_2_patch.py +0 -221
functions_2_patch.py DELETED
@@ -1,221 +0,0 @@
1
- import torch
2
- import inspect
3
- import importlib
4
-
5
- from typing import Callable, Optional, Union, Any, List
6
- from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
7
- from transformers.cache_utils import Cache
8
- from transformers.processing_utils import Unpack
9
-
10
- from .sep_cache_utils import SepCache
11
-
12
-
13
-
14
def truncate_input_ids_4_autoregression(input_ids, key_states):
    """Align `input_ids` with `key_states` along the sequence dimension.

    During autoregressive decoding `input_ids` may still carry the whole
    prompt while `key_states` only covers the current key window; in that
    case keep just the trailing `key_states.shape[-2]` token ids.

    Args:
        input_ids: Token-id tensor whose last dimension is the sequence length.
        key_states: Key tensor whose second-to-last dimension is the key
            sequence length.

    Returns:
        `input_ids` unchanged when the lengths already match, otherwise the
        last `key_states.shape[-2]` positions of `input_ids`.

    Raises:
        ValueError: If `input_ids` is shorter than the key sequence, since
            there is no valid trailing window to take.
    """
    kv_len = key_states.shape[-2]
    if input_ids.shape[-1] == kv_len:
        return input_ids
    # Explicit raise (not `assert`) so the check also runs under `python -O`.
    if input_ids.shape[-1] < kv_len:
        raise ValueError(
            f"`input_ids` length ({input_ids.shape[-1]}) must be >= "
            f"`key_states` sequence length ({kv_len})."
        )
    return input_ids[..., -kv_len:]
21
-
22
def llama_atten_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_value: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
    """Monkey-patched Llama-style attention forward that routes the KV cache
    through `SepCache` instead of the default `Cache.update` path.

    Requires `past_key_value` to be a `SepCache`; reads its `APPLY_PE_SHIFT`
    and `APPLY_PES_INSIDE` flags to decide how rotary embeddings are applied
    and what extra kwargs `SepCache.update()` needs. `input_ids` and
    `position_ids` must be reachable via `kwargs` (directly or nested inside
    `kwargs["sepllm_kwargs"]`).
    """
    input_shape = hidden_states.shape[:-1]

    # Different transformers versions name the per-head width differently.
    # NOTE(review): if neither attribute exists, `head_dim` is unbound and the
    # next line raises NameError — presumably never the case for patched models.
    if hasattr(self, "head_dim"):
        head_dim = self.head_dim
    elif hasattr(self, "head_size"):
        head_dim = self.head_size

    hidden_shape = (*input_shape, -1, head_dim)

    # Project to (batch, num_heads, seq_len, head_dim).
    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)


    ###########################SepCache########################
    # This patched forward only works with SepCache; its flags drive the
    # positional-encoding handling below.
    assert isinstance(past_key_value, SepCache), f"`past_key_value` must be of the type: `SepCache`."
    APPLY_PE_SHIFT = past_key_value.APPLY_PE_SHIFT
    APPLY_PES_INSIDE = past_key_value.APPLY_PES_INSIDE
    ###########################################################


    ########################Monkey Patching####################
    # Resolve helpers from the module that defines the *actual* attention
    # class being patched, so the right model-specific implementations
    # (RoPE helpers, attention backends) are used.
    module = importlib.import_module(self.__module__)

    apply_rotary_pos_emb = module.apply_rotary_pos_emb
    rotate_half = module.rotate_half
    eager_attention_forward = module.eager_attention_forward
    ALL_ATTENTION_FUNCTIONS = module.ALL_ATTENTION_FUNCTIONS
    ###########################################################

    # When PE shifting is off, apply standard RoPE here; otherwise rotary
    # application is deferred to SepCache (inside `update`, or via shifted
    # sin/cos passed through `cache_kwargs`).
    if not APPLY_PE_SHIFT:
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        # ##################################################Default#########################################################
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        # ##################################################################################################################

        ##################################################SepCache#########################################################
        # sin and cos are specific to RoPE models; position_ids needed for the static cache
        if APPLY_PE_SHIFT and (not APPLY_PES_INSIDE):
            ### At least the shifted `sin` and `cos` should be properly provided (not `None`).
            # NOTE(review): on this path `sin`/`cos` were NOT assigned above
            # (that branch runs only when `not APPLY_PE_SHIFT`), and `cos_q`/
            # `sin_q` are never defined in this function — this line looks like
            # it would raise NameError as written. Confirm whether callers ever
            # hit APPLY_PE_SHIFT=True with APPLY_PES_INSIDE=False.
            cache_kwargs = {"sin": sin, "cos": cos, "cos_q": cos_q, "sin_q": sin_q, "cache_position": cache_position, "partial_rotation_size": None }
        else:
            cache_kwargs = {}


        # NOTE(review): `kwargs` is a declared parameter, so it is always in
        # `locals()` — the first branch always wins and the
        # `flash_attn_kwargs` fallback / NameError are effectively dead code.
        if "kwargs" in locals():
            pass
        elif "flash_attn_kwargs" in locals():
            kwargs = flash_attn_kwargs
        else:
            raise NameError("`kwargs` or `flash_attn_kwargs` should be given and they need to contain `sepllm_kwargs` (which contains `input_ids`) and `position_ids`.")

        # Recover `input_ids` either directly from kwargs or from the nested
        # `sepllm_kwargs` dict that the patched model.forward threads through.
        if "input_ids" not in locals():
            if "input_ids" in kwargs:
                input_ids = kwargs.get("input_ids", None)
            else:
                sepllm_kwargs = kwargs.get("sepllm_kwargs", None)
                assert sepllm_kwargs is not None, f"`sepllm_kwargs` must be provided when `input_ids` is not given."
                input_ids = sepllm_kwargs.get("input_ids", None)

        assert input_ids is not None, f"`input_ids` must be properly provided directly or through `sepllm_kwargs` when calling `update()` in `SepCache`."

        if "position_ids" not in locals():
            position_ids = kwargs.get("position_ids")

        assert input_ids is not None, f"`input_ids` must be properly provided when calling `update()` in `SepCache`."
        bsz, q_len, _ = hidden_states.size()

        # During decoding, keep only the token ids that correspond to the
        # current key window (see truncate_input_ids_4_autoregression).
        input_ids = truncate_input_ids_4_autoregression(input_ids = input_ids, key_states = key_states )

        # SepCache also rewrites `query_states` when it applies the PE shift
        # itself; otherwise it only manages the key/value cache.
        if APPLY_PE_SHIFT:
            key_states, value_states, query_states = past_key_value.update(
                key_states = key_states,
                value_states = value_states,
                query_states = query_states,
                input_ids = input_ids,
                layer_idx = self.layer_idx,
                position_ids = position_ids,
                PREFILLING_FLAG = q_len > 1,
                cache_kwargs = cache_kwargs )

        else:
            key_states, value_states = past_key_value.update(
                key_states = key_states,
                value_states = value_states,
                input_ids = input_ids,
                layer_idx = self.layer_idx,
                position_ids = position_ids,
                PREFILLING_FLAG = q_len > 1,
                cache_kwargs = cache_kwargs )

        # SepCache may keep fewer positions than the raw sequence length, so
        # clip the attention mask to what is actually cached.
        seq_len = past_key_value.get_usable_length(self.layer_idx)

        if attention_mask is not None:
            attention_mask = attention_mask[..., :seq_len]
        ##################################################################################################################


    # Dispatch to the configured attention backend (eager / sdpa / flash ...).
    attention_interface: Callable = eager_attention_forward
    if self.config._attn_implementation != "eager":
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

    attn_output, attn_weights = attention_interface(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=0.0 if not self.training else self.attention_dropout,
        scaling=self.scaling,
        **kwargs,
    )

    # Merge heads back to (batch, seq_len, hidden) and project out.
    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights
152
-
153
-
154
-
155
-
156
-
157
-
158
def _validate_model_kwargs(self, model_kwargs: dict[str, Any]):
    """Validates model kwargs for generation. Generate argument typos will also be caught here."""
    # A `Cache` instance is only acceptable when the model opts into the cache API.
    cache_candidate = model_kwargs.get("past_key_values", None)
    if isinstance(cache_candidate, Cache) and not self._supports_cache_class:
        raise ValueError(
            f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please "
            "check the model documentation for supported cache formats."
        )

    # `decoder_input_ids` is consumed before any model function is called, so
    # it never counts as an unused kwarg for encoder-decoder models.
    if self.config.is_encoder_decoder:
        model_kwargs.pop("decoder_input_ids", None)

    accepted_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
    # `kwargs`/`model_kwargs` is often used to forward optional inputs like
    # `attention_mask`; when `prepare_inputs_for_generation` takes them, widen
    # the accepted set with everything `forward` accepts.
    if not accepted_args.isdisjoint({"kwargs", "model_kwargs"}):
        accepted_args |= set(inspect.signature(self.forward).parameters)

    # Encoder-Decoder models may also route `model_kwargs` to the encoder/decoder.
    if self.config.is_encoder_decoder:
        backbone = getattr(self, self.base_model_prefix, None)

        # Allow encoder kwargs. `MusicgenForConditionalGeneration` has
        # `text_encoder`/`audio_encoder` and `base_model_prefix = "encoder_decoder"`
        # with no `self.encoder_decoder`, hence the backbone fallback.
        # TODO: A better way to handle this.
        enc = getattr(self, "encoder", None)
        if enc is None and backbone is not None:
            enc = getattr(backbone, "encoder", None)
        if enc is not None:
            accepted_args |= set(inspect.signature(enc.forward).parameters)

        # Allow decoder kwargs; callers pass them prefixed with `decoder_`.
        dec = getattr(self, "decoder", None)
        if dec is None and backbone is not None:
            dec = getattr(backbone, "decoder", None)
        if dec is not None:
            accepted_args |= {f"decoder_{name}" for name in inspect.signature(dec.forward).parameters}

    ###############################SepCache###########################
    # SepCache-specific kwargs (any key containing "sep") are deliberately
    # exempt from the unused-argument check.
    unused_model_args = [
        key
        for key, value in model_kwargs.items()
        if value is not None and key not in accepted_args and "sep" not in str(key).lower()
    ]
    ###################################################################

    if unused_model_args:
        raise ValueError(
            f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
            " generate arguments will also show up in this list)"
        )
220
-
221
-