File size: 5,511 Bytes
436b829
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from abc import ABC, abstractmethod
from omegaconf import DictConfig
import cv2
import numpy as np
import imageio
from ppd.utils.logger import Log
import time
import h5py
import torch
from torchvision.transforms import Compose
from PIL import Image


class Dataset(ABC):
    """Abstract base class for image / depth datasets backed by file lists.

    Keyword arguments given to the constructor are wrapped in a
    ``DictConfig`` and stored as ``self.cfg``. Subclasses implement
    :meth:`build_metas`, which must populate ``self.rgb_files`` and may
    populate ``self.depth_files`` (depth is skipped when that attribute is
    absent).
    """

    # A sample is rejected when its (post-transform) mask sums below this.
    _MIN_MASK_SUM = 10
    # Number of random re-draws allowed before __getitem__ gives up.
    _MAX_RESAMPLES = 10
    # File reads slower than this many seconds are reported with a warning.
    _SLOW_READ_SECONDS = 1

    def __init__(self, **kwargs):
        """Store the config, then build the file lists and the transform.

        Args:
            **kwargs: Dataset options. Keys read here are ``dataset_name``
                (default ``'unknown'``), ``use_low`` (default ``True``),
                ``transforms`` (see :meth:`build_transforms`) and ``split``
                (only used in the summary log line).
        """
        super().__init__()
        self.cfg = DictConfig(kwargs)
        self.dataset_name = self.cfg.get('dataset_name', 'unknown')
        self.use_low = self.cfg.get('use_low', True)
        self.build_metas()
        self.build_transforms()
        Log.info(
            f'{self.cfg.split} split of {self.dataset_name} dataset: {len(self.rgb_files)} frames in total.')

    @abstractmethod
    def build_metas(self):
        """Prepare the per-sample file lists.

        Implementations must set ``self.rgb_files`` and may set
        ``self.depth_files``. The original note also lists ``low_files``;
        nothing in this base class reads it, so it is presumably consumed
        by subclasses -- TODO confirm.
        """

    def build_transforms(self):
        """Build ``self.transform`` from the ``transforms`` config entry.

        With no transforms configured (missing, ``None`` or empty) the
        transform is the identity; otherwise the layers are logged and
        chained with ``Compose``.
        """
        transforms = self.cfg.get('transforms', [])
        # Truthiness also covers an explicit ``transforms: None`` entry, on
        # which ``len()`` would raise.
        if not transforms:
            self.transform = lambda x: x
            return
        layer_lines = '\n'.join(str(transform) for transform in transforms)
        Log.info(f'{self.dataset_name} transform layers: \n' + layer_lines)
        self.transform = Compose(transforms)

    def read_rgb(self, index):
        """Load the RGB image at ``index`` as float32 scaled to [0, 1].

        Args:
            index: Position in ``self.rgb_files``.

        Returns:
            The image converted from BGR to RGB, divided by 255 and cast to
            ``np.float32``.

        Raises:
            ValueError: If OpenCV cannot read the file.
        """
        img_path = self.rgb_files[index]
        start_time = time.time()
        rgb = cv2.imread(img_path)
        elapsed = time.time() - start_time
        if elapsed > self._SLOW_READ_SECONDS:
            Log.warn(f'Long time to read {img_path}: {elapsed}')
        # cv2.imread signals failure by returning None instead of raising.
        if rgb is None:
            raise ValueError(f'Failed to read image: {img_path}')
        rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
        return np.asarray(rgb / 255.).astype(np.float32)

    def read_rgb_name(self, index):
        """Return ``'<parent_dir>__<file_name>'`` for the image at ``index``.

        The path is split on ``'/'`` and its last two components are joined
        with a double underscore.
        """
        return '__'.join(self.rgb_files[index].split('/')[-2:])

    def read_depth(self, index, depth=None):
        """Load the depth map at ``index`` together with its validity mask.

        Args:
            index: Position in ``self.depth_files``.
            depth: Optional already-loaded depth array. When given, nothing
                is read from disk and this array is used (and may be
                modified in place when non-finite values are filled).

        Returns:
            ``(depth, mask)`` where ``mask`` is a ``uint8`` array that is 1
            where depth is finite and greater than 0.01. Returns
            ``(None, None)`` when the dataset has no ``depth_files``.

        Raises:
            ValueError: If the file extension is unsupported, the file
                cannot be read, or the array is neither HxW nor HxWx1.
        """
        if not hasattr(self, 'depth_files'):
            return None, None
        depth_path = self.depth_files[index]
        # Single pre-formatted message: passing (index, path) positionally
        # would be treated as a format string plus args by a stdlib logger.
        Log.debug(f'{index} {depth_path}')
        start_time = time.time()
        if depth is not None:
            pass
        elif depth_path.endswith('.png'):
            raw = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR |
                             cv2.IMREAD_ANYDEPTH)
            # cv2.imread signals failure by returning None instead of raising.
            if raw is None:
                raise ValueError(f"Invalid depth file: {depth_path}")
            # PNG values are divided by 1000 (presumably millimetres to
            # metres -- TODO confirm against the dataset spec).
            depth = raw / 1000.
        elif depth_path.endswith('.npz'):
            # Context manager closes the archive once the array is extracted.
            with np.load(depth_path) as archive:
                depth = archive['data']
        elif depth_path.endswith('.hdf5'):
            # Read-only, and closed as soon as the data is copied out.
            with h5py.File(depth_path, 'r') as h5_file:
                depth = np.asarray(h5_file['dataset'])
        elif depth_path.endswith('.npy'):
            depth = np.load(depth_path)
        else:
            raise ValueError(f"Invalid depth file: {depth_path}")

        if len(depth.shape) == 2:
            pass
        elif len(depth.shape) == 3 and depth.shape[2] == 1:
            # Drop the trailing singleton channel.
            depth = depth[:, :, 0]
        else:
            raise ValueError(f"Invalid depth file: {depth_path}")

        elapsed = time.time() - start_time
        if elapsed > self._SLOW_READ_SECONDS:
            Log.warn(
                f'Long time to read {depth_path}: {elapsed}')

        non_finite = ~np.isfinite(depth)
        valid_mask = (depth > 0.01) & ~non_finite
        if valid_mask.sum() == 0:
            Log.warn('No valid mask in the depth map of {}'.format(
                depth_path))
        elif non_finite.any():
            # Replace NaN / inf with the largest valid depth so downstream
            # arithmetic stays finite; the mask still marks them invalid.
            depth[non_finite] = depth[valid_mask].max()
        return depth, valid_mask.astype(np.uint8)

    def check_shape(self, rgb, dpt):
        """Assert that ``rgb`` is HxWxC, ``dpt`` is HxW, and H, W agree."""
        assert (rgb.shape[:2] == dpt.shape[:2]), "rgb.shape: {}, dpt.shape: {}".format(
            rgb.shape, dpt.shape)
        assert (len(rgb.shape) == 3), "rgb.shape: {}".format(rgb.shape)
        assert (len(dpt.shape) == 2), "dpt.shape: {}".format(dpt.shape)

    def __getitem__(self, index):
        """Return the transformed sample at ``index`` (wrapping around).

        When the transformed sample carries a ``'mask'`` whose sum is below
        ``_MIN_MASK_SUM``, a different index is drawn at random and the load
        is retried, up to ``_MAX_RESAMPLES`` times.

        Returns:
            A dict with ``'image'``, optionally ``'depth'`` and ``'mask'``,
            plus ``'dataset_name'``, ``'image_name'`` and ``'image_path'``
            (plus whatever the transform adds or changes).

        Raises:
            ValueError: If no sample with a usable mask is found within the
                retry budget.
        """
        index = index % len(self.rgb_files)
        retries = 0
        while True:
            rgb = self.read_rgb(index)
            dpt, msk = self.read_depth(index)
            sample = {'image': rgb}
            if dpt is not None:
                self.check_shape(rgb, dpt)
                sample['depth'] = dpt
                sample['mask'] = msk

            sample = self.transform(sample)
            if 'mask' not in sample or sample['mask'].sum() >= self._MIN_MASK_SUM:
                break

            # Name the sample that was actually rejected, before re-drawing.
            rejected_name = self.rgb_files[index]
            retries += 1
            if retries > self._MAX_RESAMPLES:
                raise ValueError(
                    f'No valid mask in the depth map of {rejected_name}.')
            Log.warn(
                f'No valid mask in the depth map of {rejected_name}.')
            index = int(np.random.randint(0, len(self.rgb_files)))

        sample['dataset_name'] = self.dataset_name
        sample['image_name'] = self.read_rgb_name(index)
        sample['image_path'] = self.rgb_files[index]
        return sample

    def __len__(self):
        """Return the number of RGB frames in the dataset."""
        return len(self.rgb_files)