# AniGen — anigen/utils/export_utils.py
# Initial commit: AniGen - Animatable 3D Generation (commit 6b92ff7, by Yihua7)
import io
import numpy as np
import torch
import trimesh
from PIL import Image as PILImage
from pygltflib import (
GLTF2, Scene, Node, Mesh, Primitive, Attributes, Buffer, BufferView, Accessor, Asset, Skin,
Material, PbrMetallicRoughness, TextureInfo,
Image as GltfImage, Texture as GltfTexture, Sampler as GltfSampler,
FLOAT, UNSIGNED_SHORT, UNSIGNED_INT, VEC3, VEC2, VEC4, SCALAR, MAT4, ARRAY_BUFFER, ELEMENT_ARRAY_BUFFER
)
from anigen.utils.skin_utils import repair_skeleton_parents
# Basis change from the app's Z-up convention to glTF's Y-up convention.
# Applied as a right-multiplication on row vectors (v @ M), it maps
# (x, y, z) -> (-x, z, y).
_ROT_Z_UP_TO_Y_UP = np.array([[-1, 0, 0], [0, 0, 1], [0, 1, 0]], dtype=np.float32)
def visualize_skeleton_as_mesh(joints, parents, joint_radius_ratio=0.02, bone_radius_ratio=0.01):
    """
    Build a trimesh visualization of a skeleton: one sphere per joint and one
    cone per parent->child bone. All sizes are scaled relative to the
    skeleton's bounding-box diagonal so the result looks right at any scale.
    """
    if len(joints) == 0:
        return trimesh.Trimesh()
    # Rotate into the Y-up convention used for export/preview.
    joints = np.asarray(joints, dtype=np.float32) @ _ROT_Z_UP_TO_Y_UP
    # Adaptive size: bounding-box diagonal, with a floor for degenerate inputs.
    extent = np.max(joints, axis=0) - np.min(joints, axis=0)
    scale = np.linalg.norm(extent)
    if scale < 1e-5:
        scale = 1.0
    joint_radius = scale * joint_radius_ratio
    bone_radius = scale * bone_radius_ratio
    parts = []
    # Joint spheres (warm yellow vertex color).
    for position in joints:
        ball = trimesh.creation.icosphere(radius=joint_radius, subdivisions=2)
        ball.apply_translation(position)
        ball.visual.vertex_colors = [230, 180, 80, 255]
        parts.append(ball)
    # Bone cones (grey-blue), one per valid parent->child pair.
    for child_idx, parent_idx in enumerate(parents):
        if parent_idx < 0 or parent_idx == child_idx:
            continue  # root or self-parented joint: nothing to draw
        start = joints[parent_idx]
        offset = joints[child_idx] - start
        length = np.linalg.norm(offset)
        if length < 1e-5:
            continue  # degenerate (zero-length) bone
        bone = trimesh.creation.cone(radius=bone_radius, height=length)
        # Rotate the cone's default +Z axis onto the bone direction, then
        # translate its origin to the parent joint.
        rotation = trimesh.geometry.align_vectors(np.array([0, 0, 1]), offset / length)
        bone.apply_transform(rotation)
        bone.apply_translation(start)
        bone.visual.vertex_colors = [160, 180, 200, 255]
        parts.append(bone)
    if not parts:
        return trimesh.Trimesh()
    return trimesh.util.concatenate(parts)
