| |
| |
| |
| |
| |
| |
|
|
| import sys |
|
|
| sys.path.append("./") |
|
|
| import math |
| import os |
| import pdb |
|
|
| import numpy as np |
| import torch |
| from plyfile import PlyData, PlyElement |
|
|
| from core.models.rendering.utils.sh_utils import RGB2SH, SH2RGB |
| from core.models.rendering.utils.typing import * |
| from engine.MVSRender.camera_utils import MiniCam, OrbitCamera, orbit_camera |
|
|
|
|
def avaliable_device():
    """Return the torch device string used for rendering.

    Returns:
        ``"cuda:<index>"`` for the currently selected CUDA device when CUDA is
        available, otherwise ``"cpu"``.
    """
    if not torch.cuda.is_available():
        return "cpu"
    return "cuda:{}".format(torch.cuda.current_device())
|
|
|
|
| from diff_gaussian_rasterization import ( |
| GaussianRasterizationSettings, |
| GaussianRasterizer, |
| ) |
| from PIL import Image |
|
|
|
|
def inverse_sigmoid(x):
    """Return the logit ``log(x / (1 - x))``, the inverse of ``torch.sigmoid``.

    Args:
        x: A tensor, a Python ``int``/``float``, or any array-like accepted by
            ``torch.as_tensor`` (e.g. a NumPy array).

    Returns:
        A tensor of the same shape.  Python scalars are converted to a float32
        tensor first; inputs of exactly 0 or 1 map to ``-inf`` / ``inf``.
    """
    if isinstance(x, (int, float)):
        # Python scalars: keep the historical float32 conversion.
        x = torch.tensor(x).float()
    elif not isinstance(x, torch.Tensor):
        # NumPy arrays / lists: torch.log only accepts tensors.
        x = torch.as_tensor(x)
    return torch.log(x / (1 - x))
