File size: 9,331 Bytes
7feac49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# Copyright (c) Alibaba, Inc. and its affiliates.

import re
import types
from dataclasses import dataclass, field
from typing import List, Union

import torch
from torch import nn

from swift.utils import get_logger
from swift.utils.torch_utils import find_sub_module
from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()


@dataclass
class PromptConfig(SwiftConfig):
    """
    The configuration class for the prompt module.

    Visual prompt tuning (VPT) is proposed to initialize tunable prompt tokens
    and prepend to the original tokens in the first layer or multiple layers.
    'Visual Prompt Tuning' by Jia et al.(2022)
    See https://arxiv.org/abs/2203.12119

    Here we apply the VPT to other fields.

    Args:
        dim(`Union[int, List[int]]`): The dimension of the hidden states, use list if there are up-sample blocks
            or down-sample blocks
        target_modules(str): The layer module to be replaced, in regex format
        embedding_pos(Union[str, int]): The position of the embedding tensor
        attention_mask_pos(Union[str, int]): The position of the attention mask
        attention_mask_value(Union[float, int, bool]): The value to pad to the attention mask
        prompt_length(int): The length of the prompt tokens
        attach_front(bool): When set to True, prompt is attached in front of the embedding
        extract_embedding(bool): Whether the embedding is extracted at final stage to keep the same dims with inputs
    """

    # Hidden-state width; a list maps one width per matched module (indexed in
    # match order) when the network changes width across blocks.
    dim: Union[int, List[int]] = field(default=None, metadata={'help': 'The dimension of the hidden states'})

    # A str is treated as a full-match regex over module keys; a non-str is
    # treated as a collection of suffixes (see Prompt.prepare_model).
    target_modules: str = field(default=None, metadata={'help': 'The layer module to be replaced, in regex format'})

    # int -> positional argument index of the wrapped forward; str -> kwarg name.
    embedding_pos: Union[str, int] = field(default=None, metadata={'help': 'The position of the embedding tensor'})

    # int -> positional argument index of the wrapped forward; str -> kwarg name.
    attention_mask_pos: Union[str, int] = field(default=None, metadata={'help': 'The position of the attention mask'})

    attention_mask_value: Union[float, int, bool] = field(
        default=0., metadata={'help': 'The value to pad to the attention mask'})

    prompt_length: int = field(default=16, metadata={'help': 'The length of the prompt tokens'})

    attach_front: bool = field(
        default=True, metadata={'help': 'When set to True, prompt is attached in front of the embedding'})

    extract_embedding: bool = field(
        default=False,
        metadata={'help': 'Whether the embedding is extracted at final stage to keep the same dims with inputs'})

    def __post_init__(self):
        # Imported here to avoid a circular import with the tuner mapping.
        from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.PROMPT


class Prompt(SwiftAdapter):
    """Adapter that injects tunable prompt tokens into matched modules.

    Each matched module gets a ``PromptModule`` child (``prompt_<adapter>``)
    and a wrapped ``forward`` that inserts the prompt tokens into the
    embedding argument, pads the attention mask to match, then delegates to
    the original forward (saved as ``forward_origin_<adapter>``).
    """

    @staticmethod
    def prepare_model(model: nn.Module, config: PromptConfig, adapter_name: str):
        """Patch ``model`` in place and return the adapter bookkeeping.

        Args:
            model: The model whose sub-modules are scanned and wrapped.
            config: The prompt configuration (see ``PromptConfig``).
            adapter_name: Namespace for the attributes added to each module.

        Returns:
            SwiftOutput: Holds the config plus the state-dict filter and the
            (no-op) mark-trainable callback.
        """
        module_keys = [key for key, _ in model.named_modules()]
        match_module_keys = []
        for module_key in module_keys:
            # A str target is a full-match regex; otherwise treat it as a
            # collection of module-key suffixes.
            if isinstance(config.target_modules, str):
                target_module_found = re.fullmatch(config.target_modules, module_key)
            else:
                target_module_found = any(module_key.endswith(target_key) for target_key in config.target_modules)
            if target_module_found:  # noqa
                module = model.get_submodule(module_key)

                def _forward(self, *args, **kwargs):
                    # Locate the embedding positionally or by keyword.
                    if isinstance(config.embedding_pos, int):
                        input_embedding = args[config.embedding_pos]
                    else:
                        input_embedding = kwargs[config.embedding_pos]

                    input_embedding = getattr(self, f'prompt_{adapter_name}').forward(input_embedding)
                    if isinstance(config.embedding_pos, int):
                        # Rebuild args with the prompted embedding spliced in,
                        # preserving the original tuple type.
                        args = type(args)(
                            args[0:config.embedding_pos] + (input_embedding, ) + args[config.embedding_pos + 1:])
                    else:
                        kwargs[config.embedding_pos] = input_embedding

                    # Compare against None (not truthiness): a positional
                    # index of 0 is a valid attention-mask position and must
                    # not be skipped.
                    if config.attention_mask_pos is not None:
                        attention_mask = None
                        if isinstance(config.attention_mask_pos, int):
                            attention_mask = args[config.attention_mask_pos]
                        elif isinstance(config.attention_mask_pos, str):
                            attention_mask = kwargs[config.attention_mask_pos]

                        if attention_mask is not None:
                            attention_mask = getattr(self,
                                                     f'prompt_{adapter_name}').patch_attention_mask(attention_mask)
                        if isinstance(config.attention_mask_pos, int):
                            args = type(args)(
                                args[0:config.attention_mask_pos] + (attention_mask, )
                                + args[config.attention_mask_pos + 1:])
                        else:
                            kwargs[config.attention_mask_pos] = attention_mask

                    forward_output = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs)
                    if config.extract_embedding:
                        # Strip the prompt tokens so output dims match the input.
                        forward_output = getattr(self, f'prompt_{adapter_name}').extract(forward_output)

                    return forward_output

                setattr(module, f'forward_origin_{adapter_name}', module.forward)
                module.forward = types.MethodType(_forward, module)
                # A list of dims maps one width per matched module, in match order.
                if isinstance(config.dim, list):
                    input_dim = config.dim[len(match_module_keys)]
                else:
                    input_dim = config.dim
                # NOTE(review): assumes the last dotted component of the module
                # key is a numeric layer index (e.g. 'blocks.3') — confirm the
                # target_modules pattern only matches such keys.
                prompt_module = PromptModule(input_dim, int(module_key.rsplit('.')[-1]), adapter_name, module_key,
                                             config.prompt_length, config.attention_mask_value, config.attach_front)
                setattr(module, f'prompt_{adapter_name}', prompt_module)
                logger.info(f'Prompt modules(module_key): {module_key}.prompt_{adapter_name}')
                match_module_keys.append(module_key)

        def state_dict_callback(state_dict, adapter_name, **kwargs):
            # Persist only this adapter's prompt parameters.
            return {key: value for key, value in state_dict.items() if f'prompt_{adapter_name}' in key}

        def mark_trainable_callback(model):
            # Nothing extra to mark: prompt parameters are regular nn.Parameters
            # registered on the injected PromptModule.
            return

        return SwiftOutput(
            config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
        """Toggle this adapter's prompt modules on/off, optionally offloading."""
        modules = find_sub_module(module, f'prompt_{adapter_name}')
        for _module in modules:
            _module: ActivationMixin
            _module: nn.Module
            _module.set_activation(adapter_name, activate)
            SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload)


