# Copyright (c) Alibaba, Inc. and its affiliates.
import re
from copy import deepcopy
from dataclasses import dataclass
from types import MethodType
from typing import Dict, Optional

import torch
from torch import nn

from swift.utils import get_logger
from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()


@dataclass
class PartConfig(SwiftConfig):
    """
    Freeze the model and train a part of it.

    Args:
        target_modules(`Optional[str]`): A regex pattern matching the full names of the modules to be trained.
    """

    target_modules: Optional[str] = None

    def __post_init__(self):
        from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.PART


class Part(SwiftAdapter):

    @staticmethod
    def target_module_matched(module_key: str, config: PartConfig):
        return re.fullmatch(config.target_modules, module_key)

    @staticmethod
    def prepare_model(model: nn.Module, config: PartConfig, adapter_name: str):
        name_list = [name for name, _ in model.named_modules(remove_duplicate=False)]
        for name in name_list:
            module: nn.Module = model.get_submodule(name)
            if Part.target_module_matched(name, config) and not getattr(module, 'plugin', False):
                # If the module is already wrapped by another tuner (e.g. LoRA),
                # patch the underlying layer instead of the wrapper.
                if hasattr(module, 'base_layer'):
                    module = module.base_layer

                def _forward(self, *args, **kwargs):
                    # Route the call to the single activated `_part_` copy,
                    # falling back to the frozen original forward.
                    child_list = [
                        sub_module for name, sub_module in self.named_modules(remove_duplicate=False)
                        if '_part_' in name
                    ]
                    sub_modules = [child for child in child_list if getattr(child, 'activated', False)]
                    assert len(sub_modules) <= 1
                    if len(sub_modules) == 1:
                        return sub_modules[0].forward(*args, **kwargs)
                    else:
                        return self.forward_origin(*args, **kwargs)

                # Patch the forward only once, even if several adapters target this module.
                if not hasattr(module, 'forward_origin'):
                    module.forward_origin = module.forward
                    module.forward = MethodType(_forward, module)

                # Attach a trainable deep copy of the module, stripping any nested
                # `_part_` copies inherited from previously prepared adapters.
                new_module = deepcopy(module)
                for attr in dir(new_module):
                    if '_part_' in attr:
                        delattr(new_module, attr)
                new_module.part_name = adapter_name
                ActivationMixin.mark_all_sub_modules_as_plugin(new_module)
                setattr(module, f'_part_{adapter_name}', new_module)
                new_module.requires_grad_(True)

        def state_dict_callback(state_dict, adapter_name, **kwargs):
            # Keep only the parameters of the trainable copies; optionally map
            # their keys back to the original module names.
            new_state_dict = {}
            for key, value in state_dict.items():
                if f'_part_{adapter_name}.' in key:
                    if kwargs.get('replace_key', True):
                        new_key = key.replace(f'_part_{adapter_name}.', '').replace('base_layer.', '')
                    else:
                        new_key = key
                    new_state_dict[new_key] = value
            return new_state_dict

        def mark_trainable_callback(model: nn.Module):
            pass

        def load_state_dict_callback(model: nn.Module, adapter_name: str, state_dict: Dict[str, torch.Tensor]):
            # Map checkpoint keys saved under the original module names back to
            # the `_part_{adapter_name}` copies created in prepare_model.
            new_state_dict = {}
            for name, module in model.named_modules(remove_duplicate=False):
                module: nn.Module
                if Part.target_module_matched(name, config):
                    for param_name in state_dict:
                        if param_name.startswith(name):
                            end = param_name[len(name):]
                            if '_part_' not in param_name:
                                if hasattr(module, 'base_layer'):
                                    new_state_dict[name + f'.base_layer._part_{adapter_name}'
                                                   + end] = state_dict[param_name]
                                else:
                                    new_state_dict[name + f'._part_{adapter_name}' + end] = state_dict[param_name]
                            else:
                                new_state_dict[param_name] = state_dict[param_name]
            return new_state_dict

        return SwiftOutput(
            config=config,
            state_dict_callback=state_dict_callback,
            mark_trainable_callback=mark_trainable_callback,
            load_state_dict_callback=load_state_dict_callback)

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: Optional[str] = None):
        name_list = [name for name, _ in module.named_modules(remove_duplicate=False)]
        for name in name_list:
            sub_module: nn.Module = module.get_submodule(name)
            if re.fullmatch(f'.*_part_{adapter_name}$', name):
                sub_module.activated = activate
                SwiftAdapter.save_memory(sub_module, adapter_name, name, activate, offload)
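
# A minimal usage sketch (kept as comments so this module stays import-safe).
# It assumes the generic ``Swift.prepare_model`` entry point dispatches
# ``PartConfig`` to this tuner; the model and regex below are illustrative only:
#
#     from swift import Swift
#
#     # Train only the modules whose full names match the regex; the rest of
#     # the model stays frozen and keeps its original forward.
#     config = PartConfig(target_modules=r'model.layers.0.*')
#     model = Swift.prepare_model(model, config)
#
#     # Forward passes of matched modules now route through their trainable
#     # deep copies, while deactivating the adapter restores forward_origin.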