|
|
|
|
class GaussianModel:
    """Set of 3D Gaussians that can be read from and written to a PLY file.

    After ``load_ply`` the tensors hold *activated* values:
    - ``opacity`` has been passed through a sigmoid.
    - ``scaling`` has been passed through ``exp``.
    - ``rotation`` has been row-normalized.
    - ``shs`` is laid out as ``(N, num_coeffs, 3)``.

    ``save_ply`` applies the inverse activations, so the file stores raw
    parameters.  When ``use_rgb`` is true, ``shs`` is converted with
    ``SH2RGB`` on load and back with ``RGB2SH`` on save.
    """

    def setup_functions(self):
        """Bind the activation functions and their inverses used on load/save."""
        self.scaling_activation = torch.exp
        self.scaling_inverse_activation = torch.log

        self.opacity_activation = torch.sigmoid
        self.inverse_opacity_activation = inverse_sigmoid
        # F.normalize defaults to dim=1: each row of an (N, k) tensor is normalized.
        self.rotation_activation = torch.nn.functional.normalize

        self.rgb_activation = torch.sigmoid

    def __init__(self, use_rgb=True):
        """Create an empty model.

        Args:
            use_rgb: If true, ``shs`` holds values converted with ``SH2RGB``
                when loading, and is converted back with ``RGB2SH`` before saving.
        """
        self.setup_functions()

        # Empty placeholders.  ``load_ply`` also uses their dtype/device as the
        # conversion target for the data it reads.
        self.xyz: Tensor = torch.empty(0)
        self.opacity: Tensor = torch.empty(0)
        self.rotation: Tensor = torch.empty(0)
        self.scaling: Tensor = torch.empty(0)
        self.shs: Tensor = torch.empty(0)
        self.use_rgb = use_rgb

    def construct_list_of_attributes(self):
        """Return the PLY vertex property names in the column order used by ``save_ply``.

        The number of ``f_dc_*`` / ``f_rest_*`` entries is derived from
        ``shs``: the first coefficient is DC, the remaining ones are "rest",
        each multiplied by the channel count.  The ``scale_*`` and ``rot_*``
        counts come from the second dimension of ``scaling`` and ``rotation``.
        """
        l = ["x", "y", "z", "nx", "ny", "nz"]
        features_dc = self.shs[:, :1]
        features_rest = self.shs[:, 1:]

        for i in range(features_dc.shape[1] * features_dc.shape[2]):
            l.append("f_dc_{}".format(i))
        for i in range(features_rest.shape[1] * features_rest.shape[2]):
            l.append("f_rest_{}".format(i))
        l.append("opacity")
        for i in range(self.scaling.shape[1]):
            l.append("scale_{}".format(i))
        for i in range(self.rotation.shape[1]):
            l.append("rot_{}".format(i))
        return l

    def _build_vertex_array(self):
        """Pack all Gaussian parameters into a structured float32 array, one record per Gaussian.

        Raises:
            ValueError: if the number of packed columns does not match
                ``construct_list_of_attributes()``.
        """
        xyz = self.xyz.detach().cpu().numpy()
        # Normal columns are part of the written layout; they are always zero.
        normals = np.zeros_like(xyz)

        shs = RGB2SH(self.shs) if self.use_rgb else self.shs

        # shs is (N, num_coeffs, channels).  Transpose to (N, channels, num_coeffs)
        # before flattening so f_dc_* / f_rest_* are written channel-major, which
        # is how load_ply reshapes them back: (N, 3, num_coeffs - 1).
        f_dc = (
            shs[:, :1].float().detach().transpose(1, 2)
            .flatten(start_dim=1).contiguous().cpu().numpy()
        )
        f_rest = (
            shs[:, 1:].float().detach().transpose(1, 2)
            .flatten(start_dim=1).contiguous().cpu().numpy()
        )

        # Undo the activations applied in load_ply so the file stores raw values.
        opacities = self.inverse_opacity_activation(self.opacity).detach().cpu().numpy()
        scale = self.scaling_inverse_activation(self.scaling).detach().cpu().numpy()
        rotation = self.rotation.detach().cpu().numpy()

        attribute_names = self.construct_list_of_attributes()
        columns = np.concatenate(
            (xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1
        )
        if columns.shape[1] != len(attribute_names):
            raise ValueError(
                f"attribute count mismatch: {columns.shape[1]} columns for "
                f"{len(attribute_names)} PLY properties"
            )

        vertices = np.empty(
            columns.shape[0], dtype=[(name, "f4") for name in attribute_names]
        )
        # Column-wise assignment instead of building one Python tuple per row.
        for name, column in zip(attribute_names, columns.T):
            vertices[name] = column
        return vertices

    def save_ply(self, path):
        """Write the model to ``path`` as a PLY file with a single ``vertex`` element."""
        el = PlyElement.describe(self._build_vertex_array(), "vertex")
        PlyData([el]).write(path)

    def load_ply(self, path):
        """Load Gaussians from a PLY file laid out as written by ``save_ply``.

        Reads the first element of the file and applies the activations from
        ``setup_functions`` (sigmoid opacity, exp scale, normalized rotation).
        Converts SH with ``SH2RGB`` when ``use_rgb`` is set, and stores ``shs``
        as ``(N, num_coeffs, 3)``.  Sets ``active_sh_degree`` from the number of
        ``f_rest_*`` properties.
        """
        plydata = PlyData.read(path)
        vertex = plydata.elements[0]

        def _sorted_props(prefix):
            # Property names with this prefix, ordered by their trailing integer index.
            names = [p.name for p in vertex.properties if p.name.startswith(prefix)]
            return sorted(names, key=lambda name: int(name.split("_")[-1]))

        def _stack(names, num_points):
            # (num_points, len(names)) float64 array, one column per property.
            out = np.zeros((num_points, len(names)))
            for idx, name in enumerate(names):
                out[:, idx] = np.asarray(vertex[name])
            return out

        xyz = np.stack([np.asarray(vertex[axis]) for axis in ("x", "y", "z")], axis=1)
        num_points = xyz.shape[0]
        opacities = np.asarray(vertex["opacity"])[..., np.newaxis]

        # DC term, one coefficient per colour channel: (N, 3, 1).
        features_dc = np.zeros((num_points, 3, 1))
        for channel in range(3):
            features_dc[:, channel, 0] = np.asarray(vertex[f"f_dc_{channel}"])

        extra_f_names = _sorted_props("f_rest_")
        # 3 * ((deg + 1) ** 2 - 1) rest coefficients  =>  deg = sqrt((n + 3) / 3) - 1.
        sh_degree = int(math.sqrt((len(extra_f_names) + 3) / 3)) - 1
        # Rest coefficients are stored channel-major: (N, 3, num_coeffs - 1).
        features_extra = _stack(extra_f_names, num_points).reshape(
            (num_points, 3, (sh_degree + 1) ** 2 - 1)
        )

        scales = _stack(_sorted_props("scale_"), num_points)
        rots = _stack(_sorted_props("rot"), num_points)

        # Convert to the dtype/device of the current tensors.
        xyz = torch.from_numpy(xyz).to(self.xyz)
        opacities = torch.from_numpy(opacities).to(self.opacity)
        rotation = torch.from_numpy(rots).to(self.rotation)
        scales = torch.from_numpy(scales).to(self.scaling)
        features_dc = torch.from_numpy(features_dc).to(self.shs)
        features_rest = torch.from_numpy(features_extra).to(self.shs)

        # (N, 3, num_coeffs)
        shs = torch.cat([features_dc, features_rest], dim=2)
        if self.use_rgb:
            shs = SH2RGB(shs)

        self.xyz: Tensor = xyz
        self.opacity: Tensor = self.opacity_activation(opacities)
        self.rotation: Tensor = self.rotation_activation(rotation)
        self.scaling: Tensor = self.scaling_activation(scales)
        # Stored as (N, num_coeffs, 3).
        self.shs: Tensor = shs.permute(0, 2, 1)

        self.active_sh_degree = sh_degree
|
|
|
|
def camera_traj(cam, ref_size=1024, views=30, radius=2.0):
    """Build ``views`` cameras evenly spaced in azimuth around one orbit.

    Args:
        cam: Template camera.  Its ``fovy``, ``fovx``, ``near`` and ``far``
            are copied onto every generated camera.
        ref_size: Width and height, in pixels, of each generated camera.
        views: Number of cameras.  Azimuths start at 0 and step by
            ``360 / views`` (presumably degrees, as interpreted by
            ``orbit_camera`` — confirm).
        radius: Orbit radius passed to ``orbit_camera``.

    Returns:
        list of ``MiniCam``, one per view, in increasing azimuth order.
    """
    # Iterate ``views`` times (not a fixed count) so the ring always closes.
    azimuth_step = 360.0 / views

    cameras = []
    for view_idx in range(views):
        # First argument is 0 for every view (presumably elevation — confirm).
        pose = orbit_camera(0, view_idx * azimuth_step, radius)
        cameras.append(
            MiniCam(
                pose,
                ref_size,
                ref_size,
                cam.fovy,
                cam.fovx,
                cam.near,
                cam.far,
            )
        )
    return cameras
|
|
|
|
class MVSRender:
    """Render a Gaussian-splat PLY from a ring of cameras and save PNG frames.

    The Gaussians are re-centred on their bounding-box centre (``autosize``).
    They are rasterized with precomputed per-Gaussian colours (the first
    ``shs`` coefficient); the rasterizer receives ``shs=None``.
    """

    def __init__(
        self, gs_file, fovy=50, ref_size=1024, views=30, radius=2.0, sh_degree=0
    ):
        """
        Args:
            gs_file: PLY path, read with ``GaussianModel(use_rgb=True).load_ply``.
            fovy: Vertical field of view passed to ``OrbitCamera``.
            ref_size: Square render resolution in pixels.
            views: Number of cameras generated by ``camera_traj``.
            radius: Orbit radius.
            sh_degree: Forwarded to the rasterizer settings.
        """
        self.cam = OrbitCamera(ref_size, ref_size, r=radius, fovy=fovy)

        self.fovy = fovy
        self.ref_size = ref_size
        self.views = views
        self.radius = radius

        input_gs = GaussianModel(use_rgb=True)
        input_gs.load_ply(gs_file)

        self.gs = input_gs
        self.autosize()

        self.camera_views = camera_traj(self.cam, ref_size, views, radius)
        self.device = avaliable_device()
        self.sh_degree = sh_degree

    def autosize(self):
        """Translate the Gaussians so their axis-aligned bounding box is centred at 0.

        The applied translation is stored in ``self.offset``.
        """
        xyz = self.gs.xyz
        min_xyz = xyz.min(dim=0).values
        max_xyz = xyz.max(dim=0).values
        middle_offset = (min_xyz + max_xyz) / 2
        xyz -= middle_offset
        self.gs.xyz = xyz
        self.offset = middle_offset

    def gsplat(
        self,
        viewpoint_camera,
        background_color,
    ):
        """Rasterize the Gaussians from ``viewpoint_camera``.

        Args:
            viewpoint_camera: Camera providing ``FoVx``/``FoVy``, image size,
                ``world_view_transform``, ``full_proj_transform`` and
                ``camera_center``.
            background_color: Scalar (applied to all three channels) or a
                3-element RGB background passed to the rasterizer.

        Returns:
            dict with ``comp_rgb``, ``comp_mask`` and ``comp_depth`` (the
            rasterizer outputs permuted channel-last with ``permute(1, 2, 0)``)
            and ``comp_rgb_bg`` (the 3-element background tensor).
        """
        gs = self.gs
        device = self.device

        # Buffer the rasterizer uses for 2D means; its gradient is only useful
        # when autograd is active.
        screenspace_points = (
            torch.zeros_like(
                gs.xyz, dtype=gs.xyz.dtype, requires_grad=True, device=device
            )
            + 0
        )
        try:
            screenspace_points.retain_grad()
        except RuntimeError:
            # Under torch.no_grad() the "+ 0" result does not require grad and
            # retain_grad() raises; no gradient is needed in that case.
            pass

        bg_color = torch.as_tensor(background_color, dtype=torch.float32, device=device)
        if bg_color.ndim == 0:
            # Broadcast a scalar grey level to RGB.
            bg_color = bg_color.expand(3).contiguous()

        tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
        tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

        raster_settings = GaussianRasterizationSettings(
            image_height=int(viewpoint_camera.image_height),
            image_width=int(viewpoint_camera.image_width),
            tanfovx=tanfovx,
            tanfovy=tanfovy,
            bg=bg_color,
            scale_modifier=1.0,
            viewmatrix=viewpoint_camera.world_view_transform,
            projmatrix=viewpoint_camera.full_proj_transform.float(),
            sh_degree=self.sh_degree,
            campos=viewpoint_camera.camera_center,
            prefiltered=False,
            debug=False,
        )
        rasterizer = GaussianRasterizer(raster_settings=raster_settings)

        # shs is (N, num_coeffs, 3).  Index 0 is the DC colour.  This equals the
        # former squeeze(1) for degree-0 models and remains (N, 3) for higher degrees.
        colors_precomp = gs.shs[:, 0].float()

        with torch.autocast(device_type=bg_color.device.type, dtype=torch.float32):
            rendered_image, radii, rendered_depth, rendered_alpha = rasterizer(
                means3D=gs.xyz.float().to(device),
                means2D=screenspace_points.float().to(device),
                shs=None,
                colors_precomp=colors_precomp.to(device),
                opacities=gs.opacity.float().to(device),
                scales=gs.scaling.float().to(device),
                rotations=gs.rotation.float().to(device),
                cov3D_precomp=None,
            )

        return {
            "comp_rgb": rendered_image.permute(1, 2, 0),
            "comp_rgb_bg": bg_color,
            "comp_mask": rendered_alpha.permute(1, 2, 0),
            "comp_depth": rendered_depth.permute(1, 2, 0),
        }

    @torch.no_grad()
    def rendering(self, save_path):
        """Render every camera in ``camera_views`` to ``save_path/{cam_id:05d}.png`` on white."""
        os.makedirs(save_path, exist_ok=True)
        for cam_id, camera in enumerate(self.camera_views):
            # Render on black so comp_rgb contains only the accumulated colour.
            # The white background is then composited exactly once.
            out = self.gsplat(camera, 0.0)
            rgb = self._to_uint8_over_white(out["comp_rgb"], out["comp_mask"])

            save_rgb_path = os.path.join(save_path, f"{cam_id:05d}.png")
            Image.fromarray(rgb).save(save_rgb_path)

    @staticmethod
    def _to_uint8_over_white(comp_rgb, mask):
        """Composite colour rendered on black over white and convert to 8-bit.

        Args:
            comp_rgb: (H, W, 3) colour rendered on a black background.  Assumed
                to be coverage-weighted (premultiplied), as the rasterizer
                accumulates it.
            mask: (H, W, 1) accumulated opacity.  Assumed to equal
                1 - transmittance.

        Returns:
            (H, W, 3) ``uint8`` NumPy array.
        """
        rgb = comp_rgb + (1 - mask)
        # Clamp before the uint8 cast; out-of-range values would otherwise wrap around.
        rgb = (rgb.clamp(0.0, 1.0) * 255).detach().float().cpu().numpy()
        return rgb.astype(np.uint8)
|
|
|
|
def basename(path):
    """Return the file name of ``path`` with only its final extension removed.

    ``"/a/b/scan.v2.ply"`` -> ``"scan.v2"``.  Only the last suffix is dropped,
    so files that differ before it (e.g. ``x.1.ply`` / ``x.2.ply``) do not map
    to the same output directory.
    """
    return os.path.splitext(os.path.basename(path))[0]
|
|
|
|
def main(
    view_path="./exps/output_gs/eval-heursample/LHM-A4O-SR-B-soft-hard-B-large/view_016/",
    save_path="./debug/gs_render_debug",
):
    """Render every ``.ply`` file in ``view_path`` into ``save_path/<name>/``.

    Args:
        view_path: Directory scanned (non-recursively) for ``.ply`` files.
        save_path: Root output directory; one sub-directory is created per PLY file.
    """
    os.makedirs(save_path, exist_ok=True)

    # Only regular .ply files can be loaded; sort for a deterministic order.
    ply_files = sorted(
        name
        for name in os.listdir(view_path)
        if name.lower().endswith(".ply")
        and os.path.isfile(os.path.join(view_path, name))
    )
    for name in ply_files:
        ply_path = os.path.join(view_path, name)
        print(f"py file, {ply_path}")

        renderer = MVSRender(ply_path, radius=2.5)

        save_dir = os.path.join(save_path, basename(ply_path))
        renderer.rendering(save_dir)


if __name__ == "__main__":
    main()
|
|