from typing import Any, Dict, List, Optional, Tuple, Union

import clip
import mmengine
import numpy as np
import torch
import torch.nn.functional as F
from mmengine.dist import all_gather, get_rank
from mmengine.model import BaseModel
from mmengine.structures import LabelData

from mmaction.registry import MODELS

from .adapter import TransformerAdapter


class GatherLayer(torch.autograd.Function):
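    """Gather tensors from all ranks while keeping the autograd graph.

    ``mmengine.dist.all_gather`` does not propagate gradients, so this
    ``torch.autograd.Function`` supplies a backward pass that returns to
    each rank the gradient computed for its own slice of the output.
    """
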
@staticmethod
    def forward(ctx: Any, input: torch.Tensor) -> Tuple[torch.Tensor, ...]:
ctx.save_for_backward(input)
output = all_gather(input)
return tuple(output)

    @staticmethod
def backward(ctx: Any, *grads: torch.Tensor) -> torch.Tensor:
input, = ctx.saved_tensors
grad_out = torch.zeros_like(input)
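        # Only the gradient computed for this rank's own slice flows back.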
grad_out[:] = grads[get_rank()]
return grad_out


def text_prompt(labels_or_label_file, templates_or_template_file=None):
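    """Tokenize a CLIP text prompt for every (template, label) pair.

    The returned tensor is template-major with shape
    ``(num_templates * num_classes, context_length)``, so per-class prompts
    can be recovered via ``prompt.view(num_prompt, num_classes, -1)``.

    A minimal sketch with hypothetical labels::

        prompt, num_prompt = text_prompt(['archery', 'bowling'])
        # prompt.shape == (num_prompt * 2, 77); CLIP tokenizes to a
        # default context length of 77.
    """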
if isinstance(labels_or_label_file, str):
labels = mmengine.list_from_file(labels_or_label_file)
elif isinstance(labels_or_label_file, list):
labels = labels_or_label_file
else:
raise ValueError(f'`labels_or_label_file` must be `list` or `str`, '
f'but got {type(labels_or_label_file)}')
if templates_or_template_file is None:
templates = [
'a photo of action {}', 'a picture of action {}',
'Human action of {}', '{}, an action', '{} this is an action',
'{}, a video of action', 'Playing action of {}', '{}',
'Playing a kind of action, {}', 'Doing a kind of action, {}',
'Look, the human is {}', 'Can you recognize the action of {}?',
'Video classification of {}', 'A video of {}', 'The man is {}',
'The woman is {}'
]
    elif isinstance(templates_or_template_file, str):
        templates = mmengine.list_from_file(templates_or_template_file)
    elif not mmengine.is_seq_of(templates_or_template_file, str):
        raise ValueError(f'`templates_or_template_file` must be a list of '
                         f'`str`, a `str` or `None`, but got '
                         f'{type(templates_or_template_file)}')
    else:
        templates = templates_or_template_file

    num_prompt = len(templates)
prompt = torch.cat(
[clip.tokenize(t.format(c)) for t in templates for c in labels])
return prompt, num_prompt


@MODELS.register_module()
class ActionClip(BaseModel):
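    """ActionCLIP: CLIP adapted to video with a temporal transformer.

    Per-frame features from the CLIP image encoder are fused by a
    :class:`TransformerAdapter`; training matches video and text-prompt
    features with a symmetric contrastive loss, while prediction scores a
    video against the prompts of every class.
    """
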
def __init__(self,
clip_arch: str,
num_adapter_segs: int,
num_adapter_layers: int = 6,
to_float32: bool = False,
labels_or_label_file: Optional[Union[List[str], str]] = None,
templates_or_template_file: Optional[Union[List[str],
str]] = None,
data_preprocessor: Optional[Dict] = None,
loss: Dict = dict(type='CrossEntropyLoss', loss_weight=0.5)):
super(ActionClip, self).__init__(data_preprocessor=data_preprocessor)
self.clip = clip.load(clip_arch, device='cpu')[0]
if to_float32:
self.clip.float()
self.adapter = TransformerAdapter(self.clip, num_adapter_segs,
num_adapter_layers)
self.loss = MODELS.build(loss)
if labels_or_label_file is not None:
self.prompt, self.num_prompt = text_prompt(
labels_or_label_file, templates_or_template_file)

    def encode_video(self, video):
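        """Encode ``video`` of shape ``(B, N, C, H, W)`` into ``(B, D)``.

        Frames are folded into the batch dimension, encoded independently
        by the CLIP image encoder, and fused by the temporal adapter.
        """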
b, n, c, h, w = video.shape
video = video.view(-1, c, h, w)
frames_features = self.encode_image(video)
frames_features = frames_features.view(b, n, -1)
video_features = self.adapter(frames_features)
return video_features

    def encode_image(self, image):
        return self.clip.encode_image(image)

    def encode_text(self, text):
        return self.clip.encode_text(text)

    def forward(self,
inputs: torch.Tensor,
data_samples: Optional[List] = None,
mode: str = 'tensor'):
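        """Dispatch to feature extraction, prediction or loss computation.

        - ``'tensor'``: return raw video features.
        - ``'predict'``: attach softmax class scores to ``data_samples``.
        - ``'loss'``: return the symmetric video-text contrastive losses.
        """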
if mode == 'tensor':
return self.encode_video(inputs)
elif mode == 'predict':
            assert hasattr(self, 'prompt'), \
                '`labels_or_label_file` is required to perform prediction.'
video_features = self.encode_video(inputs)
video_features = video_features / video_features.norm(
dim=-1, keepdim=True)
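            # Test-time augmentation may yield several clips (views) per
            # sample; scores are averaged over all views and templates below.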
bsz = len(data_samples)
num_views = video_features.shape[0] // bsz
text_features = self.encode_text(self.prompt.to(inputs.device))
text_features = text_features / text_features.norm(
dim=-1, keepdim=True)
            # (bsz * num_views, num_prompt * num_classes) ->
            # (bsz, num_views * num_prompt, num_classes)
            similarity = (100.0 * video_features @ text_features.T).view(
                bsz, num_views * self.num_prompt, -1)
cls_scores = F.softmax(similarity, dim=2).mean(dim=1)
for data_sample, score in zip(data_samples, cls_scores):
data_sample.pred_scores = LabelData(item=score)
return data_samples
elif mode == 'loss':
video_features = self.encode_video(inputs)
video_features = video_features / video_features.norm(
dim=-1, keepdim=True)
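            # Draw one random template per sample and select the tokenized
            # prompt of its ground-truth label (prompt is template-major).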
text_id = np.random.randint(
self.num_prompt, size=len(data_samples))
real_labels = [x.gt_labels.item.item() for x in data_samples]
selected_prompt = self.prompt.view(
self.num_prompt, -1,
self.prompt.shape[-1])[text_id, real_labels].to(inputs.device)
text_features = self.encode_text(selected_prompt)
text_features = text_features / text_features.norm(
dim=-1, keepdim=True)
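            # Gather features from every rank (keeping gradients) so the
            # contrastive loss is computed over the global batch.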
video_features = torch.cat(
GatherLayer.apply(video_features), dim=0)
text_features = torch.cat(GatherLayer.apply(text_features), dim=0)
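            # CLIP-style symmetric loss: matching video-text pairs sit on
            # the diagonal, so the targets are simply arange(batch_size).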
logit_scale = self.clip.logit_scale.exp()
logits_per_video = logit_scale * video_features @ text_features.t()
logits_per_text = logits_per_video.t()
labels = torch.arange(logits_per_video.shape[0]).to(
logit_scale.device)
sim_loss_v2t = self.loss(logits_per_video, labels)
sim_loss_t2v = self.loss(logits_per_text, labels)
losses = dict()
losses['sim_loss_v2t'] = sim_loss_v2t
losses['sim_loss_t2v'] = sim_loss_t2v
return losses
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only `predict`, `loss` and `tensor` modes '
                               'are supported.')