class PromptModule(nn.Module, ActivationMixin):
    """The implementation of vision prompt tuning method.

    Visual prompt tuning (VPT) is proposed to initialize tunable prompt tokens
    and prepend to the original tokens in the first layer or multiple layers.
    'Visual Prompt Tuning' by Jia et al.(2022)
    See https://arxiv.org/abs/2203.12119

    Args:
        dim: An integer indicating the embedding dimension.
        layer_num: An integer indicating number of layers.
        adapter_name: The adapter name used for activation checks.
        module_key: The key of the module this prompt is attached to.
        prompt_length: An integer indicating the length of vision prompt tuning.
        mask_values: The fill value used when padding the attention mask.
        attach_front: Whether the prompt is prepended (True) or appended.
    """

    def __init__(self, dim, layer_num, adapter_name, module_key, prompt_length=None, mask_values=0., attach_front=True):
        # Cooperative init of both bases: the first call runs nn.Module's
        # __init__, the second routes past nn.Module in the MRO so that
        # ActivationMixin.__init__ receives module_key.
        super(PromptModule, self).__init__()
        super(nn.Module, self).__init__(module_key)
        self.dim = dim
        self.layer_num = layer_num
        self.adapter_name = adapter_name
        self.prompt_length = prompt_length
        self.mask_values = mask_values
        self.attach_front = attach_front
        self.prompt_token = nn.Parameter(torch.zeros(1, prompt_length, dim))
        nn.init.xavier_uniform_(self.prompt_token)
        self.mark_all_sub_modules_as_plugin()

    def forward(self, x):
        """Insert (layer 0) or refresh (deeper layers) the prompt tokens in x."""
        if not self.is_activated(self.adapter_name):
            return x
        # Broadcast the learned prompt over the batch, matching x's device/dtype.
        prompt_token = self.prompt_token.expand(x.shape[0], -1, -1).to(x.device, x.dtype)

        if self.layer_num == 0:
            # First layer: tokens are concatenated, growing the sequence.
            if self.attach_front:
                x = torch.cat((prompt_token, x), dim=1)
            else:
                x = torch.cat((x, prompt_token), dim=1)
        else:
            # Deeper layers: the prompt slots already exist, so overwrite them.
            if self.attach_front:
                x = torch.cat((prompt_token, x[:, self.prompt_length:, :]), dim=1)
            else:
                x = torch.cat((x[:, :-self.prompt_length, :], prompt_token), dim=1)
        return x

    def patch_attention_mask(self, m):
        """Pad the attention mask so it also covers the prompt tokens."""
        if not self.is_activated(self.adapter_name):
            return m
        # Cast the padding to m's device AND dtype: torch.full defaults to
        # float32, and concatenating that with a bool/int mask would either
        # fail or promote the whole mask to float for downstream consumers.
        prefix_attention_mask = torch.full((*m.shape[:-1], self.prompt_length),
                                           self.mask_values).to(m.device, m.dtype)
        if self.attach_front:
            return torch.cat((prefix_attention_mask, m), dim=-1)
        else:
            return torch.cat((m, prefix_attention_mask), dim=-1)

    def extract(self, x):
        """Drop the prompt positions from x, restoring the input sequence length."""
        if self.attach_front:
            return x[:, self.prompt_length:, :]
        else:
            return x[:, :-self.prompt_length, :]