# Provenance (model-hub page residue, converted to a comment so the file parses):
# uploaded by AZIIIIIIIIZ — "Upload 1039 files", commit d670799 (verified), 6.94 kB
from typing import Any, Dict, List, Optional, Tuple, Union
import clip
import mmengine
import numpy as np
import torch
import torch.nn.functional as F
from mmengine.dist import all_gather, get_rank
from mmengine.model import BaseModel
from mmengine.structures import LabelData
from mmaction.registry import MODELS
from .adapter import TransformerAdapter
class GatherLayer(torch.autograd.Function):
    """All-gather across ranks with a gradient path back to each rank.

    Plain ``all_gather`` is not differentiable; this autograd Function
    returns the gathered tensors and, in backward, routes only the
    gradient slice that corresponds to the current rank's contribution.
    """

    @staticmethod
    def forward(ctx: Any, input: torch.Tensor) -> Tuple[List]:
        ctx.save_for_backward(input)
        gathered = all_gather(input)
        return tuple(gathered)

    @staticmethod
    def backward(ctx: Any, *grads: torch.Tensor) -> torch.Tensor:
        (saved_input, ) = ctx.saved_tensors
        local_grad = torch.zeros_like(saved_input)
        # Pick out the gradient chunk belonging to this process's output.
        local_grad[:] = grads[get_rank()]
        return local_grad
def text_prompt(labels_or_label_file, templates_or_template_file=None):
    """Tokenize prompts for every (template, label) pair with CLIP.

    Args:
        labels_or_label_file (list[str] | str): Class labels, or the path of
            a file containing one label per line.
        templates_or_template_file (list[str] | str | None): Prompt templates
            (each containing one ``{}`` placeholder for the label), the path
            of a template file, or ``None`` to use the built-in defaults.

    Returns:
        tuple[torch.Tensor, int]: The tokenized prompts, stacked
        template-major (all labels for template 0, then template 1, ...),
        and the number of templates used.

    Raises:
        ValueError: If either argument has an unsupported type.
    """
    if isinstance(labels_or_label_file, str):
        labels = mmengine.list_from_file(labels_or_label_file)
    elif isinstance(labels_or_label_file, list):
        labels = labels_or_label_file
    else:
        raise ValueError(f'`labels_or_label_file` must be `list` or `str`, '
                         f'but got {type(labels_or_label_file)}')
    if templates_or_template_file is None:
        templates = [
            'a photo of action {}', 'a picture of action {}',
            'Human action of {}', '{}, an action', '{} this is an action',
            '{}, a video of action', 'Playing action of {}', '{}',
            'Playing a kind of action, {}', 'Doing a kind of action, {}',
            'Look, the human is {}', 'Can you recognize the action of {}?',
            'Video classification of {}', 'A video of {}', 'The man is {}',
            'The woman is {}'
        ]
    elif isinstance(templates_or_template_file, str):
        templates = mmengine.list_from_file(templates_or_template_file)
    elif not mmengine.is_seq_of(templates_or_template_file, str):
        raise ValueError(f'`template` must be list of `str`, `str` or `None`, '
                         f'but got {type(templates_or_template_file)}')
    else:
        # Bug fix: a valid sequence of template strings previously fell
        # through every branch without assigning `templates`, causing a
        # NameError at `len(templates)` below.
        templates = list(templates_or_template_file)
    num_prompt = len(templates)
    prompt = torch.cat(
        [clip.tokenize(t.format(c)) for t in templates for c in labels])
    return prompt, num_prompt
