# DAP-weights / datasets / insta23k.py
# Insta360-Research's picture
# Upload 372 files
# f4d2177 verified
from __future__ import print_function
import os
import cv2
import numpy as np
import random
import torch
from torch.utils import data
from torchvision import transforms
def read_list(list_file):
    """Read a sample list file into a list of per-line fields.

    Each non-blank line is split on runs of whitespace. Blank lines (for
    example a trailing empty line at the end of the file) are skipped so they
    do not turn into bogus, field-less samples.

    Args:
        list_file (string): Path to the text file listing the samples.

    Returns:
        list[list[str]]: One entry per non-blank line, holding that line's
        whitespace-separated fields in their original order.
    """
    rgb_depth_list = []
    with open(list_file) as f:
        # Iterate the file object directly instead of loading every line
        # into memory first with readlines().
        for line in f:
            # split() with no argument strips the line ends and collapses
            # repeated separators, so "a  b" gives ["a", "b"], not ["a", "", "b"].
            fields = line.split()
            if fields:
                rgb_depth_list.append(fields)
    return rgb_depth_list
class Insta23k(data.Dataset):
    """Dataset of RGB images paired with depth maps loaded via ``np.load``.

    Every sample is resized to ``(height, width)``, optionally augmented
    (training only), and returned as a dict of tensors. The yaw augmentation
    is a circular shift along the image width, which presumably assumes
    equirectangular panoramas -- confirm against the data.
    """

    def __init__(self, root_dir, list_file, height=504, width=1008, color_augmentation=True,
                 LR_filp_augmentation=True, yaw_rotation_augmentation=True, repeat=1, is_training=False):
        """
        Args:
            root_dir (string): Directory that the paths in ``list_file`` are joined onto.
            list_file (string): Path to the txt file listing the samples; field 0 of each
                entry is used as the image path and field 1 as the depth path.
            height, width: output size of the image and the depth map.
            color_augmentation, LR_filp_augmentation, yaw_rotation_augmentation (bool):
                enable the respective augmentation; each is applied only when
                ``is_training`` is True.
            repeat: accepted for interface compatibility; not used by this class.
            is_training (bool): True if the dataset is the training set.
        """
        self.root_dir = root_dir
        self.rgb_depth_list = read_list(list_file)

        self.w = width
        self.h = height

        self.max_depth_meters = 100.0
        # Stored for external readers; nothing in this class reads it.
        self.min_depth_meters = 0.01

        self.color_augmentation = color_augmentation
        self.LR_filp_augmentation = LR_filp_augmentation
        self.yaw_rotation_augmentation = yaw_rotation_augmentation

        self.is_training = is_training

        if self.color_augmentation:
            try:
                # Range-style arguments first ...
                self.brightness = (0.8, 1.2)
                self.contrast = (0.8, 1.2)
                self.saturation = (0.8, 1.2)
                self.hue = (-0.1, 0.1)
                self.color_aug = transforms.ColorJitter(
                    self.brightness, self.contrast, self.saturation, self.hue)
            except TypeError:
                # ... falling back to scalar arguments when ColorJitter
                # rejects tuples with a TypeError.
                self.brightness = 0.2
                self.contrast = 0.2
                self.saturation = 0.2
                self.hue = 0.1
                self.color_aug = transforms.ColorJitter(
                    self.brightness, self.contrast, self.saturation, self.hue)

        self.to_tensor = transforms.ToTensor()
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def __len__(self):
        """Return the number of samples listed in the list file."""
        return len(self.rgb_depth_list)

    def __getitem__(self, idx):
        """Load, augment and package the sample at position ``idx``.

        Args:
            idx: Sample index; a tensor index is converted with ``tolist()``.

        Returns:
            dict with the keys
                "rgb": image tensor after ``ToTensor`` and ``Normalize``;
                "gt_depth": depth divided by ``max_depth_meters`` and clipped
                    to [0.001, 1.0], with a leading axis of size 1;
                "val_mask": True where the clamped depth is > 0,
                    <= ``max_depth_meters`` and not NaN;
                "mask_100": True where the clamped depth is < 100.0;
                "touying": always False for this dataset.

        Raises:
            OSError: if the RGB image cannot be read.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        rgb_name = os.path.join(self.root_dir, self.rgb_depth_list[idx][0])
        rgb = cv2.imread(rgb_name)
        if rgb is None:
            # imread reports failure by returning None; stop here with the
            # offending path instead of failing later inside cvtColor.
            raise OSError("Could not read RGB image: {}".format(rgb_name))
        rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
        rgb = cv2.resize(rgb, dsize=(self.w, self.h), interpolation=cv2.INTER_CUBIC)

        depth_name = os.path.join(self.root_dir, self.rgb_depth_list[idx][1])
        gt_depth = np.load(depth_name).astype(np.float32)
        # Nearest-neighbour resize so no new depth values are interpolated.
        gt_depth = cv2.resize(gt_depth, dsize=(self.w, self.h), interpolation=cv2.INTER_NEAREST)
        # Clamp far values; NaNs are unaffected and are excluded by val_mask below.
        gt_depth[gt_depth > self.max_depth_meters] = self.max_depth_meters

        if self.is_training and self.yaw_rotation_augmentation:
            # Random yaw rotation: circular shift of image and depth along the width axis.
            # randint is inclusive on both ends, so a shift of self.w
            # (equivalent to no shift) can also be drawn.
            roll_idx = random.randint(0, self.w)
            rgb = np.roll(rgb, roll_idx, 1)
            gt_depth = np.roll(gt_depth, roll_idx, 1)

        if self.is_training and self.LR_filp_augmentation and random.random() > 0.5:
            rgb = cv2.flip(rgb, 1)
            gt_depth = cv2.flip(gt_depth, 1)

        if self.is_training and self.color_augmentation and random.random() > 0.5:
            aug_rgb = np.asarray(self.color_aug(transforms.ToPILImage()(rgb)))
        else:
            aug_rgb = rgb

        # A single copy hands ToTensor its own array in both branches.
        aug_rgb = self.to_tensor(aug_rgb.copy())

        gt_depth = torch.from_numpy(np.expand_dims(gt_depth, axis=0)).to(torch.float32)

        val_mask = ((gt_depth > 0) & (gt_depth <= self.max_depth_meters) & ~torch.isnan(gt_depth))
        # Depth was clamped above, so this excludes pixels at (or clamped to) 100.0.
        mask_100 = (gt_depth < 100.0)

        # Normalize depth by the maximum depth, keeping a small positive floor.
        gt_depth_norm = gt_depth / self.max_depth_meters
        gt_depth_norm = torch.clip(gt_depth_norm, 0.001, 1.0)

        # Conduct output
        inputs = {}
        inputs["rgb"] = self.normalize(aug_rgb)
        inputs["gt_depth"] = gt_depth_norm
        inputs["val_mask"] = val_mask
        inputs["mask_100"] = mask_100
        inputs["touying"] = False

        return inputs