| | |
| | |
| | |
| | |
| | |
| | |
| | import os |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| | from core.models.rendering.base_gs_render import BaseGSRender |
| | from core.models.rendering.gaussian_decoder.mlp_decoder import GSMLPDecoder |
| | from core.models.rendering.utils.typing import * |
| | from core.outputs.output import GaussianAppOutput |
| |
|
| |
|
class GS3DRenderer(BaseGSRender):
    """Gaussian Splatting renderer whose Gaussian attributes come from an MLP decoder.

    Extends ``BaseGSRender`` by building ``self.gs_net`` (a ``GSMLPDecoder``) that maps
    per-point query features to Gaussian attributes (``GaussianAppOutput``).
    """

    def __init__(
        self,
        human_model_path,
        subdivide_num,
        smpl_type,
        feat_dim,
        query_dim,
        use_rgb,
        sh_degree,
        xyz_offset_max_step,
        mlp_network_config,
        expr_param_dim,
        shape_param_dim,
        clip_scaling=0.2,
        cano_pose_type=0,
        decoder_mlp=False,
        skip_decoder=False,
        fix_opacity=False,
        fix_rotation=False,
        decode_with_extra_info=None,
        gradient_checkpointing=False,
        apply_pose_blendshape=False,
        dense_sample_pts=40000,
        gs_deform_scale=0.005,
        render_features=False,
    ):
        """
        Initialize the GS3DRenderer, a subclass of BaseGSRender for 3D Gaussian Splatting rendering.

        All arguments are forwarded positionally to ``BaseGSRender.__init__``. A subset is
        also used to configure the ``GSMLPDecoder`` stored as ``self.gs_net``.

        Args:
            human_model_path (str): Path to human model files.
            subdivide_num (int): Subdivision number for the base mesh.
            smpl_type (str): Type of SMPL/SMPL-X/other model to use.
            feat_dim (int): Dimension of feature embeddings.
            query_dim (int): Dimension of query features; input channels of ``self.gs_net``.
            use_rgb (bool): Whether the decoder predicts RGB instead of SH coefficients.
            sh_degree (int): Spherical harmonics degree for appearance.
            xyz_offset_max_step (float): Max per-step position offset allowed by the decoder.
            mlp_network_config (dict or None): MLP configuration for feature mapping. When not
                None, ``self.mlp_net`` is applied to features in ``forward_gs_attr``.
            expr_param_dim (int): Expression parameter dimension.
            shape_param_dim (int): Shape parameter dimension.
            clip_scaling (float, optional): Scaling clip value for the decoder. Default 0.2.
            cano_pose_type (int, optional): Canonical pose type. Default 0.
            decoder_mlp (bool, optional): Use MLP in decoder cross-attention. Default False.
            skip_decoder (bool, optional): Whether to skip decoder and cross-attn layers. Default False.
            fix_opacity (bool, optional): Fix opacity in the decoder. Default False.
            fix_rotation (bool, optional): Fix rotation in the decoder. Default False.
            decode_with_extra_info (dict or None, optional): Extra decode info. If it is not None
                and its ``"type"`` entry is not None, the decoder consumes fine features.
                Default None.
            gradient_checkpointing (bool, optional): Enable gradient checkpointing. Default False.
            apply_pose_blendshape (bool, optional): Apply pose blendshape. Default False.
            dense_sample_pts (int, optional): Dense sample points for mesh/voxel. Default 40000.
            gs_deform_scale (float, optional): Deformation scale for Gaussian Splatting; also
                stored as ``self.gs_deform_scale``. Default 0.005.
            render_features (bool, optional): Output additional features in renderer. Default False.
        """
        super().__init__(
            human_model_path,
            subdivide_num,
            smpl_type,
            feat_dim,
            query_dim,
            use_rgb,
            sh_degree,
            xyz_offset_max_step,
            mlp_network_config,
            expr_param_dim,
            shape_param_dim,
            clip_scaling,
            cano_pose_type,
            decoder_mlp,
            skip_decoder,
            fix_opacity,
            fix_rotation,
            decode_with_extra_info,
            gradient_checkpointing,
            apply_pose_blendshape,
            dense_sample_pts,
            gs_deform_scale,
            render_features,
        )

        # Fine features are only fed to the decoder when an extra-info "type" is configured.
        # `.get` so that a config dict without a "type" entry means "no fine features"
        # instead of raising KeyError.
        use_fine_feat = (
            decode_with_extra_info is not None
            and decode_with_extra_info.get("type") is not None
        )

        self.gs_net = GSMLPDecoder(
            in_channels=query_dim,
            use_rgb=use_rgb,
            # Read back from self: the base class owns the stored SH degree.
            sh_degree=self.sh_degree,
            clip_scaling=clip_scaling,
            init_scaling=-6.0,
            init_density=0.1,
            xyz_offset=True,
            restrict_offset=True,
            xyz_offset_max_step=xyz_offset_max_step,
            fix_opacity=fix_opacity,
            fix_rotation=fix_rotation,
            use_fine_feat=use_fine_feat,
        )

        self.gs_deform_scale = gs_deform_scale

    def hyper_step(self, step):
        """Forward the training step to the decoder (used to adjust the constraint scale)."""
        self.gs_net.hyper_step(step)

    def forward_gs_attr(
        self, x, query_points, smplx_data, debug=False, x_fine=None, mesh_meta=None
    ):
        """Decode Gaussian attributes for a set of query points.

        Args:
            x: Per-point features, [N, C] (Float[Tensor, "Np Cp"]).
            query_points: Query positions, [N, 3] (Float[Tensor, "Np 3"]).
            smplx_data: Not used by this implementation.
            debug: Not used by this implementation.
            x_fine: Optional fine-level features. When ``mlp_network_config`` is set,
                they go through ``self.mlp_net`` just like ``x``.
            mesh_meta: Mapping with the ``"is_constrain_body"``, ``"is_rhand"``,
                ``"is_lhand"`` and ``"is_upper_body"`` masks. Required despite the
                ``None`` default, which is kept for signature compatibility.

        Returns:
            GaussianAppOutput: Output of ``self.gs_net``.

        Raises:
            ValueError: If ``mesh_meta`` is None.
        """
        if mesh_meta is None:
            raise ValueError(
                "forward_gs_attr requires mesh_meta with 'is_constrain_body', "
                "'is_rhand', 'is_lhand' and 'is_upper_body' masks"
            )

        if self.mlp_network_config is not None:
            x = self.mlp_net(x)
            if x_fine is not None:
                x_fine = self.mlp_net(x_fine)

        # Tensor.to(other) casts to the dtype/device of the model's own mask tensors.
        is_constrain_body = mesh_meta["is_constrain_body"].to(
            self.smplx_model.is_constrain_body
        )
        # Right- and left-hand masks are summed into a single hand mask
        # (for bool tensors torch's `+` behaves as a logical OR).
        is_hands = (mesh_meta["is_rhand"] + mesh_meta["is_lhand"]).to(
            self.smplx_model.is_rhand
        )
        is_upper_body = mesh_meta["is_upper_body"].to(self.smplx_model.is_upper_body)

        constrain_dict = dict(
            is_constrain_body=is_constrain_body,
            is_hands=is_hands,
            is_upper_body=is_upper_body,
        )

        gs_attr: GaussianAppOutput = self.gs_net(
            x, query_points, x_fine, constrain_dict
        )
        return gs_attr