def convert_to_glb_from_data(mesh, joints, parents, skin_weights, output_path, vertex_colors=None, texture_image=None):
    """
    Export a skinned mesh as a glTF 2.0 binary (.glb) file.

    Args:
        mesh: trimesh-like object exposing `vertices`, `faces` and optionally
            `vertex_normals` and `visual.uv`.
        joints: (J, 3) joint positions (Z-up; rotated to glTF's Y-up here).
        parents: (J,) parent index per joint; -1 (or self-index) marks a root.
        skin_weights: (V, J) per-vertex skinning weights.
        output_path: destination path; a `.glb` suffix selects binary output.
        vertex_colors: optional (V, 3|4) float colors in [0, 1]; skipped when a
            texture is supplied (glTF multiplies COLOR_0 with the texture).
        texture_image: optional HxWxC image (array-like) embedded as a PNG.

    Returns:
        None. Raises ValueError on out-of-range faces, non-finite data, or a
        mis-shaped `parents`; prints a warning and returns early on a
        vertex-count mismatch or when no root joint can be found.
    """
    joints = np.asarray(joints, dtype=np.float32) @ _ROT_Z_UP_TO_Y_UP
    num_joints = joints.shape[0]
    num_verts = mesh.vertices.shape[0]
    # Basic sanity checks to avoid writing malformed GLB that can hang Blender.
    if hasattr(mesh, 'faces'):
        faces = np.asarray(mesh.faces)
        if faces.size > 0:
            if faces.min() < 0 or faces.max() >= num_verts:
                raise ValueError(f"Mesh faces contain out-of-range indices (min={faces.min()}, max={faces.max()}, num_verts={num_verts}).")
    # Reject NaN/Inf in any geometry/skinning array before serializing.
    for name, arr in (
        ("mesh.vertices", getattr(mesh, 'vertices', None)),
        ("joints", joints),
        ("skin_weights", skin_weights),
    ):
        if arr is None:
            continue
        arr_np = np.asarray(arr)
        if arr_np.size and not np.isfinite(arr_np).all():
            raise ValueError(f"Non-finite values detected in {name}; refusing to write GLB.")
    parents = np.asarray(parents).astype(np.int64, copy=False)
    if parents.shape != (num_joints,):
        parents = parents.reshape(-1)
        if parents.shape[0] != num_joints:
            raise ValueError(f"parents has wrong shape: expected ({num_joints},) got {parents.shape}")
    # Repair invalid/cyclic skeletons instead of exporting a GLB that hangs viewers.
    parents = repair_skeleton_parents(joints=joints, parents=parents, verbose=True)
    if skin_weights.shape[0] != num_verts:
        print(f"Warning: Mismatch in vertex count. Mesh: {num_verts}, Skin: {skin_weights.shape[0]}")
        return
    # Prepare Binary Data
    # All accessors share buffer 0; data is appended to this blob and each
    # section records its byte offset/length via a BufferView.
    binary_data = bytearray()

    def align_to_4bytes():
        # glTF requires accessor data to start on 4-byte boundaries.
        padding = len(binary_data) % 4
        if padding > 0:
            binary_data.extend(b'\x00' * (4 - padding))
    accessors = []
    buffer_views = []

    def add_buffer_view(data_bytes, target=None):
        # Append `data_bytes` to the blob (4-byte aligned) and register a
        # BufferView; returns the new view's index.
        align_to_4bytes()
        byte_offset = len(binary_data)
        binary_data.extend(data_bytes)
        byte_length = len(data_bytes)
        bv = BufferView(buffer=0, byteOffset=byte_offset, byteLength=byte_length, target=target)
        buffer_views.append(bv)
        return len(buffer_views) - 1

    def add_accessor(buffer_view_idx, component_type, count, type_str, min_val=None, max_val=None):
        # Register an Accessor over a BufferView; returns the accessor index.
        acc = Accessor(
            bufferView=buffer_view_idx,
            componentType=component_type,
            count=count,
            type=type_str,
            min=min_val,
            max=max_val
        )
        accessors.append(acc)
        return len(accessors) - 1
    # 1. Indices
    # NOTE(review): `indices_acc_idx` is only defined when `mesh` has faces;
    # a face-less mesh would hit a NameError at the Primitive below — callers
    # appear to always pass triangle meshes (TODO confirm).
    if hasattr(mesh, 'faces'):
        indices = mesh.faces.flatten().astype(np.uint32)
        # Use the smaller uint16 component type when all indices fit.
        if indices.max() <= 65535:
            indices = indices.astype(np.uint16)
            component_type = UNSIGNED_SHORT
        else:
            component_type = UNSIGNED_INT
        indices_bytes = indices.tobytes()
        bv_idx = add_buffer_view(indices_bytes, target=ELEMENT_ARRAY_BUFFER)
        indices_acc_idx = add_accessor(bv_idx, component_type, len(indices), SCALAR, min_val=[int(indices.min())], max_val=[int(indices.max())])
    # 2. Positions
    positions = mesh.vertices.astype(np.float32) @ _ROT_Z_UP_TO_Y_UP
    # GLTF requires min/max for POSITION
    min_pos = positions.min(axis=0).tolist()
    max_pos = positions.max(axis=0).tolist()
    pos_bytes = positions.tobytes()
    bv_idx = add_buffer_view(pos_bytes, target=ARRAY_BUFFER)
    pos_acc_idx = add_accessor(bv_idx, FLOAT, len(positions), VEC3, min_val=min_pos, max_val=max_pos)
    # 3. Normals (rotated by the same basis change; the matrix is orthonormal
    # so directions stay unit length)
    norm_acc_idx = None
    if hasattr(mesh, 'vertex_normals'):
        normals = mesh.vertex_normals.astype(np.float32) @ _ROT_Z_UP_TO_Y_UP
        norm_bytes = normals.tobytes()
        bv_idx = add_buffer_view(norm_bytes, target=ARRAY_BUFFER)
        norm_acc_idx = add_accessor(bv_idx, FLOAT, len(normals), VEC3)
    # 4. UVs
    tex_acc_idx = None
    if hasattr(mesh.visual, 'uv') and mesh.visual.uv is not None:
        uvs = mesh.visual.uv.astype(np.float32)  # astype copies, so the flip below does not touch mesh.visual.uv
        uvs[:, 1] = 1.0 - uvs[:, 1]  # OpenGL -> glTF: flip V
        uv_bytes = uvs.tobytes()
        bv_idx = add_buffer_view(uv_bytes, target=ARRAY_BUFFER)
        tex_acc_idx = add_accessor(bv_idx, FLOAT, len(uvs), VEC2)
    # 4.5 Vertex Colors (skipped when a texture image is provided, since glTF
    # multiplies vertex colors with the texture which is not desired here)
    color_acc_idx = None
    if vertex_colors is not None and texture_image is None:
        colors = np.asarray(vertex_colors)
        if colors.shape[0] != num_verts:
            print(f"Warning: Mismatch in vertex color count. Mesh: {num_verts}, Colors: {colors.shape[0]}")
        else:
            if colors.ndim != 2 or colors.shape[1] not in (3, 4):
                print(f"Warning: Unexpected vertex_colors shape: {colors.shape}. Expected (N,3) or (N,4).")
            else:
                colors = colors.astype(np.float32, copy=False)
                # Expect colors in [0,1]; clamp to be safe.
                np.clip(colors, 0.0, 1.0, out=colors)
                color_bytes = colors.tobytes()
                bv_idx = add_buffer_view(color_bytes, target=ARRAY_BUFFER)
                color_acc_idx = add_accessor(bv_idx, FLOAT, len(colors), VEC4 if colors.shape[1] == 4 else VEC3)
    # 4.6 Texture Image (embedded as PNG in the binary blob)
    gltf_images = []
    gltf_textures = []
    gltf_samplers = []
    if texture_image is not None:
        tex_arr = np.asarray(texture_image)
        img_pil = PILImage.fromarray(tex_arr)
        buf = io.BytesIO()
        img_pil.save(buf, format='PNG')
        png_bytes = buf.getvalue()
        # Image data gets a BufferView with no `target` (it is not a GPU
        # vertex/index buffer), so it is appended by hand rather than via
        # add_buffer_view.
        align_to_4bytes()
        img_bv_offset = len(binary_data)
        binary_data.extend(png_bytes)
        img_bv_len = len(png_bytes)
        img_bv = BufferView(buffer=0, byteOffset=img_bv_offset, byteLength=img_bv_len)
        buffer_views.append(img_bv)
        img_bv_idx = len(buffer_views) - 1
        gltf_images.append(GltfImage(bufferView=img_bv_idx, mimeType='image/png'))
        # 9729=LINEAR, 9987=LINEAR_MIPMAP_LINEAR, 10497=REPEAT (glTF enums).
        gltf_samplers.append(GltfSampler(magFilter=9729, minFilter=9987, wrapS=10497, wrapT=10497))
        gltf_textures.append(GltfTexture(source=0, sampler=0))
    # 5. Joints & Weights
    # glTF JOINTS_0/WEIGHTS_0 carry at most 4 influences per vertex: keep the
    # 4 largest weights (descending), padding with joint 0 / weight 0 when the
    # skeleton has fewer than 4 joints.
    if num_joints >= 4:
        top_indices = np.argsort(skin_weights, axis=1)[:, -4:]
        top_indices = np.flip(top_indices, axis=1)
        top_weights = np.take_along_axis(skin_weights, top_indices, axis=1)
    else:
        top_indices = np.argsort(skin_weights, axis=1)
        top_indices = np.flip(top_indices, axis=1)
        top_weights = np.take_along_axis(skin_weights, top_indices, axis=1)
        pad_width = 4 - num_joints
        top_indices = np.pad(top_indices, ((0,0), (0, pad_width)), mode='constant', constant_values=0)
        top_weights = np.pad(top_weights, ((0,0), (0, pad_width)), mode='constant', constant_values=0.0)
    # Renormalize so the kept weights sum to 1 (guarding against all-zero rows).
    weight_sums = np.sum(top_weights, axis=1, keepdims=True)
    weight_sums[weight_sums == 0] = 1.0
    top_weights = top_weights / weight_sums
    joints_0 = top_indices.astype(np.uint16)
    weights_0 = top_weights.astype(np.float32)
    joints_bytes = joints_0.tobytes()
    bv_idx = add_buffer_view(joints_bytes, target=ARRAY_BUFFER)
    joints_acc_idx = add_accessor(bv_idx, UNSIGNED_SHORT, len(joints_0), VEC4)
    weights_bytes = weights_0.tobytes()
    bv_idx = add_buffer_view(weights_bytes, target=ARRAY_BUFFER)
    weights_acc_idx = add_accessor(bv_idx, FLOAT, len(weights_0), VEC4)
    # 6. Inverse Bind Matrices
    # Bind pose is pure translation per joint, so each IBM is the inverse of a
    # translation matrix; flatten('F') emits column-major order as glTF requires.
    ibms_list = []
    for i in range(num_joints):
        mat = np.eye(4, dtype=np.float32)
        mat[:3, 3] = joints[i]
        inv_mat = np.linalg.inv(mat)
        ibms_list.append(inv_mat.flatten('F'))
    ibms_data = np.concatenate(ibms_list)
    ibms_bytes = ibms_data.tobytes()
    bv_idx = add_buffer_view(ibms_bytes)
    ibms_acc_idx = add_accessor(bv_idx, FLOAT, num_joints, MAT4)
    # Nodes
    # Node layout: node 0 is the mesh, nodes 1..num_joints are the joints
    # (joint i -> node i + 1); a virtual root may be appended after those.
    nodes = []
    # Mesh Node (Node 0)
    mesh_node = Node(name="Mesh", mesh=0, skin=0)
    nodes.append(mesh_node)
    # Skeleton Nodes (Node 1 to M)
    joint_nodes = []
    for i in range(num_joints):
        parent_idx = parents[i]
        pos = joints[i]
        # glTF node translations are parent-relative; roots keep world position.
        if parent_idx == -1 or parent_idx == i:
            local_pos = pos
        else:
            parent_pos = joints[parent_idx]
            local_pos = pos - parent_pos
        node = Node(name=f"joint_{i}", translation=local_pos.tolist())
        joint_nodes.append(node)
    # Set children for skeleton nodes
    for i in range(num_joints):
        parent_idx = parents[i]
        if parent_idx != -1 and parent_idx != i:
            parent_node = joint_nodes[parent_idx]
            if parent_node.children is None:
                parent_node.children = []
            parent_node.children.append(i + 1)  # +1: node 0 is the mesh node
    nodes.extend(joint_nodes)
    # Find root joint index (roots are marked either by -1 or self-parenting)
    root_joint_indices = np.where(parents == -1)[0]
    if len(root_joint_indices) == 0:
        root_joint_indices = np.where(parents == np.arange(num_joints))[0]
    if len(root_joint_indices) == 0:
        print("Error: No root joint found.")
        return
    if len(root_joint_indices) > 1:
        print(f"Warning: Found {len(root_joint_indices)} root joints. Creating a virtual root.")
        virtual_root_node_idx = len(nodes)
        virtual_root_node = Node(name="VirtualRoot", children=[int(i) + 1 for i in root_joint_indices])
        nodes.append(virtual_root_node)
        root_node_idx = virtual_root_node_idx
    else:
        root_joint_idx = root_joint_indices[0]
        root_node_idx = int(root_joint_idx) + 1
    # Skin
    skin = Skin(
        inverseBindMatrices=ibms_acc_idx,
        joints=[i + 1 for i in range(num_joints)],
        skeleton=root_node_idx
    )
    # Mesh
    attributes_dict = {
        "POSITION": pos_acc_idx,
        "JOINTS_0": joints_acc_idx,
        "WEIGHTS_0": weights_acc_idx
    }
    if norm_acc_idx is not None:
        attributes_dict["NORMAL"] = norm_acc_idx
    if tex_acc_idx is not None:
        attributes_dict["TEXCOORD_0"] = tex_acc_idx
    if color_acc_idx is not None:
        attributes_dict["COLOR_0"] = color_acc_idx
    primitive = Primitive(
        attributes=Attributes(**attributes_dict),
        indices=indices_acc_idx,
        material=0,
    )
    mesh_obj = Mesh(primitives=[primitive])
    # Scene
    scene = Scene(nodes=[0, root_node_idx])
    # Ensure GLB BIN chunk is 4-byte aligned.
    align_to_4bytes()
    # Buffer
    buffer = Buffer(byteLength=len(binary_data))
    # Material — attach texture when available, otherwise plain white
    if gltf_textures:
        pbr = PbrMetallicRoughness(
            baseColorTexture=TextureInfo(index=0),
            metallicFactor=0.0,
            roughnessFactor=1.0,
        )
    else:
        pbr = PbrMetallicRoughness(
            baseColorFactor=[1.0, 1.0, 1.0, 1.0],
            metallicFactor=0.0,
            roughnessFactor=1.0,
        )
    material = Material(pbrMetallicRoughness=pbr)
    gltf = GLTF2(
        asset=Asset(version="2.0"),
        scene=0,
        scenes=[scene],
        nodes=nodes,
        meshes=[mesh_obj],
        materials=[material],
        skins=[skin],
        accessors=accessors,
        bufferViews=buffer_views,
        buffers=[buffer],
        images=gltf_images if gltf_images else None,
        textures=gltf_textures if gltf_textures else None,
        samplers=gltf_samplers if gltf_samplers else None,
    )
    gltf.set_binary_blob(bytes(binary_data))
    # IMPORTANT: `save()` may write JSON `.gltf` even if the filename ends with `.glb`.
    # Many viewers (Meshlab/Open3D) will then report "glb not supported".
    if str(output_path).lower().endswith('.glb'):
        gltf.save_binary(output_path)
    else:
        gltf.save(output_path)
    # Quick sanity check: a real GLB starts with magic bytes b'glTF'.
    if str(output_path).lower().endswith('.glb'):
        try:
            with open(output_path, 'rb') as f:
                magic = f.read(4)
            if magic != b'glTF':
                print(f"Warning: output does not look like a valid GLB (magic={magic!r}).")
        except Exception as e:
            print(f"Warning: failed to validate GLB header: {e}")
    print(f"Saved GLB to {output_path}")
