# LHM2 / LHM / models / rendering / gs_renderer.py
# Hosting-page residue kept for provenance (not code):
#   "DyrusQZ's picture", "update zerogpu", "d74e574"
import copy
import math
import os
import pdb
from collections import defaultdict
from dataclasses import dataclass, field
import numpy as np
import omegaconf
import torch
import torch.nn as nn
import torch.nn.functional as F
from diff_gaussian_rasterization import (
GaussianRasterizationSettings,
GaussianRasterizer,
)
from plyfile import PlyData, PlyElement
from pytorch3d.transforms import matrix_to_quaternion
from pytorch3d.transforms.rotation_conversions import quaternion_multiply
from LHM.models.rendering.smpl_x import SMPLXModel, read_smplx_param
from LHM.models.rendering.smpl_x_voxel_dense_sampling import SMPLXVoxelMeshModel
from LHM.models.rendering.utils.sh_utils import RGB2SH, SH2RGB
from LHM.models.rendering.utils.typing import *
from LHM.models.rendering.utils.utils import MLP, trunc_exp
from LHM.models.utils import LinerParameterTuner, StaticParameterTuner
from LHM.outputs.output import GaussianAppOutput
def auto_repeat_size(tensor, repeat_num, axis=0):
    """Return a size list that repeats ``tensor`` along a single axis.

    Args:
        tensor: Object exposing ``dim()`` (e.g. a ``torch.Tensor``).
        repeat_num: Repetition count placed at position ``axis``.
        axis: Index of the dimension to repeat; negative indices are allowed.

    Returns:
        A list with one entry per tensor dimension: ``repeat_num`` at
        ``axis`` and ``1`` everywhere else. Note that the tensor itself is
        not repeated here; only the size list is built.
    """
    ndim = tensor.dim()
    sizes = [1 for _ in range(ndim)]
    sizes[axis] = repeat_num
    return sizes
def aabb(xyz):
    """Return the axis-aligned bounding box of a point set.

    Args:
        xyz: Tensor whose first dimension enumerates the points.

    Returns:
        ``(minimum, maximum)`` reduced over dimension 0.
    """
    lower = xyz.min(dim=0).values
    upper = xyz.max(dim=0).values
    return lower, upper
def inverse_sigmoid(x):
    """Compute the logit ``log(x / (1 - x))``, the inverse of the sigmoid.

    Args:
        x: A tensor, or any value ``torch.as_tensor`` accepts (Python float
            or int, NumPy scalar/array). Non-tensor inputs are converted to
            a float32 tensor first, so integer inputs no longer fail.

    Returns:
        A tensor holding the logit of ``x``. Values of exactly 0 or 1 map
        to ``-inf`` / ``+inf``.
    """
    if not torch.is_tensor(x):
        x = torch.as_tensor(x, dtype=torch.float32)
    return torch.log(x / (1 - x))
def generate_rotation_matrix_y(degrees):
    """Build the 3x3 rotation matrix for a rotation about the y axis.

    Args:
        degrees: Rotation angle in degrees.

    Returns:
        A ``(3, 3)`` float32 NumPy array.
    """
    angle = math.radians(degrees)
    c, s = math.cos(angle), math.sin(angle)
    rows = [
        [c, 0, s],
        [0, 1, 0],
        [-s, 0, c],
    ]
    return np.asarray(rows, dtype=np.float32)
def getWorld2View2(R, t, translate=np.array([0.0, 0.0, 0.0]), scale=1.0):
    """Assemble a 4x4 world-to-view matrix with an adjusted camera center.

    The matrix is built from ``R`` (transposed) and ``t``, inverted to get
    the camera-to-world matrix, whose translation column is shifted by
    ``translate`` and multiplied by ``scale``; the result is inverted back.

    Args:
        R: ``(3, 3)`` NumPy rotation; its transpose fills the rotation block.
        t: Length-3 translation placed in the last column.
        translate: Offset added to the camera center (never modified here).
        scale: Factor applied to the shifted camera center.

    Returns:
        A ``(4, 4)`` float32 NumPy array.
    """
    world_to_view = np.zeros((4, 4))
    world_to_view[:3, :3] = R.transpose()
    world_to_view[:3, 3] = t
    world_to_view[3, 3] = 1.0

    view_to_world = np.linalg.inv(world_to_view)
    view_to_world[:3, 3] = (view_to_world[:3, 3] + translate) * scale

    return np.linalg.inv(view_to_world).astype(np.float32)
def getProjectionMatrix(znear, zfar, fovX, fovY):
    """Build a 4x4 perspective projection matrix from fields of view.

    Args:
        znear: Near clipping distance.
        zfar: Far clipping distance.
        fovX: Horizontal field of view in radians.
        fovY: Vertical field of view in radians.

    Returns:
        A ``(4, 4)`` torch tensor (CPU, default dtype). Entry ``[3, 2]`` is
        ``1.0``, i.e. the clip-space ``w`` equals the input ``z``.
    """
    half_height = math.tan(fovY / 2) * znear
    half_width = math.tan(fovX / 2) * znear
    top, bottom = half_height, -half_height
    right, left = half_width, -half_width

    z_sign = 1.0
    depth_range = zfar - znear

    proj = torch.zeros(4, 4)
    # x / y scaling and (here zero) principal-point offsets.
    proj[0, 0] = 2.0 * znear / (right - left)
    proj[0, 2] = (right + left) / (right - left)
    proj[1, 1] = 2.0 * znear / (top - bottom)
    proj[1, 2] = (top + bottom) / (top - bottom)
    # depth mapping and perspective divide.
    proj[2, 2] = z_sign * zfar / depth_range
    proj[2, 3] = -(zfar * znear) / depth_range
    proj[3, 2] = z_sign
    return proj
def intrinsic_to_fov(intrinsic, w, h):
    """Convert focal lengths from an intrinsic matrix into fields of view.

    Args:
        intrinsic: Matrix whose ``[0, 0]`` and ``[1, 1]`` entries are the
            focal lengths ``fx`` and ``fy``.
        w: Image width as a tensor.
        h: Image height as a tensor.

    Returns:
        ``(fov_x, fov_y)`` in radians, each ``2 * atan2(size, 2 * focal)``.
    """
    focal_x = intrinsic[0, 0]
    focal_y = intrinsic[1, 1]
    horizontal = 2 * torch.arctan2(w, 2 * focal_x)
    vertical = 2 * torch.arctan2(h, 2 * focal_y)
    return horizontal, vertical
class Camera:
    """Camera description consumed by the Gaussian rasterizer.

    The view and projection matrices are stored transposed relative to the
    inputs: ``world_view_transform`` is ``w2c`` transposed and
    ``projection_matrix`` is the transposed output of
    ``getProjectionMatrix``.
    """

    def __init__(
        self,
        w2c,
        intrinsic,
        FoVx,
        FoVy,
        height,
        width,
        trans=np.array([0.0, 0.0, 0.0]),
        scale=1.0,
    ) -> None:
        """Store the camera parameters and derive the rasterizer matrices.

        Args:
            w2c: ``(4, 4)`` world-to-camera tensor.
            intrinsic: Intrinsic matrix, stored unchanged.
            FoVx: Horizontal field of view in radians.
            FoVy: Vertical field of view in radians.
            height: Image height in pixels.
            width: Image width in pixels.
            trans: Stored as ``self.trans``; not used in any computation here.
            scale: Stored as ``self.scale``; not used in any computation here.
        """
        self.FoVx = FoVx
        self.FoVy = FoVy
        self.height = height
        self.width = width
        self.world_view_transform = w2c.transpose(0, 1)

        # Fixed clipping planes.
        self.zfar = 100.0
        self.znear = 0.01
        self.trans = trans
        self.scale = scale

        projection = getProjectionMatrix(
            znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy
        )
        self.projection_matrix = projection.transpose(0, 1).to(w2c.device)

        # view (transposed) times projection (transposed), via a batch of one.
        self.full_proj_transform = torch.bmm(
            self.world_view_transform[None], self.projection_matrix[None]
        )[0]

        # Row 3 of the inverted, transposed view matrix holds the camera center.
        self.camera_center = torch.inverse(self.world_view_transform)[3, :3]
        self.intrinsic = intrinsic

    @staticmethod
    def from_c2w(c2w, intrinsic, height, width):
        """Create a camera from a camera-to-world matrix and intrinsics.

        Args:
            c2w: ``(4, 4)`` camera-to-world tensor; inverted to get ``w2c``.
            intrinsic: Intrinsic matrix providing the focal lengths.
            height: Image height in pixels.
            width: Image width in pixels.

        Returns:
            A ``Camera`` whose fields of view come from ``intrinsic_to_fov``.
        """
        world_to_cam = torch.inverse(c2w)
        fov_x, fov_y = intrinsic_to_fov(
            intrinsic,
            w=torch.tensor(width, device=world_to_cam.device),
            h=torch.tensor(height, device=world_to_cam.device),
        )
        return Camera(
            w2c=world_to_cam,
            intrinsic=intrinsic,
            FoVx=fov_x,
            FoVy=fov_y,
            height=height,
            width=width,
        )
