| import gradio as gr |
|
|
|
|
| import numpy as np |
| import cv2 |
| from tqdm import tqdm |
|
|
| import torch |
| from pytorch3d.io.obj_io import load_obj |
| import tempfile |
| import main_mcc |
| import mcc_model |
| import util.misc as misc |
| from engine_mcc import prepare_data |
| from plyfile import PlyData, PlyElement |
| import trimesh |
|
|
def _expected_color(color_logits, temperature):
    """Decode per-channel 256-bin color logits into intensities in [0, 1].

    ``color_logits`` is reshaped to ``(-1, 3, 256)``. With a positive
    ``temperature`` the result is the expected bin value (bin ``k`` maps to
    ``k / 255``) under ``softmax(logits / temperature)``. A temperature of 0
    (which the demo's slider allows) would divide by zero and yield NaN, so it
    is treated as the T -> 0 limit: the value of the most likely bin.
    """
    bins = color_logits.reshape((-1, 3, 256))
    if temperature <= 0:
        return bins.argmax(dim=2).float() / 255
    probs = torch.nn.Softmax(dim=2)(bins / temperature)
    return (probs * torch.linspace(0, 1, 256, device=bins.device)).sum(dim=2)


def run_inference(model, samples, device, temperature, args):
    """Run the model over all query points in chunks and collect predictions.

    Args:
        model: The MCC model; switched to eval mode here.
        samples: Raw inputs forwarded to ``prepare_data``.
        device: Device string; when ``"cuda"`` the occupancy chunks are kept
            on the GPU until the final concatenation.
        temperature: Softmax temperature for color decoding (ignored when
            ``args.regress_color`` is set). 0 selects the most likely bin.
        args: Parsed arguments; ``args.regress_color`` picks the color head.

    Returns:
        Tuple of NumPy arrays ``(colors, occupancy, unseen_xyz)``:
        ``colors`` has shape ``(-1, 3)`` (one row per query point, flattened
        over the batch), ``occupancy`` holds sigmoid probabilities
        concatenated along dim 1, and ``unseen_xyz`` are the query points
        produced by ``prepare_data``.
    """
    model.eval()

    seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_images = prepare_data(
        samples, device, is_train=False, args=args, is_viz=True
    )
    # Query points are processed in chunks of at most this many per forward pass.
    max_n_unseen_fwd = 2000
    pred_occupy = []
    pred_colors = []

    # Drop any encoder cache left from a previous call. The passes below use
    # cache_enc=True (presumably so later chunks reuse the seen-view
    # encoding -- TODO confirm in mcc_model).
    model.cached_enc_feat = None
    with torch.no_grad():
        for p_start in range(0, unseen_xyz.shape[1], max_n_unseen_fwd):
            p_end = p_start + max_n_unseen_fwd
            cur_unseen_xyz = unseen_xyz[:, p_start:p_end]
            # Slices are views, so zero_() also clears these ranges of
            # unseen_rgb / labels (presumably so no targets leak into the model).
            cur_unseen_rgb = unseen_rgb[:, p_start:p_end].zero_()
            cur_labels = labels[:, p_start:p_end].zero_()

            _, pred = model(
                seen_images=seen_images,
                seen_xyz=seen_xyz,
                unseen_xyz=cur_unseen_xyz,
                unseen_rgb=cur_unseen_rgb,
                unseen_occupy=cur_labels,
                cache_enc=True,
                valid_seen_xyz=valid_seen_xyz,
            )
            occupancy_logits = pred[..., 0]
            pred_occupy.append(
                occupancy_logits.cuda() if device == "cuda" else occupancy_logits.cpu()
            )
            if args.regress_color:
                pred_colors.append(pred[..., 1:].reshape((-1, 3)))
            else:
                pred_colors.append(_expected_color(pred[..., 1:], temperature))

    pred_occupy = torch.sigmoid(torch.cat(pred_occupy, dim=1))
    return torch.cat(pred_colors, dim=0).cpu().numpy(), pred_occupy.cpu().numpy(), unseen_xyz.cpu().numpy()
