File size: 1,349 Bytes
d670799 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmaction.models import OmniHead
class obj:
    """Lightweight stand-in for a data sample holding one attribute.

    ``obj(name, value)`` builds an instance whose attribute called
    ``name`` is set to ``value``. In this file it is used to fake
    samples exposing a ``gt_label`` attribute.

    Args:
        name: Name of the attribute to create on the instance.
        value: Value to assign to that attribute.
    """

    def __init__(self, name, value):
        super().__init__()
        # Attribute name is only known at runtime, hence setattr.
        setattr(self, name, value)
def testOmniHead():
    """Smoke-test OmniHead's image/video branches and its loss output.

    Checks that 4-D features produce ``image_classes`` scores and 5-D
    features produce ``video_classes`` scores. This is done for a head built
    with default options and for one built with ``video_nl_head=True``.
    Also checks that ``loss_by_feat`` returns a mapping containing
    ``'loss_cls'``.
    """

    def expect_shape(scores, *dims):
        # Compare the output shape against an explicit torch.Size.
        assert scores.shape == torch.Size(dims)

    def make_samples(count):
        # Each fake sample carries class index 1 as its ``gt_label``.
        return [obj('gt_label', torch.tensor(1)) for _ in range(count)]

    # Head with default options: image branch, then video branch.
    omni = OmniHead(image_classes=100, video_classes=200, in_channels=400)
    expect_shape(omni(torch.randn(2, 400, 8, 8)), 2, 100)
    expect_shape(omni(torch.randn(2, 400, 8, 8, 8)), 2, 200)

    # Head built with video_nl_head=True: video forward plus loss.
    omni = OmniHead(
        image_classes=100,
        video_classes=200,
        in_channels=400,
        video_nl_head=True,
    )
    video_logits = omni(torch.randn(2, 400, 8, 8, 8))
    expect_shape(video_logits, 2, 200)
    assert 'loss_cls' in omni.loss_by_feat(video_logits, make_samples(2))

    # Single-image batch, run after switching the head to eval mode.
    single = torch.randn(1, 400, 8, 8)
    # NOTE(review): eval() is presumably required because a batch of one
    # cannot pass through batch-statistics layers (e.g. BatchNorm) in
    # training mode — confirm against OmniHead's implementation.
    omni.eval()
    image_logits = omni(single)
    expect_shape(image_logits, 1, 100)
    assert 'loss_cls' in omni.loss_by_feat(image_logits, make_samples(1))
|