class GaussianModel:
    """Container for one set of 3D Gaussian attributes, with PLY I/O.

    ``save_ply`` treats the stored ``opacity`` and ``scaling`` as activated
    values (it writes their logit / log), and ``load_ply`` applies the
    activations again after reading.
    """

    def setup_functions(self):
        """Register the activation functions and their inverses."""
        self.scaling_activation = torch.exp
        self.scaling_inverse_activation = torch.log

        self.opacity_activation = torch.sigmoid
        self.inverse_opacity_activation = inverse_sigmoid

        self.rotation_activation = torch.nn.functional.normalize

        # rgb activation function
        self.rgb_activation = torch.sigmoid

    def __init__(self, xyz, opacity, rotation, scaling, shs, use_rgb=False) -> None:
        """Initialize the GaussianModel.

        Args:
            xyz (Tensor): Gaussian centers.
            opacity (Tensor): Opacity values.
            rotation (Tensor): Rotation values (quaternions).
            scaling (Tensor): Scaling values.
            shs (Tensor): Spherical harmonics coefficients, laid out as
                ``[num_points, SH_Coeff, 3]``.
            use_rgb (bool, optional): Whether ``shs`` holds RGB values
                instead of SH coefficients. Defaults to False.
        """
        self.setup_functions()

        self.xyz: Tensor = xyz
        self.opacity: Tensor = opacity
        self.rotation: Tensor = rotation
        self.scaling: Tensor = scaling
        self.shs: Tensor = shs  # [num_points, SH_Coeff, 3]

        self.use_rgb = use_rgb  # shs indicates rgb?

    @staticmethod
    def _flatten_sh_features(features):
        """Flatten ``(P, K, 3)`` SH features to a channel-major ``(P, 3 * K)`` array.

        Channel-major order (all coefficients of channel 0, then channel 1,
        ...) is what ``load_ply`` assumes when it reshapes to ``(P, 3, K)``.
        """
        return (
            features.float()
            .detach()
            .transpose(1, 2)
            .flatten(start_dim=1)
            .contiguous()
            .cpu()
            .numpy()
        )

    def construct_list_of_attributes(self):
        """Return the ordered PLY property names for the current tensors."""
        names = ["x", "y", "z", "nx", "ny", "nz"]
        features_dc = self.shs[:, :1]
        features_rest = self.shs[:, 1:]

        names.extend(
            "f_dc_{}".format(i)
            for i in range(features_dc.shape[1] * features_dc.shape[2])
        )
        names.extend(
            "f_rest_{}".format(i)
            for i in range(features_rest.shape[1] * features_rest.shape[2])
        )
        names.append("opacity")
        names.extend("scale_{}".format(i) for i in range(self.scaling.shape[1]))
        names.extend("rot_{}".format(i) for i in range(self.rotation.shape[1]))
        return names

    def save_ply(self, path):
        """Write the Gaussians to ``path`` as a PLY ``vertex`` element.

        Opacity is clamped to ``[1e-3, 1 - 1e-3]`` and stored as a logit,
        scaling is stored as a log, and normals are written as zeros. When
        ``use_rgb`` is set the colors are converted with ``RGB2SH`` first.

        The SH features are written channel-major so that a file written
        here is read back unchanged by ``load_ply`` for any SH degree. The
        previous coefficient-major flattening only round-tripped when there
        was a single coefficient.
        """
        xyz = self.xyz.detach().cpu().numpy()
        normals = np.zeros_like(xyz)

        shs = RGB2SH(self.shs) if self.use_rgb else self.shs
        f_dc = self._flatten_sh_features(shs[:, :1])
        f_rest = self._flatten_sh_features(shs[:, 1:])

        opacities = (
            inverse_sigmoid(torch.clamp(self.opacity, 1e-3, 1 - 1e-3))
            .detach()
            .cpu()
            .numpy()
        )
        scale = np.log(self.scaling.detach().cpu().numpy())
        rotation = self.rotation.detach().cpu().numpy()

        dtype_full = [
            (attribute, "f4") for attribute in self.construct_list_of_attributes()
        ]

        elements = np.empty(xyz.shape[0], dtype=dtype_full)
        attributes = np.concatenate(
            (xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1
        )
        elements[:] = list(map(tuple, attributes))
        el = PlyElement.describe(elements, "vertex")
        PlyData([el]).write(path)

    def load_ply(self, path):
        """Replace this model's tensors with the contents of a PLY file.

        The loaded arrays are moved to the dtype/device of the tensors
        currently held by the model. Opacity, rotation and scaling are
        passed through their activation functions, and ``active_sh_degree``
        is set from the number of ``f_rest_*`` properties.
        """
        plydata = PlyData.read(path)
        vertex = plydata.elements[0]

        def sorted_names(prefix):
            # Property names with this prefix, ordered by trailing index.
            found = [p.name for p in vertex.properties if p.name.startswith(prefix)]
            return sorted(found, key=lambda x: int(x.split("_")[-1]))

        def read_columns(names):
            # Stack the named properties as columns of a (P, len(names)) array.
            columns = np.zeros((xyz.shape[0], len(names)))
            for idx, attr_name in enumerate(names):
                columns[:, idx] = np.asarray(vertex[attr_name])
            return columns

        xyz = np.stack(
            (
                np.asarray(vertex["x"]),
                np.asarray(vertex["y"]),
                np.asarray(vertex["z"]),
            ),
            axis=1,
        )
        opacities = np.asarray(vertex["opacity"])[..., np.newaxis]

        features_dc = np.zeros((xyz.shape[0], 3, 1))
        features_dc[:, 0, 0] = np.asarray(vertex["f_dc_0"])
        features_dc[:, 1, 0] = np.asarray(vertex["f_dc_1"])
        features_dc[:, 2, 0] = np.asarray(vertex["f_dc_2"])

        extra_f_names = sorted_names("f_rest_")
        sh_degree = int(math.sqrt((len(extra_f_names) + 3) / 3)) - 1
        print("load sh degree: ", sh_degree)

        features_extra = read_columns(extra_f_names)
        # Reshape (P, F*SH_coeffs) to (P, F, SH_coeffs except DC);
        # the last dimension is 0, 3, 8 or 15 for degrees 0..3.
        features_extra = features_extra.reshape(
            (features_extra.shape[0], 3, (sh_degree + 1) ** 2 - 1)
        )

        scales = read_columns(sorted_names("scale_"))
        rots = read_columns(sorted_names("rot"))

        xyz = torch.from_numpy(xyz).to(self.xyz)
        opacities = torch.from_numpy(opacities).to(self.opacity)
        rotation = torch.from_numpy(rots).to(self.rotation)
        scales = torch.from_numpy(scales).to(self.scaling)
        features_dc = torch.from_numpy(features_dc).to(self.shs)
        features_rest = torch.from_numpy(features_extra).to(self.shs)

        shs = torch.cat([features_dc, features_rest], dim=2)
        if self.use_rgb:
            shs = SH2RGB(shs)

        self.xyz: Tensor = xyz
        self.opacity: Tensor = self.opacity_activation(opacities)
        self.rotation: Tensor = self.rotation_activation(rotation)
        self.scaling: Tensor = self.scaling_activation(scales)
        self.shs: Tensor = shs.permute(0, 2, 1)  # (P, 3, K) -> (P, K, 3)

        self.active_sh_degree = sh_degree

    def clone(self):
        """Return a new ``GaussianModel`` holding clones of every tensor."""
        return GaussianModel(
            self.xyz.clone(),
            self.opacity.clone(),
            self.rotation.clone(),
            self.scaling.clone(),
            self.shs.clone(),
            self.use_rgb,
        )
class GSLayer(nn.Module):
    """Linear heads that regress Gaussian attributes from per-point features.

    One ``nn.Linear`` head is created per attribute (``shs``, ``scaling``,
    ``xyz``, and optionally ``opacity`` / ``rotation``). Attributes that are
    fixed get an ``nn.Identity`` placeholder and a constant value in
    ``forward``.
    """

    def setup_functions(self):
        """Register the activation functions and their inverses."""
        self.scaling_activation = trunc_exp  # proposed by torch-ngp
        self.scaling_inverse_activation = torch.log

        self.opacity_activation = torch.sigmoid
        self.inverse_opacity_activation = inverse_sigmoid
        self.rotation_activation = torch.nn.functional.normalize

        self.rgb_activation = torch.sigmoid

    def __init__(
        self,
        in_channels,
        use_rgb,
        clip_scaling=0.2,
        init_scaling=-5.0,
        init_density=0.1,
        sh_degree=None,
        xyz_offset=True,
        restrict_offset=True,
        xyz_offset_max_step=None,
        fix_opacity=False,
        fix_rotation=False,
        use_fine_feat=False,
    ):
        """Create the attribute heads and initialize them.

        Args:
            in_channels: Feature width of the input to every head.
            use_rgb: If True the color head outputs 3 RGB channels (sigmoid
                activated) instead of SH coefficients.
            clip_scaling: Upper clamp for the activated scaling. A list or
                omegaconf ``ListConfig`` is unpacked into
                ``LinerParameterTuner``; anything else is wrapped in
                ``StaticParameterTuner``.
            init_scaling: Initial bias of the scaling head (pre-activation).
            init_density: Initial opacity; the bias is its logit.
            sh_degree: SH degree of the color head. Required unless
                ``use_rgb`` is True.
            xyz_offset: Must be True; the head predicts an offset that the
                caller adds to the query points.
            restrict_offset: If True the offset is squashed to
                ``(-max_step / 2, max_step / 2)`` with a sigmoid.
            xyz_offset_max_step: ``max_step`` used by ``restrict_offset``.
            fix_opacity: If True opacity is constant 1 and has no head.
            fix_rotation: If True rotation is the identity quaternion and
                has no head.
            use_fine_feat: If True an extra ``fine_shs`` head is created and
                its output is added to the color prediction.

        Raises:
            ValueError: If ``sh_degree`` is None while ``use_rgb`` is False.
        """
        super().__init__()
        self.setup_functions()

        if isinstance(clip_scaling, (omegaconf.listconfig.ListConfig, list)):
            self.clip_scaling_pruner = LinerParameterTuner(*clip_scaling)
        else:
            self.clip_scaling_pruner = StaticParameterTuner(clip_scaling)
        self.clip_scaling = self.clip_scaling_pruner.get_value(0)

        self.use_rgb = use_rgb
        self.restrict_offset = restrict_offset
        self.xyz_offset = xyz_offset
        self.xyz_offset_max_step = xyz_offset_max_step  # 1.2 / 32
        self.fix_opacity = fix_opacity
        self.fix_rotation = fix_rotation
        self.use_fine_feat = use_fine_feat

        if sh_degree is not None:
            sh_channels = (sh_degree + 1) ** 2 * 3
        elif use_rgb:
            # The SH width is irrelevant in RGB mode: the head always has 3
            # outputs, so a missing degree is not an error.
            sh_channels = 3
        else:
            raise ValueError("sh_degree must be given when use_rgb is False")

        # None marks an attribute without a learnable head.
        self.attr_dict = {
            "shs": sh_channels,
            "scaling": 3,
            "xyz": 3,
            "opacity": None if self.fix_opacity else 1,
            "rotation": None if self.fix_rotation else 4,
        }

        self.out_layers = nn.ModuleDict()
        for key, out_ch in self.attr_dict.items():
            if out_ch is None:
                self.out_layers[key] = nn.Identity()
                continue

            if key == "shs":
                if use_rgb:
                    out_ch = 3
                shs_out_ch = out_ch
            layer = nn.Linear(in_channels, out_ch)

            # Every head starts from zero weights/bias except the RGB color
            # head, which keeps nn.Linear's default initialization.
            if not (key == "shs" and use_rgb):
                nn.init.constant_(layer.weight, 0)
                nn.init.constant_(layer.bias, 0)

            if key == "scaling":
                nn.init.constant_(layer.bias, init_scaling)
            elif key == "rotation":
                # bias (1, 0, 0, 0): the first quaternion component is one.
                nn.init.constant_(layer.bias, 0)
                nn.init.constant_(layer.bias[0], 1.0)
            elif key == "opacity":
                nn.init.constant_(layer.bias, inverse_sigmoid(init_density))

            self.out_layers[key] = layer

        if self.use_fine_feat:
            fine_shs_layer = nn.Linear(in_channels, shs_out_ch)
            nn.init.constant_(fine_shs_layer.weight, 0)
            nn.init.constant_(fine_shs_layer.bias, 0)
            self.out_layers["fine_shs"] = fine_shs_layer

    def hyper_step(self, step):
        """Refresh ``clip_scaling`` from the tuner for training step ``step``."""
        self.clip_scaling = self.clip_scaling_pruner.get_value(step)

    def forward(self, x, pts, x_fine=None):
        """Predict activated Gaussian attributes for every input point.

        Args:
            x: ``(N, in_channels)`` features.
            pts: Query points. Unused while only offset mode is supported;
                kept for interface compatibility.
            x_fine: Features for the ``fine_shs`` head; required when
                ``use_fine_feat`` is True.

        Returns:
            ``GaussianAppOutput`` built from ``shs``, ``scaling``,
            ``offset_xyz``, ``opacity``, ``rotation`` and ``use_rgb``.

        Raises:
            NotImplementedError: If ``xyz_offset`` is False. The original
                ``assert NotImplementedError`` never fired and let absolute
                positions be returned under the ``offset_xyz`` key.
        """
        assert len(x.shape) == 2
        ret = {}
        for k in self.attr_dict:
            v = self.out_layers[k](x)
            if k == "rotation":
                if self.fix_rotation:
                    # constant (identity) rotation
                    v = matrix_to_quaternion(
                        torch.eye(3).type_as(x)[None, :, :].repeat(x.shape[0], 1, 1)
                    )
                else:
                    v = self.rotation_activation(v)
            elif k == "scaling":
                v = self.scaling_activation(v)
                if self.clip_scaling is not None:
                    v = torch.clamp(v, min=0, max=self.clip_scaling)
            elif k == "opacity":
                if self.fix_opacity:
                    v = torch.ones_like(x)[..., 0:1]
                else:
                    v = self.opacity_activation(v)
            elif k == "shs":
                if self.use_rgb:
                    v = self.rgb_activation(v)
                    if self.use_fine_feat:
                        # The fine residual is bounded by tanh in RGB mode.
                        v = v + torch.tanh(self.out_layers["fine_shs"](x_fine))
                elif self.use_fine_feat:
                    v = v + self.out_layers["fine_shs"](x_fine)
                v = torch.reshape(v, (v.shape[0], -1, 3))
            elif k == "xyz":
                if self.restrict_offset:
                    max_step = self.xyz_offset_max_step
                    v = (torch.sigmoid(v) - 0.5) * max_step
                if not self.xyz_offset:
                    raise NotImplementedError(
                        "GSLayer only supports xyz_offset=True"
                    )
                k = "offset_xyz"
            ret[k] = v

        ret["use_rgb"] = self.use_rgb

        return GaussianAppOutput(**ret)
class PointEmbed(nn.Module):
    """Sinusoidal embedding of 3D points followed by a linear layer and LayerNorm."""

    def __init__(self, hidden_dim=48, dim=128):
        """Build the frequency basis and the projection layers.

        Args:
            hidden_dim: Size of the sin/cos embedding; must be divisible by 6
                (3 axes x sin/cos x frequency bands).
            dim: Output feature size.
        """
        super().__init__()

        assert hidden_dim % 6 == 0

        self.embedding_dim = hidden_dim
        bands = self.embedding_dim // 6
        freqs = torch.pow(2, torch.arange(bands)).float() * np.pi

        # Row `axis` carries the frequencies in its own column slot and zeros
        # elsewhere, so each coordinate is projected independently.
        basis = torch.zeros(3, 3 * bands)
        for axis in range(3):
            basis[axis, axis * bands : (axis + 1) * bands] = freqs
        self.register_buffer("basis", basis)  # 3 x (hidden_dim // 2)

        self.mlp = nn.Linear(self.embedding_dim + 3, dim)
        self.norm = nn.LayerNorm(dim)

    @staticmethod
    def embed(input, basis):
        """Project ``input`` (B x N x 3) onto ``basis`` and return ``[sin, cos]``."""
        projected = torch.einsum("bnd,de->bne", input, basis)
        return torch.cat([projected.sin(), projected.cos()], dim=2)

    def forward(self, input):
        """Embed points.

        Args:
            input: ``B x N x 3`` point coordinates.

        Returns:
            ``B x N x dim`` normalized embeddings.
        """
        encoded = self.embed(input, self.basis)
        features = self.mlp(torch.cat([encoded, input], dim=2))  # B x N x C
        return self.norm(features)
class CrossAttnBlock(nn.Module):
    """
    Transformer block that takes in a cross-attention condition.
    Designed for SparseLRM architecture.

    The block always contains a cross-attention layer; a self-attention
    layer and an MLP are added only when ``feedforward`` is True.
    """

    def __init__(
        self,
        inner_dim: int,
        cond_dim: int,
        num_heads: int,
        eps: float = None,
        attn_drop: float = 0.0,
        attn_bias: bool = False,
        mlp_ratio: float = 4.0,
        mlp_drop: float = 0.0,
        feedforward=False,
    ):
        """Build the attention layers.

        Args:
            inner_dim: Query/embedding width.
            cond_dim: Width of the condition tokens (keys and values).
            num_heads: Number of attention heads.
            eps: LayerNorm epsilon for the feedforward path. ``None`` now
                falls back to ``nn.LayerNorm``'s own default; previously
                ``None`` was passed through and broke the forward pass.
            attn_drop: Attention dropout probability.
            attn_bias: Whether the attention projections use a bias.
            mlp_ratio: Hidden-width multiplier of the MLP.
            mlp_drop: Dropout probability inside the MLP.
            feedforward: Add the self-attention + MLP sub-blocks.
        """
        super().__init__()
        # Inputs are expected to be normalized already, so the query/key
        # norms are placeholders.
        # TODO check already apply normalization
        self.norm_q = nn.Identity()
        self.norm_k = nn.Identity()

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=inner_dim,
            num_heads=num_heads,
            kdim=cond_dim,
            vdim=cond_dim,
            dropout=attn_drop,
            bias=attn_bias,
            batch_first=True,
        )

        self.mlp = None
        if feedforward:
            norm_kwargs = {} if eps is None else {"eps": eps}
            hidden_dim = int(inner_dim * mlp_ratio)

            self.norm2 = nn.LayerNorm(inner_dim, **norm_kwargs)
            self.self_attn = nn.MultiheadAttention(
                embed_dim=inner_dim,
                num_heads=num_heads,
                dropout=attn_drop,
                bias=attn_bias,
                batch_first=True,
            )
            self.norm3 = nn.LayerNorm(inner_dim, **norm_kwargs)
            self.mlp = nn.Sequential(
                nn.Linear(inner_dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(mlp_drop),
                nn.Linear(hidden_dim, inner_dim),
                nn.Dropout(mlp_drop),
            )

    def forward(self, x, cond):
        """Attend from ``x`` to ``cond``.

        Args:
            x: ``[N, L, D]`` queries.
            cond: ``[N, L_cond, D_cond]`` condition tokens.

        Returns:
            ``[N, L, D]`` tensor. The cross-attention output replaces ``x``
            (no residual); the optional self-attention and MLP stages are
            residual.
        """
        x = self.cross_attn(
            self.norm_q(x), self.norm_k(cond), cond, need_weights=False
        )[0]

        if self.mlp is not None:
            before_sa = self.norm2(x)
            x = (
                x
                + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
            )
            x = x + self.mlp(self.norm3(x))

        return x
class DecoderCrossAttn(nn.Module):
    """Cross-attention decoder with an optional second, image-conditioned stage."""

    def __init__(
        self, query_dim, context_dim, num_heads, mlp=False, decode_with_extra_info=None
    ):
        """Build the decoder.

        Args:
            query_dim: Width of the point queries.
            context_dim: Width of the latent tokens attended to.
            num_heads: Number of attention heads.
            mlp: Forwarded as ``feedforward`` to the main ``CrossAttnBlock``.
            decode_with_extra_info: Optional mapping with a ``"type"`` entry
                selecting the second stage: ``"dinov2p14_feat"`` (also reads
                ``"cond_dim"``), ``"decoder_dinov2p14_feat"`` or
                ``"decoder_resnet18_feat"``. Other types add no second stage.
        """
        super().__init__()
        self.query_dim = query_dim
        self.context_dim = context_dim

        self.cross_attn = CrossAttnBlock(
            inner_dim=query_dim,
            cond_dim=context_dim,
            num_heads=num_heads,
            feedforward=mlp,
            eps=1e-5,
        )

        self.decode_with_extra_info = decode_with_extra_info
        if decode_with_extra_info is None:
            return

        def make_color_attn(cond_dim):
            # Second-stage attention over image features, without feedforward.
            return CrossAttnBlock(
                inner_dim=query_dim,
                cond_dim=cond_dim,
                num_heads=num_heads,
                feedforward=False,
                eps=1e-5,
            )

        kind = decode_with_extra_info["type"]
        if kind == "dinov2p14_feat":
            # Features arrive precomputed; only their width is configured.
            self.cross_attn_color = make_color_attn(decode_with_extra_info["cond_dim"])
        elif kind == "decoder_dinov2p14_feat":
            from LHM.models.encoders.dinov2_wrapper import Dinov2Wrapper

            self.encoder = Dinov2Wrapper(
                model_name="dinov2_vits14_reg", freeze=False, encoder_feat_dim=384
            )
            self.cross_attn_color = make_color_attn(384)
        elif kind == "decoder_resnet18_feat":
            from LHM.models.encoders.xunet_wrapper import XnetWrapper

            self.encoder = XnetWrapper(
                model_name="resnet18", freeze=False, encoder_feat_dim=64
            )
            self.cross_attn_color = make_color_attn(64)

    def resize_image(self, image, multiply):
        """Bilinearly resize ``image`` so height and width are multiples of ``multiply``.

        Args:
            image: ``(B, C, H, W)`` tensor.
            multiply: Each spatial size is rounded up to a multiple of this.

        Returns:
            The resized ``(B, C, H', W')`` tensor.
        """
        _, _, height, width = image.shape
        target = (
            math.ceil(height / multiply) * multiply,
            math.ceil(width / multiply) * multiply,
        )
        return F.interpolate(image, target, align_corners=True, mode="bilinear")

    def forward(self, pcl_query, pcl_latent, extra_info=None):
        """Decode point queries against the latent tokens.

        Args:
            pcl_query: Point query features.
            pcl_latent: Latent tokens used as the attention condition.
            extra_info: Mapping providing ``"image_feats"`` or ``"image"``,
                depending on the configured second stage.

        Returns:
            The coarse features when no (known) second stage is configured,
            otherwise ``{"coarse": ..., "fine": ...}``.
        """
        coarse = self.cross_attn(pcl_query, pcl_latent)
        if self.decode_with_extra_info is None:
            return coarse

        kind = self.decode_with_extra_info["type"]
        if kind == "dinov2p14_feat":
            cond = extra_info["image_feats"]
        elif kind == "decoder_dinov2p14_feat":
            cond = self.encoder(extra_info["image"])
        elif kind == "decoder_resnet18_feat":
            resized = self.resize_image(extra_info["image"], multiply=32)
            cond = self.encoder(resized)
        else:
            # Unknown stage type: fall back to the coarse features only.
            return coarse

        return {"coarse": coarse, "fine": self.cross_attn_color(coarse, cond)}
class GS3DRenderer(nn.Module):
def __init__(
    self,
    human_model_path,
    subdivide_num,
    smpl_type,
    feat_dim,
    query_dim,
    use_rgb,
    sh_degree,
    xyz_offset_max_step,
    mlp_network_config,
    expr_param_dim,
    shape_param_dim,
    clip_scaling=0.2,
    cano_pose_type=0,
    decoder_mlp=False,
    skip_decoder=False,
    fix_opacity=False,
    fix_rotation=False,
    decode_with_extra_info=None,
    gradient_checkpointing=False,
    apply_pose_blendshape=False,
    dense_sample_pts=40000,  # only used by the dense-sampled SMPL-X model
):
    """Assemble the body model, the point decoder and the Gaussian head.

    ``smpl_type`` selects the body model: ``"smplx"`` / ``"smplx_0"`` use
    ``SMPLXModel``, ``"smplx_2"`` uses ``SMPLXVoxelMeshModel`` (which also
    receives ``dense_sample_pts``), and ``"smplx_1"`` is rejected. The
    point embedding and cross-attention decoder are only built when
    ``skip_decoder`` is False, and ``mlp_net`` only when
    ``mlp_network_config`` is given.
    """
    super().__init__()
    self.gradient_checkpointing = gradient_checkpointing
    self.skip_decoder = skip_decoder
    self.smpl_type = smpl_type
    assert self.smpl_type in ["smplx", "smplx_0", "smplx_1", "smplx_2"]

    self.scaling_modifier = 1.0
    self.sh_degree = sh_degree

    # Arguments shared by both body-model variants.
    body_model_kwargs = dict(
        gender="neutral",
        subdivide_num=subdivide_num,
        shape_param_dim=shape_param_dim,
        expr_param_dim=expr_param_dim,
        cano_pose_type=cano_pose_type,
    )
    if self.smpl_type in ("smplx_0", "smplx"):
        # Using pytorch3d dense sampling
        self.smplx_model = SMPLXModel(
            human_model_path,
            apply_pose_blendshape=apply_pose_blendshape,
            **body_model_kwargs,
        )
    elif self.smpl_type == "smplx_1":
        raise NotImplementedError("inference version does not support")
    elif self.smpl_type == "smplx_2":
        self.smplx_model = SMPLXVoxelMeshModel(
            human_model_path,
            dense_sample_points=dense_sample_pts,
            apply_pose_blendshape=apply_pose_blendshape,
            **body_model_kwargs,
        )
    else:
        raise NotImplementedError

    if not self.skip_decoder:
        self.pcl_embed = PointEmbed(dim=query_dim)
        self.decoder_cross_attn = DecoderCrossAttn(
            query_dim=query_dim,
            context_dim=feat_dim,
            num_heads=1,
            mlp=decoder_mlp,
            decode_with_extra_info=decode_with_extra_info,
        )

    self.mlp_network_config = mlp_network_config
    # Maps the decoder features (which went through a norm layer) to
    # features suitable for regression.
    if self.mlp_network_config is not None:
        self.mlp_net = MLP(query_dim, query_dim, **self.mlp_network_config)

    has_fine_stage = (
        decode_with_extra_info is not None
        and decode_with_extra_info["type"] is not None
    )
    self.gs_net = GSLayer(
        in_channels=query_dim,
        use_rgb=use_rgb,
        sh_degree=self.sh_degree,
        clip_scaling=clip_scaling,
        init_scaling=-5.0,
        init_density=0.1,
        xyz_offset=True,
        restrict_offset=True,
        xyz_offset_max_step=xyz_offset_max_step,
        fix_opacity=fix_opacity,
        fix_rotation=fix_rotation,
        use_fine_feat=has_fine_stage,
    )
def hyper_step(self, step):
    """Forward the training step to the Gaussian head's ``hyper_step``."""
    gaussian_head = self.gs_net
    gaussian_head.hyper_step(step)
def forward_single_view(
    self,
    gs: GaussianModel,
    viewpoint_camera: Camera,
    background_color: Optional[Float[Tensor, "3"]],
    ret_mask: bool = True,
):
    """Rasterize one Gaussian model from one camera.

    Args:
        gs: Gaussians to render.
        viewpoint_camera: Camera providing matrices, FoV and image size.
        background_color: Background color handed to the rasterizer.
        ret_mask: Unused; the alpha returned by the rasterizer is always
            reported as ``comp_mask``.

    Returns:
        Dict with ``comp_rgb`` ``[H, W, 3]``, ``comp_rgb_bg`` (the
        background color), ``comp_mask`` and ``comp_depth`` (both permuted
        to channel-last).
    """
    # Zero tensor used to make pytorch return gradients of the 2D
    # (screen-space) means. It is created on the Gaussians' own device, so
    # this method no longer depends on ``self.device`` having been set by a
    # batch-level method first.
    screenspace_points = torch.zeros_like(gs.xyz, requires_grad=True) + 0
    try:
        screenspace_points.retain_grad()
    except RuntimeError:
        # Best effort: retain_grad() refuses tensors that do not require
        # grad (presumably the case when grad mode is disabled); rendering
        # does not need the retained gradient.
        pass

    bg_color = background_color

    # Set up rasterization configuration
    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.height),
        image_width=int(viewpoint_camera.width),
        tanfovx=tanfovx,
        tanfovy=tanfovy,
        bg=bg_color,
        scale_modifier=self.scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform.float(),
        sh_degree=self.sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        debug=True,
    )
    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    # The rasterizer derives the 3D covariance from scaling / rotation, so
    # no precomputed covariance is passed.
    cov3D_precomp = None

    # Either pass precomputed RGB colors or let the rasterizer convert the
    # SH coefficients to RGB.
    if self.gs_net.use_rgb:
        colors_precomp = gs.shs.squeeze(1)
        shs = None
    else:
        colors_precomp = None
        shs = gs.shs

    # Rasterize visible Gaussians to image, obtain their radii (on screen).
    rendered_image, radii, rendered_depth, rendered_alpha = rasterizer(
        means3D=gs.xyz,
        means2D=screenspace_points,
        shs=shs,
        colors_precomp=colors_precomp,
        opacities=gs.opacity,
        scales=gs.scaling,
        rotations=gs.rotation,
        cov3D_precomp=cov3D_precomp,
    )

    return {
        "comp_rgb": rendered_image.permute(1, 2, 0),  # [H, W, 3]
        "comp_rgb_bg": bg_color,
        "comp_mask": rendered_alpha.permute(1, 2, 0),
        "comp_depth": rendered_depth.permute(1, 2, 0),
    }