|
|
def pad_image(im, value):
    """Pad an ``(H, W, C)`` tensor to a square by appending filler.

    When the image is taller than wide, columns are appended on the right.
    Otherwise rows are appended at the bottom; a square input gets zero rows,
    so it is returned unchanged. The filler has ``im``'s dtype and device, so
    the result keeps the input's dtype and works for non-CPU tensors.

    Args:
        im: Tensor of shape ``(H, W, C)``.
        value: Fill value for the padded region (e.g. ``float('inf')``).

    Returns:
        Tensor of shape ``(S, S, C)`` with ``S = max(H, W)``.
    """
    height, width, channels = im.shape[0], im.shape[1], im.shape[2]
    if height > width:
        filler = torch.full(
            (height, height - width, channels), value, dtype=im.dtype, device=im.device
        )
        return torch.cat([im, filler], dim=1)
    filler = torch.full(
        (width - height, width, channels), value, dtype=im.dtype, device=im.device
    )
    return torch.cat([im, filler], dim=0)
|
|
def backproject_depth_to_pointcloud(depth, rotation=None, translation=None):
    """Back-project a depth map into a per-pixel 3-D point map.

    Each pixel ``(u, v)`` is lifted through the inverse of the pinhole
    intrinsics from ``get_intrinsics`` (principal point at the image centre)
    and scaled by its depth. The point is then transformed by
    ``rotation @ p + translation``, and finally its y and z coordinates are
    negated (presumably converting a y-down / z-forward camera frame to a
    y-up / z-backward one -- confirm).

    Args:
        depth: 2-D array ``(H, W)`` of depth values.
        rotation: ``(3, 3)`` rotation; defaults to the identity.
        translation: ``(3,)`` translation; defaults to zeros.

    Returns:
        Float array of shape ``(H, W, 3)``.
    """
    # None defaults avoid sharing ndarray default arguments between calls.
    rotation = np.eye(3) if rotation is None else np.asarray(rotation)
    translation = np.zeros(3) if translation is None else np.asarray(translation)

    height, width = depth.shape
    principal_point = [width / 2, height / 2]
    intrinsics = get_intrinsics(height, width, principal_point)

    u, v = np.meshgrid(np.arange(width), np.arange(height))
    pixels = np.stack((u, v, np.ones_like(u)), axis=-1).reshape(-1, 3)

    # Rays through each pixel, scaled by that pixel's depth.
    points_cam = (pixels @ np.linalg.inv(intrinsics).T) * depth.reshape(-1, 1)
    points = points_cam @ rotation.T + translation
    points[:, 1:] *= -1
    return points.reshape(height, width, 3)
|
|
| |
def get_intrinsics(H, W, principal_point, fov_degrees=55.0):
    """Build a 3x3 pinhole-camera intrinsics matrix.

    The focal length comes from the horizontal field of view and the image
    width, ``f = (W / 2) / tan(fov / 2)``, and is used for both axes.

    Args:
        H: Image height in pixels. Not used by the computation; kept for
            signature compatibility.
        W: Image width in pixels.
        principal_point: ``(cx, cy)`` in pixels, used as given.
        fov_degrees: Horizontal field of view in degrees (default 55).

    Returns:
        ``np.ndarray`` ``[[f, 0, cx], [0, f, cy], [0, 0, 1]]``.
    """
    f = 0.5 * W / np.tan(0.5 * fov_degrees * np.pi / 180.0)
    cx, cy = principal_point
    return np.array([[f, 0, cx],
                     [0, f, cy],
                     [0, 0, 1]])
| |
def normalize(seen_xyz):
    """Rescale and recenter a point map using only its finite points.

    A point (last-axis vector) is used only if the sum of its coordinates is
    finite, so rows containing inf/NaN -- e.g. masked-out pixels -- are
    ignored. The map is divided by the mean of the per-axis standard
    deviations of the usable points, then shifted so the usable points have
    zero mean. Non-finite entries stay non-finite.
    """
    def _usable(points):
        return points[torch.isfinite(points.sum(dim=-1))]

    spread = (_usable(seen_xyz).var(dim=0) ** 0.5).mean()
    scaled = seen_xyz / spread
    # The finite-point mask is re-derived after scaling, as the shift must
    # only average points that are still finite.
    centre = _usable(scaled).mean(axis=0)
    return scaled - centre
