# LHMPP / core/models/rendering/gs_renderer.py
# Lingteng Qiu (邱陵腾)
# rm assets & wheels
# 434b0b0
# -*- coding: utf-8 -*-
# @Organization : Tongyi Lab, Alibaba
# @Author : Lingteng Qiu
# @Email : 220019047@link.cuhk.edu.cn
# @Time : 2025-06-04 20:43:18
# @Function : 3DGSRender Class
import os
import numpy as np
import torch
from core.models.rendering.base_gs_render import BaseGSRender
from core.models.rendering.gaussian_decoder.mlp_decoder import GSMLPDecoder
from core.models.rendering.utils.typing import *
from core.outputs.output import GaussianAppOutput
class GS3DRenderer(BaseGSRender):
    """3D Gaussian Splatting renderer with an MLP Gaussian-attribute decoder.

    Extends ``BaseGSRender`` by attaching a ``GSMLPDecoder`` (``self.gs_net``)
    that decodes per-point query features into Gaussian attributes, including
    xyz offsets.
    """

    def __init__(
        self,
        human_model_path,
        subdivide_num,
        smpl_type,
        feat_dim,
        query_dim,
        use_rgb,
        sh_degree,
        xyz_offset_max_step,
        mlp_network_config,
        expr_param_dim,
        shape_param_dim,
        clip_scaling=0.2,
        cano_pose_type=0,
        decoder_mlp=False,
        skip_decoder=False,
        fix_opacity=False,
        fix_rotation=False,
        decode_with_extra_info=None,
        gradient_checkpointing=False,
        apply_pose_blendshape=False,
        dense_sample_pts=40000,  # only used when dense-sampling SMPL-X points
        gs_deform_scale=0.005,
        render_features=False,
    ):
        """
        Initializes the GS3DRenderer, a subclass of BaseGSRender for 3D Gaussian Splatting rendering.

        All arguments are forwarded positionally to ``BaseGSRender.__init__``.
        A subset of them also configures the ``GSMLPDecoder`` stored in ``self.gs_net``.

        Args:
            human_model_path (str): Path to human model files.
            subdivide_num (int): Subdivision number for base mesh.
            smpl_type (str): Type of SMPL/SMPL-X/other model to use.
            feat_dim (int): Dimension of feature embeddings.
            query_dim (int): Dimension of query points/features; input channels of the decoder.
            use_rgb (bool): Whether to use RGB channels.
            sh_degree (int): Spherical harmonics degree for appearance.
            xyz_offset_max_step (float): Max offset per step for position.
            mlp_network_config (dict or None): MLP configuration for feature mapping.
            expr_param_dim (int): Expression parameter dimension.
            shape_param_dim (int): Shape parameter dimension.
            clip_scaling (float, optional): Output scaling for decoder. Default 0.2.
            cano_pose_type (int, optional): Canonical pose type. Default 0.
            decoder_mlp (bool, optional): Use MLP in decoder cross-attention. Default False.
            skip_decoder (bool, optional): Whether to skip decoder and cross-attn layers. Default False.
            fix_opacity (bool, optional): Fix opacity during training. Default False.
            fix_rotation (bool, optional): Fix rotation during training. Default False.
            decode_with_extra_info (dict or None, optional): Extra info for the decoder.
                When it is a dict whose ``"type"`` entry is not None, the decoder
                is built with ``use_fine_feat=True``. Default None.
            gradient_checkpointing (bool, optional): Enable gradient checkpointing. Default False.
            apply_pose_blendshape (bool, optional): Apply pose blendshape. Default False.
            dense_sample_pts (int, optional): Dense sample points for mesh/voxel. Default 40000.
            gs_deform_scale (float, optional): Deformation scale for Gaussian Splatting. Default 0.005.
            render_features (bool, optional): Output additional features in renderer. Default False.
        """
        super().__init__(
            human_model_path,
            subdivide_num,
            smpl_type,
            feat_dim,
            query_dim,
            use_rgb,
            sh_degree,
            xyz_offset_max_step,
            mlp_network_config,
            expr_param_dim,
            shape_param_dim,
            clip_scaling,
            cano_pose_type,
            decoder_mlp,
            skip_decoder,
            fix_opacity,
            fix_rotation,
            decode_with_extra_info,
            gradient_checkpointing,
            apply_pose_blendshape,
            dense_sample_pts,
            gs_deform_scale,
            render_features,
        )

        # The decoder consumes fine-level features only when extra decode info
        # with a non-None "type" is configured.
        use_fine_feat = (
            decode_with_extra_info is not None
            and decode_with_extra_info["type"] is not None
        )

        # self.sh_degree is presumably set by BaseGSRender.__init__ — confirm there.
        self.gs_net = GSMLPDecoder(
            in_channels=query_dim,
            use_rgb=use_rgb,
            sh_degree=self.sh_degree,
            clip_scaling=clip_scaling,
            init_scaling=-6.0,
            init_density=0.1,
            xyz_offset=True,
            restrict_offset=True,
            xyz_offset_max_step=xyz_offset_max_step,
            fix_opacity=fix_opacity,
            fix_rotation=fix_rotation,
            use_fine_feat=use_fine_feat,
        )
        self.gs_deform_scale = gs_deform_scale  # deform mask scale.

    def hyper_step(self, step):
        """Forward the training step to the decoder so it can adjust its constraint scale.

        Args:
            step: Current training step, passed unchanged to ``self.gs_net.hyper_step``.
        """
        self.gs_net.hyper_step(step)

    def forward_gs_attr(
        self, x, query_points, smplx_data, debug=False, x_fine=None, mesh_meta=None
    ):
        """Decode Gaussian attributes for a set of query points.

        Args:
            x: Per-point features, ``Float[Tensor, "Np Cp"]``.
            query_points: Query point positions, ``Float[Tensor, "Np 3"]``.
            smplx_data: Not used by this implementation; kept for interface compatibility.
            debug: Not used by this implementation; kept for interface compatibility.
            x_fine: Optional fine-level features. When an MLP is configured, they
                are passed through ``self.mlp_net`` just like ``x``.
            mesh_meta: Optional dict providing the region masks
                ``"is_constrain_body"``, ``"is_rhand"``, ``"is_lhand"`` and
                ``"is_upper_body"``. Each mask is converted to the dtype/device
                of the matching mask on ``self.smplx_model``. When ``None``,
                the masks stored on ``self.smplx_model`` are used directly.

        Returns:
            GaussianAppOutput: the decoder output. It contains the xyz offsets.
        """
        if self.mlp_network_config is not None:
            # x is processed by LayerNorm
            x = self.mlp_net(x)
            if x_fine is not None:
                x_fine = self.mlp_net(x_fine)

        smplx_model = self.smplx_model
        if mesh_meta is None:
            # No per-mesh metadata: fall back to the model-level region masks.
            is_constrain_body = smplx_model.is_constrain_body
            is_hands = smplx_model.is_rhand + smplx_model.is_lhand
            is_upper_body = smplx_model.is_upper_body
        else:
            # Tensor.to(other) matches both dtype and device of the reference mask.
            is_constrain_body = mesh_meta["is_constrain_body"].to(
                smplx_model.is_constrain_body
            )
            is_hands = (mesh_meta["is_rhand"] + mesh_meta["is_lhand"]).to(
                smplx_model.is_rhand
            )
            is_upper_body = mesh_meta["is_upper_body"].to(smplx_model.is_upper_body)

        constrain_dict = dict(
            is_constrain_body=is_constrain_body,
            is_hands=is_hands,
            is_upper_body=is_upper_body,
        )

        # NOTE that gs_attr contains offset xyz
        gs_attr: GaussianAppOutput = self.gs_net(
            x, query_points, x_fine, constrain_dict
        )
        return gs_attr
