# -*- coding: utf-8 -*-
# @Organization  : Tongyi Lab, Alibaba
# @Author        : Lingteng Qiu
# @Email         : 220019047@link.cuhk.edu.cn
# @Time          : 2025-10-14 19:43:20
# @Function      : GSPlat-based Renderer.

try:
    from gsplat.rendering import rasterization

    gsplat_enable = True
    _GSPLAT_IMPORT_ERROR = None
except Exception as exc:  # best-effort: gsplat is optional until a renderer is built
    # Keep the cause so the ImportError raised in GSPlatRenderer.__init__ can
    # show why the import failed (missing package, broken CUDA extension, ...).
    gsplat_enable = False
    _GSPLAT_IMPORT_ERROR = exc

from collections import defaultdict

import torch
import torch.nn as nn
from pytorch3d.transforms import matrix_to_quaternion

from core.models.rendering.gs_renderer import GS3DRenderer
from core.models.rendering.utils.typing import *
from core.outputs.output import GaussianAppOutput
from core.structures.camera import Camera
from core.structures.gaussian_model import GaussianModel


def scale_intrs(intrs, ratio_x, ratio_y):
    """
    Scale camera intrinsics in place for a resized image.

    The whole first row of each matrix is multiplied by ``ratio_x`` and the
    whole second row by ``ratio_y``.

    Args:
        intrs (Tensor): Intrinsic matrix/matrices with shape [..., R, C]. The
            tensor is modified in place; pass a clone to keep the original.
        ratio_x (float): Horizontal scale factor (new_width / old_width).
        ratio_y (float): Vertical scale factor (new_height / old_height).

    Returns:
        Tensor: The same ``intrs`` object, after scaling.
    """
    # Index rows from the end so any number of leading batch dimensions works
    # (the result is identical for single and [N, R, C] batched matrices).
    intrs[..., 0, :] = intrs[..., 0, :] * ratio_x
    intrs[..., 1, :] = intrs[..., 1, :] * ratio_y
    return intrs


def aabb(xyz):
    """
    Compute the axis-aligned bounding box of a point set.

    Args:
        xyz (Tensor): Points reduced along dim 0 (e.g. [N, 3]).

    Returns:
        Tuple[Tensor, Tensor]: Per-axis minimum and maximum along dim 0.
    """
    return torch.min(xyz, dim=0).values, torch.max(xyz, dim=0).values


def _patch_grid_intrinsics(viewpoint_camera, patch_size):
    """Return (intrinsics, patch_width, patch_height) for a patch-resolution render."""
    img_height = int(viewpoint_camera.height)
    img_width = int(viewpoint_camera.width)
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")

    patch_height = img_height // patch_size
    patch_width = img_width // patch_size
    if patch_height == 0 or patch_width == 0:
        raise ValueError(
            f"image size {img_height}x{img_width} is smaller than "
            f"patch_size={patch_size}"
        )

    # The image may not be an exact multiple of patch_size, so derive the
    # effective per-axis scale from the integer patch grid.
    scale_y = img_height / patch_height
    scale_x = img_width / patch_width
    intrinsics = scale_intrs(
        viewpoint_camera.intrinsic.clone(), 1.0 / scale_x, 1.0 / scale_y
    )
    return intrinsics, patch_width, patch_height


def _rasterize_gaussians(
    device_type,
    viewpoint_camera,
    xyz,
    rotations,
    scales,
    opacity,
    colors,
    intrinsics,
    width,
    height,
    backgrounds,
):
    """Rasterize one view with gsplat and return (colors, alphas) without the camera dim."""
    # NOTE(review): world_view_transform is presumably stored transposed, so
    # the transpose yields the world-to-camera matrix gsplat expects -- confirm
    # against Camera.
    extrinsics = viewpoint_camera.world_view_transform.transpose(0, 1).contiguous()

    # All inputs are cast to float32 for the rasterizer.
    with torch.autocast(device_type=device_type, dtype=torch.float32):
        render_colors, render_alphas, _meta = rasterization(
            means=xyz.float(),
            quats=rotations.float(),
            scales=scales.float(),
            opacities=opacity.float(),
            colors=colors.float(),
            viewmats=extrinsics.unsqueeze(0).float(),
            Ks=intrinsics.float().unsqueeze(0)[:, :3, :3],
            width=width,
            height=height,
            near_plane=viewpoint_camera.znear,
            far_plane=viewpoint_camera.zfar,
            eps2d=0.3,
            render_mode="RGB",
            backgrounds=(
                None if backgrounds is None else backgrounds.unsqueeze(0).float()
            ),
            camera_model="pinhole",
        )
    return render_colors.squeeze(0), render_alphas.squeeze(0)