def save_colored_pcl(array, filename):
    """
    Save an (N, >=3) point array as an ASCII PLY point cloud.

    Args:
        array: (N, 3) or (N, >=6) numpy array or torch tensor; columns 0:3 are
            XYZ and, when present, columns 3:6 are RGB in [0, 1]. Extra columns
            beyond 6 are ignored; points with no color columns are written white.
        filename: output .ply path.
    """
    if isinstance(array, torch.Tensor):
        array = array.detach().cpu().numpy()
    points = array[:, :3]
    if array.shape[1] >= 6:
        # Clamp to [0, 1] before scaling: casting out-of-range floats straight
        # to uint8 wraps around (e.g. 1.2 -> 306 -> 50), corrupting colors.
        colors = (np.clip(array[:, 3:6], 0.0, 1.0) * 255).astype(np.uint8)
    else:
        colors = np.full((points.shape[0], 3), 255, dtype=np.uint8)
    # Combine points and colors into a single array (promoted to float; the
    # '%d' fields below format the integral color values correctly).
    data = np.hstack((points, colors))
    # Define the .ply header
    header = f"""ply\nformat ascii 1.0\nelement vertex {points.shape[0]}\nproperty float x\nproperty float y\nproperty float z\nproperty uchar red\nproperty uchar green\nproperty uchar blue\nend_header\n"""
    # Write to the .ply file
    with open(filename, 'w') as file:
        file.write(header)
        np.savetxt(file, data, fmt='%f %f %f %d %d %d')
