Spaces:
Runtime error
Runtime error
| import cv2 | |
| import os | |
| import numpy as np | |
| import torch | |
| import imageio | |
| from torchvision.utils import make_grid, save_image | |
| from .ray_marcher import RayMarcher, generate_colored_boxes | |
def get_pose_on_orbit(radius, height, angles, world_up=None):
    """Build camera extrinsics ``[R | t]`` for viewpoints on a horizontal orbit.

    For each azimuth ``a`` in ``angles`` the orbit point is
    ``(radius * cos(a), height, radius * sin(a))``. ``forward`` is the unit
    vector from the origin toward that point; ``right`` and ``up`` complete an
    orthonormal frame with respect to ``world_up``. The rotation rows are
    ``(right, up, forward)`` and the translation is ``(0, 0, radius)`` for every
    view.

    Args:
        radius: Orbit radius in the xz-plane (also used as the z translation).
        height: y coordinate of the orbit points.
        angles: 1-D tensor of azimuths in radians, shape [N].
        world_up: Up vector of shape [3]; defaults to ``(0, 1, 0)``.

    Returns:
        Tensor of shape [N, 3, 4] on the same device as ``angles``.
    """
    if not torch.is_floating_point(angles):
        angles = angles.to(torch.get_default_dtype())
    if world_up is None:
        world_up = torch.tensor([0.0, 1.0, 0.0])
    # Keep every intermediate on the caller's device/dtype.
    world_up = world_up.to(angles)
    num_points = angles.shape[0]

    x = radius * torch.cos(angles)
    h = torch.full_like(angles, float(height))
    z = radius * torch.sin(angles)
    position = torch.stack([x, h, z], dim=-1)
    forward = position / torch.norm(position, p=2, dim=-1, keepdim=True)

    # dim is passed explicitly: without it torch.cross uses the *first* axis of
    # size 3, which is the batch axis whenever exactly three angles are given.
    right = -torch.cross(world_up.expand_as(forward), forward, dim=-1)
    right = right / torch.norm(right, p=2, dim=-1, keepdim=True)
    up = torch.cross(forward, right, dim=-1)
    up = up / torch.norm(up, p=2, dim=-1, keepdim=True)
    rotation = torch.stack([right, up, forward], dim=1)

    # NOTE(review): the translation uses `radius`, not |position|; the two only
    # coincide when height == 0 (the only value used in this file).
    translation = angles.new_tensor([0.0, 0.0, float(radius)])[None, :, None].repeat(num_points, 1, 1)
    return torch.cat([rotation, translation], dim=2)
def render_mvp_boxes(rm, batch, preds):
    """Render the predicted primitives as colored boxes, without gradients.

    The primitive payload ``preds["prim_rgba"]`` is replaced by the output of
    ``generate_colored_boxes`` (computed from the payload and
    ``preds["prim_rot"]``). The primitives' positions, scales and rotations are
    kept, and the result is rendered with the cameras in ``batch["Rt"]`` /
    ``batch["K"]``.

    Returns:
        The first three channels of ``rm``'s ``"rgba_image"``, moved from
        axis 1 to the last axis (channel-last RGB).
    """
    with torch.no_grad():
        colored_payload = generate_colored_boxes(preds["prim_rgba"], preds["prim_rot"])
        render_kwargs = dict(
            prim_rgba=colored_payload,
            prim_pos=preds["prim_pos"],
            prim_scale=preds["prim_scale"],
            prim_rot=preds["prim_rot"],
            RT=batch["Rt"],
            K=batch["K"],
        )
        rendered = rm(**render_kwargs)
        rgb_channels = rendered["rgba_image"][:, :3]
        return rgb_channels.permute(0, 2, 3, 1)