def test(
    human_model_path="./pretrained_models/human_model_files",
    smplx_data_root="/data1/projects/ExAvatar_RELEASE/avatar/data/Custom/data/gyeongsik/smplx_optimized/smplx_params_smoothed",
    shape_param_file="/data1/projects/ExAvatar_RELEASE/avatar/data/Custom/data/gyeongsik/smplx_optimized/shape_param.json",
    expr_param_dim=100,
    shape_param_dim=10,
    device="cuda",
    save_dir="./debug_vis",
):
    """Manual smoke test: render zero features for a few views and dump the outputs.

    Loads SMPL-X parameters and cameras with ``read_smplx_param`` and builds a
    ``GS3DRenderer``. It runs one forward pass with all-zero hidden features
    and writes every output map (``.jpg``) and Gaussian set (``.ply``) into
    ``save_dir``. ``save_dir`` is created if missing.

    Requires the local dataset/model files referenced by the path arguments
    and, by default, a CUDA device.

    NOTE(review): ``expr_param_dim`` / ``shape_param_dim`` defaults are
    assumptions. They must match the SMPL-X setup the model is configured
    with; the original call omitted these required arguments entirely.
    """
    import cv2

    from core.models.rendering.smpl_x import read_smplx_param

    batch_size = 1
    num_views = 2  # passed to read_smplx_param as its batch_size; one camera per view

    smplx_data, cam_param_list, ori_image_list = read_smplx_param(
        smplx_data_root=smplx_data_root,
        shape_param_file=shape_param_file,
        batch_size=num_views,
    )

    # Add a leading batch dim. For these keys only the first entry is kept.
    first_only_keys = {"betas", "face_offset", "joint_offset"}
    smplx_data = {
        k: (v[0] if k in first_only_keys else v).unsqueeze(0)
        for k, v in smplx_data.items()
    }

    gs_render = GS3DRenderer(
        human_model_path=human_model_path,
        subdivide_num=2,
        smpl_type="smplx",
        feat_dim=64,
        query_dim=64,
        use_rgb=False,
        sh_degree=3,
        mlp_network_config=None,
        xyz_offset_max_step=1.8 / 32,
        expr_param_dim=expr_param_dim,
        shape_param_dim=shape_param_dim,
    )
    gs_render.to(device)

    c2w_list = []
    intr_list = []
    for cam_param in cam_param_list:
        c2w = torch.eye(4, device=device)
        c2w[:3, :3] = cam_param["R"]
        c2w[:3, 3] = cam_param["t"]
        c2w_list.append(c2w)

        intr = torch.eye(4, device=device)
        intr[0, 0] = cam_param["focal"][0]
        intr[1, 1] = cam_param["focal"][1]
        intr[0, 2] = cam_param["princpt"][0]
        intr[1, 2] = cam_param["princpt"][1]
        intr_list.append(intr)

    c2w = torch.stack(c2w_list).unsqueeze(0)
    intrinsic = torch.stack(intr_list).unsqueeze(0)

    # NOTE(review): image size is estimated as 2x the principal point, i.e. it
    # assumes the principal point sits at the image center.
    out = gs_render.forward(
        gs_hidden_features=torch.zeros((batch_size, 2048, 64)).float().to(device),
        query_points=None,
        smplx_data=smplx_data,
        c2w=c2w,
        intrinsic=intrinsic,
        height=int(cam_param_list[0]["princpt"][1]) * 2,
        width=int(cam_param_list[0]["princpt"][0]) * 2,
        background_color=torch.ones(
            (batch_size, num_views, 3), dtype=torch.float32, device=device
        ),
        debug=False,
    )

    # cv2.imwrite / save_ply do not create missing directories.
    os.makedirs(save_dir, exist_ok=True)

    for k, v in out.items():
        if k == "comp_rgb_bg":
            print("comp_rgb_bg", v)
            continue
        for b_idx in range(len(v)):
            if k == "3dgs":
                for v_idx, gaussians in enumerate(v[b_idx]):
                    gaussians.save_ply(os.path.join(save_dir, f"{b_idx}_{v_idx}.ply"))
                continue
            for v_idx in range(v.shape[1]):
                img = v[b_idx, v_idx].detach().cpu().numpy()
                # Clip first so out-of-range values saturate instead of wrapping in uint8.
                img = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
                # NOTE(review): cv2.imwrite expects BGR; if these maps are RGB the
                # saved colors will appear channel-swapped.
                save_path = os.path.join(save_dir, f"{b_idx}_{v_idx}_{k}.jpg")
                cv2.imwrite(save_path, img)