File size: 1,820 Bytes
a567fa4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
# Copyright (c) Facebook, Inc. and its affiliates.
import contextlib
import os
import tempfile
import unittest
import torch
from torchvision.utils import save_image
from densepose.data.image_list_dataset import ImageListDataset
from densepose.data.transform import ImageResizeTransform
@contextlib.contextmanager
def temp_image(height, width):
    """Yield the path of a temporary ``.jpg`` file holding a random image.

    A ``height`` x ``width`` tensor produced by ``torch.rand`` is written to
    a uniquely named temporary file with ``save_image``. The file is removed
    when the context exits, including when the managed block raises.

    Args:
        height: first dimension of the random tensor.
        width: second dimension of the random tensor.

    Yields:
        str: path of the saved image file.
    """
    random_image = torch.rand(height, width)
    # delete=False: the temporary file is used only to reserve a unique path.
    # Its handle is closed on leaving the `with`, so save_image can reopen the
    # path by name, and removal is handled explicitly below.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
        image_fpath = f.name
    try:
        save_image(random_image, image_fpath)
        yield image_fpath
    finally:
        # contextmanager re-raises the caller's exception at the yield above;
        # the finally clause makes sure the file is still cleaned up then.
        os.unlink(image_fpath)
class TestImageListDataset(unittest.TestCase):
    """Exercise ``ImageListDataset`` on a list containing one temporary image."""

    def _check_single_entry(self, dataset, expected_shape):
        """Assert ``dataset`` has exactly one entry with the expected contents.

        The entry's ``"images"`` value must be a float32 tensor of
        ``expected_shape`` and the first item of ``"categories"`` must be None.
        """
        self.assertEqual(len(dataset), 1)
        # The entry is indexed once per field, in this order.
        images = dataset[0]["images"]
        categories = dataset[0]["categories"]
        self.assertEqual(images.shape, torch.Size(expected_shape))
        self.assertEqual(images.dtype, torch.float32)
        self.assertIsNone(categories[0])

    def test_image_list_dataset(self):
        """Without a transform the entry has shape (1, 3, height, width)."""
        height, width = 720, 1280
        with temp_image(height, width) as image_fpath:
            dataset = ImageListDataset([image_fpath], [None])
            self._check_single_entry(dataset, (1, 3, height, width))

    def test_image_list_dataset_with_transform(self):
        """With ``ImageResizeTransform`` the entry has shape (1, 3, 749, 1333)."""
        height, width = 720, 1280
        with temp_image(height, width) as image_fpath:
            resize = ImageResizeTransform()
            dataset = ImageListDataset([image_fpath], [None], resize)
            self._check_single_entry(dataset, (1, 3, 749, 1333))
|