import re
from copy import deepcopy
from dataclasses import dataclass
from types import MethodType
from typing import Dict, Optional

import torch
from torch import nn

from swift.utils import get_logger
from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()


@dataclass
class PartConfig(SwiftConfig):
    """
    Freeze the model and train a deep copy of the modules selected by a regex.

    Args:
        target_modules (`Optional[str]`): A regex pattern; every module whose full
            name matches it is copied, and only the copy is trained.
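
    Example:
        A minimal sketch; the regex is hypothetical and depends on the model's
        actual module names:

        >>> config = PartConfig(target_modules='.*(mlp|ffn)')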
| """ |

    target_modules: Optional[str] = None

    def __post_init__(self):
        from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.PART


class Part(SwiftAdapter):

    @staticmethod
    def target_module_matched(module_key: str, config: PartConfig):
        # Match the configured regex against the module's full dotted name.
        return re.fullmatch(config.target_modules, module_key)

    @staticmethod
    def prepare_model(model: nn.Module, config: PartConfig, adapter_name: str):
        name_list = [name for name, _ in model.named_modules(remove_duplicate=False)]
        for name in name_list:
            module: nn.Module = model.get_submodule(name)
            if Part.target_module_matched(name, config) and not getattr(module, 'plugin', False):
                if hasattr(module, 'base_layer'):
                    # The module is wrapped by another tuner (e.g. LoRA); patch its base layer instead.
                    module = module.base_layer
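
                # Replacement forward: dispatch to the single activated '_part_' copy
                # if there is one, otherwise fall back to the original forward.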
                def _forward(self, *args, **kwargs):
                    child_list = [
                        sub_module for name, sub_module in self.named_modules(remove_duplicate=False)
                        if '_part_' in name
                    ]
                    sub_modules = [child for child in child_list if getattr(child, 'activated', False)]
                    assert len(sub_modules) <= 1
                    if len(sub_modules) == 1:
                        return sub_modules[0].forward(*args, **kwargs)
                    else:
                        return self.forward_origin(*args, **kwargs)
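
                # Attach the dispatcher only once per module; later adapters reuse it.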
                if not hasattr(module, 'forward_origin'):
                    module.forward_origin = module.forward
                    module.forward = MethodType(_forward, module)
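
                # Create the trainable copy: a deepcopy of the target module with any
                # '_part_' copies from other adapters stripped, so copies never nest.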
                new_module = deepcopy(module)
                for attr in dir(new_module):
                    if '_part_' in attr:
                        delattr(new_module, attr)
                new_module.part_name = adapter_name
                ActivationMixin.mark_all_sub_modules_as_plugin(new_module)
                setattr(module, f'_part_{adapter_name}', new_module)
                new_module.requires_grad_(True)
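
        # Saving: keep only this adapter's parameters, remapped (by default) to the
        # original module paths so the checkpoint mirrors the base model layout.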
        def state_dict_callback(state_dict, adapter_name, **kwargs):
            new_state_dict = {}
            for key, value in state_dict.items():
                if f'_part_{adapter_name}.' in key:
                    if kwargs.get('replace_key', True):
                        new_key = key.replace(f'_part_{adapter_name}.', '').replace('base_layer.', '')
                    else:
                        new_key = key
                    new_state_dict[new_key] = value
            return new_state_dict
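
        # Nothing extra to mark trainable: `requires_grad_(True)` was already applied
        # to each '_part_' copy in the loop above.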
        def mark_trainable_callback(model: nn.Module):
            pass
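
        # Loading: remap keys saved under the original module paths back onto the
        # corresponding '_part_' copies of this adapter.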
        def load_state_dict_callback(model: nn.Module, adapter_name: str, state_dict: Dict[str, torch.Tensor]):
            new_state_dict = {}
            for name, module in model.named_modules(remove_duplicate=False):
                module: nn.Module
                if Part.target_module_matched(name, config):
                    for param_name in state_dict:
                        if param_name.startswith(name):
                            end = param_name[len(name):]
                            if '_part_' not in param_name:
                                if hasattr(module, 'base_layer'):
                                    new_state_dict[name + f'.base_layer._part_{adapter_name}'
                                                   + end] = state_dict[param_name]
                                else:
                                    new_state_dict[name + f'._part_{adapter_name}' + end] = state_dict[param_name]
                            else:
                                new_state_dict[param_name] = state_dict[param_name]
            return new_state_dict

        return SwiftOutput(
            config=config,
            state_dict_callback=state_dict_callback,
            mark_trainable_callback=mark_trainable_callback,
            load_state_dict_callback=load_state_dict_callback)
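
    # Toggling an adapter walks the tree and flips `activated` on its '_part_' copies;
    # the dispatcher installed in `prepare_model` then routes calls accordingly.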
    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: Optional[str] = None):
        name_list = [name for name, _ in module.named_modules(remove_duplicate=False)]
        for name in name_list:
            sub_module: nn.Module = module.get_submodule(name)
            if re.fullmatch(f'.*_part_{adapter_name}$', name):
                sub_module.activated = activate
                SwiftAdapter.save_memory(sub_module, adapter_name, name, activate, offload)
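

# A minimal usage sketch (illustrative only; it assumes the public `Swift.prepare_model`
# entry point dispatches a `PartConfig` to this tuner):
#
#     from swift import Swift
#     model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
#     model = Swift.prepare_model(model, PartConfig(target_modules='0'))
#     # The first Linear now carries a trainable '_part_' copy; while the adapter
#     # is activated, calls to the layer run through that copy.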