# Source file: LHMPP / engine/MVSRender/mvs_render.py
# Author: Lingteng Qiu (邱陵腾)
# Commit: "rm assets & wheels" (434b0b0)
# -*- coding: utf-8 -*-
# @Organization : Tongyi Lab, Alibaba
# @Author : Lingteng Qiu
# @Email : 220019047@link.cuhk.edu.cn
# @Time : 2025-08-31 10:02:15
# @Function : MVS Gaussian splatting renderer
import sys
sys.path.append("./")
import math
import os
import pdb
import numpy as np
import torch
from plyfile import PlyData, PlyElement
from core.models.rendering.utils.sh_utils import RGB2SH, SH2RGB
from core.models.rendering.utils.typing import *
from engine.MVSRender.camera_utils import MiniCam, OrbitCamera, orbit_camera
def avaliable_device():
    """Return the torch device string for this process.

    Prefers the currently selected CUDA device ("cuda:<id>") when CUDA is
    usable, otherwise falls back to "cpu".

    NOTE: the misspelled name ("avaliable") is kept for backward
    compatibility with existing callers.
    """
    if not torch.cuda.is_available():
        return "cpu"
    return f"cuda:{torch.cuda.current_device()}"
from diff_gaussian_rasterization import (
GaussianRasterizationSettings,
GaussianRasterizer,
)
from PIL import Image
def inverse_sigmoid(x):
    """Logit function: the inverse of ``torch.sigmoid``, log(x / (1 - x)).

    Accepts a tensor or a plain Python float (which is promoted to a
    float32 tensor before the computation).
    """
    t = torch.tensor(x).float() if isinstance(x, float) else x
    return torch.log(t / (1 - t))