class GSPlatRenderer(GS3DRenderer):
    def __init__(
        self,
        human_model_path,
        subdivide_num,
        smpl_type,
        feat_dim,
        query_dim,
        use_rgb,
        sh_degree,
        xyz_offset_max_step,
        mlp_network_config,
        expr_param_dim,
        shape_param_dim,
        clip_scaling=0.2,
        cano_pose_type=0,
        decoder_mlp=False,
        skip_decoder=False,
        fix_opacity=False,
        fix_rotation=False,
        decode_with_extra_info=None,
        gradient_checkpointing=False,
        apply_pose_blendshape=False,
        dense_sample_pts=40000,  # only use for dense_smaple_smplx
        gs_deform_scale=0.005,
        render_features=False,
    ):
        """
        Initializes the GSPlatRenderer, an extension of GS3DRenderer for Gaussian Splatting rendering.

        Args:
            human_model_path (str): Path to human model files.
            subdivide_num (int): Subdivision number for base mesh.
            smpl_type (str): Type of SMPL/SMPL-X/other model to use.
            feat_dim (int): Dimension of feature embeddings.
            query_dim (int): Dimension of query points/features.
            use_rgb (bool): Whether to use RGB channels.
            sh_degree (int): Spherical harmonics degree for appearance.
            xyz_offset_max_step (float): Max offset per step for position.
            mlp_network_config (dict or None): MLP configuration for feature mapping.
            expr_param_dim (int): Expression parameter dimension.
            shape_param_dim (int): Shape parameter dimension.
            clip_scaling (float, optional): Output scaling for decoder. Default 0.2.
            cano_pose_type (int, optional): Canonical pose type. Default 0.
            decoder_mlp (bool, optional): Use MLP in decoder cross-attention. Default False.
            skip_decoder (bool, optional): Whether to skip decoder and cross-attn layers. Default False.
            fix_opacity (bool, optional): Fix opacity during training. Default False.
            fix_rotation (bool, optional): Fix rotation during training. Default False.
            decode_with_extra_info (dict or None, optional): Provide extra info to decoder. Default None.
            gradient_checkpointing (bool, optional): Enable gradient checkpointing. Default False.
            apply_pose_blendshape (bool, optional): Apply pose blendshape. Default False.
            dense_sample_pts (int, optional): Dense sample points for mesh/voxel. Default 40000.
            gs_deform_scale (float, optional): Deformation scale for Gaussian Splatting. Default 0.005.
            render_features (bool, optional): Output additional features in renderer. Default False.

        Raises:
            ImportError: If gsplat could not be imported.
        """
        if not gsplat_enable:
            raise ImportError(
                "GSPlat is not installed, please install it first."
            ) from _GSPLAT_IMPORT_ERROR

        # Arguments are forwarded positionally, exactly as the parent expects.
        super(GSPlatRenderer, self).__init__(
            human_model_path,
            subdivide_num,
            smpl_type,
            feat_dim,
            query_dim,
            use_rgb,
            sh_degree,
            xyz_offset_max_step,
            mlp_network_config,
            expr_param_dim,
            shape_param_dim,
            clip_scaling,
            cano_pose_type,
            decoder_mlp,
            skip_decoder,
            fix_opacity,
            fix_rotation,
            decode_with_extra_info,
            gradient_checkpointing,
            apply_pose_blendshape,
            dense_sample_pts,  # only use for dense_smaple_smplx
            gs_deform_scale,
            render_features,
        )

    def get_gaussians_properties(self, viewpoint_camera, gaussian_model):
        """
        Extracts and returns the 3D Gaussian properties for rendering from a GaussianModel
        instance in the context of the provided viewpoint camera.

        Args:
            viewpoint_camera (Camera): The viewpoint camera for rendering.
                (Unused here, but kept for interface compatibility.)
            gaussian_model (GaussianModel): The GaussianModel object containing the properties to extract.

        Returns:
            Tuple:
                xyz (Tensor): The 3D coordinates of the Gaussians.
                shs (None): Always None; only precomputed colors are supported.
                colors_precomp (Tensor): The precomputed colors (``gaussian_model.shs``).
                opacity (Tensor): The opacities of the Gaussians.
                scales (Tensor): The scaling factors per Gaussian.
                rotations (Tensor): Rotations per Gaussian.
                cov3D_precomp (None): Reserved for covariance data (not used here).

        Raises:
            NotImplementedError: If ``gaussian_model.use_rgb`` is falsy.
        """
        if not gaussian_model.use_rgb:
            raise NotImplementedError(
                "GSPlatRenderer only supports precomputed RGB colors (use_rgb=True)."
            )

        shs = None
        cov3D_precomp = None
        return (
            gaussian_model.xyz,
            shs,
            gaussian_model.shs,
            gaussian_model.opacity,
            gaussian_model.scaling,
            gaussian_model.rotation,
            cov3D_precomp,
        )

    def forward_single_view(
        self,
        gaussian_model: GaussianModel,
        viewpoint_camera: Camera,
        background_color: Optional[Float[Tensor, "3"]],
        ret_mask: bool = True,
        features=None,
    ):
        """
        Renders a single view using the provided GaussianModel and camera parameters.

        Args:
            gaussian_model (GaussianModel): The 3D Gaussian model to be rendered.
            viewpoint_camera (Camera): The viewpoint camera used for rendering (contains intrinsics, extrinsics, size).
            background_color (Optional[Float[Tensor, "3"]]): Optional background color. Only its
                first component is used and repeated over every rendered channel. When None,
                no background is passed to the rasterizer.
            ret_mask (bool, optional): Unused; the alpha mask is always returned. Default: True.
            features (Optional[Tensor], optional): Optional per-Gaussian feature tensor concatenated
                to the colors along dim 1.

        Returns:
            dict: {
                "comp_rgb": First three rendered channels,
                "comp_rgb_bg": Background actually used (None when no background was given),
                "comp_mask": Rendered alpha,
                "comp_features": (only if features is not None) rendered channels after the first three
            }
        """
        xyz, _shs, colors_precomp, opacity, scales, rotations, _cov3d = (
            self.get_gaussians_properties(viewpoint_camera, gaussian_model)
        )

        colors = colors_precomp.squeeze(1)
        opacity = opacity.squeeze(1)
        if features is not None:
            colors = torch.cat([colors, features], dim=1)

        if background_color is not None:
            # Uniform background: replicate the first component over all channels.
            background_color = background_color[0].repeat(colors.shape[1])

        rendered, render_alphas = _rasterize_gaussians(
            self.device.type,
            viewpoint_camera,
            xyz,
            rotations,
            scales,
            opacity,
            colors,
            viewpoint_camera.intrinsic,
            int(viewpoint_camera.width),
            int(viewpoint_camera.height),
            background_color,
        )

        ret = {"comp_rgb": rendered[:, :, :3]}
        if features is not None:
            ret["comp_features"] = rendered[:, :, 3:]
        ret["comp_rgb_bg"] = background_color
        ret["comp_mask"] = render_alphas
        return ret


