Spaces:
Sleeping
Sleeping
| import os | |
| import random | |
| import h5py | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from torch.utils.data import DataLoader, Dataset | |
| from torchvision import transforms | |
def hypersim_distance_to_depth(npyDistance, intWidth=1024, intHeight=768, fltFocal=886.81):
    """Convert per-pixel ray distances into planar (z-axis) depth.

    For every pixel a ray ``(x, y, fltFocal)`` is built on an image plane
    whose pixel centres are symmetric around the optical axis. The input
    distance is then scaled by ``fltFocal / ||ray||``.

    Args:
        npyDistance: Array of distances, broadcastable against
            ``(intHeight, intWidth)``.
        intWidth: Image width in pixels. Defaults to 1024.
        intHeight: Image height in pixels. Defaults to 768.
        fltFocal: Focal length in pixels. Defaults to 886.81.

    Returns:
        Array of planar depth values with the broadcast shape of
        ``npyDistance`` and ``(intHeight, intWidth)``.
    """
    # Pixel-centre coordinates relative to the principal point.
    npyPlaneX = np.linspace((-0.5 * intWidth) + 0.5, (0.5 * intWidth) - 0.5,
                            intWidth).astype(np.float32)
    npyPlaneY = np.linspace((-0.5 * intHeight) + 0.5, (0.5 * intHeight) - 0.5,
                            intHeight).astype(np.float32)

    # One (x, y, focal) ray per pixel, stored as an (H, W, 3) float32 array.
    npyImageplane = np.empty((intHeight, intWidth, 3), dtype=np.float32)
    npyImageplane[:, :, 0] = npyPlaneX[None, :]
    npyImageplane[:, :, 1] = npyPlaneY[:, None]
    npyImageplane[:, :, 2] = fltFocal

    # distance / ||ray|| * focal projects the ray length onto the z axis.
    npyDepth = npyDistance / np.linalg.norm(npyImageplane, 2, 2) * fltFocal
    return npyDepth
def creat_uv_mesh(H, W):
    """Return homogeneous pixel coordinates for an ``H`` x ``W`` grid.

    The result is a float64 array of shape ``(3, H * W)``: row 0 holds the
    column index (x), row 1 the row index (y) and row 2 is all ones. Pixels
    are enumerated in row-major order.
    """
    row_idx, col_idx = np.indices((H, W), dtype=np.float64)
    homogeneous = np.ones(H * W, dtype=np.float64)
    return np.stack([col_idx.ravel(), row_idx.ravel(), homogeneous])
| # Some Hypersim normals are not properly oriented towards the camera. | |
| # The align_normals and creat_uv_mesh functions are from GeoWizard | |
| # https://github.com/fuxiao0719/GeoWizard/blob/5ff496579c6be35d9d86fe4d0760a6b5e6ba25c5/geowizard/training/dataloader/file_io.py#L79 | |
def align_normals(normal, depth, K, H, W):
    """Flip normals whose dot product with the back-projected point is negative.

    Orientation of surface normals in Hypersim is not always consistent,
    see https://github.com/apple/ml-hypersim/issues/26

    Args:
        normal: ``(H, W, 3)`` array of normals; modified in place.
        depth: Per-pixel depth, broadcast against the ``(3, H, W)`` ray grid.
        K: Sequence indexed as ``(fx, fy, cx, cy)``.
        H: Image height in pixels.
        W: Image width in pixels.

    Returns:
        The same ``normal`` array, after the in-place sign flips.
    """
    # Assemble the 3x3 pinhole intrinsics and invert them.
    intrinsics = np.array([
        [K[0], 0, K[2]],
        [0, K[1], K[3]],
        [0, 0, 1],
    ])
    intrinsics_inv = np.linalg.inv(intrinsics)

    # Back-project every pixel: ray direction scaled by depth, laid out (H, W, 3).
    pixel_grid = creat_uv_mesh(H, W)
    rays = (intrinsics_inv @ pixel_grid).reshape(3, H, W)
    cam_points = np.transpose(depth * rays, (1, 2, 0))

    # Negate the normals selected by the sign of <normal, point>.
    flip_mask = (normal * cam_points).sum(axis=2) < 0
    normal[flip_mask] *= -1
    return normal
