File size: 1,349 Bytes
d670799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmaction.models import OmniHead


class obj:
    """Bare object carrying one dynamically named attribute.

    Used in this file to fake data samples that only need a ``gt_label``.

    Args:
        name (str): Name of the attribute to create on the instance.
        value: Value bound to that attribute.
    """

    def __init__(self, name, value):
        super().__init__()
        # Attribute name is only known at call time, hence setattr.
        setattr(self, name, value)


def testOmniHead():
    """Check OmniHead output shapes and its classification loss.

    First, a default head scores a 4D input against the image classes and a
    5D input against the video classes. Then a head built with
    ``video_nl_head=True`` is checked on a 5D batch in training mode and on a
    single 4D sample after ``eval()``. In both of those cases,
    ``loss_by_feat`` must return a ``loss_cls`` entry.
    """
    num_image_cls, num_video_cls, channels = 100, 200, 400

    plain_head = OmniHead(
        image_classes=num_image_cls,
        video_classes=num_video_cls,
        in_channels=channels)

    # 4D input -> image logits.
    scores = plain_head(torch.randn(2, channels, 8, 8))
    assert scores.shape == torch.Size([2, num_image_cls])

    # 5D input -> video logits.
    scores = plain_head(torch.randn(2, channels, 8, 8, 8))
    assert scores.shape == torch.Size([2, num_video_cls])

    nl_head = OmniHead(
        image_classes=num_image_cls,
        video_classes=num_video_cls,
        in_channels=channels,
        video_nl_head=True)

    scores = nl_head(torch.randn(2, channels, 8, 8, 8))
    assert scores.shape == torch.Size([2, num_video_cls])
    samples = [obj('gt_label', torch.tensor(1)) for _ in range(2)]
    assert 'loss_cls' in nl_head.loss_by_feat(scores, samples)

    # Single-sample batch, run in eval mode. NOTE(review): presumably eval()
    # is needed so normalization layers accept batch size 1 -- confirm.
    single_image = torch.randn(1, channels, 8, 8)
    nl_head.eval()
    scores = nl_head(single_image)
    assert scores.shape == torch.Size([1, num_image_cls])
    samples = [obj('gt_label', torch.tensor(1))]
    assert 'loss_cls' in nl_head.loss_by_feat(scores, samples)