| import io |
| import numpy as np |
| import torch |
| import trimesh |
| from PIL import Image as PILImage |
| from pygltflib import ( |
| GLTF2, Scene, Node, Mesh, Primitive, Attributes, Buffer, BufferView, Accessor, Asset, Skin, |
| Material, PbrMetallicRoughness, TextureInfo, |
| Image as GltfImage, Texture as GltfTexture, Sampler as GltfSampler, |
| FLOAT, UNSIGNED_SHORT, UNSIGNED_INT, VEC3, VEC2, VEC4, SCALAR, MAT4, ARRAY_BUFFER, ELEMENT_ARRAY_BUFFER |
| ) |
| from anigen.utils.skin_utils import repair_skeleton_parents |
|
|
# Change-of-basis from the pipeline's z-up frame to glTF's y-up frame.
# Applied on the right to row vectors (points @ M), it maps (x, y, z) -> (-x, z, y);
# the determinant is +1, so handedness is preserved (a proper rotation).
_ROT_Z_UP_TO_Y_UP = np.array([[-1, 0, 0], [0, 0, 1], [0, 1, 0]], dtype=np.float32)
|
|
def visualize_skeleton_as_mesh(joints, parents, joint_radius_ratio=0.02, bone_radius_ratio=0.01):
    """
    Convert a skeleton into a trimesh object with spheres for joints and cones for bones.
    The sizes are adaptive to the skeleton's bounding box.

    Args:
        joints: (J, 3) joint positions in the pipeline's z-up frame.
        parents: (J,) parent index per joint; values < 0 or equal to the
            joint's own index mark roots (no bone is drawn for them).
        joint_radius_ratio: sphere radius as a fraction of the bbox diagonal.
        bone_radius_ratio: cone base radius as a fraction of the bbox diagonal.

    Returns:
        A single concatenated trimesh.Trimesh (empty when there are no joints).
    """
    if len(joints) == 0:
        return trimesh.Trimesh()

    # Rotate into the y-up convention used by the exported assets.
    joints = np.asarray(joints, dtype=np.float32) @ _ROT_Z_UP_TO_Y_UP

    # Size primitives relative to the bounding-box diagonal; guard degenerate
    # (single-point / collapsed) skeletons against a zero scale.
    min_bound = np.min(joints, axis=0)
    max_bound = np.max(joints, axis=0)
    scale = np.linalg.norm(max_bound - min_bound)
    if scale < 1e-5:
        scale = 1.0

    joint_radius = scale * joint_radius_ratio
    bone_radius = scale * bone_radius_ratio

    meshes = []

    # One sphere per joint (index is not needed here).
    for joint in joints:
        sphere = trimesh.creation.icosphere(radius=joint_radius, subdivisions=2)
        sphere.apply_translation(joint)
        sphere.visual.vertex_colors = [230, 180, 80, 255]
        meshes.append(sphere)

    # One cone per bone, base at the parent, apex towards the child.
    num_joints = len(joints)
    for i, parent_idx in enumerate(parents):
        # Skip roots (negative / self-parented) and out-of-range parents,
        # which would otherwise raise an IndexError below.
        if parent_idx < 0 or parent_idx == i or parent_idx >= num_joints:
            continue

        child = joints[i]
        parent = joints[parent_idx]

        vec = child - parent
        length = np.linalg.norm(vec)
        if length < 1e-5:
            continue  # coincident joints: bone direction is undefined

        # trimesh cones have their base at the origin and apex along +Z.
        cone = trimesh.creation.cone(radius=bone_radius, height=length)

        # Rotate +Z onto the bone direction, then move the base to the parent.
        z_axis = np.array([0, 0, 1])
        direction = vec / length
        rot_mat = trimesh.geometry.align_vectors(z_axis, direction)
        cone.apply_transform(rot_mat)
        cone.apply_translation(parent)

        cone.visual.vertex_colors = [160, 180, 200, 255]
        meshes.append(cone)

    if not meshes:
        return trimesh.Trimesh()

    skeleton_mesh = trimesh.util.concatenate(meshes)
    return skeleton_mesh