def animate_gs_model(
self, gs_attr: GaussianAppOutput, query_points, smplx_data, debug=False
):
"""
Warp canonical Gaussians into every requested pose plus one canonical pose.

Args:
gs_attr: Predicted Gaussian attributes; ``offset_xyz`` is added to
``query_points``. NOTE: ``gs_attr.scaling`` is clamped in place for the
constrained body points, and ``debug=True`` overwrites ``xyz``,
``opacity`` and ``rotation`` on it.
query_points: [N, 3]
smplx_data: Per-view pose entries (see ``cano_smplx_data_keys``) plus
``betas`` and ``transform_mat_neutral_pose``.
debug: Replace opacity/rotation with constants before warping.

Returns:
``(gs_list, cano_gs_list)``: one ``GaussianModel`` per input view, and a
single-element list holding the model for the appended canonical pose.
"""
device = gs_attr.offset_xyz.device
if debug:
N = gs_attr.offset_xyz.shape[0]
gs_attr.xyz = torch.ones_like(gs_attr.offset_xyz) * 0.0
rotation = matrix_to_quaternion(
torch.eye(3).float()[None, :, :].repeat(N, 1, 1)
).to(
device
) # constant rotation
opacity = torch.ones((N, 1)).float().to(device) # constant opacity
gs_attr.opacity = opacity
gs_attr.rotation = rotation
# gs_attr.scaling = torch.ones_like(gs_attr.scaling) * 0.05
# print(gs_attr.shs.shape)
# build cano_dependent_pose: every listed entry gets one extra all-zero
# row appended, so the last "view" is the canonical pose.
cano_smplx_data_keys = [
"root_pose",
"body_pose",
"jaw_pose",
"leye_pose",
"reye_pose",
"lhand_pose",
"rhand_pose",
"expr",
"trans",
]
merge_smplx_data = dict()
for cano_smplx_data_key in cano_smplx_data_keys:
warp_data = smplx_data[cano_smplx_data_key]
cano_pose = torch.zeros_like(warp_data[:1])
if cano_smplx_data_key == "body_pose":
# A-posed: body joints 15 and 16 are rotated by -/+ pi/6 about
# their last axis component.
cano_pose[0, 15, -1] = -math.pi / 6
cano_pose[0, 16, -1] = +math.pi / 6
merge_pose = torch.cat([warp_data, cano_pose], dim=0)
merge_smplx_data[cano_smplx_data_key] = merge_pose
# betas and the neutral-pose transform are shared by all views.
merge_smplx_data["betas"] = smplx_data["betas"]
merge_smplx_data["transform_mat_neutral_pose"] = smplx_data[
"transform_mat_neutral_pose"
]
# NOTE(review): the extent of this autocast block is not recoverable from
# the de-indented source; confirm which statements it covers.
with torch.autocast(device_type=device.type, dtype=torch.float32):
mean_3d = (
query_points + gs_attr.offset_xyz
) # [N, 3] # canonical space offset.
# matrix to warp predefined pose to zero-pose
transform_mat_neutral_pose = merge_smplx_data[
"transform_mat_neutral_pose"
] # [55, 4, 4]
num_view = merge_smplx_data["body_pose"].shape[0] # [Nv, 21, 3]
# Repeat the per-point inputs once per view (including the canonical one).
mean_3d = mean_3d.unsqueeze(0).repeat(num_view, 1, 1) # [Nv, N, 3]
query_points = query_points.unsqueeze(0).repeat(num_view, 1, 1)
transform_mat_neutral_pose = transform_mat_neutral_pose.unsqueeze(0).repeat(
num_view, 1, 1, 1
)
# print(mean_3d.shape, transform_mat_neutral_pose.shape, query_points.shape, smplx_data["body_pose"].shape, smplx_data["betas"].shape)
mean_3d, transform_matrix = (
self.smplx_model.transform_to_posed_verts_from_neutral_pose(
mean_3d,
merge_smplx_data,
query_points,
transform_mat_neutral_pose=transform_mat_neutral_pose, # from predefined pose to zero-pose matrix
device=device,
)
) # [B, N, 3]
# rotation appearance from canonical space to view_posed
num_view, N, _, _ = transform_matrix.shape
transform_rotation = transform_matrix[:, :, :3, :3]
# Despite the name, this holds unit quaternions, not matrices.
rigid_rotation_matrix = torch.nn.functional.normalize(
matrix_to_quaternion(transform_rotation), dim=-1
)
# Identity quaternion.
I = matrix_to_quaternion(torch.eye(3)).to(device)
# inference constrain: points flagged by the body model keep an identity
# rigid rotation and have their scaling capped at 0.02 (in-place write
# into gs_attr.scaling).
is_constrain_body = self.smplx_model.is_constrain_body
rigid_rotation_matrix[:, is_constrain_body] = I
gs_attr.scaling[is_constrain_body] = gs_attr.scaling[
is_constrain_body
].clamp(max=0.02)
rotation_neutral_pose = gs_attr.rotation.unsqueeze(0).repeat(num_view, 1, 1)
# TODO do not move underarm gs
# QUATERNION MULTIPLY: compose the rigid rotation with the predicted one.
rotation_pose_verts = quaternion_multiply(
rigid_rotation_matrix, rotation_neutral_pose
)
# rotation_pose_verts = rotation_neutral_pose
gs_list = []
cano_gs_list = []
# Opacity, scaling and shs are shared by all views; only the centers and
# rotations differ per view.
for i in range(num_view):
gs_copy = GaussianModel(
xyz=mean_3d[i],
opacity=gs_attr.opacity,
# rotation=gs_attr.rotation,
rotation=rotation_pose_verts[i],
scaling=gs_attr.scaling,
shs=gs_attr.shs,
use_rgb=self.gs_net.use_rgb,
) # [N, 3]
# The last entry is the canonical pose appended above.
if i == num_view - 1:
cano_gs_list.append(gs_copy)
else:
gs_list.append(gs_copy)
return gs_list, cano_gs_list
def forward_gs_attr(self, x, query_points, smplx_data, debug=False, x_fine=None):
    """Regress Gaussian attributes from per-point features.

    Args:
        x: [N, C] Float[Tensor, "Np Cp"] coarse features.
        query_points: [N, 3] Float[Tensor, "Np 3"].
        smplx_data: Unused here.
        debug: Unused here.
        x_fine: Optional fine features; they go through the same
            ``mlp_net`` as ``x`` when an MLP is configured.

    Returns:
        The ``GaussianAppOutput`` produced by ``gs_net`` (it contains the
        xyz offset, not absolute positions).
    """
    has_mlp = self.mlp_network_config is not None
    if has_mlp:
        # x is processed by LayerNorm upstream; map it to regression features.
        x = self.mlp_net(x)
    if has_mlp and x_fine is not None:
        x_fine = self.mlp_net(x_fine)

    gs_attr: GaussianAppOutput = self.gs_net(x, query_points, x_fine)
    return gs_attr