def save_image_summary(path, batch, preds):
    """Save a grid of rendered RGB images stacked above their primitive-box renders.

    Args:
        path: Output image path. When ``batch`` holds both "folder" and "key",
            the per-image "folder/key" labels are also written to a sibling
            ``.txt`` file and drawn onto the box renders.
        batch: Mapping that may contain per-image "folder" and "key" entries.
        preds: Must hold "rgb" and "rgb_boxes", channel-last images [B, H, W, 3]
            with values on a 0-255 scale (they are divided by 255 and clipped).

    The caller's ``preds`` tensors are not modified.
    """
    rgb = preds["rgb"].detach().permute(0, 3, 1, 2)
    # clone(): detach()/permute() share storage with preds["rgb_boxes"], so
    # drawing labels below would otherwise write into the caller's tensor.
    rgb_boxes = preds["rgb_boxes"].detach().permute(0, 3, 1, 2).clone()

    if "folder" in batch and "key" in batch:
        obj_list = []
        for bs_idx in range(rgb_boxes.shape[0]):
            folder = batch["folder"][bs_idx]
            key = batch["key"][bs_idx]
            obj_list.append(f"{folder}/{key}\n")
            # Clamp before the uint8 cast: out-of-range floats would otherwise wrap.
            tmp_img = rgb_boxes[bs_idx].permute(1, 2, 0).clamp(0, 255).to(torch.uint8).cpu().numpy()
            # cv2 drawing requires a C-contiguous buffer.
            tmp_img = np.ascontiguousarray(tmp_img)
            cv2.putText(tmp_img, f"{folder}", (200, 200), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 2)
            cv2.putText(tmp_img, f"{key}", (200, 400), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 2)
            rgb_boxes[bs_idx] = torch.as_tensor(tmp_img).permute(2, 0, 1).float()
        with open(os.path.splitext(path)[0] + ".txt", "w", encoding="utf-8") as f:
            f.writelines(obj_list)

    # Concatenate along height (dim 2 of NCHW): each tile is the render above its box view.
    img = make_grid(torch.cat([rgb, rgb_boxes], dim=2) / 255.0).clip(0.0, 1.0)
    save_image(img, path)
def visualize_primsdf_box(image_save_path, model, rm: RayMarcher, device):
    """Render a PrimSDF model's stored primitive features from a fixed camera and save a summary.

    The payload is taken directly from ``model.feat_geo`` (converted to alpha
    with ``model.sdf2alpha``) and ``model.feat_tex``, both scaled by 255. The
    shaded render and a colored-box render are passed to ``save_image_summary``.

    Args:
        image_save_path: Destination path forwarded to ``save_image_summary``.
        model: Object exposing ``sdf2alpha``, ``feat_geo``, ``feat_tex``, ``pos``,
            ``scale``, ``num_prims`` and ``prim_shape``.
        rm: Raymarcher; ``volradius``, ``image_height`` and ``image_width`` set the
            camera placement and intrinsics.
        device: Device on which the camera matrices are created.
    """
    # Raymarcher inputs:
    # prim_rgba: primitive payload [B, K, 4, S, S, S],
    #   K - # of primitives, S - primitive size
    # prim_pos: locations [B, K, 3]
    # prim_rot: rotations [B, K, 3, 3]
    # prim_scale: scales [B, K, 3]
    # K: intrinsics [B, 3, 3]
    # RT: extrinsics [B, 3, 4]
    num_prims, s = model.num_prims, model.prim_shape
    preds = {}
    batch = {}
    # Visualization only: avoid building an autograd graph through model parameters.
    with torch.no_grad():
        prim_alpha = model.sdf2alpha(model.feat_geo).reshape(1, num_prims, 1, s, s, s) * 255
        prim_rgb = model.feat_tex.reshape(1, num_prims, 3, s, s, s) * 255
        preds["prim_rgba"] = torch.cat([prim_rgb, prim_alpha], dim=2)
        preds["prim_pos"] = model.pos.reshape(1, num_prims, 3) * rm.volradius
        preds["prim_rot"] = torch.eye(3).to(preds["prim_pos"])[None, None, ...].repeat(1, num_prims, 1, 1)
        preds["prim_scale"] = 1 / model.scale.reshape(1, num_prims, 1).repeat(1, 1, 3)

        # Fixed extrinsics: identity x, flipped y and z, translation 5 * volradius along z.
        batch["Rt"] = torch.tensor(
            [
                [1.0, 0.0, 0.0, 0.0],
                [0.0, -1.0, 0.0, 0.0],
                [0.0, 0.0, -1.0, float(5 * rm.volradius)],
            ],
            device=device,
        )[None, ...]
        # Intrinsics defined for a 1024x1024 image, rescaled to the raymarcher's
        # resolution. Row 0 (fx, cx) follows image width and row 1 (fy, cy) image
        # height (standard pinhole convention); the original scaled them swapped.
        focal, principal = 2084.9526697685183, 512.0
        ratio_w = rm.image_width / 1024.0
        ratio_h = rm.image_height / 1024.0
        batch["K"] = torch.tensor(
            [
                [focal * ratio_w, 0.0, principal * ratio_w],
                [0.0, focal * ratio_h, principal * ratio_h],
                [0.0, 0.0, 1.0],
            ],
            device=device,
        )[None, ...]

        # raymarcher is in mm
        rm_preds = rm(
            prim_rgba=preds["prim_rgba"],
            prim_pos=preds["prim_pos"],
            prim_scale=preds["prim_scale"],
            prim_rot=preds["prim_rot"],
            RT=batch["Rt"],
            K=batch["K"],
        )
        rgba = rm_preds["rgba_image"].permute(0, 2, 3, 1)
        preds.update(alpha=rgba[..., -1].contiguous(), rgb=rgba[..., :3].contiguous())
        preds["rgb_boxes"] = render_mvp_boxes(rm, batch, preds)
    save_image_summary(image_save_path, batch, preds)