|
|
def convert_to_glb_from_data(mesh, joints, parents, skin_weights, output_path, vertex_colors=None, texture_image=None):
    """Export a rigged, skinned mesh as a glTF 2.0 asset (.glb or .gltf).

    Builds a single-binary-buffer glTF document with the mesh geometry
    (indices, positions, optional normals / UVs / vertex colors), an optional
    embedded PNG base-color texture, the joint node hierarchy, and a skin
    limited to the 4 strongest influences per vertex, then writes it to
    ``output_path``.

    Args:
        mesh: trimesh-like object; ``vertices`` (N, 3) is required while
            ``faces``, ``vertex_normals`` and ``visual.uv`` are used when
            present.  NOTE(review): a mesh without ``faces`` would leave
            ``indices_acc_idx`` undefined at primitive creation below —
            confirm callers always pass faceted meshes.
        joints: (J, 3) joint rest positions in the pipeline's z-up frame.
        parents: (J,) parent index per joint; -1 (or self-index) marks a root.
        skin_weights: (N, J) per-vertex skinning weights.
        output_path: destination path; a ``.glb`` suffix selects binary output.
        vertex_colors: optional (N, 3|4) float colors, expected in [0, 1];
            only written when ``texture_image`` is None.
        texture_image: optional PIL-compatible image array embedded as the
            PNG base-color texture (needs UVs on the mesh to be visible).

    Raises:
        ValueError: on out-of-range face indices, non-finite geometry/weights,
            or a ``parents`` array whose length does not match ``joints``.

    Returns:
        None.  Prints a warning/error and returns early on recoverable
        validation failures.
    """
    # Rotate joints into glTF's y-up convention (vertex positions get the
    # same transform below, so mesh and skeleton stay aligned).
    joints = np.asarray(joints, dtype=np.float32) @ _ROT_Z_UP_TO_Y_UP
    num_joints = joints.shape[0]
    num_verts = mesh.vertices.shape[0]

    # ---- Validation ------------------------------------------------------
    # Faces must index existing vertices.
    if hasattr(mesh, 'faces'):
        faces = np.asarray(mesh.faces)
        if faces.size > 0:
            if faces.min() < 0 or faces.max() >= num_verts:
                raise ValueError(f"Mesh faces contain out-of-range indices (min={faces.min()}, max={faces.max()}, num_verts={num_verts}).")

    # NaN/Inf in geometry or weights would produce a corrupt file.
    for name, arr in (
        ("mesh.vertices", getattr(mesh, 'vertices', None)),
        ("joints", joints),
        ("skin_weights", skin_weights),
    ):
        if arr is None:
            continue
        arr_np = np.asarray(arr)
        if arr_np.size and not np.isfinite(arr_np).all():
            raise ValueError(f"Non-finite values detected in {name}; refusing to write GLB.")

    # Normalize parents to a flat (J,) int64 array.
    parents = np.asarray(parents).astype(np.int64, copy=False)
    if parents.shape != (num_joints,):
        parents = parents.reshape(-1)
        if parents.shape[0] != num_joints:
            raise ValueError(f"parents has wrong shape: expected ({num_joints},) got {parents.shape}")

    # Repair cycles / invalid parent links before building the node tree.
    parents = repair_skeleton_parents(joints=joints, parents=parents, verbose=True)

    if skin_weights.shape[0] != num_verts:
        print(f"Warning: Mismatch in vertex count. Mesh: {num_verts}, Skin: {skin_weights.shape[0]}")
        return

    # ---- Binary buffer assembly -----------------------------------------
    # Everything is packed into buffer 0; the glTF spec requires accessor
    # data to start on 4-byte boundaries, hence the alignment helper.
    binary_data = bytearray()

    def align_to_4bytes():
        # Pad the running blob to the next 4-byte boundary.
        padding = len(binary_data) % 4
        if padding > 0:
            binary_data.extend(b'\x00' * (4 - padding))

    accessors = []
    buffer_views = []

    def add_buffer_view(data_bytes, target=None):
        # Append bytes to the blob and register a bufferView over them;
        # returns the new bufferView index.
        align_to_4bytes()
        byte_offset = len(binary_data)
        binary_data.extend(data_bytes)
        byte_length = len(data_bytes)
        bv = BufferView(buffer=0, byteOffset=byte_offset, byteLength=byte_length, target=target)
        buffer_views.append(bv)
        return len(buffer_views) - 1

    def add_accessor(buffer_view_idx, component_type, count, type_str, min_val=None, max_val=None):
        # Register an accessor describing typed data in a bufferView;
        # returns the new accessor index.
        acc = Accessor(
            bufferView=buffer_view_idx,
            componentType=component_type,
            count=count,
            type=type_str,
            min=min_val,
            max=max_val
        )
        accessors.append(acc)
        return len(accessors) - 1

    # Triangle indices: use the smallest component type that fits.
    if hasattr(mesh, 'faces'):
        indices = mesh.faces.flatten().astype(np.uint32)
        if indices.max() <= 65535:
            indices = indices.astype(np.uint16)
            component_type = UNSIGNED_SHORT
        else:
            component_type = UNSIGNED_INT

        indices_bytes = indices.tobytes()
        bv_idx = add_buffer_view(indices_bytes, target=ELEMENT_ARRAY_BUFFER)
        indices_acc_idx = add_accessor(bv_idx, component_type, len(indices), SCALAR, min_val=[int(indices.min())], max_val=[int(indices.max())])

    # Vertex positions, rotated into y-up; POSITION accessors must carry
    # min/max bounds per the glTF spec.
    positions = mesh.vertices.astype(np.float32) @ _ROT_Z_UP_TO_Y_UP

    min_pos = positions.min(axis=0).tolist()
    max_pos = positions.max(axis=0).tolist()

    pos_bytes = positions.tobytes()
    bv_idx = add_buffer_view(pos_bytes, target=ARRAY_BUFFER)
    pos_acc_idx = add_accessor(bv_idx, FLOAT, len(positions), VEC3, min_val=min_pos, max_val=max_pos)

    # Optional vertex normals (same rotation as positions).
    norm_acc_idx = None
    if hasattr(mesh, 'vertex_normals'):
        normals = mesh.vertex_normals.astype(np.float32) @ _ROT_Z_UP_TO_Y_UP
        norm_bytes = normals.tobytes()
        bv_idx = add_buffer_view(norm_bytes, target=ARRAY_BUFFER)
        norm_acc_idx = add_accessor(bv_idx, FLOAT, len(normals), VEC3)

    # Optional UVs; glTF's V axis runs top-down, so flip V.
    tex_acc_idx = None
    if hasattr(mesh.visual, 'uv') and mesh.visual.uv is not None:
        uvs = mesh.visual.uv.astype(np.float32)
        uvs[:, 1] = 1.0 - uvs[:, 1]
        uv_bytes = uvs.tobytes()
        bv_idx = add_buffer_view(uv_bytes, target=ARRAY_BUFFER)
        tex_acc_idx = add_accessor(bv_idx, FLOAT, len(uvs), VEC2)

    # Optional per-vertex colors (skipped when a texture is supplied, since
    # the texture drives the base color instead).
    color_acc_idx = None
    if vertex_colors is not None and texture_image is None:
        colors = np.asarray(vertex_colors)
        if colors.shape[0] != num_verts:
            print(f"Warning: Mismatch in vertex color count. Mesh: {num_verts}, Colors: {colors.shape[0]}")
        else:
            if colors.ndim != 2 or colors.shape[1] not in (3, 4):
                print(f"Warning: Unexpected vertex_colors shape: {colors.shape}. Expected (N,3) or (N,4).")
            else:
                colors = colors.astype(np.float32, copy=False)
                # NOTE(review): with copy=False above, this in-place clip can
                # mutate the caller's array when it is already float32.
                np.clip(colors, 0.0, 1.0, out=colors)
                color_bytes = colors.tobytes()
                bv_idx = add_buffer_view(color_bytes, target=ARRAY_BUFFER)
                color_acc_idx = add_accessor(bv_idx, FLOAT, len(colors), VEC4 if colors.shape[1] == 4 else VEC3)

    # Optional texture: encode as PNG and embed it in the same binary buffer.
    gltf_images = []
    gltf_textures = []
    gltf_samplers = []
    if texture_image is not None:
        tex_arr = np.asarray(texture_image)
        img_pil = PILImage.fromarray(tex_arr)
        buf = io.BytesIO()
        img_pil.save(buf, format='PNG')
        png_bytes = buf.getvalue()

        # Image bufferViews carry no target (they are not vertex data).
        align_to_4bytes()
        img_bv_offset = len(binary_data)
        binary_data.extend(png_bytes)
        img_bv_len = len(png_bytes)
        img_bv = BufferView(buffer=0, byteOffset=img_bv_offset, byteLength=img_bv_len)
        buffer_views.append(img_bv)
        img_bv_idx = len(buffer_views) - 1

        # Sampler enums: 9729=LINEAR, 9987=LINEAR_MIPMAP_LINEAR, 10497=REPEAT.
        gltf_images.append(GltfImage(bufferView=img_bv_idx, mimeType='image/png'))
        gltf_samplers.append(GltfSampler(magFilter=9729, minFilter=9987, wrapS=10497, wrapT=10497))
        gltf_textures.append(GltfTexture(source=0, sampler=0))

    # Skinning attributes: keep the 4 strongest influences per vertex,
    # ordered strongest-first; with fewer than 4 joints, pad the extra slots
    # with joint 0 / weight 0.
    if num_joints >= 4:
        top_indices = np.argsort(skin_weights, axis=1)[:, -4:]
        top_indices = np.flip(top_indices, axis=1)
        top_weights = np.take_along_axis(skin_weights, top_indices, axis=1)
    else:
        top_indices = np.argsort(skin_weights, axis=1)
        top_indices = np.flip(top_indices, axis=1)
        top_weights = np.take_along_axis(skin_weights, top_indices, axis=1)

        pad_width = 4 - num_joints
        top_indices = np.pad(top_indices, ((0,0), (0, pad_width)), mode='constant', constant_values=0)
        top_weights = np.pad(top_weights, ((0,0), (0, pad_width)), mode='constant', constant_values=0.0)

    # Renormalize so the kept weights sum to 1 (guard all-zero rows).
    weight_sums = np.sum(top_weights, axis=1, keepdims=True)
    weight_sums[weight_sums == 0] = 1.0
    top_weights = top_weights / weight_sums

    joints_0 = top_indices.astype(np.uint16)
    weights_0 = top_weights.astype(np.float32)

    joints_bytes = joints_0.tobytes()
    bv_idx = add_buffer_view(joints_bytes, target=ARRAY_BUFFER)
    joints_acc_idx = add_accessor(bv_idx, UNSIGNED_SHORT, len(joints_0), VEC4)

    weights_bytes = weights_0.tobytes()
    bv_idx = add_buffer_view(weights_bytes, target=ARRAY_BUFFER)
    weights_acc_idx = add_accessor(bv_idx, FLOAT, len(weights_0), VEC4)

    # Inverse bind matrices: each joint's rest transform here is a pure
    # translation, so the IBM is the inverse of that translation.  glTF
    # stores matrices column-major, hence flatten('F').
    ibms_list = []
    for i in range(num_joints):
        mat = np.eye(4, dtype=np.float32)
        mat[:3, 3] = joints[i]
        inv_mat = np.linalg.inv(mat)
        ibms_list.append(inv_mat.flatten('F'))

    ibms_data = np.concatenate(ibms_list)
    ibms_bytes = ibms_data.tobytes()
    bv_idx = add_buffer_view(ibms_bytes)
    ibms_acc_idx = add_accessor(bv_idx, FLOAT, num_joints, MAT4)

    # ---- Node hierarchy --------------------------------------------------
    # Node 0 is the skinned mesh node; joint i becomes glTF node i + 1.
    nodes = []

    mesh_node = Node(name="Mesh", mesh=0, skin=0)
    nodes.append(mesh_node)

    # Joint nodes carry translations local to their parent; roots keep their
    # world-space position.
    joint_nodes = []
    for i in range(num_joints):
        parent_idx = parents[i]
        pos = joints[i]

        if parent_idx == -1 or parent_idx == i:
            local_pos = pos
        else:
            parent_pos = joints[parent_idx]
            local_pos = pos - parent_pos

        node = Node(name=f"joint_{i}", translation=local_pos.tolist())
        joint_nodes.append(node)

    # Wire up children using glTF node indices (joint i -> node i + 1).
    for i in range(num_joints):
        parent_idx = parents[i]
        if parent_idx != -1 and parent_idx != i:
            parent_node = joint_nodes[parent_idx]
            if parent_node.children is None:
                parent_node.children = []
            parent_node.children.append(i + 1)

    nodes.extend(joint_nodes)

    # Find skeleton roots: parent == -1, or (fallback) self-parented joints.
    root_joint_indices = np.where(parents == -1)[0]
    if len(root_joint_indices) == 0:
        root_joint_indices = np.where(parents == np.arange(num_joints))[0]

    if len(root_joint_indices) == 0:
        print("Error: No root joint found.")
        return

    # Multiple roots are grouped under a synthetic root node; it is not part
    # of skin.joints, so it does not affect skinning.
    if len(root_joint_indices) > 1:
        print(f"Warning: Found {len(root_joint_indices)} root joints. Creating a virtual root.")
        virtual_root_node_idx = len(nodes)
        virtual_root_node = Node(name="VirtualRoot", children=[int(i) + 1 for i in root_joint_indices])
        nodes.append(virtual_root_node)
        root_node_idx = virtual_root_node_idx
    else:
        root_joint_idx = root_joint_indices[0]
        root_node_idx = int(root_joint_idx) + 1

    # Skin over every joint node (indices shifted by the mesh node at 0).
    skin = Skin(
        inverseBindMatrices=ibms_acc_idx,
        joints=[i + 1 for i in range(num_joints)],
        skeleton=root_node_idx
    )

    # ---- Mesh primitive --------------------------------------------------
    attributes_dict = {
        "POSITION": pos_acc_idx,
        "JOINTS_0": joints_acc_idx,
        "WEIGHTS_0": weights_acc_idx
    }
    if norm_acc_idx is not None:
        attributes_dict["NORMAL"] = norm_acc_idx
    if tex_acc_idx is not None:
        attributes_dict["TEXCOORD_0"] = tex_acc_idx
    if color_acc_idx is not None:
        attributes_dict["COLOR_0"] = color_acc_idx

    primitive = Primitive(
        attributes=Attributes(**attributes_dict),
        indices=indices_acc_idx,
        material=0,
    )
    mesh_obj = Mesh(primitives=[primitive])

    # Scene contains the mesh node and the skeleton root.
    scene = Scene(nodes=[0, root_node_idx])

    # Final padding so the GLB binary chunk length is 4-byte aligned.
    align_to_4bytes()

    buffer = Buffer(byteLength=len(binary_data))

    # Material: textured when a texture was embedded, otherwise plain white.
    if gltf_textures:
        pbr = PbrMetallicRoughness(
            baseColorTexture=TextureInfo(index=0),
            metallicFactor=0.0,
            roughnessFactor=1.0,
        )
    else:
        pbr = PbrMetallicRoughness(
            baseColorFactor=[1.0, 1.0, 1.0, 1.0],
            metallicFactor=0.0,
            roughnessFactor=1.0,
        )
    material = Material(pbrMetallicRoughness=pbr)

    # ---- Assemble the document and save ---------------------------------
    gltf = GLTF2(
        asset=Asset(version="2.0"),
        scene=0,
        scenes=[scene],
        nodes=nodes,
        meshes=[mesh_obj],
        materials=[material],
        skins=[skin],
        accessors=accessors,
        bufferViews=buffer_views,
        buffers=[buffer],
        images=gltf_images if gltf_images else None,
        textures=gltf_textures if gltf_textures else None,
        samplers=gltf_samplers if gltf_samplers else None,
    )

    gltf.set_binary_blob(bytes(binary_data))

    if str(output_path).lower().endswith('.glb'):
        gltf.save_binary(output_path)
    else:
        gltf.save(output_path)

    # Best-effort sanity check of the GLB magic header (b'glTF').
    if str(output_path).lower().endswith('.glb'):
        try:
            with open(output_path, 'rb') as f:
                magic = f.read(4)
                if magic != b'glTF':
                    print(f"Warning: output does not look like a valid GLB (magic={magic!r}).")
        except Exception as e:
            print(f"Warning: failed to validate GLB header: {e}")
    print(f"Saved GLB to {output_path}")