|
|
def voxel_grid_downsample(points, colors, voxel_size):
    """Average points and colors over a regular voxel grid.

    Each point falls in voxel ``floor(point / voxel_size)``. For every
    occupied voxel the mean point (centroid) and mean color are returned,
    with voxels ordered lexicographically by integer index (``np.unique``
    order). The averaging is vectorized, so the cost is linear in the number
    of points rather than points x voxels.

    Args:
        points: ``(N, D)`` array of coordinates.
        colors: ``(N, C)`` array of per-point colors.
        voxel_size: Edge length of a voxel, in the units of ``points``.

    Returns:
        ``(centroids, avg_colors)``: ``centroids`` is a float ``(V, D)``
        array. ``avg_colors`` is ``(V, C)``, cast back to ``colors.dtype``
        (integer dtypes truncate), with its channel order reversed
        (``[:, ::-1]``, e.g. BGR <-> RGB).
    """
    voxel_indices = np.floor(points / voxel_size).astype(int)
    unique_voxels, inverse = np.unique(voxel_indices, axis=0, return_inverse=True)
    # The shape of the inverse has changed across NumPy 2.0.x releases;
    # flatten it so it always indexes rows.
    inverse = np.asarray(inverse).reshape(-1)
    n_voxels = len(unique_voxels)

    counts = np.bincount(inverse, minlength=n_voxels).astype(float)[:, None]
    point_sums = np.zeros((n_voxels, points.shape[1]), dtype=float)
    np.add.at(point_sums, inverse, points)
    color_sums = np.zeros((n_voxels, colors.shape[1]), dtype=float)
    np.add.at(color_sums, inverse, colors)

    centroids = point_sums / counts
    avg_colors = (color_sums / counts).astype(colors.dtype)
    return centroids, avg_colors[:, ::-1]
|
|
def _uploaded_path(upload):
    """Return the filesystem path of a Gradio file upload."""
    # Gradio 3 passes a tempfile wrapper exposing ``.name``; Gradio 4 passes a str.
    return upload if isinstance(upload, str) else upload.name


def _write_point_cloud_ply(xyz, colors):
    """Write points and 0-255 colors to a new ASCII ``.ply`` temp file; return its path."""
    vertex = np.rec.fromarrays(
        np.hstack((xyz, colors)).transpose(),
        names='x, y, z, red, green, blue',
        formats='f8, f8, f8, u1, u1, u1',
    )
    element = PlyElement.describe(vertex, 'vertex')
    # delete=False: the file must outlive this call because its path is returned.
    with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as f:
        PlyData([element], text=True).write(f)
        return f.name


def _write_voxel_obj(xyz, colors, voxel_size):
    """Render voxel-averaged points as colored cubes into a new ``.obj`` temp file; return its path."""
    centers, voxel_colors = voxel_grid_downsample(xyz, colors, voxel_size)
    cubes = []
    for center, color in zip(centers, voxel_colors):
        cube = trimesh.creation.box(extents=[voxel_size] * 3)
        cube.apply_translation(center)
        # Append an opaque alpha channel to the voxel's color.
        cube.visual.vertex_colors = np.hstack([color, 255])
        cubes.append(cube)
    with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as f:
        obj_path = f.name
    trimesh.util.concatenate(cubes).export(obj_path)
    return obj_path