@MODELS.register_module()
class ActionClip(BaseModel):
    """CLIP adapted for video action recognition (ActionCLIP-style).

    Frames are encoded with the CLIP image encoder, aggregated over time by
    a ``TransformerAdapter``, and matched against text features built from
    class-label prompts (see :func:`text_prompt`).

    Args:
        clip_arch (str): CLIP architecture name passed to ``clip.load``.
        num_adapter_segs (int): Number of frame segments fed to the adapter.
        num_adapter_layers (int): Transformer layers in the adapter.
            Defaults to 6.
        to_float32 (bool): If True, cast CLIP weights to float32.
            Defaults to False.
        labels_or_label_file (list[str] | str, optional): Class labels or a
            label file; required for ``predict`` and ``loss`` modes.
        templates_or_template_file (list[str] | str, optional): Prompt
            templates; ``None`` uses :func:`text_prompt`'s defaults.
        data_preprocessor (dict, optional): Data preprocessor config.
        loss (dict): Config of the contrastive similarity loss.
    """

    def __init__(self,
                 clip_arch: str,
                 num_adapter_segs: int,
                 num_adapter_layers: int = 6,
                 to_float32: bool = False,
                 labels_or_label_file: Optional[Union[List[str], str]] = None,
                 templates_or_template_file: Optional[Union[List[str],
                                                            str]] = None,
                 data_preprocessor: Optional[Dict] = None,
                 loss: Dict = dict(type='CrossEntropyLoss', loss_weight=0.5)):
        super(ActionClip, self).__init__(data_preprocessor=data_preprocessor)
        # clip.load returns (model, preprocess); only the model is kept.
        self.clip = clip.load(clip_arch, device='cpu')[0]
        if to_float32:
            self.clip.float()
        self.adapter = TransformerAdapter(self.clip, num_adapter_segs,
                                          num_adapter_layers)
        self.loss = MODELS.build(loss)
        if labels_or_label_file is not None:
            # Pre-tokenized prompts, stacked template-major over all labels.
            self.prompt, self.num_prompt = text_prompt(
                labels_or_label_file, templates_or_template_file)

    def encode_video(self, video):
        """Encode a (B, N, C, H, W) clip batch into per-video features."""
        b, n, c, h, w = video.shape
        # Flatten frames so CLIP's image encoder sees a plain image batch.
        video = video.view(-1, c, h, w)
        frames_features = self.encode_image(video)
        frames_features = frames_features.view(b, n, -1)
        # Temporal aggregation over the N frame features.
        video_features = self.adapter(frames_features)
        return video_features

    def encode_image(self, image):
        """Encode images with the CLIP visual tower."""
        return self.clip.encode_image(image)

    def encode_text(self, text):
        """Encode tokenized text with the CLIP text tower."""
        return self.clip.encode_text(text)

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[List] = None,
                mode: str = 'tensor'):
        """Unified entry point.

        Args:
            inputs (torch.Tensor): Video batch; ``encode_video`` unpacks it
                as (B, N, C, H, W).
            data_samples (list, optional): Data samples with ground-truth
                labels; required for ``'predict'`` and ``'loss'``.
            mode (str): ``'tensor'`` returns raw video features,
                ``'predict'`` attaches classification scores to the samples,
                ``'loss'`` returns the video<->text contrastive losses.

        Raises:
            RuntimeError: If ``mode`` is not one of the three above.
        """
        if mode == 'tensor':
            return self.encode_video(inputs)
        elif mode == 'predict':
            assert hasattr(self, 'prompt'),\
                '`labels_or_label_file` is required to perform prediction. '
            video_features = self.encode_video(inputs)
            # L2-normalize so the dot product below is a cosine similarity.
            video_features = video_features / video_features.norm(
                dim=-1, keepdim=True)
            # Multiple test views (crops/clips) per sample are folded into
            # the batch dimension by the pipeline.
            bsz = len(data_samples)
            num_views = video_features.shape[0] // bsz
            text_features = self.encode_text(self.prompt.to(inputs.device))
            text_features = text_features / text_features.norm(
                dim=-1, keepdim=True)
            # (bsz*num_views, num_prompt, num_classes) ->
            # (bsz, num_views*num_prompt, num_classes)
            similarity = (100.0 * video_features @ text_features.T). \
                view(bsz, num_views * self.num_prompt, -1)
            # Average softmaxed scores over all views and prompt templates.
            cls_scores = F.softmax(similarity, dim=2).mean(dim=1)
            for data_sample, score in zip(data_samples, cls_scores):
                data_sample.pred_scores = LabelData(item=score)
            return data_samples
        elif mode == 'loss':
            video_features = self.encode_video(inputs)
            video_features = video_features / video_features.norm(
                dim=-1, keepdim=True)
            # One randomly chosen template per sample (prompt augmentation).
            text_id = np.random.randint(
                self.num_prompt, size=len(data_samples))
            real_labels = [x.gt_labels.item.item() for x in data_samples]
            # prompt is template-major, so this view is
            # (num_prompt, num_classes, token_len); index by (template, label).
            selected_prompt = self.prompt.view(
                self.num_prompt, -1,
                self.prompt.shape[-1])[text_id, real_labels].to(inputs.device)
            text_features = self.encode_text(selected_prompt)
            text_features = text_features / text_features.norm(
                dim=-1, keepdim=True)
            # Gather features from all ranks (with gradients) so the
            # contrastive loss uses the global batch as negatives.
            video_features = torch.cat(
                GatherLayer.apply(video_features), dim=0)
            text_features = torch.cat(GatherLayer.apply(text_features), dim=0)
            logit_scale = self.clip.logit_scale.exp()
            logits_per_video = logit_scale * video_features @ text_features.t()
            logits_per_text = logits_per_video.t()
            # Matching video/text pairs lie on the diagonal.
            labels = torch.arange(logits_per_video.shape[0]).to(
                logit_scale.device)
            sim_loss_v2t = self.loss(logits_per_video, labels)
            sim_loss_t2v = self.loss(logits_per_text, labels)
            losses = dict()
            losses['sim_loss_v2t'] = sim_loss_v2t
            losses['sim_loss_t2v'] = sim_loss_t2v
            return losses
        else:
            raise RuntimeError(
                f'Invalid mode "{mode}". '
                'Only supports `predict`, `loss` and `tensor` mode. ')