|
|
|
|
|
import torch
|
|
|
|
|
|
from mmaction.models import OmniHead
|
|
|
|
|
|
|
|
|
class obj():
|
|
|
|
|
|
def __init__(self, name, value):
|
|
|
super(obj, self).__init__()
|
|
|
setattr(self, name, value)
|
|
|
|
|
|
|
|
|
def testOmniHead():
|
|
|
head = OmniHead(image_classes=100, video_classes=200, in_channels=400)
|
|
|
|
|
|
image_feat = torch.randn(2, 400, 8, 8)
|
|
|
image_score = head(image_feat)
|
|
|
assert image_score.shape == torch.Size([2, 100])
|
|
|
|
|
|
video_feat = torch.randn(2, 400, 8, 8, 8)
|
|
|
video_score = head(video_feat)
|
|
|
assert video_score.shape == torch.Size([2, 200])
|
|
|
|
|
|
head = OmniHead(
|
|
|
image_classes=100,
|
|
|
video_classes=200,
|
|
|
in_channels=400,
|
|
|
video_nl_head=True)
|
|
|
|
|
|
video_feat = torch.randn(2, 400, 8, 8, 8)
|
|
|
video_score = head(video_feat)
|
|
|
assert video_score.shape == torch.Size([2, 200])
|
|
|
data_samples = [obj('gt_label', torch.tensor(1)) for _ in range(2)]
|
|
|
losses = head.loss_by_feat(video_score, data_samples)
|
|
|
assert 'loss_cls' in losses
|
|
|
|
|
|
image_feat = torch.randn(1, 400, 8, 8)
|
|
|
head.eval()
|
|
|
image_score = head(image_feat)
|
|
|
assert image_score.shape == torch.Size([1, 100])
|
|
|
data_samples = [obj('gt_label', torch.tensor(1))]
|
|
|
losses = head.loss_by_feat(image_score, data_samples)
|
|
|
assert 'loss_cls' in losses
|
|
|
|