def infer(
    image,
    depth_image,
    seg,
    granularity,
    temperature,
):
    """Gradio callback: reconstruct an object from an image, a depth map and a mask.

    Back-projects the depth map to a point map, masks and crops it to the
    segmented object, runs the model, and keeps query points whose predicted
    occupancy exceeds 0.1. Relies on the module-level ``args``, ``model`` and
    ``device`` assigned in the ``__main__`` block.

    Args:
        image: Array from ``gr.Image`` (assumed H x W x 3 uint8 -- TODO confirm).
        depth_image: Uploaded depth file. It is read with flag -1
            (``IMREAD_UNCHANGED``) and divided by 256 (presumably a 16-bit
            PNG storing depth * 256 -- confirm units).
        seg: Uploaded segmentation image; every nonzero pixel after resizing
            counts as object.
        granularity: Stored on ``args.viz_granularity`` before inference.
        temperature: Color softmax temperature, forwarded to ``run_inference``.

    Returns:
        ``(ply_path, obj_path)``: a colored ASCII point cloud and a mesh of
        colored cubes (voxel size 0.2), both written to temp files.

    Raises:
        ValueError: if an upload cannot be read, the mask is empty, or no
            point is predicted as occupied.
    """
    args.viz_granularity = granularity

    depth = cv2.imread(_uploaded_path(depth_image), -1)
    if depth is None:
        raise ValueError("Could not read the depth image.")
    depth = depth.astype(np.float32) / 256
    seen_xyz = torch.tensor(backproject_depth_to_pointcloud(depth)).float()

    # Bring every per-pixel input to the depth map's resolution, so a depth
    # map whose size differs from the photo still lines up with the RGB and mask.
    H, W = seen_xyz.shape[:2]
    # NOTE(review): gr.Image presumably yields RGB, so [2, 1, 0] reorders to
    # BGR; voxel_grid_downsample flips channels back but the PLY output does
    # not -- confirm which order the model expects.
    seen_rgb = (torch.tensor(image).float() / 255)[..., [2, 1, 0]]
    seen_rgb = torch.nn.functional.interpolate(
        seen_rgb.permute(2, 0, 1)[None],
        size=[H, W],
        mode="bilinear",
        align_corners=False,
    )[0].permute(1, 2, 0)

    seg_map = cv2.imread(_uploaded_path(seg), cv2.IMREAD_UNCHANGED)
    if seg_map is None:
        raise ValueError("Could not read the segmentation file.")
    mask = torch.tensor(cv2.resize(seg_map, (W, H))).bool()
    if not mask.any():
        raise ValueError("The segmentation mask is empty.")
    # Masked-out points become +inf so normalize() (finite points only) ignores them.
    seen_xyz[~mask] = float('inf')
    seen_xyz = normalize(seen_xyz)

    # Crop to the mask's bounding box plus a 40-pixel margin.
    pixel_coords = mask.nonzero()
    bottom, right = pixel_coords.max(dim=0)[0].tolist()
    top, left = pixel_coords.min(dim=0)[0].tolist()
    margin = 40
    top, left = max(top - margin, 0), max(left - margin, 0)
    bottom, right = bottom + margin, right + margin
    seen_xyz = seen_xyz[top:bottom + 1, left:right + 1]
    seen_rgb = seen_rgb[top:bottom + 1, left:right + 1]

    # Square the crop: padded points are +inf, padded pixels are 0 (black).
    seen_xyz = pad_image(seen_xyz, float('inf'))
    seen_rgb = pad_image(seen_rgb, 0)

    seen_rgb = torch.nn.functional.interpolate(
        seen_rgb.permute(2, 0, 1)[None],
        size=[800, 800],
        mode="bilinear",
        align_corners=False,
    )
    seen_xyz = torch.nn.functional.interpolate(
        seen_xyz.permute(2, 0, 1)[None],
        size=[112, 112],
        mode="bilinear",
        align_corners=False,
    ).permute(0, 2, 3, 1)

    # The second pair is a zero placeholder; how it is used is up to
    # prepare_data(is_viz=True) -- TODO confirm.
    samples = [
        [seen_xyz, seen_rgb],
        [torch.zeros((20000, 3)), torch.zeros((20000, 3))],
    ]
    pred_colors, pred_occupy, unseen_xyz = run_inference(model, samples, device, temperature, args)

    occupied = pred_occupy > 0.1
    unseen_xyz = unseen_xyz[occupied]
    # Clip before the uint8 PLY cast so out-of-range predictions cannot wrap around.
    pred_colors = np.clip(pred_colors[None, ...][occupied] * 255, 0, 255)
    if unseen_xyz.shape[0] == 0:
        raise ValueError("No occupied points were predicted for this input.")

    ply_path = _write_point_cloud_ply(unseen_xyz, pred_colors)
    obj_path = _write_voxel_obj(unseen_xyz, pred_colors, voxel_size=0.2)
    return ply_path, obj_path
|
|
if __name__ == '__main__':
    # infer() reads ``device``, ``args`` and ``model`` as module globals, so
    # they are assigned at module level here rather than inside a function.
    device = "cpu"

    parser = main_mcc.get_args_parser()
    parser.set_defaults(eval=True)
    args = parser.parse_args()

    model = mcc_model.get_mcc_model(
        occupancy_weight=1.0,
        rgb_weight=0.01,
        args=args,
    )
    if device == "cuda":
        model = model.cuda()

    misc.load_model(args=args, model_without_ddp=model, optimizer=None, loss_scaler=None)

    demo = gr.Interface(
        fn=infer,
        inputs=[
            gr.Image(label="Input Image"),
            gr.File(label="Depth Image"),
            gr.File(label="Segmentation File"),
            gr.Slider(minimum=0.05, maximum=0.5, step=0.05, value=0.2, label="Grain Size"),
            gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.1, label="Color Temperature"),
        ],
        outputs=[
            # gr.File also serves as an output component; the legacy
            # gr.outputs namespace is gone in Gradio 4.
            gr.File(label="Point Cloud"),
            gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
        ],
        examples=[["demo/quest2.jpg", "demo/quest2_depth.png", "demo/quest2_seg.png", 0.2, 0.1]],
        cache_examples=True,
    )
    demo.launch(server_name="0.0.0.0", server_port=7860)
|
|