# ct_detection/mmdetection/tests/test_structures/test_track_data_sample.py
# Scraped page header (presumably uploader / commit message / commit id):
# maverickrzw, "des", 2402804
from unittest import TestCase
import pytest
from mmdet.structures import DetDataSample, TrackDataSample
class TestDetDataSample(TestCase):
    """Unit tests for ``TrackDataSample`` (a container of per-frame samples).

    NOTE(review): despite its name, this class exercises ``TrackDataSample``
    built from ``DetDataSample`` frames, not ``DetDataSample`` itself. The
    name is kept so existing test selectors keep working; consider renaming
    it to ``TestTrackDataSample``.
    """

    @staticmethod
    def _make_track_data_sample():
        """Build a two-frame ``TrackDataSample`` used by several tests.

        Frame 0 is declared as the key frame and frame 1 as the reference
        frame via ``metainfo``. The two ``DetDataSample`` frames carry
        distinct ``scale_factor`` values so tests can tell them apart.

        Returns:
            TrackDataSample: sample with ``video_data_samples`` assigned.
        """
        det_data_sample_1 = DetDataSample(
            metainfo=dict(scale_factor=(1.5, 1.5)))
        det_data_sample_2 = DetDataSample(
            metainfo=dict(scale_factor=(2., 2.)))
        track_data_sample = TrackDataSample(
            metainfo=dict(key_frames_inds=[0], ref_frames_inds=[1]))
        track_data_sample.video_data_samples = [
            det_data_sample_1, det_data_sample_2
        ]
        return track_data_sample

    def test_init(self):
        """Metainfo passed at construction is stored and readable."""
        track_data_sample = TrackDataSample(
            metainfo=dict(key_frames_inds=[0], ref_frames_inds=[1]))
        # Separate asserts so a failure names the missing key.
        assert 'key_frames_inds' in track_data_sample.metainfo
        assert 'ref_frames_inds' in track_data_sample.metainfo
        assert track_data_sample.key_frames_inds == [0]
        assert track_data_sample.ref_frames_inds == [1]
        # No video_data_samples have been assigned yet, so the frame
        # accessors are expected to raise AssertionError.
        with pytest.raises(AssertionError):
            track_data_sample.get_key_frames()
        with pytest.raises(AssertionError):
            track_data_sample.get_ref_frames()

    def test_setter(self):
        """Assigned frames are returned by the key/ref frame accessors."""
        track_data_sample = self._make_track_data_sample()
        # Key frame index 0 -> first sample; ref frame index 1 -> second.
        assert track_data_sample.get_key_frames()[0].scale_factor == (1.5,
                                                                      1.5)
        assert track_data_sample.get_ref_frames()[0].scale_factor == (2., 2.)

    def test_deleter(self):
        """Deleting ``video_data_samples`` removes it from the sample."""
        track_data_sample = self._make_track_data_sample()
        assert 'video_data_samples' in track_data_sample
        del track_data_sample.video_data_samples
        assert 'video_data_samples' not in track_data_sample