|
|
def save_colored_pcl(array, filename):
    """Write an (N, >=3) point array to *filename* as an ASCII PLY cloud.

    Columns 0-2 are XYZ.  When at least 6 columns are present, columns 3-5
    are treated as RGB in [0, 1] and scaled to uint8; otherwise every point
    is written as white.
    """
    if isinstance(array, torch.Tensor):
        array = array.detach().cpu().numpy()

    xyz = array[:, :3]
    n_points = xyz.shape[0]

    if array.shape[1] >= 6:
        rgb = (array[:, 3:6] * 255).astype(np.uint8)
    else:
        rgb = np.full((n_points, 3), 255, dtype=np.uint8)

    rows = np.hstack((xyz, rgb))

    header_lines = [
        "ply",
        "format ascii 1.0",
        f"element vertex {n_points}",
        "property float x",
        "property float y",
        "property float z",
        "property uchar red",
        "property uchar green",
        "property uchar blue",
        "end_header",
    ]

    with open(filename, 'w') as out:
        out.write("\n".join(header_lines) + "\n")
        np.savetxt(out, rows, fmt='%f %f %f %d %d %d')
|
|
def _extract_vertex_rgb(vertex_attrs):
    """Extract Nx3 RGB in [0,1] from vertex_attrs (torch/np), or return None.

    Accepts a (N, >=3) array/tensor whose first three columns are RGB;
    returns a float32 copy clipped to [0, 1], or None for unusable input.
    """
    if vertex_attrs is None:
        return None
    attrs = vertex_attrs
    if isinstance(attrs, torch.Tensor):
        attrs = attrs.detach().cpu().numpy()
    attrs = np.asarray(attrs)
    if attrs.ndim != 2 or attrs.shape[1] < 3:
        return None
    # astype() copies by default; the previous copy=False + in-place clip
    # silently mutated the caller's array when it was already float32.
    colors = attrs[:, :3].astype(np.float32)
    np.clip(colors, 0.0, 1.0, out=colors)
    return colors
