Spaces:
Build error
Build error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| import contextlib | |
| import os | |
| import random | |
| import tempfile | |
| import unittest | |
| import torch | |
| import torchvision.io as io | |
| from densepose.data.transform import ImageResizeTransform | |
| from densepose.data.video import RandomKFramesSelector, VideoKeyframeDataset | |
| try: | |
| import av | |
| except ImportError: | |
| av = None | |
| # copied from torchvision test/test_io.py | |
| def _create_video_frames(num_frames, height, width): | |
| y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width)) | |
| data = [] | |
| for i in range(num_frames): | |
| xc = float(i) / num_frames | |
| yc = 1 - float(i) / (2 * num_frames) | |
| d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255 | |
| data.append(d.unsqueeze(2).repeat(1, 1, 3).byte()) | |
| return torch.stack(data, 0) | |
# adapted from torchvision test/test_io.py
@contextlib.contextmanager
def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None, options=None):
    """Write a synthetic clip to a temporary .mp4 file for the duration of a ``with`` block.

    Args:
        num_frames: number of frames, forwarded to ``_create_video_frames``.
        height: frame height in pixels, forwarded to ``_create_video_frames``.
        width: frame width in pixels, forwarded to ``_create_video_frames``.
        fps: frame rate passed to ``io.write_video``.
        lossless: if True, encode with ``libx264rgb`` and ``{"crf": "0"}``.
            This cannot be combined with an explicit ``video_codec`` or ``options``.
        video_codec: encoder name; defaults to ``"libx264"``.
        options: encoder options dict; defaults to ``{}``.

    Yields:
        ``(path, frames)``: the temporary file path and the uint8 frame tensor
        that was written to it.

    Raises:
        ValueError: if ``lossless`` is combined with ``video_codec`` or ``options``.
    """
    if lossless:
        if video_codec is not None:
            raise ValueError("video_codec can't be specified together with lossless")
        if options is not None:
            raise ValueError("options can't be specified together with lossless")
        video_codec = "libx264rgb"
        options = {"crf": "0"}
    if video_codec is None:
        video_codec = "libx264"
    if options is None:
        options = {}
    data = _create_video_frames(num_frames, height, width)
    with tempfile.NamedTemporaryFile(suffix=".mp4") as f:
        # Release the handle so write_video can create the file by name.
        # With the default delete=True, close() also removes the original file,
        # so the file write_video produces is no longer owned by NamedTemporaryFile.
        f.close()
        try:
            io.write_video(f.name, data, fps=fps, video_codec=video_codec, options=options)
            yield f.name, data
        finally:
            # Remove the file even if write_video or the caller's with-body raised.
            # Tolerate it already being gone, e.g. after a partial write failure.
            with contextlib.suppress(FileNotFoundError):
                os.unlink(f.name)
# Writing the clips goes through torchvision.io, which is presumably backed by
# PyAV here (see the optional `av` import), so skip rather than error without it.
@unittest.skipIf(av is None, "PyAV unavailable")
class TestVideoKeyframeDataset(unittest.TestCase):
    """Tests for ``VideoKeyframeDataset`` reading keyframes from a synthetic clip.

    Each test writes a 60-frame, 300x300, 5 fps mpeg4 clip via ``temp_video``.
    It then checks the shape and dtype of the ``"images"`` tensor returned by the
    dataset, and that the ``None`` category is passed through unchanged.
    """

    def test_read_keyframes_all(self):
        with temp_video(60, 300, 300, 5, video_codec="mpeg4") as (fname, _frames):
            video_list = [fname]
            category_list = [None]
            dataset = VideoKeyframeDataset(video_list, category_list)
            self.assertEqual(len(dataset), 1)
            # Fetch the sample once so images and categories come from the same read.
            sample = dataset[0]
            data1, categories1 = sample["images"], sample["categories"]
            # NOTE(review): 5 is presumably the number of keyframes the mpeg4
            # encoder emits for this clip; confirm if encoder defaults change.
            self.assertEqual(data1.shape, torch.Size((5, 3, 300, 300)))
            self.assertEqual(data1.dtype, torch.float32)
            self.assertIsNone(categories1[0])
            return
        # Only reachable if temp_video swallowed an exception from the with-body.
        self.fail("temp_video context exited without running the test body to completion")

    def test_read_keyframes_with_selector(self):
        with temp_video(60, 300, 300, 5, video_codec="mpeg4") as (fname, _frames):
            video_list = [fname]
            category_list = [None]
            # Seed Python's RNG; RandomKFramesSelector presumably draws from it.
            random.seed(0)
            frame_selector = RandomKFramesSelector(3)
            dataset = VideoKeyframeDataset(video_list, category_list, frame_selector)
            self.assertEqual(len(dataset), 1)
            # A single fetch consumes the selector's randomness only once.
            sample = dataset[0]
            data1, categories1 = sample["images"], sample["categories"]
            self.assertEqual(data1.shape, torch.Size((3, 3, 300, 300)))
            self.assertEqual(data1.dtype, torch.float32)
            self.assertIsNone(categories1[0])
            return
        # Only reachable if temp_video swallowed an exception from the with-body.
        self.fail("temp_video context exited without running the test body to completion")

    def test_read_keyframes_with_selector_with_transform(self):
        with temp_video(60, 300, 300, 5, video_codec="mpeg4") as (fname, _frames):
            video_list = [fname]
            category_list = [None]
            # Seed Python's RNG; RandomKFramesSelector presumably draws from it.
            random.seed(0)
            frame_selector = RandomKFramesSelector(1)
            transform = ImageResizeTransform()
            dataset = VideoKeyframeDataset(video_list, category_list, frame_selector, transform)
            self.assertEqual(len(dataset), 1)
            sample = dataset[0]
            data1, categories1 = sample["images"], sample["categories"]
            # The asserted 800x800 output implies the default ImageResizeTransform
            # target size; confirm against its definition if this changes.
            self.assertEqual(data1.shape, torch.Size((1, 3, 800, 800)))
            self.assertEqual(data1.dtype, torch.float32)
            self.assertIsNone(categories1[0])
            return
        # Only reachable if temp_video swallowed an exception from the with-body.
        self.fail("temp_video context exited without running the test body to completion")