Gausson committed on
Commit
c6c6a1b
·
verified ·
1 Parent(s): e994cb5

Upload 2 files

Browse files
Files changed (2) hide show
  1. functions_2_patch.py +221 -0
  2. monkey_patching_utils.py +154 -0
functions_2_patch.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import inspect
3
+ import importlib
4
+
5
+ from typing import Callable, Optional, Union, Any, List
6
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
7
+ from transformers.cache_utils import Cache
8
+ from transformers.processing_utils import Unpack
9
+
10
+ from .sep_cache_utils import SepCache
11
+
12
+
13
+
14
def truncate_input_ids_4_autoregression(input_ids, key_states):
    """Trim `input_ids` so its sequence length matches that of `key_states`.

    During autoregressive decoding the token ids handed to a layer may cover a
    longer span than the freshly projected keys; only the trailing
    `key_states.shape[-2]` ids are kept so both line up. When the lengths
    already agree, `input_ids` is returned unchanged.
    """
    kv_len = key_states.shape[-2]
    ids_len = input_ids.shape[-1]
    if ids_len == kv_len:
        return input_ids
    # There must be at least as many ids as keys to truncate from the right.
    assert ids_len >= kv_len
    return input_ids[..., -kv_len:]
21
+
22
def llama_atten_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_value: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
    """Monkey-patched Llama attention forward that routes the KV cache through `SepCache`.

    Mirrors the stock attention forward, except that `past_key_value` must be a
    `SepCache`, whose `update()` additionally needs `input_ids` (taken from
    `kwargs` directly, or from `kwargs["sepllm_kwargs"]`) and `position_ids`.

    Args:
        hidden_states: input activations of shape (batch, seq_len, hidden).
        position_embeddings: `(cos, sin)` rotary tables for the current positions.
        attention_mask: optional mask; clipped to the usable cache length.
        past_key_value: the `SepCache` holding past KV states (required).
        cache_position: positions of the current tokens within the cache.
        **kwargs: must carry `position_ids` and `input_ids`/`sepllm_kwargs`;
            forwarded to the selected attention backend.

    Returns:
        `(attn_output, attn_weights)` from the selected attention backend.
    """
    input_shape = hidden_states.shape[:-1]

    # Different transformers versions expose the per-head size under
    # different attribute names.
    if hasattr(self, "head_dim"):
        head_dim = self.head_dim
    elif hasattr(self, "head_size"):
        head_dim = self.head_size
    else:
        # FIX: the original left `head_dim` unbound here, surfacing later as a
        # confusing NameError instead of a clear failure.
        raise AttributeError("The attention module exposes neither `head_dim` nor `head_size`.")

    hidden_shape = (*input_shape, -1, head_dim)

    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    ###########################SepCache########################
    assert isinstance(past_key_value, SepCache), f"`past_key_value` must be of the type: `SepCache`."
    APPLY_PE_SHIFT = past_key_value.APPLY_PE_SHIFT
    APPLY_PES_INSIDE = past_key_value.APPLY_PES_INSIDE
    ###########################################################

    ########################Monkey Patching####################
    # Resolve the rotary/attention helpers from the module that defines this
    # attention class, so the patch tracks the installed transformers version.
    module = importlib.import_module(self.__module__)

    apply_rotary_pos_emb = module.apply_rotary_pos_emb
    rotate_half = module.rotate_half
    eager_attention_forward = module.eager_attention_forward
    ALL_ATTENTION_FUNCTIONS = module.ALL_ATTENTION_FUNCTIONS
    ###########################################################

    # FIX: unpack the rotary tables unconditionally. The original only did so
    # in the `not APPLY_PE_SHIFT` branch, leaving `sin`/`cos` (and the derived
    # `sin_q`/`cos_q`, which were never defined at all) to raise NameError when
    # a PE shift was requested with `APPLY_PES_INSIDE == False`.
    cos, sin = position_embeddings
    if not APPLY_PE_SHIFT:
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; position_ids needed for the static cache
        if APPLY_PE_SHIFT and (not APPLY_PES_INSIDE):
            ### At least the shifted `sin` and `cos` should be properly provided (not `None`).
            # NOTE(review): the caller supplies a single (cos, sin) pair, so the
            # query-side tables default to the same tensors -- confirm this
            # matches SepCache's expected `cos_q`/`sin_q` semantics.
            cache_kwargs = {"sin": sin, "cos": cos, "cos_q": cos, "sin_q": sin,
                            "cache_position": cache_position, "partial_rotation_size": None}
        else:
            cache_kwargs = {}

        # `input_ids` may arrive directly in `kwargs` or packed inside `sepllm_kwargs`.
        # (The original's `"kwargs" in locals()` scaffolding was dead code:
        # `kwargs` is always bound as a parameter of this function.)
        if "input_ids" in kwargs:
            input_ids = kwargs.get("input_ids", None)
        else:
            sepllm_kwargs = kwargs.get("sepllm_kwargs", None)
            assert sepllm_kwargs is not None, f"`sepllm_kwargs` must be provided when `input_ids` is not given."
            input_ids = sepllm_kwargs.get("input_ids", None)

        assert input_ids is not None, f"`input_ids` must be properly provided directly or through `sepllm_kwargs` when calling `update()` in `SepCache`."

        position_ids = kwargs.get("position_ids")

        bsz, q_len, _ = hidden_states.size()

        input_ids = truncate_input_ids_4_autoregression(input_ids=input_ids, key_states=key_states)

        if APPLY_PE_SHIFT:
            # SepCache applies the (possibly shifted) rotary embeddings itself
            # and therefore also returns rotated query states.
            key_states, value_states, query_states = past_key_value.update(
                key_states=key_states,
                value_states=value_states,
                query_states=query_states,
                input_ids=input_ids,
                layer_idx=self.layer_idx,
                position_ids=position_ids,
                PREFILLING_FLAG=q_len > 1,
                cache_kwargs=cache_kwargs)
        else:
            key_states, value_states = past_key_value.update(
                key_states=key_states,
                value_states=value_states,
                input_ids=input_ids,
                layer_idx=self.layer_idx,
                position_ids=position_ids,
                PREFILLING_FLAG=q_len > 1,
                cache_kwargs=cache_kwargs)

        # SepCache may compress the stored sequence; clip the mask accordingly.
        seq_len = past_key_value.get_usable_length(self.layer_idx)

        if attention_mask is not None:
            attention_mask = attention_mask[..., :seq_len]

    attention_interface: Callable = eager_attention_forward
    if self.config._attn_implementation != "eager":
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

    attn_output, attn_weights = attention_interface(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=0.0 if not self.training else self.attention_dropout,
        scaling=self.scaling,
        **kwargs,
    )

    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights
