import re
import types
from dataclasses import dataclass, field
from typing import List, Union

import torch
from torch import nn

from swift.utils import get_logger
from swift.utils.torch_utils import find_sub_module
from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()


@dataclass
class PromptConfig(SwiftConfig):
    """
    The configuration class for the prompt module.

    Visual prompt tuning (VPT) initializes tunable prompt tokens and prepends them to the
    original tokens in the first layer or in multiple layers.
    'Visual Prompt Tuning' by Jia et al. (2022)
    See https://arxiv.org/abs/2203.12119

    Here we apply VPT to fields other than vision.

    Args:
        dim (`Union[int, List[int]]`): The dimension of the hidden states; use a list if there are
            up-sample or down-sample blocks.
        target_modules (`str`): The layer modules to be replaced, in regex format.
        embedding_pos (`Union[str, int]`): The position of the embedding tensor in the forward arguments.
        attention_mask_pos (`Union[str, int]`): The position of the attention mask in the forward arguments.
        attention_mask_value (`Union[float, int, bool]`): The value used to pad the attention mask.
        prompt_length (`int`): The length of the prompt tokens.
        attach_front (`bool`): When set to True, the prompt is attached in front of the embedding.
        extract_embedding (`bool`): Whether the embedding is extracted at the final stage to keep the
            same dims as the inputs.
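
    Example:
        A minimal, illustrative configuration. The target regex, `dim` and `embedding_pos` below
        are assumptions and must be adapted to the model actually being tuned:

        >>> config = PromptConfig(
        ...     dim=768,
        ...     target_modules='transformer.layers.[0-9]+',
        ...     embedding_pos=0,
        ...     prompt_length=16)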
    """

    dim: Union[int, List[int]] = field(default=None, metadata={'help': 'The dimension of the hidden states'})
    target_modules: str = field(default=None, metadata={'help': 'The layer module to be replaced, in regex format'})
    embedding_pos: Union[str, int] = field(default=None, metadata={'help': 'The position of the embedding tensor'})
    attention_mask_pos: Union[str, int] = field(default=None, metadata={'help': 'The position of the attention mask'})
    attention_mask_value: Union[float, int, bool] = field(
        default=0., metadata={'help': 'The value to pad to the attention mask'})
    prompt_length: int = field(default=16, metadata={'help': 'The length of the prompt tokens'})
    attach_front: bool = field(
        default=True, metadata={'help': 'When set to True, prompt is attached in front of the embedding'})
    extract_embedding: bool = field(
        default=False,
        metadata={'help': 'Whether the embedding is extracted at final stage to keep the same dims with inputs'})

    def __post_init__(self):
        from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.PROMPT


class Prompt(SwiftAdapter):

    @staticmethod
    def prepare_model(model: nn.Module, config: PromptConfig, adapter_name: str):
        module_keys = [key for key, _ in model.named_modules()]
        match_module_keys = []
        for module_key in module_keys:
            if isinstance(config.target_modules, str):
                target_module_found = re.fullmatch(config.target_modules, module_key)
            else:
                target_module_found = any(module_key.endswith(target_key) for target_key in config.target_modules)
            if target_module_found:
                module = model.get_submodule(module_key)
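
                # The wrapper below replaces the matched module's forward: it injects the prompt
                # tokens into the embedding argument (and pads the attention mask to match) before
                # delegating to the original forward, preserved as `forward_origin_{adapter_name}`.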
                def _forward(self, *args, **kwargs):
                    # Locate the input embedding by positional index or keyword name.
                    if isinstance(config.embedding_pos, int):
                        input_embedding = args[config.embedding_pos]
                    else:
                        input_embedding = kwargs[config.embedding_pos]

                    input_embedding = getattr(self, f'prompt_{adapter_name}').forward(input_embedding)
                    if isinstance(config.embedding_pos, int):
                        args = type(args)(
                            args[0:config.embedding_pos] + (input_embedding, ) + args[config.embedding_pos + 1:])
                    else:
                        kwargs[config.embedding_pos] = input_embedding
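
                    # If the wrapped module also receives an attention mask, pad it so its length
                    # matches the embedding with the prompt tokens attached.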
                    if config.attention_mask_pos:
                        attention_mask = None
                        if isinstance(config.attention_mask_pos, int):
                            attention_mask = args[config.attention_mask_pos]
                        elif isinstance(config.attention_mask_pos, str):
                            attention_mask = kwargs[config.attention_mask_pos]

                        if attention_mask is not None:
                            attention_mask = getattr(self,
                                                     f'prompt_{adapter_name}').patch_attention_mask(attention_mask)
                            if isinstance(config.attention_mask_pos, int):
                                args = type(args)(
                                    args[0:config.attention_mask_pos] + (attention_mask, )
                                    + args[config.attention_mask_pos + 1:])
                            else:
                                kwargs[config.attention_mask_pos] = attention_mask

                    forward_output = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs)
                    if config.extract_embedding:
                        forward_output = getattr(self, f'prompt_{adapter_name}').extract(forward_output)

                    return forward_output

                setattr(module, f'forward_origin_{adapter_name}', module.forward)
                module.forward = types.MethodType(_forward, module)
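                # When `dim` is a list, each matched layer has its own hidden size, indexed by match order.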
                if isinstance(config.dim, list):
                    input_dim = config.dim[len(match_module_keys)]
                else:
                    input_dim = config.dim
                prompt_module = PromptModule(input_dim, int(module_key.rsplit('.')[-1]), adapter_name, module_key,
                                             config.prompt_length, config.attention_mask_value, config.attach_front)
                setattr(module, f'prompt_{adapter_name}', prompt_module)
                logger.info(f'Prompt modules(module_key): {module_key}.prompt_{adapter_name}')
                match_module_keys.append(module_key)

        def state_dict_callback(state_dict, adapter_name, **kwargs):
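            # Persist only this adapter's prompt parameters.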
            return {key: value for key, value in state_dict.items() if f'prompt_{adapter_name}' in key}

        def mark_trainable_callback(model):
            return

        return SwiftOutput(
            config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
        modules = find_sub_module(module, f'prompt_{adapter_name}')
        for _module in modules:
            _module: ActivationMixin
            _module: nn.Module
            _module.set_activation(adapter_name, activate)
            SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload)


class PromptModule(nn.Module, ActivationMixin):
    """The implementation of the visual prompt tuning method.

    Visual prompt tuning (VPT) initializes tunable prompt tokens and prepends them to the
    original tokens in the first layer or in multiple layers.
    'Visual Prompt Tuning' by Jia et al. (2022)
    See https://arxiv.org/abs/2203.12119

    Args:
        dim: An integer indicating the embedding dimension.
        layer_num: An integer indicating the index of the layer this module is attached to.
        adapter_name: The name of the adapter.
        module_key: The key of the module this adapter is attached to.
        prompt_length: An integer indicating the length of the prompt tokens.
        mask_values: The value used to pad the attention mask at the prompt positions.
        attach_front: When set to True, the prompt is attached in front of the embedding.
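
    Note:
        At layer 0 the prompt enlarges the sequence, so an input of shape
        (batch, seq_len, dim) becomes (batch, seq_len + prompt_length, dim); at deeper layers
        the prompt overwrites the slots it already occupies and the length is unchanged.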
    """

    def __init__(self,
                 dim,
                 layer_num,
                 adapter_name,
                 module_key,
                 prompt_length=None,
                 mask_values=0.,
                 attach_front=True):
        super(PromptModule, self).__init__()
        # Initialize ActivationMixin (next in the MRO after nn.Module) with the module key.
        super(nn.Module, self).__init__(module_key)
        self.dim = dim
        self.layer_num = layer_num
        self.adapter_name = adapter_name
        self.prompt_length = prompt_length
        self.mask_values = mask_values
        self.attach_front = attach_front
        # Learnable prompt tokens of shape (1, prompt_length, dim), Xavier-initialized.
        self.prompt_token = nn.Parameter(torch.zeros(1, prompt_length, dim))
        nn.init.xavier_uniform_(self.prompt_token)
        self.mark_all_sub_modules_as_plugin()

    def forward(self, x):
        if not self.is_activated(self.adapter_name):
            return x
        prompt_token = self.prompt_token.expand(x.shape[0], -1, -1).to(x.device, x.dtype)
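
        # At the first layer, concatenate the prompt with the full input; at deeper layers,
        # overwrite the slots already occupied by the previous layer's prompt tokens.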
        if self.layer_num == 0:
            if self.attach_front:
                x = torch.cat((prompt_token, x), dim=1)
            else:
                x = torch.cat((x, prompt_token), dim=1)
        else:
            if self.attach_front:
                x = torch.cat((prompt_token, x[:, self.prompt_length:, :]), dim=1)
            else:
                x = torch.cat((x[:, :-self.prompt_length, :], prompt_token), dim=1)
        return x

    def patch_attention_mask(self, m):
        if not self.is_activated(self.adapter_name):
            return m
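        # Pad the mask with `mask_values` at the prompt positions so its length matches the
        # prompted sequence.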
        prefix_attention_mask = torch.full((*m.shape[:-1], self.prompt_length), self.mask_values).to(m.device)
        if self.attach_front:
            return torch.cat((prefix_attention_mask, m), dim=-1)
        else:
            return torch.cat((m, prefix_attention_mask), dim=-1)

    def extract(self, x):
        # Drop the prompt positions so the output has the same sequence length as the original input.
        if self.attach_front:
            return x[:, self.prompt_length:, :]
        else:
            return x[:, :-self.prompt_length, :]