def render_primsdf(image_save_path, model, rm, device):
    """Query a PrimSDF model at its sample points, render the result and save a summary.

    For each of the ``prim_shape ** 3`` sample indices ``i``, the model is called
    on ``model.sdf_sampled_point[:, i, :]``. Its "sdf" and "tex" outputs are
    stacked into the primitive payload, rendered from a fixed camera (shaded and
    colored-box views), and saved via ``save_image_summary``.

    Args:
        image_save_path: Destination path forwarded to ``save_image_summary``.
        model: Callable returning a dict with "sdf" and "tex", also exposing
            ``sdf2alpha``, ``sdf_sampled_point``, ``pos``, ``scale``,
            ``num_prims`` and ``prim_shape``.
        rm: Raymarcher; ``volradius``, ``image_height`` and ``image_width`` set the
            camera placement and intrinsics.
        device: Device for the camera matrices and the model queries.
    """
    num_prims, s = model.num_prims, model.prim_shape
    preds = {}
    batch = {}
    # Visualization only: no autograd graph is needed anywhere below.
    with torch.no_grad():
        preds["prim_pos"] = model.pos.reshape(1, num_prims, 3) * rm.volradius
        preds["prim_rot"] = torch.eye(3).to(preds["prim_pos"])[None, None, ...].repeat(1, num_prims, 1, 1)
        preds["prim_scale"] = 1 / model.scale.reshape(1, num_prims, 1).repeat(1, 1, 3)

        # Fixed extrinsics: identity x, flipped y and z, translation 5 * volradius along z.
        batch["Rt"] = torch.tensor(
            [
                [1.0, 0.0, 0.0, 0.0],
                [0.0, -1.0, 0.0, 0.0],
                [0.0, 0.0, -1.0, float(5 * rm.volradius)],
            ],
            device=device,
        )[None, ...]
        # Intrinsics for a 1024x1024 image rescaled to the render size: row 0
        # (fx, cx) follows width, row 1 (fy, cy) height (standard pinhole convention).
        focal, principal = 2084.9526697685183, 512.0
        ratio_w = rm.image_width / 1024.0
        ratio_h = rm.image_height / 1024.0
        batch["K"] = torch.tensor(
            [
                [focal * ratio_w, 0.0, principal * ratio_w],
                [0.0, focal * ratio_h, principal * ratio_h],
                [0.0, 0.0, 1.0],
            ],
            device=device,
        )[None, ...]

        # Query the model one sample index at a time.
        all_sampled_sdf = []
        all_sampled_tex = []
        for i in range(s ** 3):
            model_prediction = model(model.sdf_sampled_point[:, i, :].to(device))
            all_sampled_sdf.append(model_prediction["sdf"])
            all_sampled_tex.append(model_prediction["tex"])
        sampled_sdf = torch.stack(all_sampled_sdf, dim=1)
        prim_rgb = torch.stack(all_sampled_tex, dim=1).permute(0, 2, 1).reshape(1, num_prims, 3, s, s, s) * 255
        prim_alpha = model.sdf2alpha(sampled_sdf).reshape(1, num_prims, 1, s, s, s) * 255
        preds["prim_rgba"] = torch.cat([prim_rgb, prim_alpha], dim=2)

        rm_preds = rm(
            prim_rgba=preds["prim_rgba"],
            prim_pos=preds["prim_pos"],
            prim_scale=preds["prim_scale"],
            prim_rot=preds["prim_rot"],
            RT=batch["Rt"],
            K=batch["K"],
        )
        rgba = rm_preds["rgba_image"].permute(0, 2, 3, 1)
        preds.update(alpha=rgba[..., -1].contiguous(), rgb=rgba[..., :3].contiguous())
        preds["rgb_boxes"] = render_mvp_boxes(rm, batch, preds)
    save_image_summary(image_save_path, batch, preds)