class GSPlatFeatRenderer(GSPlatRenderer):
    """
    GSPlatFeatRenderer extends GSPlatRenderer to support rendering outputs with additional learned
    feature channels, in addition to RGB images, masks, and optional background handling.

    This class modifies the behavior of `_render_views` and related functions to enable rendering
    per-pixel features, which can be used for downstream tasks like representation learning or
    multimodal perception.

    Typical usage mirrors GSPlatRenderer but expects `features` to be provided as input for rendering
    and produces feature maps in the output dictionary under the key `"comp_features"`, along with
    the rendered RGB image and mask. The rest of the pipeline, including Gaussian attribute
    animation with SMPL-X or other mesh deformation, is inherited.
    """

    def forward_animate_gs(
        self,
        gs_attr_list: List[GaussianAppOutput],
        query_points: Dict[str, Tensor],
        smplx_data: Dict[str, Tensor],
        c2w: Float[Tensor, "B Nv 4 4"],
        intrinsic: Float[Tensor, "B Nv 4 4"],
        height: int,
        width: int,
        background_color: Optional[Float[Tensor, "B Nv 3"]] = None,
        debug: bool = False,
        df_data: Optional[Dict] = None,
        patch_size=14,
        **kwargs,
    ) -> Dict[str, Tensor]:
        """
        Animate and render Gaussian Splatting (GS) models with feature support for a batch of frames/views.

        Args:
            gs_attr_list (List[GaussianAppOutput]): List of Gaussian attribute outputs, one per batch item.
            query_points (Dict[str, Tensor]): Must include:
                - 'neutral_coords': canonical coordinate positions, indexed per batch item.
                - 'mesh_meta': mesh meta-info forwarded to the posing step.
            smplx_data (Dict[str, Tensor]): Per-batch SMPL-X (or similar model) data.
            c2w (Float[Tensor, "B Nv 4 4"]): Camera-to-world matrices for the views to render.
            intrinsic (Float[Tensor, "B Nv 4 4"]): Intrinsic camera matrices, shape matches c2w.
            height (int): Height of output render images (in pixels).
            width (int): Width of output render images (in pixels).
            background_color (Optional[Float[Tensor, "B Nv 3"]], default=None): Optional background
                color per batch/view. None renders without a background.
            debug (bool, optional): Forwarded to the animation and rendering steps.
            df_data (Optional[Dict], default=None): Accepted for interface compatibility; unused here.
            patch_size (int, optional): Patch size used to downsample the render grid (default 14).
            **kwargs: May contain 'features', indexed per batch item and forwarded to rendering.

        Returns:
            Dict[str, Tensor]: Combined rendered outputs plus:
                - '3dgs': List of all canonical-space GaussianModel instances for the batch.
        """
        batch_features = kwargs.get("features")
        query_points_pos = query_points["neutral_coords"]
        mesh_meta = query_points["mesh_meta"]
        num_views = c2w.shape[1]

        out_list, cano_out_list = [], []
        gs_list = []
        for b in range(len(gs_attr_list)):
            # Animate GS models
            anim_models, cano_models = self.animate_gs_model(
                gs_attr_list[b],
                query_points_pos[b],
                self._get_single_batch_data(smplx_data, b),
                debug=debug,
                mesh_meta=mesh_meta,
            )
            gs_list.extend(cano_models)

            # Render animated views
            out_list.append(
                self._render_views(
                    anim_models[:num_views],  # Only keep requested views
                    c2w[b],
                    intrinsic[b],
                    height,
                    width,
                    background_color[b] if background_color is not None else None,
                    debug,
                    patch_size=patch_size,
                    features=(
                        batch_features[b] if batch_features is not None else None
                    ),
                )
            )

        results = self._combine_outputs(out_list, cano_out_list)
        results["3dgs"] = gs_list
        return results

    def animate_gs_model(
        self,
        gs_attr: GaussianAppOutput,
        query_points,
        smplx_data,
        debug=False,
        mesh_meta=None,
    ):
        """
        Animates the Gaussian Splatting (GS) model by transforming canonical (neutral) points and
        attributes into the posed space using SMPL-X model deformations.

        Args:
            gs_attr (GaussianAppOutput): Gaussian attribute output for canonical points.
            query_points (Tensor): Canonical query point coordinates.
            smplx_data (dict): SMPL-X input data for the current animation frame.
            debug (bool, optional): If True, overwrite ``gs_attr`` in place with zero xyz,
                unit opacity and identity rotations. Default: False.
            mesh_meta (dict, optional): Additional mesh meta-information. Default: None.

        Returns:
            The result of ``self._create_gaussian_models``; the caller unpacks it as
            (posed GaussianModel list, canonical GaussianModel list).
        """
        device = gs_attr.offset_xyz.device

        if debug:
            num_points = gs_attr.offset_xyz.shape[0]
            gs_attr.xyz = torch.zeros_like(gs_attr.offset_xyz)
            gs_attr.opacity = torch.ones((num_points, 1), device=device)
            gs_attr.rotation = matrix_to_quaternion(
                torch.eye(3, device=device).expand(num_points, 3, 3)
            )

        # build cano_dependent_pose
        merge_smplx_data = self._prepare_smplx_data(smplx_data)
        posed_points = self._transform_points(
            merge_smplx_data, query_points, gs_attr.offset_xyz, device, mesh_meta
        )
        rotation_pose_verts = self._compute_rotations(
            posed_points["transform_mat_posed"],
            gs_attr.rotation,
            device,
            posed_points["mesh_meta"],
        )

        return self._create_gaussian_models(
            posed_points["posed_coords"],
            gs_attr,
            rotation_pose_verts,
            merge_smplx_data["body_pose"].shape[0],
        )

    def _render_views(
        self,
        gs_list: List[GaussianModel],
        c2w: Tensor,
        intrinsic: Tensor,
        height: int,
        width: int,
        bg_color: Optional[Tensor],
        debug: bool,
        patch_size: int,
        **kwargs,
    ) -> Dict[str, Tensor]:
        """
        Renders multiple views for a list of GaussianModel instances.

        Args:
            gs_list (List[GaussianModel]): One GaussianModel per view.
            c2w (Tensor): Camera-to-world matrices, indexed per view.
            intrinsic (Tensor): Intrinsic camera matrices, indexed per view.
            height (int): Output image height (in pixels).
            width (int): Output image width (in pixels).
            bg_color (Optional[Tensor]): Background color per view, or None for no background.
            debug (bool): If True, save the first rendered view via ``_debug_save_image``.
            patch_size (int): Patch size used to downsample the render grid.
            **kwargs: May include 'features'; used only when ``self.render_features`` is set.

        Returns:
            Dict[str, list]: 'render' maps to a list with one output dict per view.
        """
        results = defaultdict(list)
        if not gs_list:
            return results

        # obtain device
        self.device = gs_list[0].xyz.device

        # The same feature tensor is shared by every view, so resolve it once.
        render_features = kwargs.get("features") if self.render_features else None

        for v_idx, gs in enumerate(gs_list):
            camera = Camera.from_c2w(c2w[v_idx], intrinsic[v_idx], height, width)
            view_bg = bg_color[v_idx] if bg_color is not None else None
            results["render"].append(
                self.forward_single_view(
                    gs,
                    camera,
                    view_bg,
                    features=render_features,
                    patch_size=patch_size,
                )
            )

            if debug and v_idx == 0:
                self._debug_save_image(results["render"][-1]["comp_rgb"])

        return results

    def forward_single_view(
        self,
        gaussian_model: GaussianModel,
        viewpoint_camera: Camera,
        background_color: Optional[Float[Tensor, "3"]],
        ret_mask: bool = True,
        features=None,
        patch_size=14,
    ):
        """
        Renders a single view at patch resolution, with optional extra feature channels.

        The image is rendered on a (height // patch_size, width // patch_size) grid and the
        intrinsics are rescaled accordingly.

        Args:
            gaussian_model (GaussianModel): The 3D Gaussian model to be rendered.
            viewpoint_camera (Camera): The viewpoint camera containing intrinsics, extrinsics, and image size.
            background_color (Optional[Float[Tensor, "3"]]): Background color. Only its first
                component is used and repeated over all channels. None renders without background.
            ret_mask (bool, optional): Unused; the alpha mask is always returned. Default: True.
            features (Optional[Tensor], optional): Per-Gaussian features concatenated to the colors. Default: None.
            patch_size (int, optional): Downsampling factor of the render grid. Default: 14.

        Returns:
            dict: With features: "comp_features", "comp_mask", "comp_rgb".
                Without features: "comp_rgb", "comp_mask" and, when a background was
                given, "comp_rgb_bg".

        Raises:
            ValueError: If ``patch_size`` is not positive or larger than the image.
        """
        xyz, _shs, colors_precomp, opacity, scales, rotations, _cov3d = (
            self.get_gaussians_properties(viewpoint_camera, gaussian_model)
        )
        intrinsics, patch_width, patch_height = _patch_grid_intrinsics(
            viewpoint_camera, patch_size
        )

        colors = colors_precomp.squeeze(1)
        opacity = opacity.squeeze(1)
        if features is not None:
            colors = torch.cat([colors, features], dim=1)

        if background_color is not None:
            # Uniform background: replicate the first component over all channels.
            background_color = background_color[0].repeat(colors.shape[1])

        rendered, render_alphas = _rasterize_gaussians(
            self.device.type,
            viewpoint_camera,
            xyz,
            rotations,
            scales,
            opacity,
            colors,
            intrinsics,
            patch_width,
            patch_height,
            background_color,
        )
        rendered_image = rendered[:, :, :3]

        if features is not None:
            return {
                "comp_features": rendered[:, :, 3:],
                "comp_mask": render_alphas,
                "comp_rgb": rendered_image,
            }

        ret = {"comp_rgb": rendered_image}
        if background_color is not None:
            # Omitted when absent so _combine_outputs never stacks None values.
            ret["comp_rgb_bg"] = background_color
        ret["comp_mask"] = render_alphas
        return ret

    def forward(
        self,
        gs_hidden_features: Float[Tensor, "B Np Cp"],
        query_points: dict,
        smplx_data,  # e.g., body_pose:[B, Nv, 21, 3], betas:[B, 100]
        c2w: Float[Tensor, "B Nv 4 4"],
        intrinsic: Float[Tensor, "B Nv 4 4"],
        height,
        width,
        additional_features: Optional[Float[Tensor, "B C H W"]] = None,
        background_color: Optional[Float[Tensor, "B Nv 3"]] = None,
        debug: bool = False,
        **kwargs,
    ):
        """
        Forward rendering pass for the GSPlatFeatRenderer.

        Args:
            gs_hidden_features (Float[Tensor, "B Np Cp"]): Gaussian hidden features for the batch;
                also forwarded to rendering as the per-Gaussian ``features``.
            query_points (dict): Query data passed to ``forward_gs``.
            smplx_data: SMPL-X (or similar parametric model) data.
            c2w (Float[Tensor, "B Nv 4 4"]): Camera-to-world transformation matrices.
            intrinsic (Float[Tensor, "B Nv 4 4"]): Intrinsic camera matrices.
            height (int): Height of the output images.
            width (int): Width of the output images.
            additional_features (Optional[Float[Tensor, "B C H W"]], optional): Forwarded to ``forward_gs``.
            background_color (Optional[Float[Tensor, "B Nv 3"]], optional): Optional background colors.
            debug (bool, optional): Forwarded to ``forward_gs`` and ``forward_animate_gs``.
            **kwargs: Optional extras:
                - df_data: forwarded to ``forward_animate_gs`` (None when absent).
                - patch_size: render-grid downsampling factor (14 when absent).

        Returns:
            Dict[str, Tensor]: Rendered outputs plus 'gs_attr' and 'mesh_meta'.
        """
        # need shape_params of smplx_data to get querty points and get "transform_mat_neutral_pose"
        # only forward gs params
        gs_attr_list, query_points, smplx_data = self.forward_gs(
            gs_hidden_features,
            query_points,
            smplx_data=smplx_data,
            additional_features=additional_features,
            debug=debug,
        )

        out = self.forward_animate_gs(
            gs_attr_list,
            query_points,
            smplx_data,
            c2w,
            intrinsic,
            height,
            width,
            background_color,
            debug,
            df_data=kwargs.get("df_data"),  # optional: absent means None
            patch_size=kwargs.get("patch_size", 14),
            features=gs_hidden_features,
        )
        out["gs_attr"] = gs_attr_list
        out["mesh_meta"] = query_points["mesh_meta"]

        return out

    def _combine_outputs(
        self, out_list: List[Dict], cano_out_list: List[Dict]
    ) -> Dict[str, Tensor]:
        """
        Combines per-batch, per-view render outputs into batched tensors.

        Args:
            out_list (List[Dict]): One entry per batch item; each holds a 'render' list with one
                output dict per view.
            cano_out_list (List[Dict]): Unused; kept for interface compatibility.

        Returns:
            Dict[str, Tensor]: For every output key whose stacked values have at least 4 dims,
                a tensor reshaped to [batch_size, num_views, ...] and permuted channels-first
                (0, 1, 4, 2, 3). Lower-rank outputs (e.g. background colors) are dropped.
        """
        batch_size = len(out_list)

        # Collect render outputs
        combined = defaultdict(list)
        for out in out_list:
            for render_item in out["render"]:
                for key, value in render_item.items():
                    combined[key].append(value)

        # Reshape and permute tensors
        result = {}
        for key, values in combined.items():
            stacked = torch.stack(values)  # stack once per key
            if stacked.dim() < 4:
                continue
            result[key] = stacked.view(batch_size, -1, *values[0].shape).permute(
                0, 1, 4, 2, 3
            )
        return result