|
|
def transfer_vertex_colors_nearest(src_vertices, src_colors, dst_vertices):
    """Transfer per-vertex colors by nearest-vertex mapping.

    Each destination vertex receives the color of its closest source vertex.

    Args:
        src_vertices: (Ns, 3) float
        src_colors: (Ns, 3) float in [0,1]
        dst_vertices: (Nd, 3) float
    Returns:
        (Nd, 3) float in [0,1], or None for invalid input.
    """
    if any(a is None for a in (src_vertices, src_colors, dst_vertices)):
        return None

    src_pts = np.asarray(src_vertices)
    palette = np.asarray(src_colors)
    dst_pts = np.asarray(dst_vertices)

    def _valid_points(a):
        # (N, 3) shape check shared by all three inputs.
        return a.ndim == 2 and a.shape[1] == 3

    if not _valid_points(src_pts) or not _valid_points(dst_pts):
        return None
    if not _valid_points(palette) or palette.shape[0] != src_pts.shape[0]:
        return None

    # Run on GPU when available, CPU otherwise.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    src_t = torch.from_numpy(src_pts).to(device=device, dtype=torch.float32)
    dst_t = torch.from_numpy(dst_pts).to(device=device, dtype=torch.float32)
    colors_t = torch.from_numpy(palette).to(device=device, dtype=torch.float32)

    # Prefer pytorch3d's KNN when installed; otherwise fall back to a dense
    # pairwise-distance argmin.
    try:
        from pytorch3d.ops import knn_points
        _, nn_idx, _ = knn_points(dst_t[None], src_t[None], K=1, return_nn=False)
        nearest = nn_idx[0, :, 0]
    except Exception:
        pairwise = torch.cdist(dst_t, src_t)
        nearest = torch.argmin(pairwise, dim=1)

    mapped = colors_t[nearest].clamp(0.0, 1.0)
    return mapped.detach().cpu().numpy()
|
|