# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from typing import Optional

import torch
from mmengine.logging import MMLogger
from mmengine.model import BaseModel
from mmengine.runner.checkpoint import _load_checkpoint
from torch import nn

from mmaction.registry import MODELS, TOKENIZER
from mmaction.utils import ForwardResults, SampleList
from .utils import (interpolate_pos_embed_beit,
                    interpolate_pos_relative_bias_beit)


class VindLUBase(BaseModel):
    """VindLU base Model.



    Args:

        tokenizer: (dict): The config for tokenizer.

        vision_encoder (dict): Backbone for extracting image features.

        text_encoder (dict): Backbone for extracting text features.

        temperature (float): Temperature parameter that controls the

            concentration level of the distribution. Defaults to 0.07.

        gradient_checkpointing (bool): Whether to do gradient_checkpointing.

            Using checkpoint will save some memory while slowing down the

            training speed. Defaults to False.

        data_preprocessor (Optional[dict]): The config for preprocessing input

            data.

        init_cfg (Optional[dict]): the config to control the initialization.

            Defaults to None.

    """

    def __init__(
        self,
        tokenizer: dict,
        vision_encoder: dict,
        text_encoder: dict,
        proj_dim: int = 256,
        temperature: float = 0.07,
        gradient_checkpointing: bool = False,
        pretrained_vl: bool = True,
        data_preprocessor: Optional[dict] = None,
        init_cfg: Optional[dict] = None,
    ):
        if data_preprocessor is None:
            data_preprocessor = dict(type='ActionDataPreprocessor')
        super().__init__(
            init_cfg=init_cfg, data_preprocessor=data_preprocessor)

        self.tokenizer = TOKENIZER.build(tokenizer)
        self.vision_cfg = vision_encoder
        self.text_encoder_cfg = text_encoder
        self.gradient_checkpointing = gradient_checkpointing
        self.text_encoder_cfg.gradient_checkpointing = gradient_checkpointing

        self.vision_width = vision_encoder.pop('encoder_width')
        self.text_width = text_encoder.encoder_width
        self.pretrained_vl = pretrained_vl

        if self.vision_cfg.pop('add_ln'):
            self.vision_layernorm = nn.LayerNorm(self.vision_width, eps=1e-12)
        else:
            self.vision_layernorm = nn.Identity()

        self.vision_encoder = MODELS.build(self.vision_cfg)

        if gradient_checkpointing:
            self.vision_encoder.gradient_checkpointing_enable()

        self.text_encoder = MODELS.build(self.text_encoder_cfg)

        self.vision_proj = nn.Linear(self.vision_width, proj_dim)
        self.text_proj = nn.Linear(self.text_width, proj_dim)

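        # Learnable softmax temperature for the image-text contrastive loss
        # (clamped in ``clip_contrastive_temperature``) and a binary
        # image-text matching (ITM) head.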
        self.temp = nn.parameter.Parameter(torch.ones([]) * temperature)
        self.itm_head = nn.Linear(self.text_width, 2)

    def extract_feat(self, inputs: torch.Tensor, **kwargs) -> ForwardResults:
        """Extract features from raw inputs."""

    @abstractmethod
    def loss(self, inputs: torch.Tensor, data_samples: SampleList,
             **kwargs) -> dict:
        """Calculate losses from a batch of inputs and data samples."""

    def forward(self, inputs, data_samples, mode: str = 'loss'):
        """The unified entry for a forward process in both training and test.

        The method should accept three modes:

        - ``tensor``: Forward the whole network and return a tensor or tuple
          of tensors without any post-processing, same as a common nn.Module.
        - ``predict``: Forward and return the predictions, which are fully
          processed to a list of :obj:`ActionDataSample`.
        - ``loss``: Forward and return a dict of losses according to the
          given inputs and data samples.

        Note that this method doesn't handle back propagation or optimizer
        updating, which are done in :meth:`train_step`.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[``ActionDataSample``], optional): The
                annotation data of every sample. Defaults to None.
            mode (str): Return what kind of value. Defaults to ``loss``.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensors.
            - If ``mode="predict"``, return a list of ``ActionDataSample``.
            - If ``mode="loss"``, return a dict of tensors.
        """

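        # Example (sketch): ``model(inputs, data_samples, mode='predict')``
        # returns a list of ``ActionDataSample``; in practice the mmengine
        # runner selects the mode ('loss' in training, 'predict' in testing).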
        if mode == 'tensor':
            return self.extract_feat(inputs, data_samples)
        elif mode == 'loss':
            return self.loss(inputs, data_samples)
        elif mode == 'predict':
            return self.predict(inputs, data_samples)
        else:
            raise RuntimeError(f'Invalid mode "{mode}".')

    def encode_vision(self, image):
        """Encode images or videos as features.

        Args:
            image (torch.Tensor): The input images.

        Returns: tuple.
            - vision_embeds (torch.Tensor): The features of all patches.
                Shape: [B,T,L,C].
            - pooled_vision_embeds (torch.Tensor): The pooled features.
                Shape: [B,T,C].
        """
        output_dict = self.vision_encoder(image)
        vision_embeds = self.vision_layernorm(output_dict.last_hidden_state)
        pooled_vision_embeds = output_dict.pooler_output

        return vision_embeds, pooled_vision_embeds

    def encode_text(self, text):
        """Encode text.

        Args:
            text (dict): The output of the huggingface
                ``PreTrainedTokenizer``, which contains the keys:

                - input_ids (torch.Tensor): Token ids to be fed to a model.
                    Shape: [B,L].
                - attention_mask (torch.Tensor): The mask indicating padded
                    tokens. Shape: [B,L]. 0 marks a padded token.
                - For other keys, refer to https://huggingface.co/docs/transformers/v4.21.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__  # noqa: E501

        Returns: tuple.
            - text_embeds (torch.Tensor): The features of all tokens.
                Shape: [B,L,C].
            - pooled_text_embeds (torch.Tensor): The pooled features.
                Shape: [B,C].
        """
        text_output = self.text_encoder(
            text.input_ids,
            attention_mask=text.attention_mask,
            return_dict=True,
            mode='text',
        )
        text_embeds = text_output.last_hidden_state
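        # Use the first ([CLS]) token as the pooled sentence representation.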
        pooled_text_embeds = text_embeds[:, 0]
        return text_embeds, pooled_text_embeds

    @torch.no_grad()
    def clip_contrastive_temperature(self, min_val=0.001, max_val=0.5):
        """Clamp the learnable contrastive temperature into
        ``[min_val, max_val]``; seems only used during pre-training."""
        self.temp.clamp_(min_val, max_val)

    @property
    def device(self):
        return next(self.parameters()).device

    def preprocess_state_dict(self, state_dict):
        """Preprocess pretrained checkpoint for text_encoder."""
        for key in list(state_dict.keys()):
            if 'bert' in key:
                encoder_key = key.replace('bert.', '')
                state_dict[encoder_key] = state_dict[key]
                del state_dict[key]
        return state_dict

    def load_from_pretrained_beit(self):
        from transformers.models.beit.modeling_beit import BeitModel
        beit2d = BeitModel.from_pretrained(
            self.vision_cfg.pretrained_model_name_or_path)
        ori_state_dict = beit2d.state_dict()
        del beit2d
        # interpolate relative pos bias
        state_dict = interpolate_pos_relative_bias_beit(
            state_dict_old=ori_state_dict,
            state_dict_new=self.vision_encoder.state_dict(),
            patch_shape_new=self.vision_encoder.embeddings.patch_embeddings.
            patch_shape,
        )

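        # Drop keys that have no counterpart in the video encoder before the
        # non-strict load below.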
        for k in list(state_dict.keys()):
            if 'prompt_bias_table' in k:
                del state_dict[k]

        msg = self.vision_encoder.load_state_dict(state_dict, strict=False)
        logger = MMLogger.get_current_instance()
        logger.info(msg)

    def init_weights(self):
        if self.vision_cfg.get('pretrained2d', False):
            self.load_from_pretrained_beit()

        if self.pretrined_vl:
            assert self.init_cfg.get('type') == 'Pretrained', (
                'Please specify '
                'init_cfg to use pretrained video-language checkpoint')
            self.pretrained = self.init_cfg.get('checkpoint')
            checkpoint = _load_checkpoint(self.pretrained, map_location='cpu')
            state_dict = checkpoint['model']
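            # Adapt the BEiT positional embeddings in the checkpoint to the
            # current input size before loading.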
            state_dict = interpolate_pos_embed_beit(state_dict, self)
            state_dict = self.preprocess_state_dict(state_dict)
            msg = self.load_state_dict(state_dict, strict=False)
            logger = MMLogger.get_current_instance()
            logger.info(msg)
        else:
            super().init_weights()
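
# Minimal usage sketch (illustrative: ``model`` is a concrete subclass built
# from a config and ``videos`` is a hypothetical input tensor; neither is
# defined in this file):
#
#     texts = ['a person riding a bike']
#     tokens = model.tokenizer(
#         texts, padding=True, truncation=True,
#         return_tensors='pt').to(model.device)
#     text_embeds, pooled_text = model.encode_text(tokens)
#     vision_embeds, pooled_vision = model.encode_vision(videos)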