152
+
153
+
154
+
155
+
156
+
157
+
158
def _validate_model_kwargs(self, model_kwargs: dict[str, Any]):
    """Validates model kwargs for generation. Generate argument typos will also be caught here."""
    # A `Cache` instance in `past_key_values` is only legal when the model
    # opts into the new cache API.
    if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class:
        raise ValueError(
            f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please "
            "check the model documentation for supported cache formats."
        )

    # Arguments consumed before any model function runs are not validated.
    if self.config.is_encoder_decoder:
        for key in ["decoder_input_ids"]:
            model_kwargs.pop(key, None)

    accepted_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
    # `kwargs`/`model_kwargs` is often used to hand optional forward-pass
    # inputs (like `attention_mask`) through `prepare_inputs_for_generation`;
    # when it accepts a catch-all, `forward`'s own parameters count too.
    if "kwargs" in accepted_args or "model_kwargs" in accepted_args:
        accepted_args |= set(inspect.signature(self.forward).parameters)

    # Encoder-decoder models may also consume encoder/decoder arguments.
    if self.config.is_encoder_decoder:
        base_model = getattr(self, self.base_model_prefix, None)

        # Allow encoder kwargs. `MusicgenForConditionalGeneration` exposes
        # `text_encoder`/`audio_encoder` and has
        # `base_model_prefix = "encoder_decoder"` with no matching attribute,
        # hence the fallback through the base model.
        # TODO: A better way to handle this.
        encoder = getattr(self, "encoder", None)
        if encoder is None and base_model is not None:
            encoder = getattr(base_model, "encoder", None)
        if encoder is not None:
            accepted_args |= set(inspect.signature(encoder.forward).parameters)

        # Allow decoder kwargs, which callers pass with a `decoder_` prefix.
        decoder = getattr(self, "decoder", None)
        if decoder is None and base_model is not None:
            decoder = getattr(base_model, "decoder", None)
        if decoder is not None:
            accepted_args |= {f"decoder_{name}" for name in inspect.signature(decoder.forward).parameters}

    ###############################SepCache###########################
    # Unlike the stock implementation, any kwarg whose key mentions "sep"
    # (e.g. `sepllm_kwargs`) is tolerated so SepCache extras survive
    # `generate()`'s strict validation.
    unused_model_args = [
        key
        for key, value in model_kwargs.items()
        if (value is not None) and (key not in accepted_args) and ("sep" not in str(key).lower())
    ]
    ###################################################################

    if unused_model_args:
        raise ValueError(
            f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
            " generate arguments will also show up in this list)"
        )
220
+
221
+
monkey_patching_utils.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import inspect
3
+ import importlib
4
+ import transformers
5
+ import types
6
+
7
+ import torch.nn as nn
8
+ from transformers.modeling_utils import PreTrainedModel
9
+ from typing import Callable, Optional, Union, Any, List
10
+
11
+ from .functions_2_patch import _validate_model_kwargs, llama_atten_forward
12
+
13
+
14
def get_full_class_import_path(obj):
    """Get the complete class import path of an object.

    Nested-class qualified names (``Outer.Inner``) are flattened with
    underscores so the result is a dotted module path plus one identifier.
    """
    cls = obj.__class__
    module = cls.__module__
    qualname = cls.__qualname__
    # Flatten nested-class separators; top-level classes pass through as-is.
    suffix = qualname.replace('.', '_') if '.' in qualname else qualname
    return f"{module}.{suffix}"