class GaussianModel:
    """Container for 3D Gaussian Splatting attributes with PLY I/O.

    Attributes are kept in *activated* form (opacity already sigmoid-ed,
    scaling already exp-ed, rotation normalized) — see ``load_ply``.  When
    ``use_rgb`` is True, ``self.shs`` stores RGB colors rather than SH
    coefficients; the conversion happens in ``save_ply`` / ``load_ply``.
    """

    def setup_functions(self):
        """Bind the activation / inverse-activation callables used by I/O."""
        self.scaling_activation = torch.exp
        self.scaling_inverse_activation = torch.log
        self.opacity_activation = torch.sigmoid
        self.inverse_opacity_activation = inverse_sigmoid
        self.rotation_activation = torch.nn.functional.normalize
        # rgb activation function (not referenced elsewhere in this file)
        self.rgb_activation = torch.sigmoid

    def __init__(self, use_rgb=True):
        """Create an empty model; attributes are filled by ``load_ply``.

        Args:
            use_rgb: when True, ``self.shs`` stores RGB colors instead of
                raw SH coefficients.
        """
        self.setup_functions()
        self.xyz: Tensor = torch.empty(0)  # [N, 3] Gaussian centers
        self.opacity: Tensor = torch.empty(0)  # [N, 1] activated opacity
        self.rotation: Tensor = torch.empty(0)  # [N, K] normalized rotation params
        self.scaling: Tensor = torch.empty(0)  # [N, S] activated scales
        self.shs: Tensor = torch.empty(0)  # [N, C, 3] colors / SH coefficients
        self.use_rgb = use_rgb  # shs indicates rgb?

    def construct_list_of_attributes(self):
        """Return the PLY vertex property names, in on-disk column order."""
        l = ["x", "y", "z", "nx", "ny", "nz"]
        features_dc = self.shs[:, :1]
        features_rest = self.shs[:, 1:]
        for i in range(features_dc.shape[1] * features_dc.shape[2]):
            l.append("f_dc_{}".format(i))
        for i in range(features_rest.shape[1] * features_rest.shape[2]):
            l.append("f_rest_{}".format(i))
        l.append("opacity")
        for i in range(self.scaling.shape[1]):
            l.append("scale_{}".format(i))
        for i in range(self.rotation.shape[1]):
            l.append("rot_{}".format(i))
        return l

    def save_ply(self, path):
        """Write the model to ``path`` in the standard 3DGS PLY layout.

        Opacity is de-activated back to logits and scaling back to
        log-space so the file matches the on-disk 3DGS convention.
        """
        xyz = self.xyz.detach().cpu().numpy()
        # Normals are not modeled; zeros are written to satisfy the layout.
        normals = np.zeros_like(xyz)
        if self.use_rgb:
            # Stored colors are RGB; convert back to SH for the file.
            shs = RGB2SH(self.shs)
        else:
            shs = self.shs
        features_dc = shs[:, :1]
        features_rest = shs[:, 1:]
        f_dc = (
            features_dc.float().detach().flatten(start_dim=1).contiguous().cpu().numpy()
        )
        f_rest = (
            features_rest.float()
            .detach()
            .flatten(start_dim=1)
            .contiguous()
            .cpu()
            .numpy()
        )
        opacities = (
            # invert the sigmoid applied at load time
            inverse_sigmoid(self.opacity)
            .detach()
            .cpu()
            .numpy()
        )
        # Inverse of the exp activation applied at load time.
        scale = np.log(self.scaling.detach().cpu().numpy())
        rotation = self.rotation.detach().cpu().numpy()
        dtype_full = [
            (attribute, "f4") for attribute in self.construct_list_of_attributes()
        ]
        elements = np.empty(xyz.shape[0], dtype=dtype_full)
        # Column order must match construct_list_of_attributes().
        attributes = np.concatenate(
            (xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1
        )
        elements[:] = list(map(tuple, attributes))
        el = PlyElement.describe(elements, "vertex")
        PlyData([el]).write(path)

    def load_ply(self, path):
        """Load a 3DGS PLY from ``path`` and populate activated attributes.

        Side effects: fills xyz/opacity/rotation/scaling/shs and sets
        ``self.active_sh_degree`` inferred from the f_rest_* property count.
        """
        plydata = PlyData.read(path)
        xyz = np.stack(
            (
                np.asarray(plydata.elements[0]["x"]),
                np.asarray(plydata.elements[0]["y"]),
                np.asarray(plydata.elements[0]["z"]),
            ),
            axis=1,
        )
        opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
        features_dc = np.zeros((xyz.shape[0], 3, 1))
        features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
        features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
        features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
        extra_f_names = [
            p.name
            for p in plydata.elements[0].properties
            if p.name.startswith("f_rest_")
        ]
        extra_f_names = sorted(extra_f_names, key=lambda x: int(x.split("_")[-1]))
        # Per channel there are (deg+1)^2 SH coefficients; f_rest_* excludes
        # the 3 DC terms, so len(extra_f_names) + 3 == 3 * (deg+1)^2.
        sh_degree = int(math.sqrt((len(extra_f_names) + 3) / 3)) - 1
        features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
        for idx, attr_name in enumerate(extra_f_names):
            features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
        # Reshape (P, F*SH_coeffs) to (P, F, SH_coeffs except DC)
        features_extra = features_extra.reshape(
            (features_extra.shape[0], 3, (sh_degree + 1) ** 2 - 1)
        )
        scale_names = [
            p.name
            for p in plydata.elements[0].properties
            if p.name.startswith("scale_")
        ]
        scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1]))
        scales = np.zeros((xyz.shape[0], len(scale_names)))
        for idx, attr_name in enumerate(scale_names):
            scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
        rot_names = [
            p.name for p in plydata.elements[0].properties if p.name.startswith("rot")
        ]
        rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1]))
        rots = np.zeros((xyz.shape[0], len(rot_names)))
        for idx, attr_name in enumerate(rot_names):
            rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
        # .to(existing buffer) adopts the dtype/device of the empty tensors
        # created in __init__ (float32, CPU by default).
        xyz = torch.from_numpy(xyz).to(self.xyz)
        opacities = torch.from_numpy(opacities).to(self.opacity)
        rotation = torch.from_numpy(rots).to(self.rotation)
        scales = torch.from_numpy(scales).to(self.scaling)
        features_dc = torch.from_numpy(features_dc).to(self.shs)
        features_rest = torch.from_numpy(features_extra).to(self.shs)
        shs = torch.cat([features_dc, features_rest], dim=2)
        if self.use_rgb:
            # Store colors as RGB rather than SH coefficients.
            shs = SH2RGB(shs)
        else:
            shs = shs
        self.xyz: Tensor = xyz
        # On-disk values are logits / log-scales; activate them for use.
        self.opacity: Tensor = self.opacity_activation(opacities)
        self.rotation: Tensor = self.rotation_activation(rotation)
        self.scaling: Tensor = self.scaling_activation(scales)
        self.shs: Tensor = shs.permute(0, 2, 1)  # (N, 3, C) -> (N, C, 3)
        self.active_sh_degree = sh_degree
