import copy import math import os import pdb from collections import defaultdict from dataclasses import dataclass, field import numpy as np import omegaconf import torch import torch.nn as nn import torch.nn.functional as F from diff_gaussian_rasterization import ( GaussianRasterizationSettings, GaussianRasterizer, ) from plyfile import PlyData, PlyElement from pytorch3d.transforms import matrix_to_quaternion from pytorch3d.transforms.rotation_conversions import quaternion_multiply from LHM.models.rendering.smpl_x import SMPLXModel, read_smplx_param from LHM.models.rendering.smpl_x_voxel_dense_sampling import SMPLXVoxelMeshModel from LHM.models.rendering.utils.sh_utils import RGB2SH, SH2RGB from LHM.models.rendering.utils.typing import * from LHM.models.rendering.utils.utils import MLP, trunc_exp from LHM.models.utils import LinerParameterTuner, StaticParameterTuner from LHM.outputs.output import GaussianAppOutput def auto_repeat_size(tensor, repeat_num, axis=0): repeat_size = [1] * tensor.dim() repeat_size[axis] = repeat_num return repeat_size def aabb(xyz): return torch.min(xyz, dim=0).values, torch.max(xyz, dim=0).values def inverse_sigmoid(x): if isinstance(x, float): x = torch.tensor(x).float() return torch.log(x / (1 - x)) def generate_rotation_matrix_y(degrees): theta = math.radians(degrees) cos_theta = math.cos(theta) sin_theta = math.sin(theta) R = [[cos_theta, 0, sin_theta], [0, 1, 0], [-sin_theta, 0, cos_theta]] return np.asarray(R, dtype=np.float32) def getWorld2View2(R, t, translate=np.array([0.0, 0.0, 0.0]), scale=1.0): Rt = np.zeros((4, 4)) Rt[:3, :3] = R.transpose() Rt[:3, 3] = t Rt[3, 3] = 1.0 C2W = np.linalg.inv(Rt) cam_center = C2W[:3, 3] cam_center = (cam_center + translate) * scale C2W[:3, 3] = cam_center Rt = np.linalg.inv(C2W) return np.float32(Rt) def getProjectionMatrix(znear, zfar, fovX, fovY): tanHalfFovY = math.tan((fovY / 2)) tanHalfFovX = math.tan((fovX / 2)) top = tanHalfFovY * znear bottom = -top right = tanHalfFovX * znear left = -right P = torch.zeros(4, 4) z_sign = 1.0 P[0, 0] = 2.0 * znear / (right - left) P[1, 1] = 2.0 * znear / (top - bottom) P[0, 2] = (right + left) / (right - left) P[1, 2] = (top + bottom) / (top - bottom) P[3, 2] = z_sign P[2, 2] = z_sign * zfar / (zfar - znear) P[2, 3] = -(zfar * znear) / (zfar - znear) return P def intrinsic_to_fov(intrinsic, w, h): fx, fy = intrinsic[0, 0], intrinsic[1, 1] fov_x = 2 * torch.arctan2(w, 2 * fx) fov_y = 2 * torch.arctan2(h, 2 * fy) return fov_x, fov_y class Camera: def __init__( self, w2c, intrinsic, FoVx, FoVy, height, width, trans=np.array([0.0, 0.0, 0.0]), scale=1.0, ) -> None: self.FoVx = FoVx self.FoVy = FoVy self.height = height self.width = width self.world_view_transform = w2c.transpose(0, 1) self.zfar = 100.0 self.znear = 0.01 self.trans = trans self.scale = scale self.projection_matrix = ( getProjectionMatrix( znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy ) .transpose(0, 1) .to(w2c.device) ) self.full_proj_transform = ( self.world_view_transform.unsqueeze(0).bmm( self.projection_matrix.unsqueeze(0) ) ).squeeze(0) self.camera_center = self.world_view_transform.inverse()[3, :3] self.intrinsic = intrinsic @staticmethod def from_c2w(c2w, intrinsic, height, width): w2c = torch.inverse(c2w) FoVx, FoVy = intrinsic_to_fov( intrinsic, w=torch.tensor(width, device=w2c.device), h=torch.tensor(height, device=w2c.device), ) return Camera( w2c=w2c, intrinsic=intrinsic, FoVx=FoVx, FoVy=FoVy, height=height, width=width, ) class GaussianModel: def setup_functions(self): self.scaling_activation = torch.exp self.scaling_inverse_activation = torch.log self.opacity_activation = torch.sigmoid self.inverse_opacity_activation = inverse_sigmoid self.rotation_activation = torch.nn.functional.normalize # rgb activation function self.rgb_activation = torch.sigmoid def __init__(self, xyz, opacity, rotation, scaling, shs, use_rgb=False) -> None: """ Initializes the GSRenderer object. Args: xyz (Tensor): The xyz coordinates. opacity (Tensor): The opacity values. rotation (Tensor): The rotation values. scaling (Tensor): The scaling values. before_activate: if True, the output appearance is needed to process by activation function. shs (Tensor): The spherical harmonics coefficients. use_rgb (bool, optional): Indicates whether shs represents RGB values. Defaults to False. """ self.setup_functions() self.xyz: Tensor = xyz self.opacity: Tensor = opacity self.rotation: Tensor = rotation self.scaling: Tensor = scaling self.shs: Tensor = shs # [B, SH_Coeff, 3] self.use_rgb = use_rgb # shs indicates rgb? def construct_list_of_attributes(self): l = ["x", "y", "z", "nx", "ny", "nz"] features_dc = self.shs[:, :1] features_rest = self.shs[:, 1:] for i in range(features_dc.shape[1] * features_dc.shape[2]): l.append("f_dc_{}".format(i)) for i in range(features_rest.shape[1] * features_rest.shape[2]): l.append("f_rest_{}".format(i)) l.append("opacity") for i in range(self.scaling.shape[1]): l.append("scale_{}".format(i)) for i in range(self.rotation.shape[1]): l.append("rot_{}".format(i)) return l def save_ply(self, path): xyz = self.xyz.detach().cpu().numpy() normals = np.zeros_like(xyz) if self.use_rgb: shs = RGB2SH(self.shs) else: shs = self.shs features_dc = shs[:, :1] features_rest = shs[:, 1:] f_dc = ( features_dc.float().detach().flatten(start_dim=1).contiguous().cpu().numpy() ) f_rest = ( features_rest.float() .detach() .flatten(start_dim=1) .contiguous() .cpu() .numpy() ) opacities = ( inverse_sigmoid(torch.clamp(self.opacity, 1e-3, 1 - 1e-3)) .detach() .cpu() .numpy() ) scale = np.log(self.scaling.detach().cpu().numpy()) rotation = self.rotation.detach().cpu().numpy() dtype_full = [ (attribute, "f4") for attribute in self.construct_list_of_attributes() ] elements = np.empty(xyz.shape[0], dtype=dtype_full) attributes = np.concatenate( (xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1 ) elements[:] = list(map(tuple, attributes)) el = PlyElement.describe(elements, "vertex") PlyData([el]).write(path) def load_ply(self, path): plydata = PlyData.read(path) xyz = np.stack( ( np.asarray(plydata.elements[0]["x"]), np.asarray(plydata.elements[0]["y"]), np.asarray(plydata.elements[0]["z"]), ), axis=1, ) opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] features_dc = np.zeros((xyz.shape[0], 3, 1)) features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) extra_f_names = [ p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_") ] extra_f_names = sorted(extra_f_names, key=lambda x: int(x.split("_")[-1])) sh_degree = int(math.sqrt((len(extra_f_names) + 3) / 3)) - 1 print("load sh degree: ", sh_degree) features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) for idx, attr_name in enumerate(extra_f_names): features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) # 0, 3, 8, 15 features_extra = features_extra.reshape( (features_extra.shape[0], 3, (sh_degree + 1) ** 2 - 1) ) scale_names = [ p.name for p in plydata.elements[0].properties if p.name.startswith("scale_") ] scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1])) scales = np.zeros((xyz.shape[0], len(scale_names))) for idx, attr_name in enumerate(scale_names): scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) rot_names = [ p.name for p in plydata.elements[0].properties if p.name.startswith("rot") ] rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1])) rots = np.zeros((xyz.shape[0], len(rot_names))) for idx, attr_name in enumerate(rot_names): rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) xyz = torch.from_numpy(xyz).to(self.xyz) opacities = torch.from_numpy(opacities).to(self.opacity) rotation = torch.from_numpy(rots).to(self.rotation) scales = torch.from_numpy(scales).to(self.scaling) features_dc = torch.from_numpy(features_dc).to(self.shs) features_rest = torch.from_numpy(features_extra).to(self.shs) shs = torch.cat([features_dc, features_rest], dim=2) if self.use_rgb: shs = SH2RGB(shs) else: shs = shs self.xyz: Tensor = xyz self.opacity: Tensor = self.opacity_activation(opacities) self.rotation: Tensor = self.rotation_activation(rotation) self.scaling: Tensor = self.scaling_activation(scales) self.shs: Tensor = shs.permute(0, 2, 1) self.active_sh_degree = sh_degree def clone(self): xyz = self.xyz.clone() opacity = self.opacity.clone() rotation = self.rotation.clone() scaling = self.scaling.clone() shs = self.shs.clone() use_rgb = self.use_rgb return GaussianModel(xyz, opacity, rotation, scaling, shs, use_rgb) class GSLayer(nn.Module): """W/O Activation Function""" def setup_functions(self): self.scaling_activation = trunc_exp # proposed by torch-ngp self.scaling_inverse_activation = torch.log self.opacity_activation = torch.sigmoid self.inverse_opacity_activation = inverse_sigmoid self.rotation_activation = torch.nn.functional.normalize self.rgb_activation = torch.sigmoid def __init__( self, in_channels, use_rgb, clip_scaling=0.2, init_scaling=-5.0, init_density=0.1, sh_degree=None, xyz_offset=True, restrict_offset=True, xyz_offset_max_step=None, fix_opacity=False, fix_rotation=False, use_fine_feat=False, ): super().__init__() self.setup_functions() if isinstance(clip_scaling, omegaconf.listconfig.ListConfig) or isinstance( clip_scaling, list ): self.clip_scaling_pruner = LinerParameterTuner(*clip_scaling) else: self.clip_scaling_pruner = StaticParameterTuner(clip_scaling) self.clip_scaling = self.clip_scaling_pruner.get_value(0) self.use_rgb = use_rgb self.restrict_offset = restrict_offset self.xyz_offset = xyz_offset self.xyz_offset_max_step = xyz_offset_max_step # 1.2 / 32 self.fix_opacity = fix_opacity self.fix_rotation = fix_rotation self.use_fine_feat = use_fine_feat self.attr_dict = { "shs": (sh_degree + 1) ** 2 * 3, "scaling": 3, "xyz": 3, "opacity": None, "rotation": None, } if not self.fix_opacity: self.attr_dict["opacity"] = 1 if not self.fix_rotation: self.attr_dict["rotation"] = 4 self.out_layers = nn.ModuleDict() for key, out_ch in self.attr_dict.items(): if out_ch is None: layer = nn.Identity() else: if key == "shs" and use_rgb: out_ch = 3 if key == "shs": shs_out_ch = out_ch layer = nn.Linear(in_channels, out_ch) # initialize if not (key == "shs" and use_rgb): if key == "opacity" and self.fix_opacity: pass elif key == "rotation" and self.fix_rotation: pass else: nn.init.constant_(layer.weight, 0) nn.init.constant_(layer.bias, 0) if key == "scaling": nn.init.constant_(layer.bias, init_scaling) elif key == "rotation": if not self.fix_rotation: nn.init.constant_(layer.bias, 0) nn.init.constant_(layer.bias[0], 1.0) elif key == "opacity": if not self.fix_opacity: nn.init.constant_(layer.bias, inverse_sigmoid(init_density)) self.out_layers[key] = layer if self.use_fine_feat: fine_shs_layer = nn.Linear(in_channels, shs_out_ch) nn.init.constant_(fine_shs_layer.weight, 0) nn.init.constant_(fine_shs_layer.bias, 0) self.out_layers["fine_shs"] = fine_shs_layer def hyper_step(self, step): self.clip_scaling = self.clip_scaling_pruner.get_value(step) def forward(self, x, pts, x_fine=None): assert len(x.shape) == 2 ret = {} for k in self.attr_dict: layer = self.out_layers[k] v = layer(x) if k == "rotation": if self.fix_rotation: v = matrix_to_quaternion( torch.eye(3).type_as(x)[None, :, :].repeat(x.shape[0], 1, 1) ) # constant rotation else: # v = torch.nn.functional.normalize(v) v = self.rotation_activation(v) elif k == "scaling": # v = trunc_exp(v) v = self.scaling_activation(v) if self.clip_scaling is not None: v = torch.clamp(v, min=0, max=self.clip_scaling) elif k == "opacity": if self.fix_opacity: v = torch.ones_like(x)[..., 0:1] else: # v = torch.sigmoid(v) v = self.opacity_activation(v) elif k == "shs": if self.use_rgb: # v = torch.sigmoid(v) v = self.rgb_activation(v) if self.use_fine_feat: v_fine = self.out_layers["fine_shs"](x_fine) v_fine = torch.tanh(v_fine) v = v + v_fine else: if self.use_fine_feat: v_fine = self.out_layers["fine_shs"](x_fine) v = v + v_fine v = torch.reshape(v, (v.shape[0], -1, 3)) elif k == "xyz": # TODO check if self.restrict_offset: max_step = self.xyz_offset_max_step v = (torch.sigmoid(v) - 0.5) * max_step if self.xyz_offset: pass else: assert NotImplementedError v = v + pts k = "offset_xyz" ret[k] = v ret["use_rgb"] = self.use_rgb return GaussianAppOutput(**ret) class PointEmbed(nn.Module): def __init__(self, hidden_dim=48, dim=128): super().__init__() assert hidden_dim % 6 == 0 self.embedding_dim = hidden_dim e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi e = torch.stack( [ torch.cat( [ e, torch.zeros(self.embedding_dim // 6), torch.zeros(self.embedding_dim // 6), ] ), torch.cat( [ torch.zeros(self.embedding_dim // 6), e, torch.zeros(self.embedding_dim // 6), ] ), torch.cat( [ torch.zeros(self.embedding_dim // 6), torch.zeros(self.embedding_dim // 6), e, ] ), ] ) self.register_buffer("basis", e) # 3 x 16 self.mlp = nn.Linear(self.embedding_dim + 3, dim) self.norm = nn.LayerNorm(dim) @staticmethod def embed(input, basis): projections = torch.einsum("bnd,de->bne", input, basis) embeddings = torch.cat([projections.sin(), projections.cos()], dim=2) return embeddings def forward(self, input): # input: B x N x 3 embed = self.mlp( torch.cat([self.embed(input, self.basis), input], dim=2) ) # B x N x C embed = self.norm(embed) return embed class CrossAttnBlock(nn.Module): """ Transformer block that takes in a cross-attention condition. Designed for SparseLRM architecture. """ # Block contains a cross-attention layer, a self-attention layer, and an MLP def __init__( self, inner_dim: int, cond_dim: int, num_heads: int, eps: float = None, attn_drop: float = 0.0, attn_bias: bool = False, mlp_ratio: float = 4.0, mlp_drop: float = 0.0, feedforward=False, ): super().__init__() # TODO check already apply normalization # self.norm_q = nn.LayerNorm(inner_dim, eps=eps) # self.norm_k = nn.LayerNorm(cond_dim, eps=eps) self.norm_q = nn.Identity() self.norm_k = nn.Identity() self.cross_attn = nn.MultiheadAttention( embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim, dropout=attn_drop, bias=attn_bias, batch_first=True, ) self.mlp = None if feedforward: self.norm2 = nn.LayerNorm(inner_dim, eps=eps) self.self_attn = nn.MultiheadAttention( embed_dim=inner_dim, num_heads=num_heads, dropout=attn_drop, bias=attn_bias, batch_first=True, ) self.norm3 = nn.LayerNorm(inner_dim, eps=eps) self.mlp = nn.Sequential( nn.Linear(inner_dim, int(inner_dim * mlp_ratio)), nn.GELU(), nn.Dropout(mlp_drop), nn.Linear(int(inner_dim * mlp_ratio), inner_dim), nn.Dropout(mlp_drop), ) def forward(self, x, cond): # x: [N, L, D] # cond: [N, L_cond, D_cond] x = self.cross_attn( self.norm_q(x), self.norm_k(cond), cond, need_weights=False )[0] if self.mlp is not None: before_sa = self.norm2(x) x = ( x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0] ) x = x + self.mlp(self.norm3(x)) return x class DecoderCrossAttn(nn.Module): def __init__( self, query_dim, context_dim, num_heads, mlp=False, decode_with_extra_info=None ): super().__init__() self.query_dim = query_dim self.context_dim = context_dim self.cross_attn = CrossAttnBlock( inner_dim=query_dim, cond_dim=context_dim, num_heads=num_heads, feedforward=mlp, eps=1e-5, ) self.decode_with_extra_info = decode_with_extra_info if decode_with_extra_info is not None: if decode_with_extra_info["type"] == "dinov2p14_feat": context_dim = decode_with_extra_info["cond_dim"] self.cross_attn_color = CrossAttnBlock( inner_dim=query_dim, cond_dim=context_dim, num_heads=num_heads, feedforward=False, eps=1e-5, ) elif decode_with_extra_info["type"] == "decoder_dinov2p14_feat": from LHM.models.encoders.dinov2_wrapper import Dinov2Wrapper self.encoder = Dinov2Wrapper( model_name="dinov2_vits14_reg", freeze=False, encoder_feat_dim=384 ) self.cross_attn_color = CrossAttnBlock( inner_dim=query_dim, cond_dim=384, num_heads=num_heads, feedforward=False, eps=1e-5, ) elif decode_with_extra_info["type"] == "decoder_resnet18_feat": from LHM.models.encoders.xunet_wrapper import XnetWrapper self.encoder = XnetWrapper( model_name="resnet18", freeze=False, encoder_feat_dim=64 ) self.cross_attn_color = CrossAttnBlock( inner_dim=query_dim, cond_dim=64, num_heads=num_heads, feedforward=False, eps=1e-5, ) def resize_image(self, image, multiply): B, _, H, W = image.shape new_h, new_w = ( math.ceil(H / multiply) * multiply, math.ceil(W / multiply) * multiply, ) image = F.interpolate( image, (new_h, new_w), align_corners=True, mode="bilinear" ) return image def forward(self, pcl_query, pcl_latent, extra_info=None): out = self.cross_attn(pcl_query, pcl_latent) if self.decode_with_extra_info is not None: out_dict = {} out_dict["coarse"] = out if self.decode_with_extra_info["type"] == "dinov2p14_feat": out = self.cross_attn_color(out, extra_info["image_feats"]) out_dict["fine"] = out return out_dict elif self.decode_with_extra_info["type"] == "decoder_dinov2p14_feat": img_feat = self.encoder(extra_info["image"]) out = self.cross_attn_color(out, img_feat) out_dict["fine"] = out return out_dict elif self.decode_with_extra_info["type"] == "decoder_resnet18_feat": image = extra_info["image"] image = self.resize_image(image, multiply=32) img_feat = self.encoder(image) out = self.cross_attn_color(out, img_feat) out_dict["fine"] = out return out_dict return out class GS3DRenderer(nn.Module): def __init__( self, human_model_path, subdivide_num, smpl_type, feat_dim, query_dim, use_rgb, sh_degree, xyz_offset_max_step, mlp_network_config, expr_param_dim, shape_param_dim, clip_scaling=0.2, cano_pose_type=0, decoder_mlp=False, skip_decoder=False, fix_opacity=False, fix_rotation=False, decode_with_extra_info=None, gradient_checkpointing=False, apply_pose_blendshape=False, dense_sample_pts=40000, # only use for dense_smaple_smplx ): super().__init__() self.gradient_checkpointing = gradient_checkpointing self.skip_decoder = skip_decoder self.smpl_type = smpl_type assert self.smpl_type in ["smplx", "smplx_0", "smplx_1", "smplx_2"] self.scaling_modifier = 1.0 self.sh_degree = sh_degree if self.smpl_type == "smplx_0" or self.smpl_type == "smplx": # Using pytorch3d dense sampling self.smplx_model = SMPLXModel( human_model_path, gender="neutral", subdivide_num=subdivide_num, shape_param_dim=shape_param_dim, expr_param_dim=expr_param_dim, cano_pose_type=cano_pose_type, apply_pose_blendshape=apply_pose_blendshape, ) elif self.smpl_type == "smplx_1": raise NotImplementedError("inference version does not support") elif self.smpl_type == "smplx_2": self.smplx_model = SMPLXVoxelMeshModel( human_model_path, gender="neutral", subdivide_num=subdivide_num, shape_param_dim=shape_param_dim, expr_param_dim=expr_param_dim, cano_pose_type=cano_pose_type, dense_sample_points=dense_sample_pts, apply_pose_blendshape=apply_pose_blendshape, ) else: raise NotImplementedError if not self.skip_decoder: self.pcl_embed = PointEmbed(dim=query_dim) self.decoder_cross_attn = DecoderCrossAttn( query_dim=query_dim, context_dim=feat_dim, num_heads=1, mlp=decoder_mlp, decode_with_extra_info=decode_with_extra_info, ) self.mlp_network_config = mlp_network_config # using to mapping transformer decode feature to regression features. as decode feature is processed by NormLayer. if self.mlp_network_config is not None: self.mlp_net = MLP(query_dim, query_dim, **self.mlp_network_config) self.gs_net = GSLayer( in_channels=query_dim, use_rgb=use_rgb, sh_degree=self.sh_degree, clip_scaling=clip_scaling, init_scaling=-5.0, init_density=0.1, xyz_offset=True, restrict_offset=True, xyz_offset_max_step=xyz_offset_max_step, fix_opacity=fix_opacity, fix_rotation=fix_rotation, use_fine_feat=( True if decode_with_extra_info is not None and decode_with_extra_info["type"] is not None else False ), ) def hyper_step(self, step): self.gs_net.hyper_step(step) def forward_single_view( self, gs: GaussianModel, viewpoint_camera: Camera, background_color: Optional[Float[Tensor, "3"]], ret_mask: bool = True, ): # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means screenspace_points = ( torch.zeros_like( gs.xyz, dtype=gs.xyz.dtype, requires_grad=True, device=self.device ) + 0 ) try: screenspace_points.retain_grad() except: pass bg_color = background_color # Set up rasterization configuration tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) raster_settings = GaussianRasterizationSettings( image_height=int(viewpoint_camera.height), image_width=int(viewpoint_camera.width), tanfovx=tanfovx, tanfovy=tanfovy, bg=bg_color, scale_modifier=self.scaling_modifier, viewmatrix=viewpoint_camera.world_view_transform, projmatrix=viewpoint_camera.full_proj_transform.float(), sh_degree=self.sh_degree, campos=viewpoint_camera.camera_center, prefiltered=False, debug=True, ) rasterizer = GaussianRasterizer(raster_settings=raster_settings) means3D = gs.xyz means2D = screenspace_points opacity = gs.opacity # If precomputed 3d covariance is provided, use it. If not, then it will be computed from # scaling / rotation by the rasterizer. scales = None rotations = None cov3D_precomp = None scales = gs.scaling rotations = gs.rotation # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. shs = None colors_precomp = None if self.gs_net.use_rgb: colors_precomp = gs.shs.squeeze(1) shs = None else: colors_precomp = None shs = gs.shs # print(shs, colors_precomp) # print(means3D.device, means2D.device, opacity.device, rotations.device, self.device) # print(means3D.dtype, means2D.dtype, rotations.dtype, opacity.dtype) # print(means3D.shape, means2D.shape, rotations.shape, opacity.shape) # Rasterize visible Gaussians to image, obtain their radii (on screen). # NOTE that dadong tries to regress rgb not shs # with torch.autocast(device_type=self.device.type, dtype=torch.float32): rendered_image, radii, rendered_depth, rendered_alpha = rasterizer( means3D=means3D, means2D=means2D, shs=shs, colors_precomp=colors_precomp, opacities=opacity, scales=scales, rotations=rotations, cov3D_precomp=cov3D_precomp, ) ret = { "comp_rgb": rendered_image.permute(1, 2, 0), # [H, W, 3] "comp_rgb_bg": bg_color, "comp_mask": rendered_alpha.permute(1, 2, 0), "comp_depth": rendered_depth.permute(1, 2, 0), } # if ret_mask: # mask_bg_color = torch.zeros(3, dtype=torch.float32, device=self.device) # raster_settings = GaussianRasterizationSettings( # image_height=int(viewpoint_camera.height), # image_width=int(viewpoint_camera.width), # tanfovx=tanfovx, # tanfovy=tanfovy, # bg=mask_bg_color, # scale_modifier=self.scaling_modifier, # viewmatrix=viewpoint_camera.world_view_transform, # projmatrix=viewpoint_camera.full_proj_transform.float(), # sh_degree=0, # campos=viewpoint_camera.camera_center, # prefiltered=False, # debug=False # ) # rasterizer = GaussianRasterizer(raster_settings=raster_settings) # with torch.autocast(device_type=self.device.type, dtype=torch.float32): # rendered_mask, radii = rasterizer( # means3D = means3D, # means2D = means2D, # # shs = , # colors_precomp = torch.ones_like(means3D), # opacities = opacity, # scales = scales, # rotations = rotations, # cov3D_precomp = cov3D_precomp) # ret["comp_mask"] = rendered_mask.permute(1, 2, 0) return ret def animate_gs_model( self, gs_attr: GaussianAppOutput, query_points, smplx_data, debug=False ): """ query_points: [N, 3] """ device = gs_attr.offset_xyz.device if debug: N = gs_attr.offset_xyz.shape[0] gs_attr.xyz = torch.ones_like(gs_attr.offset_xyz) * 0.0 rotation = matrix_to_quaternion( torch.eye(3).float()[None, :, :].repeat(N, 1, 1) ).to( device ) # constant rotation opacity = torch.ones((N, 1)).float().to(device) # constant opacity gs_attr.opacity = opacity gs_attr.rotation = rotation # gs_attr.scaling = torch.ones_like(gs_attr.scaling) * 0.05 # print(gs_attr.shs.shape) # build cano_dependent_pose cano_smplx_data_keys = [ "root_pose", "body_pose", "jaw_pose", "leye_pose", "reye_pose", "lhand_pose", "rhand_pose", "expr", "trans", ] merge_smplx_data = dict() for cano_smplx_data_key in cano_smplx_data_keys: warp_data = smplx_data[cano_smplx_data_key] cano_pose = torch.zeros_like(warp_data[:1]) if cano_smplx_data_key == "body_pose": # A-posed cano_pose[0, 15, -1] = -math.pi / 6 cano_pose[0, 16, -1] = +math.pi / 6 merge_pose = torch.cat([warp_data, cano_pose], dim=0) merge_smplx_data[cano_smplx_data_key] = merge_pose merge_smplx_data["betas"] = smplx_data["betas"] merge_smplx_data["transform_mat_neutral_pose"] = smplx_data[ "transform_mat_neutral_pose" ] with torch.autocast(device_type=device.type, dtype=torch.float32): mean_3d = ( query_points + gs_attr.offset_xyz ) # [N, 3] # canonical space offset. # matrix to warp predefined pose to zero-pose transform_mat_neutral_pose = merge_smplx_data[ "transform_mat_neutral_pose" ] # [55, 4, 4] num_view = merge_smplx_data["body_pose"].shape[0] # [Nv, 21, 3] mean_3d = mean_3d.unsqueeze(0).repeat(num_view, 1, 1) # [Nv, N, 3] query_points = query_points.unsqueeze(0).repeat(num_view, 1, 1) transform_mat_neutral_pose = transform_mat_neutral_pose.unsqueeze(0).repeat( num_view, 1, 1, 1 ) # print(mean_3d.shape, transform_mat_neutral_pose.shape, query_points.shape, smplx_data["body_pose"].shape, smplx_data["betas"].shape) mean_3d, transform_matrix = ( self.smplx_model.transform_to_posed_verts_from_neutral_pose( mean_3d, merge_smplx_data, query_points, transform_mat_neutral_pose=transform_mat_neutral_pose, # from predefined pose to zero-pose matrix device=device, ) ) # [B, N, 3] # rotation appearance from canonical space to view_posed num_view, N, _, _ = transform_matrix.shape transform_rotation = transform_matrix[:, :, :3, :3] rigid_rotation_matrix = torch.nn.functional.normalize( matrix_to_quaternion(transform_rotation), dim=-1 ) I = matrix_to_quaternion(torch.eye(3)).to(device) # inference constrain is_constrain_body = self.smplx_model.is_constrain_body rigid_rotation_matrix[:, is_constrain_body] = I gs_attr.scaling[is_constrain_body] = gs_attr.scaling[ is_constrain_body ].clamp(max=0.02) rotation_neutral_pose = gs_attr.rotation.unsqueeze(0).repeat(num_view, 1, 1) # TODO do not move underarm gs # QUATERNION MULTIPLY rotation_pose_verts = quaternion_multiply( rigid_rotation_matrix, rotation_neutral_pose ) # rotation_pose_verts = rotation_neutral_pose gs_list = [] cano_gs_list = [] for i in range(num_view): gs_copy = GaussianModel( xyz=mean_3d[i], opacity=gs_attr.opacity, # rotation=gs_attr.rotation, rotation=rotation_pose_verts[i], scaling=gs_attr.scaling, shs=gs_attr.shs, use_rgb=self.gs_net.use_rgb, ) # [N, 3] if i == num_view - 1: cano_gs_list.append(gs_copy) else: gs_list.append(gs_copy) return gs_list, cano_gs_list def forward_gs_attr(self, x, query_points, smplx_data, debug=False, x_fine=None): """ x: [N, C] Float[Tensor, "Np Cp"], query_points: [N, 3] Float[Tensor, "Np 3"] """ device = x.device if self.mlp_network_config is not None: # x is processed by LayerNorm x = self.mlp_net(x) if x_fine is not None: x_fine = self.mlp_net(x_fine) # NOTE that gs_attr contains offset xyz gs_attr: GaussianAppOutput = self.gs_net(x, query_points, x_fine) return gs_attr def get_query_points(self, smplx_data, device): with torch.no_grad(): with torch.autocast(device_type=device.type, dtype=torch.float32): # print(smplx_data["betas"].shape, smplx_data["face_offset"].shape, smplx_data["joint_offset"].shape) positions, _, transform_mat_neutral_pose = ( self.smplx_model.get_query_points(smplx_data, device=device) ) # [B, N, 3] smplx_data["transform_mat_neutral_pose"] = ( transform_mat_neutral_pose # [B, 55, 4, 4] ) return positions, smplx_data def decoder_cross_attn_wrapper(self, pcl_embed, latent_feat, extra_info): # if self.training and self.gradient_checkpointing: # def create_custom_forward(module): # def custom_forward(*inputs): # return module(*inputs) # return custom_forward # ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} # gs_feats = torch.utils.checkpoint.checkpoint( # create_custom_forward(self.decoder_cross_attn), # pcl_embed.to(dtype=latent_feat.dtype), # latent_feat, # extra_info, # **ckpt_kwargs, # ) # else: gs_feats = self.decoder_cross_attn( pcl_embed.to(dtype=latent_feat.dtype), latent_feat, extra_info ) return gs_feats def query_latent_feat( self, positions: Float[Tensor, "*B N1 3"], smplx_data, latent_feat: Float[Tensor, "*B N2 C"], extra_info, ): device = latent_feat.device if self.skip_decoder: gs_feats = latent_feat assert positions is not None else: assert positions is None if positions is None: positions, smplx_data = self.get_query_points(smplx_data, device) with torch.autocast(device_type=device.type, dtype=torch.float32): pcl_embed = self.pcl_embed(positions) gs_feats = self.decoder_cross_attn_wrapper( pcl_embed, latent_feat, extra_info ) return gs_feats, positions, smplx_data def forward_single_batch( self, gs_list: list[GaussianModel], c2ws: Float[Tensor, "Nv 4 4"], intrinsics: Float[Tensor, "Nv 4 4"], height: int, width: int, background_color: Optional[Float[Tensor, "Nv 3"]], debug: bool = False, ): out_list = [] self.device = gs_list[0].xyz.device for v_idx, (c2w, intrinsic) in enumerate(zip(c2ws, intrinsics)): out_list.append( self.forward_single_view( gs_list[v_idx], Camera.from_c2w(c2w, intrinsic, height, width), background_color[v_idx], ) ) out = defaultdict(list) for out_ in out_list: for k, v in out_.items(): out[k].append(v) out = {k: torch.stack(v, dim=0) for k, v in out.items()} out["3dgs"] = gs_list # debug = True if debug: import cv2 cv2.imwrite( "fuck.png", (out["comp_rgb"].detach().cpu().numpy()[0, ..., ::-1] * 255).astype( np.uint8 ), ) return out def forward_cano_batch( self, gs_list: list[GaussianModel], c2ws: Float[Tensor, "Nv 4 4"], intrinsics: Float[Tensor, "Nv 4 4"], background_color: Optional[Float[Tensor, "Nv 3"]], height: int = 512, width: int = 512, debug: bool = False, ): """using to visualization.""" degree_list = [0, 90, 180, 270] out_list = [] self.device = gs_list[0].xyz.device gs_list_copy = [gs_list[0].clone() for _ in range(len(degree_list))] rotation_gs_list = [] for rotation_degree, gs in zip(degree_list, gs_list_copy): _R = torch.eye(3).to(gs.xyz) _R[-1, -1] *= -1 _R[1, 1] *= -1 self_R = torch.from_numpy(generate_rotation_matrix_y(rotation_degree)).to( _R ) _R = self_R @ _R gs.xyz = (_R @ gs.xyz.T).T _min, _max = aabb(gs.xyz) center = (_min + _max) / 2 gs.xyz -= center.unsqueeze(0) _R_quaternion = matrix_to_quaternion(_R) gs.rotation = quaternion_multiply(_R_quaternion, gs.rotation) gs.xyz[..., -1] += 2.5 # move to (0, 0, 3) rotation_gs_list.append(gs) intrinsics = torch.eye(4).to(intrinsics).unsqueeze(0) intrinsics[0, 0, 0] = width intrinsics[0, 1, 1] = height intrinsics[0, 0, 2] = width / 2 intrinsics[0, 1, 2] = height / 2 for v_idx, gs in enumerate(rotation_gs_list): out_list.append( self.forward_single_view( rotation_gs_list[v_idx], Camera.from_c2w(c2ws[0], intrinsics[0], height, width), torch.ones_like(background_color[0]), ) ) out = defaultdict(list) for out_ in out_list: for k, v in out_.items(): out[k].append(v) out = {k: torch.stack(v, dim=0) for k, v in out.items()} out["3dgs"] = rotation_gs_list if debug: import cv2 for i in range(4): cv2.imwrite( f"fuck_{i}.png", (out["comp_rgb"].detach().cpu().numpy()[i, ..., ::-1] * 255).astype( np.uint8 ), ) return out def get_single_batch_smpl_data(self, smpl_data, bidx): smpl_data_single_batch = {} for k, v in smpl_data.items(): smpl_data_single_batch[k] = v[ bidx ] # e.g. body_pose: [B, N_v, 21, 3] -> [N_v, 21, 3] if k == "betas" or (k == "joint_offset") or (k == "face_offset"): smpl_data_single_batch[k] = v[ bidx : bidx + 1 ] # e.g. betas: [B, 100] -> [1, 100] return smpl_data_single_batch def get_single_view_smpl_data(self, smpl_data, vidx): smpl_data_single_view = {} for k, v in smpl_data.items(): assert v.shape[0] == 1 if ( k == "betas" or (k == "joint_offset") or (k == "face_offset") or (k == "transform_mat_neutral_pose") ): smpl_data_single_view[k] = v # e.g. betas: [1, 100] -> [1, 100] else: smpl_data_single_view[k] = v[ :, vidx : vidx + 1 ] # e.g. body_pose: [1, N_v, 21, 3] -> [1, 1, 21, 3] return smpl_data_single_view def forward_gs( self, gs_hidden_features: Float[Tensor, "B Np Cp"], query_points: Float[Tensor, "B Np_q 3"], smplx_data, # e.g., body_pose:[B, Nv, 21, 3], betas:[B, 100] additional_features: Optional[dict] = None, debug: bool = False, **kwargs, ): batch_size = gs_hidden_features.shape[0] # obtain gs_features embedding, cur points position, and also smplx params query_gs_features, query_points, smplx_data = self.query_latent_feat( query_points, smplx_data, gs_hidden_features, additional_features ) gs_attr_list = [] for b in range(batch_size): if isinstance(query_gs_features, dict): gs_attr = self.forward_gs_attr( query_gs_features["coarse"][b], query_points[b], None, debug, x_fine=query_gs_features["fine"][b], ) else: gs_attr = self.forward_gs_attr( query_gs_features[b], query_points[b], None, debug ) gs_attr_list.append(gs_attr) return gs_attr_list, query_points, smplx_data def forward_animate_gs( self, gs_attr_list, query_points, smplx_data, c2w, intrinsic, height, width, background_color, debug=False, df_data=None, # deepfashion-style dataset ): batch_size = len(gs_attr_list) out_list = [] cano_out_list = [] # inference DO NOT use N_view = smplx_data["root_pose"].shape[1] for b in range(batch_size): gs_attr = gs_attr_list[b] query_pt = query_points[b] # len(animatable_gs_model_list) = num_view merge_animatable_gs_model_list, cano_gs_model_list = self.animate_gs_model( gs_attr, query_pt, self.get_single_batch_smpl_data(smplx_data, b), debug=debug, ) animatable_gs_model_list = merge_animatable_gs_model_list[:N_view] assert len(animatable_gs_model_list) == c2w.shape[1] # gs render animated gs model. out_list.append( self.forward_single_batch( animatable_gs_model_list, c2w[b], intrinsic[b], height, width, background_color[b] if background_color is not None else None, debug=debug, ) ) out = defaultdict(list) for out_ in out_list: for k, v in out_.items(): out[k].append(v) for k, v in out.items(): if isinstance(v[0], torch.Tensor): out[k] = torch.stack(v, dim=0) else: out[k] = v out["comp_rgb"] = out["comp_rgb"].permute( 0, 1, 4, 2, 3 ) # [B, NV, H, W, 3] -> [B, NV, 3, H, W] out["comp_mask"] = out["comp_mask"].permute( 0, 1, 4, 2, 3 ) # [B, NV, H, W, 3] -> [B, NV, 1, H, W] out["comp_depth"] = out["comp_depth"].permute( 0, 1, 4, 2, 3 ) # [B, NV, H, W, 3] -> [B, NV, 1, H, W] return out def forward( self, gs_hidden_features: Float[Tensor, "B Np Cp"], query_points: Float[Tensor, "B Np 3"], smplx_data, # e.g., body_pose:[B, Nv, 21, 3], betas:[B, 100] c2w: Float[Tensor, "B Nv 4 4"], intrinsic: Float[Tensor, "B Nv 4 4"], height, width, additional_features: Optional[Float[Tensor, "B C H W"]] = None, background_color: Optional[Float[Tensor, "B Nv 3"]] = None, debug: bool = False, **kwargs, ): # need shape_params of smplx_data to get querty points and get "transform_mat_neutral_pose" # only forward gs params gs_attr_list, query_points, smplx_data = self.forward_gs( gs_hidden_features, query_points, smplx_data=smplx_data, additional_features=additional_features, debug=debug, ) out = self.forward_animate_gs( gs_attr_list, query_points, smplx_data, c2w, intrinsic, height, width, background_color, debug, df_data=kwargs["df_data"], ) out["gs_attr"] = gs_attr_list return out def test(): import cv2 human_model_path = "./pretrained_models/human_model_files" smplx_data_root = "/data1/projects/ExAvatar_RELEASE/avatar/data/Custom/data/gyeongsik/smplx_optimized/smplx_params_smoothed" shape_param_file = "/data1/projects/ExAvatar_RELEASE/avatar/data/Custom/data/gyeongsik/smplx_optimized/shape_param.json" batch_size = 1 device = "cuda" smplx_data, cam_param_list, ori_image_list = read_smplx_param( smplx_data_root=smplx_data_root, shape_param_file=shape_param_file, batch_size=2 ) smplx_data_tmp = smplx_data for k, v in smplx_data.items(): smplx_data_tmp[k] = v.unsqueeze(0) if (k == "betas") or (k == "face_offset") or (k == "joint_offset"): smplx_data_tmp[k] = v[0].unsqueeze(0) smplx_data = smplx_data_tmp gs_render = GS3DRenderer( human_model_path=human_model_path, subdivide_num=2, smpl_type="smplx", feat_dim=64, query_dim=64, use_rgb=False, sh_degree=3, mlp_network_config=None, xyz_offset_max_step=1.8 / 32, ) gs_render.to(device) # print(cam_param_list[0]) c2w_list = [] intr_list = [] for cam_param in cam_param_list: c2w = torch.eye(4).to(device) c2w[:3, :3] = cam_param["R"] c2w[:3, 3] = cam_param["t"] c2w_list.append(c2w) intr = torch.eye(4).to(device) intr[0, 0] = cam_param["focal"][0] intr[1, 1] = cam_param["focal"][1] intr[0, 2] = cam_param["princpt"][0] intr[1, 2] = cam_param["princpt"][1] intr_list.append(intr) c2w = torch.stack(c2w_list).unsqueeze(0) intrinsic = torch.stack(intr_list).unsqueeze(0) out = gs_render.forward( gs_hidden_features=torch.zeros((batch_size, 2048, 64)).float().to(device), query_points=None, smplx_data=smplx_data, c2w=c2w, intrinsic=intrinsic, height=int(cam_param_list[0]["princpt"][1]) * 2, width=int(cam_param_list[0]["princpt"][0]) * 2, background_color=torch.tensor([1.0, 1.0, 1.0]) .float() .view(1, 1, 3) .repeat(batch_size, 2, 1) .to(device), debug=False, ) for k, v in out.items(): if k == "comp_rgb_bg": print("comp_rgb_bg", v) continue for b_idx in range(len(v)): if k == "3dgs": for v_idx in range(len(v[b_idx])): v[b_idx][v_idx].save_ply(f"./debug_vis/{b_idx}_{v_idx}.ply") continue for v_idx in range(v.shape[1]): save_path = os.path.join("./debug_vis", f"{b_idx}_{v_idx}_{k}.jpg") cv2.imwrite( save_path, (v[b_idx, v_idx].detach().cpu().numpy() * 255).astype(np.uint8), ) def test1(): import cv2 human_model_path = "./pretrained_models/human_model_files" device = "cuda" # root_dir = "/data1/projects/ExAvatar_RELEASE/avatar/data/Custom/data" # meta_path = "/data1/projects/ExAvatar_RELEASE/avatar/data/Custom/data/data_list.json" # dataset = ExAvatarDataset(root_dirs=root_dir, meta_path=meta_path, sample_side_views=3, # render_image_res_low=384, render_image_res_high=384, # render_region_size=(224, 224), source_image_res=384) # root_dir = "/data1/datasets1/3d_human_data/humman/humman_compressed" # meta_path = "/data1/datasets1/3d_human_data/humman/humman_id_debug_list.json" # dataset = HuMManDataset(root_dirs=root_dir, meta_path=meta_path, sample_side_views=3, # render_image_res_low=384, render_image_res_high=384, # render_region_size=(682, 384), source_image_res=384) # from openlrm.datasets.static_human import StaticHumanDataset # root_dir = "./train_data/static_human_data" # meta_path = "./train_data/static_human_data/data_id_list.json" # dataset = StaticHumanDataset(root_dirs=root_dir, meta_path=meta_path, sample_side_views=7, # render_image_res_low=384, render_image_res_high=384, # render_region_size=(682, 384), source_image_res=384, # debug=False) # from openlrm.datasets.singleview_human import SingleViewHumanDataset # root_dir = "./train_data/single_view" # meta_path = "./train_data/single_view/data_list.json" # dataset = SingleViewHumanDataset(root_dirs=root_dir, meta_path=meta_path, sample_side_views=0, # render_image_res_low=384, render_image_res_high=384, # render_region_size=(682, 384), source_image_res=384, # debug=False) from accelerate.utils import set_seed set_seed(1234) from LHM.datasets.video_human import VideoHumanDataset root_dir = "./train_data/ClothVideo" meta_path = "./train_data/ClothVideo/label/valid_id_with_img_list.json" dataset = VideoHumanDataset( root_dirs=root_dir, meta_path=meta_path, sample_side_views=7, render_image_res_low=384, render_image_res_high=384, render_region_size=(682, 384), source_image_res=384, enlarge_ratio=[0.85, 1.2], debug=False, ) data = dataset[0] def get_smplx_params(data): smplx_params = {} smplx_keys = [ "root_pose", "body_pose", "jaw_pose", "leye_pose", "reye_pose", "lhand_pose", "rhand_pose", "expr", "trans", "betas", ] for k, v in data.items(): if k in smplx_keys: # print(k, v.shape) smplx_params[k] = data[k] return smplx_params smplx_data = get_smplx_params(data) smplx_data_tmp = {} for k, v in smplx_data.items(): smplx_data_tmp[k] = v.unsqueeze(0).to(device) print(k, v.shape) smplx_data = smplx_data_tmp c2ws = data["c2ws"].unsqueeze(0).to(device) intrs = data["intrs"].unsqueeze(0).to(device) render_images = data["render_image"].numpy() render_h = data["render_full_resolutions"][0, 0] render_w = data["render_full_resolutions"][0, 1] render_bg_colors = data["render_bg_colors"].unsqueeze(0).to(device) print("c2ws", c2ws.shape, "intrs", intrs.shape, intrs) gs_render = GS3DRenderer( human_model_path=human_model_path, subdivide_num=2, smpl_type="smplx", feat_dim=64, query_dim=64, use_rgb=False, sh_degree=3, mlp_network_config=None, xyz_offset_max_step=1.8 / 32, expr_param_dim=10, shape_param_dim=10, fix_opacity=True, fix_rotation=True, ) gs_render.to(device) out = gs_render.forward( gs_hidden_features=torch.zeros((1, 2048, 64)).float().to(device), query_points=None, smplx_data=smplx_data, c2w=c2ws, intrinsic=intrs, height=render_h, width=render_w, background_color=render_bg_colors, debug=False, ) os.makedirs("./debug_vis/gs_render", exist_ok=True) for k, v in out.items(): if k == "comp_rgb_bg": print("comp_rgb_bg", v) continue for b_idx in range(len(v)): if k == "3dgs": for v_idx in range(len(v[b_idx])): v[b_idx][v_idx].save_ply( f"./debug_vis/gs_render/{b_idx}_{v_idx}.ply" ) continue for v_idx in range(v.shape[1]): save_path = os.path.join( "./debug_vis/gs_render", f"{b_idx}_{v_idx}_{k}.jpg" ) img = ( v[b_idx, v_idx].permute(1, 2, 0).detach().cpu().numpy() * 255 ).astype(np.uint8) print(img.shape, save_path) if "mask" in k: render_img = render_images[v_idx].transpose(1, 2, 0) * 255 cv2.imwrite( save_path, np.hstack( [np.tile(img, (1, 1, 3)), render_img.astype(np.uint8)] ), ) else: cv2.imwrite(save_path, img) if __name__ == "__main__": # test1() test() test()