def get_query_points(self, smplx_data, device):
with torch.no_grad():
with torch.autocast(device_type=device.type, dtype=torch.float32):
# print(smplx_data["betas"].shape, smplx_data["face_offset"].shape, smplx_data["joint_offset"].shape)
positions, _, transform_mat_neutral_pose = (
self.smplx_model.get_query_points(smplx_data, device=device)
) # [B, N, 3]
smplx_data["transform_mat_neutral_pose"] = (
transform_mat_neutral_pose # [B, 55, 4, 4]
)
return positions, smplx_data
def decoder_cross_attn_wrapper(self, pcl_embed, latent_feat, extra_info):
    """Run the cross-attention decoder on the point embeddings.

    The embeddings are cast to the dtype of ``latent_feat`` first. Gradient
    checkpointing is not applied here.

    Returns:
        Whatever ``self.decoder_cross_attn`` returns.
    """
    queries = pcl_embed.to(dtype=latent_feat.dtype)
    return self.decoder_cross_attn(queries, latent_feat, extra_info)
def query_latent_feat(
self,
positions: Float[Tensor, "*B N1 3"],
smplx_data,
latent_feat: Float[Tensor, "*B N2 C"],
extra_info,
):
"""
Produce per-point Gaussian features from the latent tokens.

With ``skip_decoder`` the latent features are used as they are and
``positions`` must be given. Otherwise ``positions`` must be None: the
query points come from ``get_query_points``, are embedded with
``pcl_embed`` and decoded against ``latent_feat``.

Returns:
``(gs_feats, positions, smplx_data)``.
"""
device = latent_feat.device
if self.skip_decoder:
gs_feats = latent_feat
assert positions is not None
else:
assert positions is None
# Always true after the assert above (unless asserts are stripped).
if positions is None:
positions, smplx_data = self.get_query_points(smplx_data, device)
# NOTE(review): the de-indented source does not show whether the decoder
# call below runs inside this autocast block — confirm.
with torch.autocast(device_type=device.type, dtype=torch.float32):
pcl_embed = self.pcl_embed(positions)
gs_feats = self.decoder_cross_attn_wrapper(
pcl_embed, latent_feat, extra_info
)
return gs_feats, positions, smplx_data
def forward_single_batch(
    self,
    gs_list: list[GaussianModel],
    c2ws: Float[Tensor, "Nv 4 4"],
    intrinsics: Float[Tensor, "Nv 4 4"],
    height: int,
    width: int,
    background_color: Optional[Float[Tensor, "Nv 3"]],
    debug: bool = False,
):
    """Render every view of one sample and stack the per-view outputs.

    View ``i`` renders ``gs_list[i]`` with ``c2ws[i]``, ``intrinsics[i]``
    and ``background_color[i]``.

    Returns:
        Dict of the ``forward_single_view`` outputs stacked along a new
        leading view dimension, plus ``"3dgs"`` holding ``gs_list``.
    """
    self.device = gs_list[0].xyz.device

    per_view = [
        self.forward_single_view(
            gs_list[view_idx],
            Camera.from_c2w(pose, view_intrinsic, height, width),
            background_color[view_idx],
        )
        for view_idx, (pose, view_intrinsic) in enumerate(zip(c2ws, intrinsics))
    ]

    # Group the per-view values by key, then stack each group.
    grouped = {}
    for view_out in per_view:
        for key, value in view_out.items():
            grouped.setdefault(key, []).append(value)
    out = {key: torch.stack(values, dim=0) for key, values in grouped.items()}
    out["3dgs"] = gs_list

    if debug:
        import cv2

        # Dump the first view with the channel order reversed.
        first_view = out["comp_rgb"].detach().cpu().numpy()[0, ..., ::-1]
        cv2.imwrite("fuck.png", (first_view * 255).astype(np.uint8))

    return out
