File size: 6,584 Bytes
d670799 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, List, Optional, Tuple

import torch
from mmengine.dist import all_gather, get_rank
from mmengine.model import BaseModel
from mmengine.structures import InstanceData

from mmaction.registry import MODELS
from mmaction.utils import ForwardResults, OptSampleList
class GatherLayer(torch.autograd.Function):
    """Gather tensors from all processes, supporting backward propagation.

    ``forward`` returns the per-process tensors produced by
    ``mmengine.dist.all_gather`` as a tuple, so callers can concatenate them
    into one global batch. ``backward`` hands back to the local input only
    the incoming gradient at index ``get_rank()``.

    NOTE(review): gradients computed by other ranks for this rank's slice are
    not all-reduced here (some gather-with-grad implementations do sum them).
    Confirm this gradient scaling is intended.
    """

    @staticmethod
    def forward(ctx: Any, input: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Gather ``input`` from every process.

        Args:
            ctx: Autograd context; ``input`` is saved on it for ``backward``.
            input (torch.Tensor): The local tensor to gather.

        Returns:
            Tuple[torch.Tensor, ...]: The gathered tensors, one per process,
            in the order returned by ``all_gather``.
        """
        ctx.save_for_backward(input)
        output = all_gather(input)
        return tuple(output)

    @staticmethod
    def backward(ctx: Any, *grads: torch.Tensor) -> torch.Tensor:
        """Route back the gradient of this process's own gathered slice.

        Args:
            ctx: Autograd context holding the tensor saved in ``forward``.
            *grads (torch.Tensor): One incoming gradient per gathered output.

        Returns:
            torch.Tensor: Gradient w.r.t. the local ``input``.
        """
        input, = ctx.saved_tensors
        # Allocate with the input's shape/dtype/device, then copy in the
        # gradient at this process's index in the gathered tuple.
        grad_out = torch.zeros_like(input)
        grad_out[:] = grads[get_rank()]
        return grad_out
@MODELS.register_module()
class CLIPSimilarity(BaseModel):
    """CLIP-based similarity model.

    Args:
        clip_arch (str): The architecture of the clip model.
            Supported choices are `'ViT-B/32'`, `'ViT-B/16'`,
            `'ViT-L/14'` and `'ViT-L/14@336px'`.
        data_preprocessor (dict): The pre-process config.
        adapter (dict): The 3D adapter config.
        to_float32 (bool): Whether to convert the dtype of params of clip
            model to float32. Defaults to False.
        frozen_layers (int): Layers to be frozen (all params fixed). -1 means
            not freezing any parameters. Defaults to -1.
        loss (dict, optional): The config of loss. ``None`` selects
            `dict(type='CrossEntropyLoss', loss_weight=0.5)`.
            Defaults to None.
    """

    def __init__(
        self,
        clip_arch: str,
        data_preprocessor: Dict[str, Dict],
        adapter: Dict,
        to_float32: bool = False,
        frozen_layers: int = -1,
        loss: Optional[Dict] = None,
    ) -> None:
        super().__init__(data_preprocessor=data_preprocessor)
        if loss is None:
            # Fresh dict per instance instead of a shared mutable default.
            loss = dict(type='CrossEntropyLoss', loss_weight=0.5)

        try:
            import clip
        except ImportError as err:
            raise ImportError('Please run `pip install '
                              'git+https://github.com/openai/CLIP.git` '
                              'to install clip first. ') from err

        # ``clip.load`` returns a tuple; only its first element (the model)
        # is used.
        self.clip = clip.load(clip_arch, device='cpu')[0]
        if to_float32:
            self.clip.float()

        self.loss = MODELS.build(loss)
        self.adapter = MODELS.build(adapter)
        self.frozen_layers = frozen_layers
        self._freeze_stages()

    def encode_video(self, video: torch.Tensor) -> torch.Tensor:
        """Encode a batch of videos.

        Args:
            video (torch.Tensor): Frames shaped ``(B, N, C, H, W)``.

        Returns:
            torch.Tensor: Output of ``self.adapter`` applied to the per-frame
            image features reshaped to ``(B, N, -1)``.
        """
        b, n, c, h, w = video.shape
        # Fold the frame axis into the batch axis for the image encoder.
        # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs.
        frames_features = self.encode_image(video.reshape(-1, c, h, w))
        frames_features = frames_features.reshape(b, n, -1)
        return self.adapter(frames_features)

    def encode_image(self, image: torch.Tensor) -> torch.Tensor:
        """Encode images with the CLIP image encoder."""
        return self.clip.encode_image(image)

    def encode_text(self, text: torch.Tensor) -> torch.Tensor:
        """Encode text with the CLIP text encoder."""
        return self.clip.encode_text(text)

    def extract_feat(self,
                     inputs: Dict[str, torch.Tensor],
                     norm: bool = True) -> Tuple:
        """Extract video and text features.

        Args:
            inputs (dict): Must contain ``'text'`` (passed to
                :meth:`encode_text`) and ``'imgs'`` (passed to
                :meth:`encode_video`).
            norm (bool): Whether to divide each feature vector by its L2 norm
                along the last dim. Defaults to True.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
            ``(video_features, text_features)``.
        """
        text_features = self.encode_text(inputs['text'])
        video_features = self.encode_video(inputs['imgs'])
        if norm:
            text_features = text_features / text_features.norm(
                dim=-1, keepdim=True)
            video_features = video_features / video_features.norm(
                dim=-1, keepdim=True)
        return video_features, text_features

    def forward(self,
                inputs: Dict[str, torch.Tensor],
                data_samples: OptSampleList = None,
                mode: str = 'tensor') -> ForwardResults:
        """Forward function.

        Args:
            inputs (dict): Passed to :meth:`extract_feat`.
            data_samples (list, optional): Required in ``'predict'`` mode;
                each sample gets a ``features`` attribute. Defaults to None.
            mode (str): ``'tensor'``, ``'loss'`` or ``'predict'``.
                Defaults to ``'tensor'``.

        Returns:
            - ``'tensor'``: un-normalized ``(video_features, text_features)``.
            - ``'loss'``: dict with ``sim_loss_v2t`` and ``sim_loss_t2v``.
            - ``'predict'``: ``data_samples`` with an ``InstanceData`` holding
              ``video_feature`` and ``text_feature`` attached to each sample.

        Raises:
            RuntimeError: If ``mode`` is not one of the supported values.
        """
        if mode == 'tensor':
            return self.extract_feat(inputs, norm=False)
        elif mode == 'loss':
            video_features, text_features = self.extract_feat(inputs)
            # Concatenate the features gathered from every process so the
            # similarity matrix spans the whole (global) batch.
            video_features = torch.cat(
                GatherLayer.apply(video_features), dim=0)
            text_features = torch.cat(GatherLayer.apply(text_features), dim=0)

            logit_scale = self.clip.logit_scale.exp()
            logits_per_video = logit_scale * video_features @ text_features.t()
            logits_per_text = logits_per_video.t()
            # The i-th video is paired with the i-th text, so the target
            # class for row i is i.
            labels = torch.arange(
                logits_per_video.shape[0], device=logit_scale.device)

            losses = dict()
            losses['sim_loss_v2t'] = self.loss(logits_per_video, labels)
            losses['sim_loss_t2v'] = self.loss(logits_per_text, labels)
            return losses
        elif mode == 'predict':
            video_features, text_features = self.extract_feat(inputs)
            for ds, vf, tf in zip(data_samples, video_features, text_features):
                ds.features = InstanceData(video_feature=vf, text_feature=tf)
            return data_samples
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode')

    def train(self, mode: bool = True) -> 'CLIPSimilarity':
        """Set the training mode and re-apply parameter freezing.

        Args:
            mode (bool): Whether to set training mode. Defaults to True.

        Returns:
            CLIPSimilarity: ``self``, matching ``torch.nn.Module.train``.
        """
        super().train(mode)
        self._freeze_stages()
        return self

    def _freeze_stages(self) -> None:
        """Prevent all the parameters from being optimized before
        ``self.frozen_layers``.

        When ``self.frozen_layers < 0`` this is a no-op. Otherwise it sets
        ``requires_grad = False`` on every CLIP parameter except:

        * parameters whose name starts with one of ``top_layers``;
        * residual-block parameters whose block index is
          ``>= self.frozen_layers``.
        """
        if self.frozen_layers < 0:
            return
        top_layers = ('ln_final', 'text_projection', 'logit_scale',
                      'visual.ln_post', 'visual.proj')
        mid_layers = ('visual.transformer.resblocks', 'transformer.resblocks')
        for name, param in self.clip.named_parameters():
            if name.startswith(top_layers):
                continue
            if name.startswith(mid_layers):
                # Block index is the path component right after 'resblocks'.
                layer_n = int(name.split('.resblocks.')[1].split('.')[0])
                if layer_n >= self.frozen_layers:
                    continue
            param.requires_grad = False
|