def visualize_primvolume(image_save_path, batch, prim_volume, rm: RayMarcher, device):
    """Render a batch of packed primitive volumes from a fixed camera and save a summary.

    Args:
        image_save_path: Destination path forwarded to ``save_image_summary``.
        batch: Mapping forwarded to ``save_image_summary`` (optional "folder"/"key"
            labels). It is not modified: camera matrices go into a shallow copy.
        prim_volume: Packed primitives [B, nprims, 4 + 6 * S^3]; per primitive the
            channels are [scale(1), pos(3), geo(S^3), tex(3*S^3), mat(2*S^3)].
        rm: Raymarcher; ``volradius``, ``image_height`` and ``image_width`` set the
            camera placement and intrinsics.
        device: Device on which the camera matrices are created.
    """

    def sdf2alpha(sdf):
        # Gaussian falloff: alpha is 1 where sdf == 0 and decays with |sdf|.
        return torch.exp(-(sdf / 0.005) ** 2)

    # Shallow copy so the caller's batch keeps its own "Rt"/"K".
    batch = dict(batch)
    preds = {}
    bs = prim_volume.shape[0]
    num_prims = prim_volume.shape[1]
    prim_shape = int(np.round(((prim_volume.shape[2] - 4) / 6) ** (1 / 3)))
    voxels = prim_shape ** 3
    geo_start = 4
    tex_start = geo_start + voxels
    tex_end = tex_start + 3 * voxels
    payload_shape = (bs, num_prims, -1, prim_shape, prim_shape, prim_shape)

    # Visualization only: no autograd graph is needed.
    with torch.no_grad():
        prim_alpha = sdf2alpha(prim_volume[:, :, geo_start:tex_start]).reshape(payload_shape) * 255
        prim_rgb = prim_volume[:, :, tex_start:tex_end].reshape(payload_shape) * 255
        preds["prim_rgba"] = torch.cat([prim_rgb, prim_alpha], dim=2)
        preds["prim_pos"] = prim_volume[:, :, 1:4].reshape(bs, num_prims, 3) * rm.volradius
        preds["prim_rot"] = torch.eye(3).to(preds["prim_pos"])[None, None, ...].repeat(bs, num_prims, 1, 1)
        preds["prim_scale"] = 1 / prim_volume[:, :, 0:1].reshape(bs, num_prims, 1).repeat(1, 1, 3)

        # Fixed extrinsics: identity x, flipped y and z, translation 5 * volradius along z.
        batch["Rt"] = torch.tensor(
            [
                [1.0, 0.0, 0.0, 0.0],
                [0.0, -1.0, 0.0, 0.0],
                [0.0, 0.0, -1.0, float(5 * rm.volradius)],
            ],
            device=device,
        )[None, ...].repeat(bs, 1, 1)
        # Intrinsics for a 1024x1024 image rescaled to the render size: row 0
        # (fx, cx) follows width, row 1 (fy, cy) height (standard pinhole convention).
        focal, principal = 2084.9526697685183, 512.0
        ratio_w = rm.image_width / 1024.0
        ratio_h = rm.image_height / 1024.0
        batch["K"] = torch.tensor(
            [
                [focal * ratio_w, 0.0, principal * ratio_w],
                [0.0, focal * ratio_h, principal * ratio_h],
                [0.0, 0.0, 1.0],
            ],
            device=device,
        )[None, ...].repeat(bs, 1, 1)

        # raymarcher is in mm
        rm_preds = rm(
            prim_rgba=preds["prim_rgba"],
            prim_pos=preds["prim_pos"],
            prim_scale=preds["prim_scale"],
            prim_rot=preds["prim_rot"],
            RT=batch["Rt"],
            K=batch["K"],
        )
        rgba = rm_preds["rgba_image"].permute(0, 2, 3, 1)
        preds.update(alpha=rgba[..., -1].contiguous(), rgb=rgba[..., :3].contiguous())
        preds["rgb_boxes"] = render_mvp_boxes(rm, batch, preds)
    save_image_summary(image_save_path, batch, preds)
