|
|
| import torch
|
|
|
| from mmaction.models import OmniHead
|
|
|
|
|
class obj:
    """Lightweight object carrying a single attribute chosen at construction.

    ``obj('gt_label', x)`` yields an instance whose ``gt_label`` attribute is
    ``x``. The test in this module uses it to stand in for data samples passed
    to ``loss_by_feat``.

    Args:
        name: Attribute name to set on the new instance.
        value: Value bound to that attribute.
    """

    def __init__(self, name, value):
        super().__init__()
        # The attribute name is only known at runtime, hence setattr.
        setattr(self, name, value)
|
|
|
|
|
def testOmniHead():
    """Exercise OmniHead on image (4-D) and video (5-D) feature tensors.

    Checks the score shapes of a default head and of one built with
    ``video_nl_head=True``. Also checks that ``loss_by_feat`` returns a
    ``'loss_cls'`` entry for both video and image scores.
    """

    def _forward_and_check(model, feat, expected_shape):
        # Run the head and verify the resulting score shape.
        score = model(feat)
        assert score.shape == torch.Size(expected_shape)
        return score

    # Default head: 4-D inputs -> 100 image classes, 5-D -> 200 video classes.
    plain_head = OmniHead(image_classes=100, video_classes=200, in_channels=400)
    _forward_and_check(plain_head, torch.randn(2, 400, 8, 8), [2, 100])
    _forward_and_check(plain_head, torch.randn(2, 400, 8, 8, 8), [2, 200])

    # Same configuration with video_nl_head enabled.
    nl_head = OmniHead(
        image_classes=100,
        video_classes=200,
        in_channels=400,
        video_nl_head=True,
    )
    vid_score = _forward_and_check(
        nl_head, torch.randn(2, 400, 8, 8, 8), [2, 200])
    vid_samples = [obj('gt_label', torch.tensor(1)) for _ in range(2)]
    vid_losses = nl_head.loss_by_feat(vid_score, vid_samples)
    assert 'loss_cls' in vid_losses

    # Single image sample, evaluated after switching the head to eval mode.
    # NOTE(review): presumably eval() is needed for batch size 1 — confirm.
    single_image = torch.randn(1, 400, 8, 8)
    nl_head.eval()
    img_score = _forward_and_check(nl_head, single_image, [1, 100])
    img_losses = nl_head.loss_by_feat(
        img_score, [obj('gt_label', torch.tensor(1))])
    assert 'loss_cls' in img_losses
|
|
|