# Copyright (c) Alibaba, Inc. and its affiliates.
import re
import types
from dataclasses import dataclass, field
from typing import List, Union

import torch
from torch import nn

from swift.utils import get_logger
from swift.utils.torch_utils import find_sub_module
from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()

@dataclass
class PromptConfig(SwiftConfig):
"""
The configuration class for the prompt module.
    Visual Prompt Tuning (VPT) initializes learnable prompt tokens and prepends them
    to the original tokens in the first layer or in multiple layers.
    'Visual Prompt Tuning' by Jia et al. (2022)
    See https://arxiv.org/abs/2203.12119
    Here we apply VPT to fields beyond vision.
    Args:
        dim (`Union[int, List[int]]`): The dimension of the hidden states. Use a list if the matched
            blocks have different widths (e.g. up-sample or down-sample blocks).
        target_modules (`str`): A regex matching the names of the modules to be patched.
        embedding_pos (`Union[str, int]`): The position of the embedding tensor in the patched module's
            forward arguments; an int indexes the positional arguments, a str indexes the keyword arguments.
        attention_mask_pos (`Union[str, int]`): The position of the attention mask in the forward arguments.
        attention_mask_value (`Union[float, int, bool]`): The value used to pad the attention mask at the
            prompt positions.
        prompt_length (`int`): The number of prompt tokens.
        attach_front (`bool`): If True, the prompt is attached in front of the embedding, otherwise it is
            appended at the end.
        extract_embedding (`bool`): Whether the prompt tokens are stripped from the final output so that its
            shape matches the input.
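    Example:
        An illustrative sketch (the regex, hidden size and argument position below are
        assumptions that depend on the wrapped model; the matched module names must end
        with the layer index)::

            config = PromptConfig(
                dim=768,
                target_modules=r'.*blocks.[0-9]+$',
                embedding_pos=0,
                prompt_length=16)
            # model = Swift.prepare_model(model, config)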
"""
dim: Union[int, List[int]] = field(default=None, metadata={'help': 'The dimension of the hidden states'})
    target_modules: str = field(
        default=None, metadata={'help': 'A regex matching the names of the modules to be patched'})
embedding_pos: Union[str, int] = field(default=None, metadata={'help': 'The position of the embedding tensor'})
attention_mask_pos: Union[str, int] = field(default=None, metadata={'help': 'The position of the attention mask'})
    attention_mask_value: Union[float, int, bool] = field(
        default=0., metadata={'help': 'The value used to pad the attention mask'})
prompt_length: int = field(default=16, metadata={'help': 'The length of the prompt tokens'})
attach_front: bool = field(
default=True, metadata={'help': 'When set to True, prompt is attached in front of the embedding'})
extract_embedding: bool = field(
default=False,
metadata={'help': 'Whether the embedding is extracted at final stage to keep the same dims with inputs'})
def __post_init__(self):
from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.PROMPT


class Prompt(SwiftAdapter):

    @staticmethod
def prepare_model(model: nn.Module, config: PromptConfig, adapter_name: str):
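        # Scan the model's named modules, patch every match with a prompt-injecting
        # forward and attach a PromptModule holding the learnable prompt tokens.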
module_keys = [key for key, _ in model.named_modules()]
match_module_keys = []
for module_key in module_keys:
if isinstance(config.target_modules, str):
target_module_found = re.fullmatch(config.target_modules, module_key)
else:
target_module_found = any(module_key.endswith(target_key) for target_key in config.target_modules)
if target_module_found: # noqa
module = model.get_submodule(module_key)
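
                # Patched forward: inject the prompt tokens into the embedding argument
                # (and pad the attention mask) before delegating to the original forward.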
def _forward(self, *args, **kwargs):
if isinstance(config.embedding_pos, int):
input_embedding = args[config.embedding_pos]
else:
input_embedding = kwargs[config.embedding_pos]
input_embedding = getattr(self, f'prompt_{adapter_name}').forward(input_embedding)
if isinstance(config.embedding_pos, int):
args = type(args)(
args[0:config.embedding_pos] + (input_embedding, ) + args[config.embedding_pos + 1:])
else:
kwargs[config.embedding_pos] = input_embedding
                    if config.attention_mask_pos is not None:
attention_mask = None
if isinstance(config.attention_mask_pos, int):
attention_mask = args[config.attention_mask_pos]
elif isinstance(config.attention_mask_pos, str):
attention_mask = kwargs[config.attention_mask_pos]
if attention_mask is not None:
attention_mask = getattr(self,
f'prompt_{adapter_name}').patch_attention_mask(attention_mask)
if isinstance(config.attention_mask_pos, int):
args = type(args)(
args[0:config.attention_mask_pos] + (attention_mask, )
+ args[config.attention_mask_pos + 1:])
else:
kwargs[config.attention_mask_pos] = attention_mask
forward_output = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs)
if config.extract_embedding:
forward_output = getattr(self, f'prompt_{adapter_name}').extract(forward_output)
return forward_output
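
                # Keep the original forward reachable under a per-adapter attribute and
                # bind the patched forward to this module instance.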
setattr(module, f'forward_origin_{adapter_name}', module.forward)
module.forward = types.MethodType(_forward, module)
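                # When `dim` is a list, pick the width corresponding to this match
                # (matches are visited in `named_modules` order).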
if isinstance(config.dim, list):
input_dim = config.dim[len(match_module_keys)]
else:
input_dim = config.dim
prompt_module = PromptModule(input_dim, int(module_key.rsplit('.')[-1]), adapter_name, module_key,
config.prompt_length, config.attention_mask_value, config.attach_front)
setattr(module, f'prompt_{adapter_name}', prompt_module)
logger.info(f'Prompt modules(module_key): {module_key}.prompt_{adapter_name}')
match_module_keys.append(module_key)
def state_dict_callback(state_dict, adapter_name, **kwargs):
return {key: value for key, value in state_dict.items() if f'prompt_{adapter_name}' in key}
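
        # The prompt parameters live on the patched modules themselves, so no extra
        # modules need to be marked as trainable here.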
def mark_trainable_callback(model):
return
return SwiftOutput(
config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)
@staticmethod
def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
modules = find_sub_module(module, f'prompt_{adapter_name}')
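        # Toggle each PromptModule that belongs to this adapter and optionally offload
        # or reload its weights.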
for _module in modules:
_module: ActivationMixin
_module: nn.Module
_module.set_activation(adapter_name, activate)
SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload)


class PromptModule(nn.Module, ActivationMixin):
"""The implementation of vision prompt tuning method.
Visual prompt tuning (VPT) is proposed to initialize tunable prompt tokens
and prepend to the original tokens in the first layer or multiple layers.
'Visual Prompt Tuning' by Jia et al.(2022)
See https://arxiv.org/abs/2203.12119
Args:
dim: An integer indicating the embedding dimension.
layer_num: An integer indicating number of layers.
prompt_length: An integer indicating the length of vision prompt tuning.
"""
def __init__(self, dim, layer_num, adapter_name, module_key, prompt_length=None, mask_values=0., attach_front=True):
super(PromptModule, self).__init__()
super(nn.Module, self).__init__(module_key)
self.dim = dim
self.layer_num = layer_num
self.adapter_name = adapter_name
self.prompt_length = prompt_length
self.mask_values = mask_values
self.attach_front = attach_front
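        # Learnable prompt tokens of shape (1, prompt_length, dim), Xavier-initialized
        # and broadcast over the batch dimension in forward().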
self.prompt_token = nn.Parameter(torch.zeros(1, prompt_length, dim))
nn.init.xavier_uniform_(self.prompt_token)
self.mark_all_sub_modules_as_plugin()
def forward(self, x):
if not self.is_activated(self.adapter_name):
return x
prompt_token = self.prompt_token.expand(x.shape[0], -1, -1).to(x.device, x.dtype)
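        # The first layer receives the raw sequence, so the prompt is concatenated to it;
        # deeper layers already carry prompt positions from the previous layer, which are
        # overwritten with this layer's prompt.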
if self.layer_num == 0:
if self.attach_front:
x = torch.cat((prompt_token, x), dim=1)
else:
x = torch.cat((x, prompt_token), dim=1)
else:
if self.attach_front:
x = torch.cat((prompt_token, x[:, self.prompt_length:, :]), dim=1)
else:
x = torch.cat((x[:, :-self.prompt_length, :], prompt_token), dim=1)
return x
def patch_attention_mask(self, m):
if not self.is_activated(self.adapter_name):
return m
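        # Pad the attention mask along the last dimension so that it also covers the
        # prompt positions.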
        prefix_attention_mask = torch.full((*m.shape[:-1], self.prompt_length),
                                           self.mask_values).to(m.device, m.dtype)
if self.attach_front:
return torch.cat((prefix_attention_mask, m), dim=-1)
else:
return torch.cat((m, prefix_attention_mask), dim=-1)
def extract(self, x):
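        # Strip the prompt positions so that the output sequence length matches the input.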
if self.attach_front:
return x[:, self.prompt_length:, :]
else:
return x[:, :-self.prompt_length, :]