def forward_cano_batch(
    self,
    gs_list: list[GaussianModel],
    c2ws: Float[Tensor, "Nv 4 4"],
    intrinsics: Float[Tensor, "Nv 4 4"],
    background_color: Optional[Float[Tensor, "Nv 3"]],
    height: int = 512,
    width: int = 512,
    debug: bool = False,
):
    """Render the first Gaussian model from four turntable angles (for visualization).

    ``gs_list[0]`` is cloned once per angle (0, 90, 180, 270 degrees about
    y). Each clone is flipped along y and z, rotated, centered on its
    bounding box and pushed 2.5 units along the last axis. All views use
    ``c2ws[0]``, a synthetic intrinsic matrix built from ``width`` /
    ``height``, and an all-ones background.

    Returns:
        Dict of stacked per-view outputs plus ``"3dgs"`` holding the four
        transformed models.
    """
    view_angles = [0, 90, 180, 270]
    self.device = gs_list[0].xyz.device

    rotation_gs_list = []
    for angle in view_angles:
        gs = gs_list[0].clone()

        # Flip the y and z axes, then rotate about y.
        flip = torch.eye(3).to(gs.xyz)
        flip[-1, -1] *= -1
        flip[1, 1] *= -1
        turn = torch.from_numpy(generate_rotation_matrix_y(angle)).to(flip)
        view_rot = turn @ flip

        gs.xyz = (view_rot @ gs.xyz.T).T

        # Center on the bounding box.
        lower, upper = aabb(gs.xyz)
        gs.xyz -= ((lower + upper) / 2).unsqueeze(0)

        gs.rotation = quaternion_multiply(matrix_to_quaternion(view_rot), gs.rotation)

        # Push the model along the last axis, in front of the camera.
        gs.xyz[..., -1] += 2.5
        rotation_gs_list.append(gs)

    # Synthetic intrinsics: focal lengths equal the image size, principal
    # point at the image center.
    cano_intrinsics = torch.eye(4).to(intrinsics).unsqueeze(0)
    cano_intrinsics[0, 0, 0] = width
    cano_intrinsics[0, 1, 1] = height
    cano_intrinsics[0, 0, 2] = width / 2
    cano_intrinsics[0, 1, 2] = height / 2

    per_view = [
        self.forward_single_view(
            rotated,
            Camera.from_c2w(c2ws[0], cano_intrinsics[0], height, width),
            torch.ones_like(background_color[0]),
        )
        for rotated in rotation_gs_list
    ]

    grouped = {}
    for view_out in per_view:
        for key, value in view_out.items():
            grouped.setdefault(key, []).append(value)
    out = {key: torch.stack(values, dim=0) for key, values in grouped.items()}
    out["3dgs"] = rotation_gs_list

    if debug:
        import cv2

        frames = out["comp_rgb"].detach().cpu().numpy()
        for i in range(4):
            cv2.imwrite(
                f"fuck_{i}.png",
                (frames[i, ..., ::-1] * 255).astype(np.uint8),
            )

    return out
