| | import torch |
| | import inspect |
| | import importlib |
| | import transformers |
| | import types |
| |
|
| | import torch.nn as nn |
| | from transformers.modeling_utils import PreTrainedModel |
| | from typing import Callable, Optional, Union, Any, List |
| |
|
| | from .functions_2_patch import _validate_model_kwargs, llama_atten_forward |
| |
|
| |
|
def get_full_class_import_path(obj):
    """Return the full dotted import path of *obj*'s class.

    Nested-class qualnames (those containing a dot, e.g. ``Outer.Inner``)
    have their inner dots flattened to underscores so the result remains a
    single ``module.name`` style path.
    """
    klass = type(obj)
    module_name = klass.__module__
    qual = klass.__qualname__

    # "Outer.Inner" -> "Outer_Inner"; plain qualnames pass through unchanged.
    suffix = qual.replace('.', '_') if '.' in qual else qual
    return f"{module_name}.{suffix}"
|
| |
|
def get_importable_class_path(obj):
    """Return a directly importable dotted path for *obj*'s class.

    Special cases:
      * builtins  -> just the class name (e.g. ``"int"``);
      * classes with no module -> a ``"<dynamic class ...>"`` placeholder;
      * nested classes -> verified by walking the attribute chain from the
        imported module; if the chain is broken (e.g. the class was created
        inside a function) the dots are flattened to underscores as a
        best-effort, unambiguous fallback name.
    """
    cls = obj.__class__
    module = cls.__module__
    qualname = cls.__qualname__

    # A class created by a dynamic factory may carry no module: not importable.
    # (The original also tested `hasattr(cls, '__module__')`, but that check was
    # vacuous — `cls.__module__` had already been read above without error.)
    if module is None:
        return f"<dynamic class {qualname}>"

    # Builtins are importable without any module prefix.
    if module == 'builtins':
        return qualname

    if '.' in qualname:
        # Nested class: prove that the full attribute chain is reachable from
        # the imported module before claiming the dotted path is importable.
        # The traversal result itself is discarded — it exists only to raise
        # on an unreachable link.
        try:
            current = importlib.import_module(module)
            for part in qualname.split('.'):
                current = getattr(current, part)
            return f"{module}.{qualname}"
        except (ImportError, AttributeError):
            # Chain broken: fall back to a flattened (non-importable) name.
            return f"{module}.{qualname.replace('.', '_')}"

    return f"{module}.{qualname}"
|
| |
|
| |
|
def monkey_patch_by_class_path(model, new_forward):
    """Replace ``forward`` on *model*'s class with *new_forward*.

    The class is re-resolved through its dotted import path so the patch is
    applied to the real class object (and therefore affects every instance of
    it).  The pristine ``forward`` is stashed once as ``_original_forward`` so
    repeated patching never loses the original implementation.

    Returns a human-readable status string instead of raising, so callers can
    simply log the outcome per layer.
    """
    class_path = get_importable_class_path(model)

    try:
        # Note: uses the module-level `importlib` import; the previous
        # function-local `import importlib` was redundant.
        module_path, class_name = class_path.rsplit('.', 1)
        module = importlib.import_module(module_path)
        target_class = getattr(module, class_name)

        # Save the original forward exactly once (idempotent patching).
        if not hasattr(target_class, '_original_forward'):
            target_class._original_forward = target_class.forward

        target_class.forward = new_forward

        # Also bind on the instance, in case `model.forward` is an instance
        # attribute that would shadow the patched class method.
        model.forward = types.MethodType(target_class.forward, model)

        return f"Successful Monkey Patch: {class_path}.forward"

    except (ImportError, AttributeError, ValueError) as e:
        # `<dynamic class ...>` paths and classes without `forward` land here.
        return f"Patch Failed: {str(e)}"