class GSPlatBackFeatRenderer(GSPlatFeatRenderer):
    """Adding an embedding to model background color"""

    def __init__(
        self,
        human_model_path,
        subdivide_num,
        smpl_type,
        feat_dim,
        query_dim,
        use_rgb,
        sh_degree,
        xyz_offset_max_step,
        mlp_network_config,
        expr_param_dim,
        shape_param_dim,
        clip_scaling=0.2,
        cano_pose_type=0,
        decoder_mlp=False,
        skip_decoder=False,
        fix_opacity=False,
        fix_rotation=False,
        decode_with_extra_info=None,
        gradient_checkpointing=False,
        apply_pose_blendshape=False,
        dense_sample_pts=40000,  # only use for dense_smaple_smplx
        gs_deform_scale=0.005,
        render_features=False,
    ):
        """
        Initializes the GSPlatBackFeatRenderer, an extension of GSPlatFeatRenderer adding a
        learnable embedding to model background color.

        Args:
            human_model_path (str): Path to human model files.
            subdivide_num (int): Subdivision number for base mesh.
            smpl_type (str): Type of SMPL/SMPL-X/other model to use.
            feat_dim (int): Dimension of feature embeddings.
            query_dim (int): Dimension of query points/features.
            use_rgb (bool): Whether to use RGB channels.
            sh_degree (int): Spherical harmonics degree for appearance.
            xyz_offset_max_step (float): Max offset per step for position.
            mlp_network_config (dict or None): MLP configuration for feature mapping.
            expr_param_dim (int): Expression parameter dimension.
            shape_param_dim (int): Shape parameter dimension.
            clip_scaling (float, optional): Output scaling for decoder. Default 0.2.
            cano_pose_type (int, optional): Canonical pose type. Default 0.
            decoder_mlp (bool, optional): Use MLP in decoder cross-attention. Default False.
            skip_decoder (bool, optional): Whether to skip decoder and cross-attn layers. Default False.
            fix_opacity (bool, optional): Fix opacity during training. Default False.
            fix_rotation (bool, optional): Fix rotation during training. Default False.
            decode_with_extra_info (dict or None, optional): Provide extra info to decoder. Default None.
            gradient_checkpointing (bool, optional): Enable gradient checkpointing. Default False.
            apply_pose_blendshape (bool, optional): Apply pose blendshape. Default False.
            dense_sample_pts (int, optional): Dense sample points for mesh/voxel. Default 40000.
            gs_deform_scale (float, optional): Deformation scale for Gaussian Splatting. Default 0.005.
            render_features (bool, optional): Output additional features in renderer. Default False.

        Adds:
            self.background_embedding (nn.Parameter): Learnable [3, 128] table with one
                128-d background embedding per supported background level.
        """
        super(GSPlatBackFeatRenderer, self).__init__(
            human_model_path,
            subdivide_num,
            smpl_type,
            feat_dim,
            query_dim,
            use_rgb,
            sh_degree,
            xyz_offset_max_step,
            mlp_network_config,
            expr_param_dim,
            shape_param_dim,
            clip_scaling,
            cano_pose_type,
            decoder_mlp,
            skip_decoder,
            fix_opacity,
            fix_rotation,
            decode_with_extra_info,
            gradient_checkpointing,
            apply_pose_blendshape,
            dense_sample_pts,  # only use for dense_smaple_smplx
            gs_deform_scale,
            render_features,
        )

        # learnable background embedding, xavier-initialised
        self.background_embedding = nn.Parameter(torch.zeros(3, 128))
        nn.init.xavier_uniform_(self.background_embedding)

    def forward_single_view(
        self,
        gaussian_model: GaussianModel,
        viewpoint_camera: Camera,
        background_color: Optional[Float[Tensor, "3"]],
        ret_mask: bool = True,
        features=None,
        patch_size=14,
    ):
        """
        Renders a single view at patch resolution, using a learnable background embedding for
        the feature channels.

        Args:
            gaussian_model (GaussianModel): The 3D Gaussian model to be rendered.
            viewpoint_camera (Camera): The viewpoint camera used for rendering (contains intrinsics, extrinsics, size).
            background_color (Float[Tensor, "3"]): Background color. Its first component selects
                the embedding row via ``int(value / 0.5)`` (0 -> 0, 0.5 -> 1, 1.0 -> 2).
            ret_mask (bool, optional): Unused; the alpha mask is always returned. Default is True.
            features (Optional[Tensor], optional): Per-Gaussian features concatenated to the colors.
                NOTE(review): their channel count presumably must equal the embedding width (128)
                so colors and background match -- confirm with the caller.
            patch_size (int, optional): Downsampling factor of the render grid (default: 14).

        Returns:
            dict: With features: "comp_features", "comp_mask", "comp_rgb".
                Without features: "comp_rgb", "comp_rgb_bg", "comp_mask".

        Raises:
            ValueError: If ``background_color`` is None, or ``patch_size`` is not positive or
                larger than the image.
        """
        if background_color is None:
            raise ValueError(
                "GSPlatBackFeatRenderer needs a background_color to select its "
                "background embedding."
            )

        xyz, _shs, colors_precomp, opacity, scales, rotations, _cov3d = (
            self.get_gaussians_properties(viewpoint_camera, gaussian_model)
        )
        intrinsics, patch_width, patch_height = _patch_grid_intrinsics(
            viewpoint_camera, patch_size
        )

        colors = colors_precomp.squeeze(1)
        opacity = opacity.squeeze(1)

        if features is not None:
            colors = torch.cat([colors, features], dim=1)
            # Extend the RGB background with the learned embedding only when
            # feature channels are rendered; otherwise the background would
            # have more channels than the colors and rasterization would fail.
            bg_idx = (background_color[0] / 0.5).int().item()
            background_color = torch.cat(
                [background_color, self.background_embedding[bg_idx]], dim=0
            )

        rendered, render_alphas = _rasterize_gaussians(
            self.device.type,
            viewpoint_camera,
            xyz,
            rotations,
            scales,
            opacity,
            colors,
            intrinsics,
            patch_width,
            patch_height,
            background_color,
        )
        rendered_image = rendered[:, :, :3]

        if features is not None:
            return {
                "comp_features": rendered[:, :, 3:],
                "comp_mask": render_alphas,
                "comp_rgb": rendered_image,
            }

        return {
            "comp_rgb": rendered_image,
            "comp_rgb_bg": background_color,
            "comp_mask": render_alphas,
        }