def _extract_vertex_rgb(vertex_attrs):
"""Extract Nx3 RGB in [0,1] from vertex_attrs (torch/np), or return None."""
if vertex_attrs is None:
return None
attrs = vertex_attrs
if isinstance(attrs, torch.Tensor):
attrs = attrs.detach().cpu().numpy()
attrs = np.asarray(attrs)
if attrs.ndim != 2 or attrs.shape[1] < 3:
return None
colors = attrs[:, :3].astype(np.float32, copy=False)
np.clip(colors, 0.0, 1.0, out=colors)
return colors
def transfer_vertex_colors_nearest(src_vertices, src_colors, dst_vertices):
    """Transfer per-vertex colors by nearest-vertex mapping.
    Args:
        src_vertices: (Ns, 3) float
        src_colors: (Ns, 3) float in [0,1]
        dst_vertices: (Nd, 3) float
    Returns:
        (Nd, 3) float in [0,1], or None on invalid/missing input.
    """
    if src_vertices is None or src_colors is None or dst_vertices is None:
        return None
    src_pts = np.asarray(src_vertices)
    col = np.asarray(src_colors)
    dst_pts = np.asarray(dst_vertices)
    # Validate: both point sets must be (N, 3) and colors must pair 1:1 with
    # the source vertices; anything else yields None rather than an exception.
    shapes_ok = (
        src_pts.ndim == 2 and src_pts.shape[1] == 3
        and dst_pts.ndim == 2 and dst_pts.shape[1] == 3
        and col.ndim == 2 and col.shape[1] == 3
        and col.shape[0] == src_pts.shape[0]
    )
    if not shapes_ok:
        return None
    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    src_t = torch.from_numpy(src_pts).to(device=dev, dtype=torch.float32)
    dst_t = torch.from_numpy(dst_pts).to(device=dev, dtype=torch.float32)
    col_t = torch.from_numpy(col).to(device=dev, dtype=torch.float32)
    try:
        # Fast path: pytorch3d's KNN when the package is installed.
        from pytorch3d.ops import knn_points
        _, nearest, _ = knn_points(dst_t[None], src_t[None], K=1, return_nn=False)
        nn_index = nearest[0, :, 0]
    except Exception:
        # Fallback: exhaustive (Nd, Ns) pairwise distances.
        nn_index = torch.argmin(torch.cdist(dst_t, src_t), dim=1)
    gathered = col_t[nn_index].clamp_(0.0, 1.0)
    return gathered.detach().cpu().numpy()