|
| |
|
| |
|
| |
|
def find_inner_attribute(obj, attr_name_list: List[str], default_type=PreTrainedModel):
    """Return the first matching inner attribute of *obj*.

    Strategy: try each name in *attr_name_list* in order; if none exists,
    fall back to scanning ``dir(obj)`` and return the first attribute whose
    value is an instance of *default_type*.

    Raises:
        AttributeError: when neither lookup strategy finds a match.
    """
    # Pass 1: explicit, ordered lookup by attribute name.
    for candidate in attr_name_list:
        try:
            return getattr(obj, candidate)
        except AttributeError:
            continue

    # Pass 2: duck-typed scan — first attribute of the expected type wins.
    for candidate in dir(obj):
        value = getattr(obj, candidate)
        if isinstance(value, default_type):
            return value

    raise AttributeError(f"In the {obj} object, there is no attribute whose name matches any name in {attr_name_list} or whose type is {default_type}.")
|
| |
|
def find_attribute_name(obj, name_pattern_list: List[str], exclude_pattern_list: List[str], match_type=nn.Module):
    """Return the first attribute of *obj* matching a name pattern.

    An attribute matches when it is an instance of *match_type* and one of
    *name_pattern_list* occurs (case-insensitively) in either its class name
    or its attribute name, provided that NO pattern from
    *exclude_pattern_list* occurs in that same name.

    Bug fix: the previous implementation tested one exclude pattern at a time
    in the innermost loop, so an attribute containing exclude pattern #2 was
    still returned during the iteration for exclude pattern #1 (e.g. an
    ``attention_norm`` module slipped past excludes ``["foo", "norm"]``).
    All exclude patterns are now checked together.

    Raises:
        AttributeError: when no attribute satisfies the criteria.
    """
    excludes = [ex.lower() for ex in exclude_pattern_list]

    for attr_name in dir(obj):
        attr_value = getattr(obj, attr_name)
        if not isinstance(attr_value, match_type):
            continue

        class_name = attr_value.__class__.__name__.lower()
        attr_lower = attr_name.lower()

        for pattern in name_pattern_list:
            p = pattern.lower()
            # Match by class name first (original priority), then by
            # attribute name; the excluded-substring test is applied to the
            # same name that produced the match.
            if p in class_name and not any(ex in class_name for ex in excludes):
                return attr_value
            if p in attr_lower and not any(ex in attr_lower for ex in excludes):
                return attr_value

    raise AttributeError(f"In the {obj} object, there is no attribute whose name matches any pattern in {name_pattern_list} and excludes any pattern in {exclude_pattern_list}, and whose type is {match_type}.")
|
| |
|
| |
|
def monkey_patching(model_obj, model_atten_forward, verbose=True):
    """Patch the attention forward of every decoder layer in *model_obj*.

    Also swaps in the relaxed ``_validate_model_kwargs`` on transformers'
    ``GenerationMixin`` so generation does not reject the extra kwargs the
    patched forward may take.  Returns the ``nn.ModuleList`` of decoder
    layers that were processed.
    """
    # Relax transformers' generation-time kwargs validation globally.
    transformers.generation.GenerationMixin._validate_model_kwargs = _validate_model_kwargs

    # Locate the inner base model (e.g. the LlamaModel under a *ForCausalLM).
    inner_model = find_inner_attribute(
        model_obj, ["model", "transformer", "gpt_neox"], PreTrainedModel
    )

    # Locate the stack of decoder layers on that inner model.
    model_layers = find_inner_attribute(inner_model, ["layers", "h"], nn.ModuleList)

    # Patch each layer's self-attention module (skip norms / whole layers).
    name_patterns = ["attention", "self_attn"]
    exclude_patterns = ["norm", "layer"]
    for i, layer in enumerate(model_layers):
        attn_module = find_attribute_name(layer, name_patterns, exclude_patterns, nn.Module)
        result = monkey_patch_by_class_path(attn_module, model_atten_forward)
        if verbose:
            decoder_class_name = get_importable_class_path(layer)
            print(f"For Layer {i}'s `{decoder_class_name}`: {result}")

    return model_layers
| |
|