# Deepfake-Detector / tests/models/heads/test_omni_head.py
# Uploaded by AZIIIIIIIIZ ("Upload 1039 files"), commit d670799 (verified).
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmaction.models import OmniHead
class obj:
    """Minimal stand-in for a data sample carrying a single named attribute.

    Args:
        name: Name of the attribute to create on the instance.
        value: Value bound to that attribute.
    """

    def __init__(self, name, value):
        super().__init__()
        # The attribute name is only known at runtime, so setattr is needed
        # instead of a direct `self.<name> = value` assignment.
        setattr(self, name, value)
def testOmniHead():
    """Exercise OmniHead on 4-D and 5-D features and check its loss output.

    Verifies the score shapes for both input ranks, with and without
    ``video_nl_head``, and that ``loss_by_feat`` returns a ``'loss_cls'``
    entry for both video and image scores.
    """

    def _forward_and_check(model, feat, expected_shape):
        # Run one forward pass and assert the score tensor's shape.
        out = model(feat)
        assert out.shape == torch.Size(expected_shape)
        return out

    # Default head: a 4-D feature yields image_classes scores, and a 5-D
    # feature yields video_classes scores.
    base_head = OmniHead(image_classes=100, video_classes=200, in_channels=400)
    _forward_and_check(base_head, torch.randn(2, 400, 8, 8), [2, 100])
    _forward_and_check(base_head, torch.randn(2, 400, 8, 8, 8), [2, 200])

    # Same head configuration, but with the video non-local option enabled.
    nl_head = OmniHead(
        image_classes=100,
        video_classes=200,
        in_channels=400,
        video_nl_head=True)
    video_out = _forward_and_check(
        nl_head, torch.randn(2, 400, 8, 8, 8), [2, 200])
    video_targets = [obj('gt_label', torch.tensor(1)) for _ in range(2)]
    assert 'loss_cls' in nl_head.loss_by_feat(video_out, video_targets)

    # Single-sample image batch. NOTE(review): eval() is presumably needed
    # because a normalization layer cannot use batch statistics on one
    # sample -- confirm against OmniHead's definition.
    image_feat = torch.randn(1, 400, 8, 8)
    nl_head.eval()
    image_out = _forward_and_check(nl_head, image_feat, [1, 100])
    image_targets = [obj('gt_label', torch.tensor(1))]
    assert 'loss_cls' in nl_head.loss_by_feat(image_out, image_targets)