def get_single_batch_smpl_data(self, smpl_data, bidx):
    """Select one batch element from every entry of ``smpl_data``.

    Most entries are indexed with ``bidx`` and lose the batch dimension
    (e.g. body_pose: [B, N_v, 21, 3] -> [N_v, 21, 3]). ``betas``,
    ``joint_offset`` and ``face_offset`` keep a batch dimension of one
    (e.g. betas: [B, 100] -> [1, 100]).
    """
    keeps_batch_dim = ("betas", "joint_offset", "face_offset")

    single = {}
    for name, value in smpl_data.items():
        picked = value[bidx]
        if name in keeps_batch_dim:
            picked = value[bidx : bidx + 1]
        single[name] = picked
    return single
def get_single_view_smpl_data(self, smpl_data, vidx):
    """Restrict single-batch SMPL data to one view.

    Args:
        smpl_data: Mapping from parameter name to an array-like exposing
            ``.shape`` and 2-axis slicing. Every value must have a leading
            axis of length 1 (checked with ``assert``).
        vidx: Index of the view to keep along axis 1.

    Returns:
        A new dict with the same keys. ``betas``, ``joint_offset``,
        ``face_offset`` and ``transform_mat_neutral_pose`` are passed through
        unchanged (e.g. betas: [1, 100] -> [1, 100]); every other value is
        sliced as ``v[:, vidx : vidx + 1]`` so the view axis is kept with
        length 1 (e.g. body_pose: [1, N_v, 21, 3] -> [1, 1, 21, 3]).
    """
    view_independent = (
        "betas",
        "joint_offset",
        "face_offset",
        "transform_mat_neutral_pose",
    )
    single_view = {}
    for name, value in smpl_data.items():
        assert value.shape[0] == 1
        if name not in view_independent:
            value = value[:, vidx : vidx + 1]
        single_view[name] = value
    return single_view