33
+
34
+
35
def get_importable_class_path(obj):
    """Get the directly importable class path (handling special cases and dynamic classes)."""
    cls = obj.__class__
    module = cls.__module__
    qualname = cls.__qualname__

    # Built-in types are addressable by their bare name.
    if module == 'builtins':
        return qualname

    # Dynamically generated classes carry no usable module.
    if not hasattr(cls, '__module__') or module is None:
        return f"<dynamic class {qualname}>"

    # Top-level classes: the module-qualified name is already importable.
    if '.' not in qualname:
        return f"{module}.{qualname}"

    # Nested class: verify the dotted path actually resolves before trusting it.
    try:
        target = importlib.import_module(module)
        for piece in qualname.split('.'):
            target = getattr(target, piece)
    except (ImportError, AttributeError):
        # Fallback: flatten the nesting with underscores.
        return f"{module}.{qualname.replace('.', '_')}"
    return f"{module}.{qualname}"
69
+
70
+
71
+
72
def monkey_patch_by_class_path(model, new_forward):
    """Perform monkey patching through class path."""
    # Resolve where the instance's class can be imported from.
    class_path = get_importable_class_path(model)

    try:
        module_path, class_name = class_path.rsplit('.', 1)
        target_class = getattr(importlib.import_module(module_path), class_name)

        # Keep a handle on the pristine method the first time we patch.
        if not hasattr(target_class, '_original_forward'):
            target_class._original_forward = target_class.forward

        # Patch the class, then rebind on this instance so it picks up the
        # new method even if `forward` was previously bound per-instance.
        target_class.forward = new_forward
        model.forward = types.MethodType(target_class.forward, model)

        return f"Successful Monkey Patch: {class_path}.forward"

    except (ImportError, AttributeError, ValueError) as e:
        return f"Patch Failed: {str(e)}"
98
+
99
+
100
+
101
+
102
def find_inner_attribute(obj, attr_name_list: List[str], default_type = PreTrainedModel ):
    """Return the first attribute of `obj` named in `attr_name_list`.

    When none of the candidate names exists, fall back to the first attribute
    (in `dir()` order) whose value is an instance of `default_type`.

    Raises:
        AttributeError: when neither lookup strategy finds anything.
    """
    for candidate in attr_name_list:
        if hasattr(obj, candidate):
            return getattr(obj, candidate)

    # Fallback: scan everything on the object for a value of the wanted type.
    for name in dir(obj):
        value = getattr(obj, name)
        if isinstance(value, default_type):
            return value

    raise AttributeError(f"In the {obj} object, there is no attribute whose name matches any name in {attr_name_list} or whose type is {default_type}.")
115
+
116
+
117
+ def find_attribute_name(obj, name_pattern_list: List[str], exclude_pattern_list: List[str], match_type = nn.Module):
118
+ for attr_name in dir(obj):
119
+ attr_value = getattr(obj, attr_name)
120
+ for pattern in name_pattern_list:
121
+ for ex_pattern in exclude_pattern_list:
122
+ if isinstance(attr_value, match_type) and (pattern.lower() in attr_value.__class__.__name__.lower()) and ( ex_pattern.lower() not in attr_value.__class__.__name__.lower() ):
123
+ return attr_value
124
+ elif isinstance(attr_value, match_type) and (pattern.lower() in attr_name.lower()) and (ex_pattern.lower() not in attr_name.lower() ):
125
+ return attr_value
126
+
127
+ raise AttributeError(f"In the {obj} object, there is no attribute whose name matches any pattern in {name_pattern_list} and excludes any pattern in {exclude_pattern_list}, and whose type is {match_type}.")
128
+
129
+
130
+
131
def monkey_patching(model_obj, model_atten_forward , verbose = True):
    """Patch every decoder layer's attention forward with `model_atten_forward`.

    Also replaces `GenerationMixin._validate_model_kwargs` so `generate()`
    tolerates the SepCache-specific kwargs. Returns the `nn.ModuleList` of
    decoder layers that were patched.
    """
    # Relax generate() kwarg validation for SepCache extras.
    transformers.generation.GenerationMixin._validate_model_kwargs = _validate_model_kwargs

    ## get inner model (attribute naming varies across architectures)
    inner_model = find_inner_attribute(model_obj, ["model", "transformer", "gpt_neox"], PreTrainedModel)

    # Locate the stack of decoder layers.
    model_layers = find_inner_attribute(inner_model, ["layers", "h"], nn.ModuleList)

    attn_patterns = ["attention", "self_attn"]
    attn_excludes = ["norm", "layer"]

    for i, decoder_layer in enumerate(model_layers):
        self_attn_module = find_attribute_name(decoder_layer, attn_patterns, attn_excludes, nn.Module)
        result = monkey_patch_by_class_path(self_attn_module, model_atten_forward)
        if verbose:
            decoder_class_name = get_importable_class_path(decoder_layer)
            print(f"For Layer {i}'s `{decoder_class_name}`: {result}")

    return model_layers