| |
|
| |
|
def test(
    human_model_path="./pretrained_models/human_model_files",
    smplx_data_root="/data1/projects/ExAvatar_RELEASE/avatar/data/Custom/data/gyeongsik/smplx_optimized/smplx_params_smoothed",
    shape_param_file="/data1/projects/ExAvatar_RELEASE/avatar/data/Custom/data/gyeongsik/smplx_optimized/shape_param.json",
    save_dir="./debug_vis",
    device="cuda",
    expr_param_dim=100,
    shape_param_dim=10,
):
    """Smoke-test GS3DRenderer on saved SMPL-X params and dump the outputs to ``save_dir``.

    The function loads SMPL-X parameters and cameras with ``read_smplx_param`` and renders
    them with zero features. It then writes every image output as a ``.jpg``, and the
    ``"3dgs"`` entries as ``.ply`` files.

    Args:
        human_model_path: Path to the human model files.
        smplx_data_root: Directory of per-frame SMPL-X parameters.
        shape_param_file: JSON file with shape parameters.
        save_dir: Output directory. It is created if missing.
        device: Torch device to run on.
        expr_param_dim: Expression parameter dimension for the renderer.
        shape_param_dim: Shape parameter dimension for the renderer.
            NOTE(review): the defaults for these two dims are assumptions; confirm they
            match the model configuration being tested.
    """
    import cv2

    from core.models.rendering.smpl_x import read_smplx_param

    batch_size = 1
    num_views = 2  # frames/cameras loaded by read_smplx_param and rendered per sample

    smplx_data, cam_param_list, _ = read_smplx_param(
        smplx_data_root=smplx_data_root,
        shape_param_file=shape_param_file,
        batch_size=num_views,
    )
    # Add a leading batch axis. For these keys only the first entry is kept
    # (presumably identical across views).
    first_entry_keys = {"betas", "face_offset", "joint_offset"}
    smplx_data = {
        k: (v[0] if k in first_entry_keys else v).unsqueeze(0)
        for k, v in smplx_data.items()
    }

    gs_render = GS3DRenderer(
        human_model_path=human_model_path,
        subdivide_num=2,
        smpl_type="smplx",
        feat_dim=64,
        query_dim=64,
        use_rgb=False,
        sh_degree=3,
        mlp_network_config=None,
        xyz_offset_max_step=1.8 / 32,
        # Required by GS3DRenderer.__init__ (no defaults); omitting them raised TypeError.
        expr_param_dim=expr_param_dim,
        shape_param_dim=shape_param_dim,
    )
    gs_render.to(device)

    c2w_list = []
    intr_list = []
    for cam_param in cam_param_list:
        c2w = torch.eye(4, device=device)
        c2w[:3, :3] = cam_param["R"]
        c2w[:3, 3] = cam_param["t"]
        c2w_list.append(c2w)

        intr = torch.eye(4, device=device)
        intr[0, 0] = cam_param["focal"][0]
        intr[1, 1] = cam_param["focal"][1]
        intr[0, 2] = cam_param["princpt"][0]
        intr[1, 2] = cam_param["princpt"][1]
        intr_list.append(intr)

    c2w = torch.stack(c2w_list).unsqueeze(0)
    intrinsic = torch.stack(intr_list).unsqueeze(0)

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        out = gs_render.forward(
            gs_hidden_features=torch.zeros((batch_size, 2048, 64)).float().to(device),
            query_points=None,
            smplx_data=smplx_data,
            c2w=c2w,
            intrinsic=intrinsic,
            # Image size is taken as twice the principal point of the first camera.
            height=int(cam_param_list[0]["princpt"][1]) * 2,
            width=int(cam_param_list[0]["princpt"][0]) * 2,
            # White background per (batch, view).
            background_color=torch.ones(batch_size, num_views, 3, device=device),
            debug=False,
        )

    # cv2.imwrite does not create missing directories (it just returns False).
    os.makedirs(save_dir, exist_ok=True)

    for k, v in out.items():
        if k == "comp_rgb_bg":
            print("comp_rgb_bg", v)
            continue
        for b_idx in range(len(v)):
            if k == "3dgs":
                for v_idx, gaussians in enumerate(v[b_idx]):
                    gaussians.save_ply(os.path.join(save_dir, f"{b_idx}_{v_idx}.ply"))
                continue
            for v_idx in range(v.shape[1]):
                save_path = os.path.join(save_dir, f"{b_idx}_{v_idx}_{k}.jpg")
                # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
                # NOTE(review): cv2 treats 3-channel arrays as BGR; channels are written as-is.
                img = np.clip(v[b_idx, v_idx].detach().cpu().numpy() * 255, 0, 255)
                cv2.imwrite(save_path, img.astype(np.uint8))
| |
|