def visualize_multiview_primvolume(image_save_path, batch, prim_volume, view_counts, rm: RayMarcher, device):
    """Render packed primitive volumes from several orbit views and save one summary image.

    Args:
        image_save_path: Destination path forwarded to ``save_image_summary``.
        batch: Mapping forwarded to ``save_image_summary`` (optional per-item
            "folder"/"key" labels). It is not modified.
        prim_volume: Packed primitives [B, nprims, 4 + 6 * S^3]; per primitive the
            channels are [scale(1), pos(3), geo(S^3), tex(3*S^3), mat(2*S^3)].
        view_counts: Number of azimuths, evenly spaced over [0.5*pi, 2.5*pi).
        rm: Raymarcher; ``volradius``, ``image_height`` and ``image_width`` set the
            orbit radius and intrinsics.
        device: Device on which the intrinsics are created.
    """
    # Drop the linspace endpoint: it is the same direction as the first angle.
    view_angles = (torch.linspace(0.5, 2.5, view_counts + 1) * torch.pi)[:-1]

    def sdf2alpha(sdf):
        # Gaussian falloff: alpha is 1 where sdf == 0 and decays with |sdf|.
        return torch.exp(-(sdf / 0.005) ** 2)

    # Shallow copy so the caller's batch keeps its own "Rt"/"K"/labels.
    batch = dict(batch)
    preds = {}
    bs = prim_volume.shape[0]
    num_prims = prim_volume.shape[1]
    prim_shape = int(np.round(((prim_volume.shape[2] - 4) / 6) ** (1 / 3)))
    voxels = prim_shape ** 3
    geo_start = 4
    tex_start = geo_start + voxels
    tex_end = tex_start + 3 * voxels
    payload_shape = (bs, num_prims, -1, prim_shape, prim_shape, prim_shape)

    rgb_frames = []
    box_frames = []
    # Visualization only: no autograd graph is needed.
    with torch.no_grad():
        prim_alpha = sdf2alpha(prim_volume[:, :, geo_start:tex_start]).reshape(payload_shape) * 255
        prim_rgb = prim_volume[:, :, tex_start:tex_end].reshape(payload_shape) * 255
        preds["prim_rgba"] = torch.cat([prim_rgb, prim_alpha], dim=2)
        preds["prim_pos"] = prim_volume[:, :, 1:4].reshape(bs, num_prims, 3) * rm.volradius
        preds["prim_rot"] = torch.eye(3).to(preds["prim_pos"])[None, None, ...].repeat(bs, num_prims, 1, 1)
        preds["prim_scale"] = 1 / prim_volume[:, :, 0:1].reshape(bs, num_prims, 1).repeat(1, 1, 3)

        # Intrinsics for a 1024x1024 image rescaled to the render size: row 0
        # (fx, cx) follows width, row 1 (fy, cy) height (standard pinhole convention).
        focal, principal = 2084.9526697685183, 512.0
        ratio_w = rm.image_width / 1024.0
        ratio_h = rm.image_height / 1024.0
        batch["K"] = torch.tensor(
            [
                [focal * ratio_w, 0.0, principal * ratio_w],
                [0.0, focal * ratio_h, principal * ratio_h],
                [0.0, 0.0, 1.0],
            ],
            device=device,
        )[None, ...].repeat(bs, 1, 1)

        for view_ang in view_angles:
            batch["Rt"] = get_pose_on_orbit(
                radius=5 * rm.volradius, height=0, angles=view_ang.repeat(bs)
            ).to(prim_volume)
            # raymarcher is in mm
            rm_preds = rm(
                prim_rgba=preds["prim_rgba"],
                prim_pos=preds["prim_pos"],
                prim_scale=preds["prim_scale"],
                prim_rot=preds["prim_rot"],
                RT=batch["Rt"],
                K=batch["K"],
            )
            rgba = rm_preds["rgba_image"].permute(0, 2, 3, 1)
            preds.update(alpha=rgba[..., -1].contiguous(), rgb=rgba[..., :3].contiguous())
            preds["rgb_boxes"] = render_mvp_boxes(rm, batch, preds)
            rgb_frames.append(preds["rgb"])
            box_frames.append(preds["rgb_boxes"])

    final_preds = {
        "rgb": torch.cat(rgb_frames, dim=0),
        "rgb_boxes": torch.cat(box_frames, dim=0),
    }
    if "folder" in batch and "key" in batch:
        # Frames are view-major (index = view * bs + item); repeat the per-item
        # labels once per view so save_image_summary can index every frame.
        batch["folder"] = list(batch["folder"]) * len(view_angles)
        batch["key"] = list(batch["key"]) * len(view_angles)
    save_image_summary(image_save_path, batch, final_preds)