def forward_gs(
    self,
    gs_hidden_features: Float[Tensor, "B Np Cp"],
    query_points: Float[Tensor, "B Np_q 3"],
    smplx_data,  # e.g., body_pose:[B, Nv, 21, 3], betas:[B, 100]
    additional_features: Optional[dict] = None,
    debug: bool = False,
    **kwargs,
):
    """Predict per-batch Gaussian attributes from the hidden features.

    The latent features, query points and SMPL-X data are first resolved by
    ``self.query_latent_feat``; the values it returns replace the incoming
    ``query_points`` and ``smplx_data``. ``self.forward_gs_attr`` is then
    called once per batch element.

    Args:
        gs_hidden_features: Hidden features; the size of axis 0 determines
            how many batch elements are processed.
        query_points: Query positions, forwarded to ``query_latent_feat``.
        smplx_data: SMPL-X parameters, forwarded to ``query_latent_feat``.
        additional_features: Optional extra features, forwarded to
            ``query_latent_feat``.
        debug: Forwarded to ``forward_gs_attr``.
        **kwargs: Accepted and ignored.

    Returns:
        Tuple ``(gs_attr_list, query_points, smplx_data)`` where
        ``gs_attr_list`` holds one ``forward_gs_attr`` result per batch
        element and the other two are the values returned by
        ``query_latent_feat``.
    """
    num_batches = gs_hidden_features.shape[0]

    # obtain gs_features embedding, cur points position, and also smplx params
    latent_feats, query_points, smplx_data = self.query_latent_feat(
        query_points, smplx_data, gs_hidden_features, additional_features
    )

    def _attr_for(index):
        # A dict result carries separate "coarse" and "fine" feature sets;
        # the fine one is handed over through the x_fine keyword.
        if isinstance(latent_feats, dict):
            return self.forward_gs_attr(
                latent_feats["coarse"][index],
                query_points[index],
                None,
                debug,
                x_fine=latent_feats["fine"][index],
            )
        return self.forward_gs_attr(
            latent_feats[index], query_points[index], None, debug
        )

    gs_attr_list = [_attr_for(index) for index in range(num_batches)]
    return gs_attr_list, query_points, smplx_data
def forward_animate_gs(
    self,
    gs_attr_list,
    query_points,
    smplx_data,
    c2w,
    intrinsic,
    height,
    width,
    background_color,
    debug=False,
    df_data=None,  # deepfashion-style dataset
):
    """Animate the predicted Gaussians per batch element and render every view.

    For each entry of ``gs_attr_list`` the Gaussians are posed with
    ``self.animate_gs_model`` and rendered with ``self.forward_single_batch``.
    The per-batch result dicts are then merged key by key.

    Args:
        gs_attr_list: One Gaussian-attribute object per batch element.
        query_points: Indexable by batch index; entry ``b`` is paired with
            ``gs_attr_list[b]``.
        smplx_data: SMPL-X parameters; ``smplx_data["root_pose"].shape[1]``
            is taken as the number of views.
        c2w: Camera-to-world matrices, indexed by batch; ``c2w.shape[1]``
            must equal the number of animated models (asserted).
        intrinsic: Camera intrinsics, indexed by batch.
        height: Render height, forwarded to ``forward_single_batch``.
        width: Render width, forwarded to ``forward_single_batch``.
        background_color: Indexed by batch when given; ``None`` is forwarded
            as ``None``.
        debug: Forwarded to the animation and rendering calls.
        df_data: Accepted but not used by this method.

    Returns:
        A ``defaultdict(list)`` with the keys produced by
        ``forward_single_batch``. Tensor values are stacked along a new
        leading batch axis; other values stay as per-batch lists.
        ``comp_rgb``, ``comp_mask`` and ``comp_depth`` additionally have
        their last axis moved to position 2.
    """
    num_views = smplx_data["root_pose"].shape[1]
    per_batch_outputs = []

    for b, gs_attr in enumerate(gs_attr_list):
        query_pt = query_points[b]

        # The second return value (canonical models) is not used here.
        merged_models, _cano_models = self.animate_gs_model(
            gs_attr,
            query_pt,
            self.get_single_batch_smpl_data(smplx_data, b),
            debug=debug,
        )
        # Only the first num_views models are rendered.
        animated_models = merged_models[:num_views]
        assert len(animated_models) == c2w.shape[1]

        bg = None if background_color is None else background_color[b]
        per_batch_outputs.append(
            self.forward_single_batch(
                animated_models,
                c2w[b],
                intrinsic[b],
                height,
                width,
                bg,
                debug=debug,
            )
        )

    merged = defaultdict(list)
    for single in per_batch_outputs:
        for key, value in single.items():
            merged[key].append(value)

    # Only existing keys are re-assigned, so the dict keeps its size while
    # it is being iterated.
    for key, values in merged.items():
        if isinstance(values[0], torch.Tensor):
            merged[key] = torch.stack(values, dim=0)

    # Channel-last -> channel-first: [B, NV, H, W, C] -> [B, NV, C, H, W].
    # NOTE(review): C is presumably 3 for rgb and 1 for mask/depth - confirm
    # against forward_single_batch.
    for key in ("comp_rgb", "comp_mask", "comp_depth"):
        merged[key] = merged[key].permute(0, 1, 4, 2, 3)

    return merged
def forward(
    self,
    gs_hidden_features: Float[Tensor, "B Np Cp"],
    query_points: Float[Tensor, "B Np 3"],
    smplx_data,  # e.g., body_pose:[B, Nv, 21, 3], betas:[B, 100]
    c2w: Float[Tensor, "B Nv 4 4"],
    intrinsic: Float[Tensor, "B Nv 4 4"],
    height,
    width,
    additional_features: Optional[Float[Tensor, "B C H W"]] = None,
    background_color: Optional[Float[Tensor, "B Nv 3"]] = None,
    debug: bool = False,
    **kwargs,
):
    """Predict Gaussian attributes, then animate and render them.

    Runs ``self.forward_gs`` followed by ``self.forward_animate_gs`` and
    attaches the predicted attributes to the render output.

    Args:
        gs_hidden_features: Hidden features forwarded to ``forward_gs``.
        query_points: Query positions forwarded to ``forward_gs``.
        smplx_data: SMPL-X parameters; ``forward_gs`` returns the version
            that is used for animation.
        c2w: Camera-to-world matrices.
        intrinsic: Camera intrinsics.
        height: Render height.
        width: Render width.
        additional_features: Optional extra features for ``forward_gs``.
        background_color: Optional background colour per batch and view.
        debug: Forwarded to both stages.
        **kwargs: Only ``df_data`` is read. It is optional and defaults to
            ``None`` when absent.

    Returns:
        The mapping returned by ``forward_animate_gs`` with an additional
        ``"gs_attr"`` entry holding the list returned by ``forward_gs``.
    """
    # need shape_params of smplx_data to get query points and get "transform_mat_neutral_pose"
    # only forward gs params
    gs_attr_list, query_points, smplx_data = self.forward_gs(
        gs_hidden_features,
        query_points,
        smplx_data=smplx_data,
        additional_features=additional_features,
        debug=debug,
    )

    out = self.forward_animate_gs(
        gs_attr_list,
        query_points,
        smplx_data,
        c2w,
        intrinsic,
        height,
        width,
        background_color,
        debug,
        # forward_animate_gs already defaults df_data to None, so a caller
        # that omits it must not hit a KeyError here.
        df_data=kwargs.get("df_data"),
    )
    out["gs_attr"] = gs_attr_list

    return out
