Spaces: Running on Zero
import os

# Must be set before cv2 is imported anywhere: OpenCV refuses to read/write
# .exr files unless OpenEXR I/O is explicitly enabled.
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'

import sys
from pathlib import Path

# Make the package root importable when this file is run directly as a script.
if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
    sys.path.insert(0, _package_root)

import itertools
import json
import warnings
from typing import *

import click
@click.command(help='Panorama depth inference with MoGe.')
@click.option('--input', 'input_path', required=True, help='Input panorama image file or a directory of images.')
@click.option('--output', 'output_path', required=True, help='Directory to write outputs to.')
@click.option('--pretrained', 'pretrained_model_name_or_path', default='Ruicheng/moge-vitl', help='Pretrained MoGe checkpoint name or local path.')
@click.option('--device', 'device_name', default='cuda', help='Torch device name, e.g. "cuda", "cuda:0" or "cpu".')
@click.option('--resize_to', type=int, default=None, help='Resize the longer side of the panorama to this size before inference.')
@click.option('--resolution_level', type=int, default=9, help='Inference resolution level in [0, 9].')
@click.option('--threshold', type=float, default=0.03, help='Relative depth-edge threshold for removing stretched mesh faces.')
@click.option('--batch_size', type=int, default=4, help='Batch size for inferring the splitted perspective views.')
@click.option('--splitted', 'save_splitted', is_flag=True, help='Also save the intermediate splitted views.')
@click.option('--maps', 'save_maps_', is_flag=True, help='Save depth / points / mask maps.')
@click.option('--glb', 'save_glb_', is_flag=True, help='Save a textured .glb mesh.')
@click.option('--ply', 'save_ply_', is_flag=True, help='Save a vertex-colored .ply mesh.')
@click.option('--show', is_flag=True, help='Show the reconstructed mesh in a viewer window.')
def main(
    input_path: str,
    output_path: str,
    pretrained_model_name_or_path: str,
    device_name: str,
    resize_to: int,
    resolution_level: int,
    threshold: float,
    batch_size: int,
    save_splitted: bool,
    save_maps_: bool,
    save_glb_: bool,
    save_ply_: bool,
    show: bool,
):
    """
    Estimate a panorama depth map with MoGe.

    Each panorama is split into perspective views (`get_panorama_cameras` /
    `split_panorama_image`), every view is run through the MoGe model to get a
    per-pixel distance map, and the per-view results are merged back into one
    panorama depth map (`merge_panorama_depth`). Optionally writes the raw
    maps, a .glb / .ply mesh, the intermediate splitted views, and/or shows
    the mesh in a viewer.

    NOTE(review): `resolution_level` is accepted but never used in this body —
    confirm whether it should be forwarded to `model.infer`.
    """
    # Lazy imports: keep CLI startup (and `--help`) fast.
    # (Unused imports from the original — PIL.Image, scipy sparse/ndimage,
    # numpy.ndarray — have been dropped; nothing in this function used them.)
    import cv2
    import numpy as np
    import torch
    from tqdm import tqdm, trange
    import trimesh
    import trimesh.visual
    import utils3d
    from moge.model.v1 import MoGeModel
    from moge.utils.io import save_glb, save_ply
    from moge.utils.vis import colorize_depth
    from moge.utils.panorama import spherical_uv_to_directions, get_panorama_cameras, split_panorama_image, merge_panorama_depth

    device = torch.device(device_name)

    # Collect input images (recursively, if a directory was given).
    include_suffices = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG']
    if Path(input_path).is_dir():
        image_paths = sorted(itertools.chain(*(Path(input_path).rglob(f'*.{suffix}') for suffix in include_suffices)))
    else:
        image_paths = [Path(input_path)]
    if len(image_paths) == 0:
        raise FileNotFoundError(f'No image files found in {input_path}')

    # Default to saving everything when no output format was requested.
    if not any([save_maps_, save_glb_, save_ply_]):
        warnings.warn('No output format specified. Defaults to saving all. Please use "--maps", "--glb", or "--ply" to specify the output.')
        save_maps_ = save_glb_ = save_ply_ = True

    model = MoGeModel.from_pretrained(pretrained_model_name_or_path).to(device).eval()

    for image_path in (pbar := tqdm(image_paths, desc='Total images', disable=len(image_paths) <= 1)):
        image = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB)
        height, width = image.shape[:2]
        if resize_to is not None:
            # Resize so the longer side equals `resize_to`, keeping the aspect ratio.
            height, width = min(resize_to, int(resize_to * height / width)), min(resize_to, int(resize_to * width / height))
            # FIX: cv2.resize's third positional parameter is `dst`, not
            # `interpolation` — the original passed cv2.INTER_AREA positionally,
            # which silently left the interpolation at its default.
            image = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)

        # Split the panorama into perspective views.
        # (FIX: local renamed from the misspelled `splitted_intriniscs`.)
        splitted_extrinsics, splitted_intrinsics = get_panorama_cameras()
        splitted_resolution = 512
        splitted_images = split_panorama_image(image, splitted_extrinsics, splitted_intrinsics, splitted_resolution)

        # Infer each view in batches.
        if pbar.disable:
            print('Inferring...')
        else:
            pbar.set_postfix_str('Inferring')
        splitted_distance_maps, splitted_masks = [], []
        for i in trange(0, len(splitted_images), batch_size, desc='Inferring splitted views', disable=len(splitted_images) <= batch_size, leave=False):
            image_tensor = torch.tensor(np.stack(splitted_images[i:i + batch_size]) / 255, dtype=torch.float32, device=device).permute(0, 3, 1, 2)
            fov_x, fov_y = np.rad2deg(utils3d.numpy.intrinsics_to_fov(np.array(splitted_intrinsics[i:i + batch_size])))
            fov_x = torch.tensor(fov_x, dtype=torch.float32, device=device)
            output = model.infer(image_tensor, fov_x=fov_x, apply_mask=False)
            # Distance = Euclidean norm of the predicted 3D point map.
            distance_map, mask = output['points'].norm(dim=-1).cpu().numpy(), output['mask'].cpu().numpy()
            splitted_distance_maps.extend(list(distance_map))
            splitted_masks.extend(list(mask))

        # Optionally save the splitted views with their distance visualizations.
        if save_splitted:
            splitted_save_path = Path(output_path, image_path.stem, 'splitted')
            splitted_save_path.mkdir(exist_ok=True, parents=True)
            for i in range(len(splitted_images)):
                cv2.imwrite(str(splitted_save_path / f'{i:02d}.jpg'), cv2.cvtColor(splitted_images[i], cv2.COLOR_RGB2BGR))
                cv2.imwrite(str(splitted_save_path / f'{i:02d}_distance_vis.png'), cv2.cvtColor(colorize_depth(splitted_distance_maps[i], splitted_masks[i]), cv2.COLOR_RGB2BGR))

        # Merge the per-view distance maps back into a panorama depth map.
        if pbar.disable:
            print('Merging...')
        else:
            pbar.set_postfix_str('Merging')
        # Merge at a capped resolution, then upsample back to the original size.
        merging_width, merging_height = min(1920, width), min(960, height)
        panorama_depth, panorama_mask = merge_panorama_depth(merging_width, merging_height, splitted_distance_maps, splitted_masks, splitted_extrinsics, splitted_intrinsics)
        panorama_depth = panorama_depth.astype(np.float32)
        # FIX: interpolation passed as a keyword (third positional arg is `dst`).
        panorama_depth = cv2.resize(panorama_depth, (width, height), interpolation=cv2.INTER_LINEAR)
        panorama_mask = cv2.resize(panorama_mask.astype(np.uint8), (width, height), interpolation=cv2.INTER_NEAREST) > 0
        # Back-project depth to 3D points along the spherical view directions.
        points = panorama_depth[:, :, None] * spherical_uv_to_directions(utils3d.numpy.image_uv(width=width, height=height))

        # Write outputs.
        if pbar.disable:
            print('Writing outputs...')
        else:
            # FIX: the original set the postfix to 'Inferring' here (copy-paste error).
            pbar.set_postfix_str('Writing outputs')
        save_path = Path(output_path, image_path.relative_to(input_path).parent, image_path.stem)
        save_path.mkdir(exist_ok=True, parents=True)
        if save_maps_:
            cv2.imwrite(str(save_path / 'image.jpg'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
            cv2.imwrite(str(save_path / 'depth_vis.png'), cv2.cvtColor(colorize_depth(panorama_depth, mask=panorama_mask), cv2.COLOR_RGB2BGR))
            cv2.imwrite(str(save_path / 'depth.exr'), panorama_depth, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
            cv2.imwrite(str(save_path / 'points.exr'), points, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
            cv2.imwrite(str(save_path / 'mask.png'), (panorama_mask * 255).astype(np.uint8))

        # Export mesh & visualization.
        if save_glb_ or save_ply_ or show:
            normals, normals_mask = utils3d.numpy.points_to_normals(points, panorama_mask)
            # Drop faces crossing depth discontinuities (pixels that are both a
            # depth edge and a normal edge are masked out of the mesh).
            faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh(
                points,
                image.astype(np.float32) / 255,
                utils3d.numpy.image_uv(width=width, height=height),
                mask=panorama_mask & ~(utils3d.numpy.depth_edge(panorama_depth, rtol=threshold) & utils3d.numpy.normals_edge(normals, tol=5, mask=normals_mask)),
                tri=True
            )
            if save_glb_:
                save_glb(save_path / 'mesh.glb', vertices, faces, vertex_uvs, image)
            if save_ply_:
                save_ply(save_path / 'mesh.ply', vertices, faces, vertex_colors)
            if show:
                trimesh.Trimesh(
                    vertices=vertices,
                    vertex_colors=vertex_colors,
                    faces=faces,
                    process=False
                ).show()
# Run the CLI when this file is executed as a script.
if __name__ == '__main__':
    main()