def visualize_video_primvolume(video_save_folder, batch, prim_volume, view_counts, rm: RayMarcher, device, fps=20):
    """Render orbit videos of packed primitive volumes: shaded RGB, primitive boxes, material.

    Writes ``rgb.mp4``, ``prim.mp4`` and ``mat.mp4`` into ``video_save_folder``.

    Args:
        video_save_folder: Existing directory that receives the three videos.
        batch: Mapping used only as scratch for camera matrices; it is copied, so
            the caller's object is not modified.
        prim_volume: Packed primitives [B, nprims, 4 + 6 * S^3]; per primitive the
            channels are [scale(1), pos(3), geo(S^3), tex(3*S^3), mat(2*S^3)].
        view_counts: Number of frames (azimuths), evenly spaced over [1.5*pi, 3.5*pi).
        rm: Raymarcher; ``volradius``, ``image_height`` and ``image_width`` set the
            orbit radius and intrinsics.
        device: Device on which the intrinsics are created.
        fps: Frame rate of the written videos.
    """
    from contextlib import ExitStack

    # Drop the linspace endpoint: it equals the start direction and would appear
    # as a duplicated frame (consistent with visualize_multiview_primvolume).
    view_angles = (torch.linspace(1.5, 3.5, view_counts + 1) * torch.pi)[:-1]

    def sdf2alpha(sdf):
        # Gaussian falloff: alpha is 1 where sdf == 0 and decays with |sdf|.
        return torch.exp(-(sdf / 0.005) ** 2)

    batch = dict(batch)
    preds = {}
    bs = prim_volume.shape[0]
    num_prims = prim_volume.shape[1]
    prim_shape = int(np.round(((prim_volume.shape[2] - 4) / 6) ** (1 / 3)))
    voxels = prim_shape ** 3
    geo_start = 4
    tex_start = geo_start + voxels
    mat_start = tex_start + 3 * voxels
    mat_end = mat_start + 2 * voxels
    payload_shape = (bs, num_prims, -1, prim_shape, prim_shape, prim_shape)

    rgb_frames, box_frames, mat_frames = [], [], []
    # Visualization only: no autograd graph is needed.
    with torch.no_grad():
        prim_alpha = sdf2alpha(prim_volume[:, :, geo_start:tex_start]).reshape(payload_shape) * 255
        prim_rgb = prim_volume[:, :, tex_start:mat_start].reshape(payload_shape) * 255
        prim_mat = prim_volume[:, :, mat_start:mat_end].reshape(payload_shape) * 255
        # Prepend a zero channel so the 2 material channels render as a 3-channel image.
        prim_mat = torch.cat([torch.zeros_like(prim_mat[:, :, 0:1, ...]), prim_mat], dim=2)
        preds["prim_rgba"] = torch.cat([prim_rgb, prim_alpha], dim=2)
        preds["prim_mata"] = torch.cat([prim_mat, prim_alpha], dim=2)
        preds["prim_pos"] = prim_volume[:, :, 1:4].reshape(bs, num_prims, 3) * rm.volradius
        preds["prim_rot"] = torch.eye(3).to(preds["prim_pos"])[None, None, ...].repeat(bs, num_prims, 1, 1)
        preds["prim_scale"] = 1 / prim_volume[:, :, 0:1].reshape(bs, num_prims, 1).repeat(1, 1, 3)

        # Intrinsics for a 1024x1024 image rescaled to the render size: row 0
        # (fx, cx) follows width, row 1 (fy, cy) height (standard pinhole convention).
        focal, principal = 2084.9526697685183, 512.0
        ratio_w = rm.image_width / 1024.0
        ratio_h = rm.image_height / 1024.0
        batch["K"] = torch.tensor(
            [
                [focal * ratio_w, 0.0, principal * ratio_w],
                [0.0, focal * ratio_h, principal * ratio_h],
                [0.0, 0.0, 1.0],
            ],
            device=device,
        )[None, ...].repeat(bs, 1, 1)

        for view_ang in view_angles:
            batch["Rt"] = get_pose_on_orbit(
                radius=5 * rm.volradius, height=0, angles=view_ang.repeat(bs)
            ).to(prim_volume)
            camera = dict(
                prim_pos=preds["prim_pos"],
                prim_scale=preds["prim_scale"],
                prim_rot=preds["prim_rot"],
                RT=batch["Rt"],
                K=batch["K"],
            )
            # raymarcher is in mm
            rgba = rm(prim_rgba=preds["prim_rgba"], **camera)["rgba_image"].permute(0, 2, 3, 1)
            preds.update(alpha=rgba[..., -1].contiguous(), rgb=rgba[..., :3].contiguous())
            preds["rgb_boxes"] = render_mvp_boxes(rm, batch, preds)
            mat_rgba = rm(prim_rgba=preds["prim_mata"], **camera)["rgba_image"].permute(0, 2, 3, 1)
            rgb_frames.append(preds["rgb"])
            box_frames.append(preds["rgb_boxes"])
            mat_frames.append(mat_rgba[..., :3].contiguous())

    def to_uint8(frames):
        # Frames are concatenated view-major (index = view * bs + item).
        # NOTE(review): with bs > 1 consecutive video frames alternate between
        # batch items — presumably this is meant for bs == 1.
        stacked = torch.cat(frames, dim=0).detach().cpu().numpy()
        return np.clip(stacked, 0, 255).astype(np.uint8)

    streams = [
        ("rgb.mp4", to_uint8(rgb_frames)),
        ("prim.mp4", to_uint8(box_frames)),
        ("mat.mp4", to_uint8(mat_frames)),
    ]
    # ExitStack guarantees every writer is closed even if encoding fails midway.
    with ExitStack() as stack:
        writers = [
            stack.enter_context(imageio.get_writer(os.path.join(video_save_folder, name), fps=fps))
            for name, _ in streams
        ]
        for frame_group in zip(*(frames for _, frames in streams)):
            for writer, frame in zip(writers, frame_group):
                writer.append_data(frame)