# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from typing import Optional

import torch
from mmengine.logging import MMLogger
from mmengine.model import BaseModel
from mmengine.runner.checkpoint import _load_checkpoint
from torch import nn

from mmaction.registry import MODELS, TOKENIZER
from mmaction.utils import ForwardResults, SampleList
from .utils import (interpolate_pos_embed_beit,
                    interpolate_pos_relative_bias_beit)


class VindLUBase(BaseModel):
    """VindLU base model.

    Args:
        tokenizer (dict): The config for the tokenizer.
        vision_encoder (dict): Backbone for extracting image features.
        text_encoder (dict): Backbone for extracting text features.
        proj_dim (int): Dimension of the shared projection space for the
            vision and text features. Defaults to 256.
        temperature (float): Temperature parameter that controls the
            concentration level of the distribution. Defaults to 0.07.
        gradient_checkpointing (bool): Whether to use gradient checkpointing.
            Using checkpointing saves some memory while slowing down the
            training speed. Defaults to False.
        pretrined_vl (bool): Whether to load a pretrained video-language
            checkpoint in ``init_weights``. Defaults to True.
        data_preprocessor (Optional[dict]): The config for preprocessing input
            data.
        init_cfg (Optional[dict]): The config to control the initialization.
            Defaults to None.
    """
    def __init__(
        self,
        tokenizer: dict,
        vision_encoder: dict,
        text_encoder: dict,
        proj_dim: int = 256,
        temperature: float = 0.07,
        gradient_checkpointing: bool = False,
        pretrined_vl: bool = True,
        data_preprocessor: Optional[dict] = None,
        init_cfg: Optional[dict] = None,
    ):
        if data_preprocessor is None:
            data_preprocessor = dict(type='ActionDataPreprocessor')
        super().__init__(
            init_cfg=init_cfg, data_preprocessor=data_preprocessor)

        self.tokenizer = TOKENIZER.build(tokenizer)
        self.vision_cfg = vision_encoder
        self.text_encoder_cfg = text_encoder
        self.gradient_checkpointing = gradient_checkpointing
        self.text_encoder_cfg.gradient_checkpointing = gradient_checkpointing

        self.vision_width = vision_encoder.pop('encoder_width')
        self.text_width = text_encoder.encoder_width
        self.pretrined_vl = pretrined_vl

        if self.vision_cfg.pop('add_ln'):
            self.vision_layernorm = nn.LayerNorm(self.vision_width, eps=1e-12)
        else:
            self.vision_layernorm = nn.Identity()

        self.vision_encoder = MODELS.build(self.vision_cfg)
        if gradient_checkpointing:
            self.vision_encoder.gradient_checkpointing_enable()

        self.text_encoder = MODELS.build(self.text_encoder_cfg)

        self.vision_proj = nn.Linear(self.vision_width, proj_dim)
        self.text_proj = nn.Linear(self.text_width, proj_dim)

        self.temp = nn.parameter.Parameter(torch.ones([]) * temperature)
        self.itm_head = nn.Linear(self.text_width, 2)
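
    # Note (descriptive, based on how these modules are used below):
    # ``vision_proj`` / ``text_proj`` map both modalities into the shared
    # ``proj_dim`` space used for contrastive alignment, ``temp`` is the
    # learnable contrastive temperature (clamped by
    # ``clip_contrastive_temperature``), and ``itm_head`` is a binary
    # image-text matching classifier on top of the text width.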

    def extract_feat(self, inputs: torch.Tensor, **kwargs) -> ForwardResults:
        """Extract features from raw inputs."""

    @abstractmethod
    def loss(self, inputs: torch.Tensor, data_samples: SampleList,
             **kwargs) -> dict:
        """Calculate losses from a batch of inputs and data samples."""

    def forward(self, inputs, data_samples, mode: str = 'loss'):
        """The unified entry for a forward process in both training and test.

        The method accepts three modes:

        - ``tensor``: Forward the whole network and return a tensor or a
          tuple of tensors without any post-processing, same as a common
          ``nn.Module``.
        - ``predict``: Forward and return the predictions, which are fully
          processed to a list of :obj:`ActionDataSample`.
        - ``loss``: Forward and return a dict of losses according to the
          given inputs and data samples.

        Note that this method handles neither back-propagation nor optimizer
        updating, which are done in :meth:`train_step`.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[:obj:`ActionDataSample`], optional): The
                annotation data of every sample. Defaults to None.
            mode (str): Return what kind of value. Defaults to ``loss``.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensors.
            - If ``mode="predict"``, return a list of ``ActionDataSample``.
            - If ``mode="loss"``, return a dict of tensors.
        """
        if mode == 'tensor':
            return self.extract_feat(inputs, data_samples)
        elif mode == 'loss':
            return self.loss(inputs, data_samples)
        elif mode == 'predict':
            return self.predict(inputs, data_samples)
        else:
            raise RuntimeError(f'Invalid mode "{mode}".')
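
    # A minimal usage sketch of the three modes (``model``, ``inputs`` and
    # ``data_samples`` are illustrative names; ``loss`` and ``predict``
    # assume a subclass that implements them):
    #
    #   feats = model(inputs, data_samples, mode='tensor')   # raw features
    #   losses = model(inputs, data_samples, mode='loss')     # dict of losses
    #   preds = model(inputs, data_samples, mode='predict')   # ActionDataSamples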

    def encode_vision(self, image):
        """Encode images / videos as features.

        Args:
            image (torch.Tensor): The input images.

        Returns:
            tuple:
            - vision_embeds (torch.Tensor): The features of all patches.
              Shape: [B,T,L,C].
            - pooled_vision_embeds (torch.Tensor): The pooled features.
              Shape: [B,T,C].
        """
        output_dict = self.vision_encoder(image)
        vision_embeds = self.vision_layernorm(output_dict.last_hidden_state)
        pooled_vision_embeds = output_dict.pooler_output
        return vision_embeds, pooled_vision_embeds

    def encode_text(self, text):
        """Encode text.

        Args:
            text (dict): The output of HuggingFace's ``PreTrainedTokenizer``.
                It contains the keys:

                - input_ids (torch.Tensor): Token ids to be fed to a model.
                  Shape: [B,L].
                - attention_mask (torch.Tensor): The mask indicating padded
                  tokens (0 means padded). Shape: [B,L].
                - other keys refer to "https://huggingface.co/docs/transformers/v4.21.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__". # noqa: E501

        Returns:
            tuple:
            - text_embeds (torch.Tensor): The features of all tokens.
              Shape: [B,L,C].
            - pooled_text_embeds (torch.Tensor): The pooled features, i.e.
              the embedding of the first ([CLS]) token. Shape: [B,C].
        """
        text_output = self.text_encoder(
            text.input_ids,
            attention_mask=text.attention_mask,
            return_dict=True,
            mode='text',
        )
        text_embeds = text_output.last_hidden_state
        pooled_text_embeds = text_embeds[:, 0]
        return text_embeds, pooled_text_embeds
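
    # Sketch of how ``text`` is typically produced (argument values are
    # illustrative, not prescribed by this class; the tokenizer follows the
    # HuggingFace ``PreTrainedTokenizer.__call__`` interface):
    #
    #   text = self.tokenizer(
    #       captions, padding='max_length', truncation=True, max_length=32,
    #       return_tensors='pt').to(self.device)
    #   text_embeds, pooled_text_embeds = self.encode_text(text)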

    @torch.no_grad()
    def clip_contrastive_temperature(self, min_val=0.001, max_val=0.5):
        """Clamp the contrastive temperature into a valid range.

        Seems only used during pre-training.
        """
        self.temp.clamp_(min_val, max_val)

    @property
    def device(self):
        return next(self.parameters()).device

    def preprocess_state_dict(self, state_dict):
        """Preprocess a pretrained checkpoint for ``self.text_encoder``."""
        for key in list(state_dict.keys()):
            if 'bert' in key:
                encoder_key = key.replace('bert.', '')
                state_dict[encoder_key] = state_dict[key]
                del state_dict[key]
        return state_dict
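
    # For example, a checkpoint key such as
    # ``text_encoder.bert.encoder.layer.0.attention.self.query.weight``
    # would be remapped to
    # ``text_encoder.encoder.layer.0.attention.self.query.weight`` so that it
    # matches the parameter names of ``self.text_encoder`` (the key shown is
    # illustrative; actual names depend on the checkpoint).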

    def load_from_pretrainded_beit(self):
        """Load weights from a pretrained 2D BEiT checkpoint, interpolating
        the relative position bias to the current patch shape."""
        from transformers.models.beit.modeling_beit import BeitModel
        beit2d = BeitModel.from_pretrained(
            self.vision_cfg.pretrained_model_name_or_path)
        ori_state_dict = beit2d.state_dict()
        del beit2d

        # interpolate the relative position bias to the new patch shape
        state_dict = interpolate_pos_relative_bias_beit(
            state_dict_old=ori_state_dict,
            state_dict_new=self.vision_encoder.state_dict(),
            patch_shape_new=self.vision_encoder.embeddings.patch_embeddings.
            patch_shape,
        )

        # prompt bias tables are not loaded from the 2D checkpoint
        for k in list(state_dict.keys()):
            if 'prompt_bias_table' in k:
                del state_dict[k]

        msg = self.vision_encoder.load_state_dict(state_dict, strict=False)
        logger = MMLogger.get_current_instance()
        logger.info(msg)

    def init_weights(self):
        # optionally inflate 2D BEiT weights into the vision encoder
        if self.vision_cfg.get('pretrained2d', False):
            self.load_from_pretrainded_beit()

        if self.pretrined_vl:
            # load a pretrained video-language checkpoint given by init_cfg
            assert self.init_cfg.get('type') == 'Pretrained', (
                'Please specify init_cfg to use a pretrained '
                'video-language checkpoint')
            self.pretrained = self.init_cfg.get('checkpoint')
            checkpoint = _load_checkpoint(self.pretrained, map_location='cpu')
            state_dict = checkpoint['model']
            state_dict = interpolate_pos_embed_beit(state_dict, self)
            state_dict = self.preprocess_state_dict(state_dict)
            msg = self.load_state_dict(state_dict, strict=False)
            logger = MMLogger.get_current_instance()
            logger.info(msg)
        else:
            super().init_weights()
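

# A rough end-to-end construction sketch (the config path and keys below are
# illustrative placeholders, not shipped with this file):
#
#   from mmengine.config import Config
#   from mmaction.registry import MODELS
#
#   cfg = Config.fromfile('configs/multimodal/vindlu/<some_config>.py')
#   model = MODELS.build(cfg.model)  # builds a concrete VindLU subclass
#   model.init_weights()             # loads BEiT-2D and/or video-language weights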