def camera_traj(cam, ref_size=1024, views=30, radius=2.0):
    """Build an orbit (turntable) trajectory of cameras at elevation 0.

    Args:
        cam: camera providing the shared intrinsics (fovy, fovx, near, far).
        ref_size: square image resolution for every generated camera.
        views: number of evenly spaced azimuth samples over a full circle.
        radius: orbit radius passed to ``orbit_camera``.

    Returns:
        list of ``MiniCam`` objects, one per azimuth step.
    """
    # BUG FIX: the loop previously iterated range(30) regardless of `views`,
    # producing the wrong camera count for any other value.  True division
    # for the step also makes the samples span the full 360 degrees even
    # when `views` does not divide 360.  Defaults (views=30) are unchanged.
    azimuth_step = 360.0 / views
    cameras = []
    for view_idx in range(views):
        pose = orbit_camera(
            0,
            view_idx * azimuth_step,
            radius,
        )
        cameras.append(
            MiniCam(
                pose,
                ref_size,
                ref_size,
                cam.fovy,
                cam.fovx,
                cam.near,
                cam.far,
            )
        )
    return cameras
class MVSRender:
    """Offline turntable renderer for a Gaussian-splat PLY file.

    Loads a ``GaussianModel`` from ``gs_file``, recenters it at the origin,
    builds an orbit camera trajectory, and rasterizes each view with
    ``diff_gaussian_rasterization``.
    """

    def __init__(
        self, gs_file, fovy=50, ref_size=1024, views=30, radius=2.0, sh_degree=0
    ):
        """
        Args:
            gs_file: path to a Gaussian-splat .ply file.
            fovy: vertical field of view (degrees) for the orbit camera.
            ref_size: square output resolution in pixels.
            views: number of azimuth steps on the turntable.
            radius: camera orbit radius.
            sh_degree: SH degree handed to the rasterizer; colors here are
                precomputed RGB, so 0 is the expected setting.
        """
        self.cam = OrbitCamera(ref_size, ref_size, r=radius, fovy=fovy)
        self.fovy = fovy
        self.ref_size = ref_size
        self.views = views
        self.radius = radius
        input_gs = GaussianModel(use_rgb=True)
        input_gs.load_ply(gs_file)
        self.gs = input_gs
        self.autosize()
        self.camera_views = camera_traj(self.cam, ref_size, views, radius)
        self.device = avaliable_device()
        self.sh_degree = sh_degree

    def autosize(self):
        """Recenter the Gaussians at the bbox middle; remember the offset."""
        xyz = self.gs.xyz
        min_xyz = xyz.min(dim=0).values
        max_xyz = xyz.max(dim=0).values
        middle_offset = (min_xyz + max_xyz) / 2
        xyz -= middle_offset  # in-place shift of the shared tensor
        self.gs.xyz = xyz
        self.offset = middle_offset

    def gsplat(
        self,
        viewpoint_camera,
        background_color,
    ):
        """Rasterize the Gaussians from ``viewpoint_camera``.

        Args:
            viewpoint_camera: MiniCam-like object exposing FoVx/FoVy,
                image_height/image_width, world_view_transform,
                full_proj_transform and camera_center.
            background_color: scalar background intensity in [0, 1] used to
                fill pixels not covered by any Gaussian.

        Returns:
            dict with "comp_rgb" [H, W, 3], "comp_mask" [H, W, 1],
            "comp_depth" [H, W, 1], and "comp_rgb_bg" (the bg color tensor).
        """
        gs = self.gs
        # Zero tensor kept with requires_grad so PyTorch could return
        # gradients of the 2D (screen-space) means if a training path used this.
        screenspace_points = (
            torch.zeros_like(
                gs.xyz, dtype=gs.xyz.dtype, requires_grad=True, device=self.device
            )
            + 0
        )
        try:
            screenspace_points.retain_grad()
        except Exception:
            # retain_grad can fail outside autograd; rendering runs no_grad anyway.
            pass
        # FIX: honor the background_color argument (previously hard-coded to
        # white).  Existing callers pass 1.0, so their output is unchanged.
        bg_color = (torch.ones(3) * float(background_color)).float().to(self.device)
        # Set up rasterization configuration.
        tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
        tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
        raster_settings = GaussianRasterizationSettings(
            image_height=int(viewpoint_camera.image_height),
            image_width=int(viewpoint_camera.image_width),
            tanfovx=tanfovx,
            tanfovy=tanfovy,
            bg=bg_color,
            scale_modifier=1.0,
            viewmatrix=viewpoint_camera.world_view_transform,
            projmatrix=viewpoint_camera.full_proj_transform.float(),
            sh_degree=self.sh_degree,
            campos=viewpoint_camera.camera_center,
            prefiltered=False,
            debug=False,
        )
        rasterizer = GaussianRasterizer(raster_settings=raster_settings)
        means3D = gs.xyz
        means2D = screenspace_points
        opacity = gs.opacity
        # 3D covariance is derived from scaling/rotation by the rasterizer.
        cov3D_precomp = None
        scales = gs.scaling
        rotations = gs.rotation
        # Colors are precomputed RGB (the model regresses rgb, not SH), so
        # the rasterizer's SH evaluation path is bypassed (shs=None below).
        colors_precomp = gs.shs.squeeze(1).float()
        # Rasterize visible Gaussians to image, obtain their radii (on screen).
        with torch.autocast(device_type=bg_color.device.type, dtype=torch.float32):
            rendered_image, radii, rendered_depth, rendered_alpha = rasterizer(
                means3D=means3D.float().cuda(),
                means2D=means2D.float().cuda(),
                shs=None,
                colors_precomp=colors_precomp.cuda(),
                opacities=opacity.float().cuda(),
                scales=scales.float().cuda(),
                rotations=rotations.float().cuda(),
                cov3D_precomp=cov3D_precomp,
            )
        ret = {
            "comp_rgb": rendered_image.permute(1, 2, 0),  # [H, W, 3]
            "comp_rgb_bg": bg_color,
            "comp_mask": rendered_alpha.permute(1, 2, 0),
            "comp_depth": rendered_depth.permute(1, 2, 0),
        }
        return ret

    @torch.no_grad()
    def rendering(self, save_path):
        """Render every turntable view into ``save_path`` as 00000.png, ..."""
        os.makedirs(save_path, exist_ok=True)
        for cam_id, camera in enumerate(self.camera_views):
            out = self.gsplat(camera, 1.0)
            comp_rgb = out["comp_rgb"]
            mask = out["comp_mask"]
            # Alpha-composite onto a white background.
            rgb = (comp_rgb * mask + (1 - mask) * 1) * 255
            rgb = rgb.detach().float().cpu().numpy().astype(np.uint8)
            rgb = Image.fromarray(rgb)
            save_rgb_path = os.path.join(save_path, f"{cam_id:05d}.png")
            rgb.save(save_rgb_path)
def basename(path):
    """Return the file name of ``path`` truncated at the first dot.

    E.g. "/a/b/model.tar.gz" -> "model".
    """
    file_name = os.path.basename(path)
    return file_name.partition(".")[0]
def main():
    """Render debug turntables for every Gaussian-splat PLY in a fixed dir."""
    view_path = (
        "./exps/output_gs/eval-heursample/LHM-A4O-SR-B-soft-hard-B-large/view_016/"
    )
    save_path = "./debug/gs_render_debug"
    os.makedirs(save_path, exist_ok=True)
    for entry in os.listdir(view_path):
        ply_path = os.path.join(view_path, entry)
        print(f"py file, {ply_path}")
        renderer = MVSRender(ply_path, radius=2.5)
        renderer.rendering(os.path.join(save_path, basename(ply_path)))
# Script entry point: render turntables for every PLY under the debug path.
if __name__ == "__main__":
    main()