Gausson commited on
Commit
61c53db
·
verified ·
1 Parent(s): fa46716

Delete monkey_patching_utils.py

Browse files
Files changed (1) hide show
  1. monkey_patching_utils.py +0 -154
monkey_patching_utils.py DELETED
@@ -1,154 +0,0 @@
1
- import torch
2
- import inspect
3
- import importlib
4
- import transformers
5
- import types
6
-
7
- import torch.nn as nn
8
- from transformers.modeling_utils import PreTrainedModel
9
- from typing import Callable, Optional, Union, Any, List
10
-
11
- from .functions_2_patch import _validate_model_kwargs, llama_atten_forward
12
-
13
-
14
- def get_full_class_import_path(obj):
15
- """Get the complete class import path of an object"""
16
- # Get the class of the object
17
- cls = obj.__class__
18
-
19
- # Get the module name where the class is defined
20
- module = cls.__module__
21
-
22
- # Get the qualified name of the class (including outer classes)
23
- qualname = cls.__qualname__
24
-
25
- # Handle nested classes (e.g., ClassA.ClassB)
26
- if '.' in qualname:
27
- # Replace nested class separators
28
- class_path = f"{module}.{qualname.replace('.', '_')}"
29
- else:
30
- class_path = f"{module}.{qualname}"
31
-
32
- return class_path
33
-
34
-
35
- def get_importable_class_path(obj):
36
- """Get the directly importable class path (handling special cases and dynamic classes)"""
37
- cls = obj.__class__
38
- module = cls.__module__
39
- qualname = cls.__qualname__
40
-
41
- # Handle built-in types
42
- if module == 'builtins':
43
- return qualname
44
-
45
- # Handle dynamically generated classes (e.g., functools.partial)
46
- if not hasattr(cls, '__module__') or module is None:
47
- return f"<dynamic class {qualname}>"
48
-
49
- # Handle nested classes
50
- if '.' in qualname:
51
- # Try to import the parent module to validate the path
52
- try:
53
- import importlib
54
- parent_module = importlib.import_module(module)
55
-
56
- # Follow the qualified name path
57
- parts = qualname.split('.')
58
- current = parent_module
59
- for part in parts:
60
- current = getattr(current, part)
61
-
62
- # If successful access, return the original path
63
- return f"{module}.{qualname}"
64
- except (ImportError, AttributeError):
65
- # Fallback: use underscore connection
66
- return f"{module}.{qualname.replace('.', '_')}"
67
-
68
- return f"{module}.{qualname}"
69
-
70
-
71
-
72
- def monkey_patch_by_class_path(model, new_forward):
73
- """Perform monkey patching through class path"""
74
- # Get the complete class path
75
- class_path = get_importable_class_path(model)
76
-
77
- # Dynamically import the class
78
- try:
79
- import importlib
80
- module_path, class_name = class_path.rsplit('.', 1)
81
- module = importlib.import_module(module_path)
82
- target_class = getattr(module, class_name)
83
-
84
- # Save the original method
85
- if not hasattr(target_class, '_original_forward'):
86
- target_class._original_forward = target_class.forward
87
-
88
- # Apply the patch
89
- target_class.forward = new_forward
90
-
91
- # Update the method binding of the current instance
92
- model.forward = types.MethodType(target_class.forward, model)
93
-
94
- return f"Successful Monkey Patch: {class_path}.forward"
95
-
96
- except (ImportError, AttributeError, ValueError) as e:
97
- return f"Patch Failed: {str(e)}"
98
-
99
-
100
-
101
-
102
- def find_inner_attribute(obj, attr_name_list: List[str], default_type = PreTrainedModel ):
103
- # try to find the attribute of the name in `attr_name_list`.
104
- for target_attr_name in attr_name_list:
105
- if hasattr(obj, target_attr_name):
106
- return getattr(obj, target_attr_name)
107
-
108
- # else: try to find the attribute of the type `default_type`
109
- for attr_name in dir(obj):
110
- attr_value = getattr(obj, attr_name)
111
- if isinstance(attr_value, default_type):
112
- return attr_value
113
-
114
- raise AttributeError(f"In the {obj} object, there is no attribute whose name matches any name in {attr_name_list} or whose type is {default_type}.")
115
-
116
-
117
- def find_attribute_name(obj, name_pattern_list: List[str], exclude_pattern_list: List[str], match_type = nn.Module):
118
- for attr_name in dir(obj):
119
- attr_value = getattr(obj, attr_name)
120
- for pattern in name_pattern_list:
121
- for ex_pattern in exclude_pattern_list:
122
- if isinstance(attr_value, match_type) and (pattern.lower() in attr_value.__class__.__name__.lower()) and ( ex_pattern.lower() not in attr_value.__class__.__name__.lower() ):
123
- return attr_value
124
- elif isinstance(attr_value, match_type) and (pattern.lower() in attr_name.lower()) and (ex_pattern.lower() not in attr_name.lower() ):
125
- return attr_value
126
-
127
- raise AttributeError(f"In the {obj} object, there is no attribute whose name matches any pattern in {name_pattern_list} and excludes any pattern in {exclude_pattern_list}, and whose type is {match_type}.")
128
-
129
-
130
-
131
- def monkey_patching(model_obj, model_atten_forward , verbose = True):
132
- transformers.generation.GenerationMixin._validate_model_kwargs = _validate_model_kwargs
133
-
134
- ## get inner model
135
- possible_inner_model_names = ["model", "transformer", "gpt_neox"]
136
- inner_model_type = PreTrainedModel
137
- inner_model = find_inner_attribute(model_obj, possible_inner_model_names, inner_model_type)
138
-
139
-
140
- possible_layers_names = ["layers", "h" ]
141
- layers_type = nn.ModuleList
142
- model_layers = find_inner_attribute(inner_model, possible_layers_names, layers_type)
143
-
144
- atten_attr_name_pattern_list = ["attention", "self_attn"]
145
- atten_attr_name_pattern_exclude = ["norm", "layer"]
146
-
147
- for i, decoder_layer in enumerate(model_layers):
148
- self_attn_module = find_attribute_name(decoder_layer, atten_attr_name_pattern_list, atten_attr_name_pattern_exclude, nn.Module)
149
- result = monkey_patch_by_class_path(self_attn_module, model_atten_forward)
150
- if verbose:
151
- decoder_class_name = get_importable_class_path(decoder_layer)
152
- print(f"For Layer {i}'s `{decoder_class_name}`: {result}")
153
-
154
- return model_layers