class HypersimImageDepthNormalTransform:
    """Joint preprocessing of an (image, depth, normal) triple.

    The transform resizes all three inputs to ``size``, optionally applies
    the same horizontal flip to each, converts them to tensors and
    normalises the depth according to ``norm_type``.

    Supported ``norm_type`` values (all outputs are clipped to ``[0, 1]``):
        * ``'instnorm'``        - min/max normalisation of depth.
        * ``'truncnorm'``       - quantile-based normalisation of depth.
        * ``'perscene_norm'``   - depth divided by ``self.d_max``.
        * ``'disparity'``       - raw ``1 / depth``.
        * ``'trunc_disparity'`` - quantile-based normalisation of ``1 / depth``.
    """

    def __init__(self, size, random_flip, norm_type, truncnorm_min=0.02, align_cam_normal=False) -> None:
        """Store the configuration.

        Args:
            size: Target size passed to the resize/interpolate calls.
            random_flip: If truthy, flip horizontally with probability 0.5.
            norm_type: Name of the depth normalisation mode (see class doc).
            truncnorm_min: Lower quantile for the truncated modes; the upper
                quantile is ``1 - truncnorm_min``.
            align_cam_normal: If truthy, run ``align_normals`` on the normals.
        """
        self.size = size
        self.random_flip = random_flip
        self.norm_type = norm_type
        self.truncnorm_min = truncnorm_min
        self.truncnorm_max = 1 - truncnorm_min
        self.d_max = 65
        self.align_cam_normal = align_cam_normal

    def to_tensor_and_resize_normal(self, normal):
        """Convert an ``(H, W, C)`` array to a tensor resized to ``self.size``.

        Nearest-neighbour interpolation is used; singleton dimensions are
        squeezed from the result.
        """
        # (H, W, C) -> (1, C, H, W) as required by F.interpolate.
        normal = torch.from_numpy(normal).permute(2, 0, 1).unsqueeze(0)
        normal = F.interpolate(normal, size=self.size,
                               mode='nearest').squeeze()
        return normal

    def _normalize_depth(self, depth):
        """Map a depth tensor to ``[0, 1]`` according to ``self.norm_type``.

        Raises:
            TypeError: If ``self.norm_type`` is not a supported mode.
        """
        eps = 1e-5
        if self.norm_type == 'instnorm':
            dmin = depth.min()
            dmax = depth.max()
            # The original assignment was commented out, leaving depth_norm
            # undefined. The [0, 1] range matches the final clip below.
            depth_norm = (depth - dmin) / (dmax - dmin + eps)
        elif self.norm_type == 'truncnorm':
            # refer to Marigold
            dmin = torch.quantile(depth, self.truncnorm_min)
            dmax = torch.quantile(depth, self.truncnorm_max)
            depth_norm = (depth - dmin) / (dmax - dmin + eps)
        elif self.norm_type == 'perscene_norm':
            depth_norm = depth / self.d_max
        elif self.norm_type == "disparity":
            # Raw inverse depth; only the final clip bounds it.
            depth_norm = 1 / depth
        elif self.norm_type == "trunc_disparity":
            disparity = 1 / depth
            disparity_min = torch.quantile(disparity, self.truncnorm_min)
            disparity_max = torch.quantile(disparity, self.truncnorm_max)
            depth_norm = ((disparity - disparity_min) /
                          (disparity_max - disparity_min + eps))
        else:
            raise TypeError(
                f"Not supported normalization type: {self.norm_type}. ")
        return depth_norm.clip(0, 1)

    def __call__(self, image, depth, normal):
        """Apply the transform.

        Args:
            image: Image accepted by torchvision's ``resize``/``hflip``/
                ``ToTensor`` (presumably a PIL image - confirm with caller).
            depth: 2-D numpy array of depth values.
            normal: ``(H, W, 3)`` numpy array of normals; modified in place.

        Returns:
            Tuple ``(image, depth, normal)`` of tensors: the image from
            ``ToTensor``, the normalised depth repeated over 3 channels, and
            the normals clipped to ``[-1, 1]``.

        Raises:
            TypeError: If ``self.norm_type`` is not supported.
        """
        # Negate the x component (original note: convert the inward normals
        # to outward normals).
        normal[:, :, 0] *= -1
        if self.align_cam_normal:
            # align normal towards camera
            H, W = normal.shape[:2]
            normal = align_normals(
                normal, depth, [886.81, 886.81, W/2, H/2], H, W)

        # Resize everything to the target size.
        image = transforms.functional.resize(
            image, self.size, interpolation=Image.BILINEAR)
        depth = torch.from_numpy(depth).unsqueeze(0).unsqueeze(0)
        depth = F.interpolate(depth, size=self.size, mode='nearest').squeeze()
        normal = self.to_tensor_and_resize_normal(normal)

        # Random horizontal flip, applied consistently to all three inputs.
        if self.random_flip and random.random() > 0.5:
            image = transforms.functional.hflip(image)
            depth = torch.flip(depth, [-1])
            normal = torch.flip(normal, [-1])
            # Mirroring the image also mirrors the x component of the normals.
            normal[0, :, :] = - normal[0, :, :]

        image = transforms.ToTensor()(image)

        depth_norm = self._normalize_depth(depth)
        depth = depth_norm.unsqueeze(0).repeat(3, 1, 1)

        normal = normal.clip(-1, 1)
        return image, depth, normal