def test():
    """Smoke-test ``GS3DRenderer`` on fitted SMPL-X parameters and dump the outputs.

    Reads SMPL-X parameters and cameras from hard-coded local paths, runs
    ``GS3DRenderer.forward`` on all-zero hidden features on the ``cuda``
    device, and writes the rendered images and Gaussian ``.ply`` files to
    ``./debug_vis``.
    """
    import cv2

    human_model_path = "./pretrained_models/human_model_files"
    smplx_data_root = "/data1/projects/ExAvatar_RELEASE/avatar/data/Custom/data/gyeongsik/smplx_optimized/smplx_params_smoothed"
    shape_param_file = "/data1/projects/ExAvatar_RELEASE/avatar/data/Custom/data/gyeongsik/smplx_optimized/shape_param.json"
    output_dir = "./debug_vis"

    batch_size = 1
    device = "cuda"

    smplx_data, cam_param_list, _ori_image_list = read_smplx_param(
        smplx_data_root=smplx_data_root, shape_param_file=shape_param_file, batch_size=2
    )

    # Add a leading batch axis. For the keys below only the first entry is
    # kept before the batch axis is added. A fresh dict is built instead of
    # mutating the one being iterated.
    first_entry_only = ("betas", "face_offset", "joint_offset")
    smplx_data = {
        k: (v[0] if k in first_entry_only else v).unsqueeze(0)
        for k, v in smplx_data.items()
    }

    gs_render = GS3DRenderer(
        human_model_path=human_model_path,
        subdivide_num=2,
        smpl_type="smplx",
        feat_dim=64,
        query_dim=64,
        use_rgb=False,
        sh_degree=3,
        mlp_network_config=None,
        xyz_offset_max_step=1.8 / 32,
    )
    gs_render.to(device)

    c2w_list = []
    intr_list = []
    for cam_param in cam_param_list:
        c2w = torch.eye(4).to(device)
        c2w[:3, :3] = cam_param["R"]
        c2w[:3, 3] = cam_param["t"]
        c2w_list.append(c2w)

        intr = torch.eye(4).to(device)
        intr[0, 0] = cam_param["focal"][0]
        intr[1, 1] = cam_param["focal"][1]
        intr[0, 2] = cam_param["princpt"][0]
        intr[1, 2] = cam_param["princpt"][1]
        intr_list.append(intr)

    c2w = torch.stack(c2w_list).unsqueeze(0)
    intrinsic = torch.stack(intr_list).unsqueeze(0)

    out = gs_render.forward(
        gs_hidden_features=torch.zeros((batch_size, 2048, 64)).float().to(device),
        query_points=None,
        smplx_data=smplx_data,
        c2w=c2w,
        intrinsic=intrinsic,
        # Image size is taken as twice the principal point.
        height=int(cam_param_list[0]["princpt"][1]) * 2,
        width=int(cam_param_list[0]["princpt"][0]) * 2,
        background_color=torch.tensor([1.0, 1.0, 1.0])
        .float()
        .view(1, 1, 3)
        .repeat(batch_size, 2, 1)
        .to(device),
        debug=False,
        # Passed explicitly so the call does not depend on how forward()
        # treats a missing df_data keyword.
        df_data=None,
    )

    # cv2.imwrite does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)

    for k, v in out.items():
        if k == "comp_rgb_bg":
            print("comp_rgb_bg", v)
            continue

        if k == "3dgs":
            for b_idx, gs_views in enumerate(v):
                for v_idx, gs in enumerate(gs_views):
                    gs.save_ply(os.path.join(output_dir, f"{b_idx}_{v_idx}.ply"))
            continue

        if not isinstance(v, torch.Tensor):
            # e.g. "gs_attr" is a plain list and cannot be written as an image
            continue

        for b_idx in range(len(v)):
            for v_idx in range(v.shape[1]):
                save_path = os.path.join(output_dir, f"{b_idx}_{v_idx}_{k}.jpg")
                frame = v[b_idx, v_idx]
                if frame.dim() == 3:
                    # forward_animate_gs returns comp_rgb / comp_mask /
                    # comp_depth channel-first, while cv2 expects H x W x C.
                    # NOTE(review): assumes any other 3-D frame is
                    # channel-first as well - confirm.
                    frame = frame.permute(1, 2, 0)
                cv2.imwrite(
                    save_path,
                    (frame.detach().cpu().numpy() * 255).astype(np.uint8),
                )
def test1():
    """Smoke-test ``GS3DRenderer`` on one ``VideoHumanDataset`` sample.

    Loads sample 0 of the dataset under ``./train_data/ClothVideo``, runs
    ``GS3DRenderer.forward`` on all-zero hidden features on the ``cuda``
    device, and writes the outputs to ``./debug_vis/gs_render``. Outputs
    whose key contains ``"mask"`` are saved side by side with the dataset's
    render image of the same view.
    """
    import cv2

    human_model_path = "./pretrained_models/human_model_files"
    device = "cuda"
    output_dir = "./debug_vis/gs_render"

    # Earlier revisions of this harness used ExAvatarDataset, HuMManDataset,
    # StaticHumanDataset or SingleViewHumanDataset here instead.

    from accelerate.utils import set_seed

    set_seed(1234)
    from LHM.datasets.video_human import VideoHumanDataset

    root_dir = "./train_data/ClothVideo"
    meta_path = "./train_data/ClothVideo/label/valid_id_with_img_list.json"
    dataset = VideoHumanDataset(
        root_dirs=root_dir,
        meta_path=meta_path,
        sample_side_views=7,
        render_image_res_low=384,
        render_image_res_high=384,
        render_region_size=(682, 384),
        source_image_res=384,
        enlarge_ratio=[0.85, 1.2],
        debug=False,
    )

    data = dataset[0]

    # Keep only the SMPL-X entries of the sample and add a leading batch axis.
    smplx_keys = (
        "root_pose",
        "body_pose",
        "jaw_pose",
        "leye_pose",
        "reye_pose",
        "lhand_pose",
        "rhand_pose",
        "expr",
        "trans",
        "betas",
    )
    smplx_data = {}
    for k, v in data.items():
        if k not in smplx_keys:
            continue
        smplx_data[k] = v.unsqueeze(0).to(device)
        print(k, v.shape)

    c2ws = data["c2ws"].unsqueeze(0).to(device)
    intrs = data["intrs"].unsqueeze(0).to(device)
    render_images = data["render_image"].numpy()
    render_h = data["render_full_resolutions"][0, 0]
    render_w = data["render_full_resolutions"][0, 1]
    render_bg_colors = data["render_bg_colors"].unsqueeze(0).to(device)
    print("c2ws", c2ws.shape, "intrs", intrs.shape, intrs)

    gs_render = GS3DRenderer(
        human_model_path=human_model_path,
        subdivide_num=2,
        smpl_type="smplx",
        feat_dim=64,
        query_dim=64,
        use_rgb=False,
        sh_degree=3,
        mlp_network_config=None,
        xyz_offset_max_step=1.8 / 32,
        expr_param_dim=10,
        shape_param_dim=10,
        fix_opacity=True,
        fix_rotation=True,
    )
    gs_render.to(device)

    out = gs_render.forward(
        gs_hidden_features=torch.zeros((1, 2048, 64)).float().to(device),
        query_points=None,
        smplx_data=smplx_data,
        c2w=c2ws,
        intrinsic=intrs,
        height=render_h,
        width=render_w,
        background_color=render_bg_colors,
        debug=False,
        # Passed explicitly so the call does not depend on how forward()
        # treats a missing df_data keyword.
        df_data=None,
    )

    os.makedirs(output_dir, exist_ok=True)

    for k, v in out.items():
        if k == "comp_rgb_bg":
            print("comp_rgb_bg", v)
            continue

        if k == "3dgs":
            for b_idx, gs_views in enumerate(v):
                for v_idx, gs in enumerate(gs_views):
                    gs.save_ply(os.path.join(output_dir, f"{b_idx}_{v_idx}.ply"))
            continue

        if not isinstance(v, torch.Tensor):
            # e.g. "gs_attr" is a plain list and cannot be written as an image
            continue

        for b_idx in range(len(v)):
            for v_idx in range(v.shape[1]):
                save_path = os.path.join(output_dir, f"{b_idx}_{v_idx}_{k}.jpg")

                # Channel-first -> channel-last for cv2.
                img = (
                    v[b_idx, v_idx].permute(1, 2, 0).detach().cpu().numpy() * 255
                ).astype(np.uint8)
                print(img.shape, save_path)

                if "mask" in k:
                    # Repeat the mask along the last axis and place it next
                    # to the dataset image of the same view.
                    render_img = render_images[v_idx].transpose(1, 2, 0) * 255
                    cv2.imwrite(
                        save_path,
                        np.hstack(
                            [np.tile(img, (1, 1, 3)), render_img.astype(np.uint8)]
                        ),
                    )
                else:
                    cv2.imwrite(save_path, img)
if __name__ == "__main__":
    # Run a single debug harness; the previous duplicate test() call repeated
    # the whole run for no benefit.
    # test1()
    test()