class HypersimDataset(Dataset):
    """Dataset of Hypersim image / depth / normal triples.

    Samples are discovered by walking ``data_dir/split`` for files ending in
    ``tonemap.jpg``. The matching depth and normal files are derived from the
    image path by string substitution.
    """

    def __init__(self, data_dir, random_flip, norm_type, resolution=(480, 720),
                 truncnorm_min=0.02, align_cam_normal=False, split="train", start=0, train_ratio=1.0):
        """Index the split and build the sample transform.

        Args:
            data_dir: Root directory containing one sub-directory per split.
            random_flip: Forwarded to ``HypersimImageDepthNormalTransform``.
            norm_type: Forwarded to ``HypersimImageDepthNormalTransform``.
            resolution: ``(height, width)`` every sample is resized to.
            truncnorm_min: Forwarded to ``HypersimImageDepthNormalTransform``.
            align_cam_normal: Forwarded to ``HypersimImageDepthNormalTransform``.
            split: Name of the sub-directory of ``data_dir`` to index.
            start: Number of leading (sorted) samples to skip.
            train_ratio: If below 1.0, keep only this leading fraction of the
                samples that remain after ``start``.
        """
        self.data_list = []
        split_dir = os.path.join(data_dir, split)

        # Search for every tonemap.jpg and derive its depth / normal paths.
        for root, _dirs, files in os.walk(split_dir):
            for name in files:
                if name.endswith("tonemap.jpg"):
                    img = os.path.join(root, name)
                    dep = img.replace("final_preview", "geometry_hdf5").replace(
                        "tonemap.jpg", "depth_meters.hdf5")
                    nor = img.replace("final_preview", "geometry_hdf5").replace(
                        "tonemap.jpg", "normal_cam.hdf5")
                    self.data_list.append((img, dep, nor))

        # Sort so that ``start`` and ``train_ratio`` select a deterministic subset.
        self.data_list.sort()
        self.data_list = self.data_list[start:]

        new_h, new_w = resolution
        self.new_h = new_h
        self.new_w = new_w
        self.transform = HypersimImageDepthNormalTransform(
            (new_h, new_w), random_flip, norm_type, truncnorm_min, align_cam_normal
        )

        if train_ratio < 1.0:
            origin_len = len(self.data_list)
            self.data_list = self.data_list[:int(origin_len * train_ratio)]
            print(
                f"Hypersim use {int(origin_len * train_ratio)} samples instead of {origin_len}...")
        else:
            print(
                f"Hypersim use origin {len(self.data_list)} samples...")

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.data_list)

    def _load_sample(self, idx):
        """Load and transform the sample at ``idx`` (already within range).

        Raises:
            ValueError: If the transformed image or depth contains NaN/inf.
            Exception: Anything raised while reading or transforming the files.
        """
        img_path, dep_path, nor_path = self.data_list[idx]
        image = Image.open(img_path).convert("RGB")

        # Load the stored ray distance and convert it to planar depth.
        with h5py.File(dep_path, 'r') as f:
            dist = np.array(f["dataset"])
        depth = hypersim_distance_to_depth(dist)

        # Unnormalised depth at the target resolution, clamped and repeated
        # over 3 channels.
        raw_depth = torch.from_numpy(depth).unsqueeze(0).unsqueeze(0)
        raw_depth = F.interpolate(raw_depth, size=(
            self.new_h, self.new_w), mode='nearest').squeeze()
        raw_depth = torch.clamp(raw_depth, 1e-3, 65).repeat(3, 1, 1)

        with h5py.File(nor_path, 'r') as f:
            normal = np.array(f["dataset"])

        image, depth, normal = self.transform(image, depth, normal)

        if torch.isnan(image).any() or torch.isinf(image).any():
            raise ValueError("image is nan or inf")
        if torch.isnan(depth).any() or torch.isinf(depth).any():
            raise ValueError("depth is nan or inf")

        return {
            "sample_idx": torch.tensor(idx),
            "images": image.unsqueeze(0),
            "disparity": depth.unsqueeze(0),
            'depth': raw_depth.unsqueeze(0),
            "normal_values": normal,
            "image_path": img_path,
            "depth_path": dep_path,
            "normal_path": nor_path,
        }

    def __getitem__(self, idx):
        """Return the sample at ``idx``, falling back to later samples on error.

        If a sample cannot be loaded or contains NaN/inf values, the error is
        printed and the next index (wrapping around) is tried. Each index is
        tried at most once per call.

        Raises:
            IndexError: If the dataset is empty.
            RuntimeError: If no sample could be loaded.
        """
        total = len(self.data_list)
        if total == 0:
            raise IndexError("HypersimDataset is empty")

        # Bounded retry: previously this recursed without limit, so a
        # systematic failure ended in a RecursionError.
        for offset in range(total):
            current = (idx + offset) % total
            try:
                return self._load_sample(current)
            except Exception as e:
                print(f"Error loading data at index {current}: {e}")

        raise RuntimeError(
            f"Failed to load any of the {total} Hypersim samples starting from index {idx}")
def collate_fn_hypersim(batch):
    """Collate a list of per-sample dicts into one batch dict.

    Tensor entries (``images``, ``disparity``, ``normal_values``, ``depth``)
    are stacked along a new leading dimension and cast to float32. Path
    entries are gathered into plain lists under pluralised keys.
    """
    def stack_as_float(key):
        return torch.stack([sample[key] for sample in batch]).float()

    def gather(key):
        return [sample[key] for sample in batch]

    stacked = {
        key: stack_as_float(key)
        for key in ("images", "disparity", "normal_values", "depth")
    }
    return {
        "images": stacked["images"],
        "disparity": stacked["disparity"],
        'depth': stacked["depth"],
        "normal_values": stacked["normal_values"],
        "image_paths": gather("image_path"),
        "depth_paths": gather("depth_path"),
        "normal_paths": gather("normal_path"),
    }
if __name__ == "__main__":
    # Debug entry point: load the dataset from the YAML config and dump the
    # first 50 batches' RGB image and colour-mapped depth into ./debug.
    import matplotlib.cm as cm
    import torch
    import torchvision.transforms.functional as TF
    import torchvision.utils as vutils
    from omegaconf import OmegaConf
    from PIL import Image

    cfg = OmegaConf.load("configs/hypersim.yaml")

    dataset = HypersimDataset(
        data_dir=cfg.train_data_dir_hypersim,
        resolution=cfg.resolution_hypersim,
        random_flip=cfg.random_flip,
        norm_type=cfg.norm_type,
        truncnorm_min=cfg.truncnorm_min,
        align_cam_normal=cfg.align_cam_normal,
        split="train",
    )
    print(f"Dataset length: {len(dataset)}")

    dataloader = DataLoader(
        dataset,
        batch_size=cfg.train_batch_size,
        shuffle=True,
        num_workers=cfg.dataloader_num_workers,
        pin_memory=True,
        collate_fn=collate_fn_hypersim,
    )

    out_dir = 'debug'
    os.makedirs(out_dir, exist_ok=True)
    print(f"Length of dataset {len(dataset)}")

    for step, batch in enumerate(dataloader):
        if step >= 50:
            break

        rgb_batch = batch['images']
        depth_batch = batch['depth']
        print(f"Step {step}:")
        print(
            f" images: {rgb_batch.shape} range {rgb_batch.min().item()} - {rgb_batch.max().item()}")
        print(
            f"Depth: {depth_batch.shape} range {depth_batch.min().item()} - {depth_batch.max().item()}")

        # Take the first sample of the batch.
        first_rgb = rgb_batch[0, 0]      # [3, H, W]
        first_depth = depth_batch[0, 0]  # [3, H, W] (when depth has 3 channels)

        # Save the RGB image.
        rgb_pil = TF.to_pil_image(first_rgb.clamp(0, 1).cpu())
        rgb_pil.save(os.path.join(out_dir, f"step_{step}_rgb.png"))

        # Save the depth map, min/max normalised first.
        channel = first_depth[0]  # first channel when depth has 3 channels
        lo = channel.min()
        hi = channel.max()
        scaled = (channel - lo) / (hi - lo + 1e-8)

        # Map through the Spectral colormap, drop alpha, convert to uint8.
        coloured = cm.Spectral(scaled.cpu().numpy())[:, :, :3]
        coloured = (coloured * 255).astype(np.uint8)
        Image.fromarray(coloured).save(
            os.path.join(out_dir, f"step_{step}_depth_spectral.png"))