Spaces:
Sleeping
Sleeping
| from enum import Enum | |
| import numpy as np | |
| import torch | |
| import trimesh | |
| from salad.utils import thutil | |
def write_obj(name: str, vertices: np.ndarray, faces: np.ndarray):
    """
    Write a triangle mesh to a Wavefront OBJ text file.

    name: filename
    vertices: (V,3)
    faces: (F,3) Assume the mesh is a triangle mesh. Indices are 0-based on
        input and written 1-based, as OBJ `f` records require.
    """
    # Both inputs are passed through thutil.th2np before being indexed.
    vertices = thutil.th2np(vertices)
    faces = thutil.th2np(faces).astype(np.uint32)
    # `with` closes the handle even when a write fails part-way through.
    with open(name, "w") as fout:
        # `%s` formats with str(), which keeps the short repr of NumPy
        # scalars (format()/f-strings may widen float32 values).
        for vert in vertices:
            fout.write("v %s %s %s\n" % (vert[0], vert[1], vert[2]))
        for face in faces:
            fout.write("f %s %s %s\n" % (face[0] + 1, face[1] + 1, face[2] + 1))
def write_obj_triangle(name: str, vertices: np.ndarray, triangles: np.ndarray):
    """
    Write vertices and triangles to a Wavefront OBJ text file.

    name: filename
    vertices: rows whose first three entries are written as `v x y z`
    triangles: rows of three 0-based vertex indices, written 1-based
    """
    # `with` closes the handle even when a write fails part-way through.
    with open(name, "w") as fout:
        # `%s` formats with str(), matching plain string concatenation.
        for vert in vertices:
            fout.write("v %s %s %s\n" % (vert[0], vert[1], vert[2]))
        for tri in triangles:
            fout.write("f %s %s %s\n" % (tri[0] + 1, tri[1] + 1, tri[2] + 1))
def write_obj_polygon(name: str, vertices: np.ndarray, polygons: np.ndarray):
    """
    Write vertices and arbitrary-size polygons to a Wavefront OBJ text file.

    name: filename
    vertices: sequence of rows; the first three entries of each are written
    polygons: sequence of index sequences (0-based); each becomes one `f`
        record with 1-based indices, so polygons may differ in length
    """
    # `with` closes the handle even when a write fails part-way through.
    with open(name, "w") as fout:
        # `%s` formats with str(), matching plain string concatenation.
        for vert in vertices:
            fout.write("v %s %s %s\n" % (vert[0], vert[1], vert[2]))
        for poly in polygons:
            fout.write("f" + "".join(" " + str(idx + 1) for idx in poly) + "\n")
def read_obj(name: str):
    """
    Read vertices and triangle faces from a Wavefront OBJ text file.

    name: filename

    Only `v ` and `f ` records are parsed; every other line is ignored.
    For each face the first three corners are kept, and only the vertex
    index of a `v/vt/vn` corner is used.

    Returns (v, f): float32 vertices of shape (V,3) and int32 0-based faces
    of shape (F,3). A file without vertices or without faces yields an
    empty (0,3) array for that part instead of raising.
    """
    verts = []
    faces = []
    with open(name, "r") as obj_file:
        # Stream the file line by line rather than loading it all first.
        for line in obj_file:
            if line.startswith("v "):
                coords = [float(tok) for tok in line.split()[1:4]]
                verts.append(np.array(coords, dtype=np.float32))
            elif line.startswith("f "):
                corner_ids = [int(tok.split("/")[0]) for tok in line.split()[1:4]]
                faces.append(np.array(corner_ids, dtype=np.int32))
    # np.vstack rejects an empty list, so handle empty parts explicitly.
    v = np.vstack(verts) if verts else np.zeros((0, 3), dtype=np.float32)
    # OBJ indices are 1-based; shift to 0-based.
    f = np.vstack(faces) - 1 if faces else np.zeros((0, 3), dtype=np.int32)
    return v, f
def scene_as_mesh(scene_or_mesh):
    """
    Collapse a trimesh.Scene into a single mesh.

    scene_or_mesh: a trimesh.Scene, or anything else (returned unchanged).

    Returns None for a scene with no geometry. Otherwise returns the
    concatenation of fresh Trimesh objects built from the vertices and
    faces of every geometry whose faces have three columns.
    """
    # Anything that is not a Scene is passed straight through.
    if not isinstance(scene_or_mesh, trimesh.Scene):
        return scene_or_mesh
    if len(scene_or_mesh.geometry) == 0:
        return None
    # Rebuild each part from raw vertices/faces; skip non-triangle parts.
    triangle_parts = tuple(
        trimesh.Trimesh(vertices=geom.vertices, faces=geom.faces)
        for geom in scene_or_mesh.geometry.values()
        if geom.faces.shape[1] == 3
    )
    return trimesh.util.concatenate(triangle_parts)
def get_center(verts):
    """
    Return the midpoint between the per-column maximum and minimum of verts.

    verts: array reduced along axis 0 with `.max(0)` / `.min(0)`
        (presumably a (V,3) NumPy array of vertex positions).
    """
    upper = verts.max(0)
    lower = verts.min(0)
    return (upper + lower) / 2
def to_center(verts):
    """
    Subtract the center given by get_center from verts and return verts.

    The subtraction uses `-=`, so a NumPy array passed in is modified
    in place and the same object is returned.
    """
    center = get_center(verts)
    verts -= center[None, :]
    return verts
def get_offset_and_scale(verts, radius=1.0):
    """
    Compute the translation and scale that fit verts in a sphere of `radius`.

    verts: vertex positions, converted with thutil.th2np; the input is not
        modified because the work happens on a copy.
    radius: target distance of the farthest vertex from the center.

    Returns (offset, scale): offset is the get_center result with a leading
    axis of length 1; scale is radius divided by the largest row norm of the
    centered vertices.
    """
    points = thutil.th2np(verts).copy()
    offset = get_center(points)[None, :]
    points -= offset
    farthest = np.linalg.norm(points, axis=1).max()
    scale = 1 / farthest * radius
    return offset, scale
def normalize_mesh(mesh: trimesh.Trimesh):
    """
    Normalize a mesh by its axis-aligned bounding box.

    The minimum corner of the bounding box is moved to the origin and the
    vertices are divided by the length of the bounding-box diagonal.

    Returns a dict with:
        mesh: new Trimesh built with process=False from the scaled vertices
        offset: the per-axis minimum that was subtracted
        scale: the diagonal length that was divided by
    """
    verts = np.array(mesh.vertices)
    tris = np.array(mesh.faces)
    upper = np.max(verts, 0)
    lower = np.min(verts, 0)
    diagonal = np.sqrt(np.sum((upper - lower) ** 2))
    scaled = (verts - lower) / diagonal
    normed_mesh = trimesh.Trimesh(vertices=scaled, faces=tris, process=False)
    return dict(mesh=normed_mesh, offset=lower, scale=diagonal)
def normalize_scene(scene: trimesh.Scene):
    """
    Normalize every geometry of a scene with one shared offset and scale.

    The offset and scale come from normalize_mesh applied to the merged
    scene (scene_as_mesh). They are then applied to each geometry
    separately, and a new Scene of the resulting meshes is returned.
    """
    stats = normalize_mesh(scene_as_mesh(scene))
    offset, scale = stats["offset"], stats["scale"]
    normalized_parts = [
        trimesh.Trimesh(
            (np.array(part.vertices) - offset) / scale,
            np.array(part.faces),
        )
        for part in scene.geometry.values()
    ]
    return trimesh.Scene(normalized_parts)
class SampleBy(Enum):
    """Face-selection strategy accepted by sample_on_mesh (`sample_s`)."""

    # Draw faces with probability proportional to their area.
    AREAS = 0
    # Draw faces with equal weight, regardless of area.
    FACES = 1
    # Use both weightings, splitting the requested samples between them.
    HYB = 2
def get_faces_normals(mesh):
    """
    Compute one unnormalized normal per triangle.

    mesh: either a tensor of per-face corner coordinates indexed as
        [face, corner, coord], or a (vs, faces) pair for which `vs[faces]`
        yields such a tensor. 2D coordinates are padded with a zero third
        component.

    Returns a (F,3) tensor holding the cross product of two edges of each
    triangle; its length is twice the triangle area.
    """
    # isinstance also accepts torch.Tensor subclasses.
    if isinstance(mesh, torch.Tensor):
        vs_faces = mesh
    else:
        vs, faces = mesh
        vs_faces = vs[faces]
    if vs_faces.shape[-1] == 2:
        zero_z = torch.zeros(
            *vs_faces.shape[:2], 1, dtype=vs_faces.dtype, device=vs_faces.device
        )
        vs_faces = torch.cat((vs_faces, zero_z), dim=2)
    edge_a = vs_faces[:, 1, :] - vs_faces[:, 0, :]
    edge_b = vs_faces[:, 2, :] - vs_faces[:, 1, :]
    # dim must be explicit: without it torch.cross uses the first dimension
    # of size 3, which is the face axis when the mesh has exactly 3 faces.
    face_normals = torch.cross(edge_a, edge_b, dim=1)
    return face_normals
def compute_face_areas(mesh):
    """
    Compute per-face areas and unit normals.

    mesh: forwarded unchanged to get_faces_normals.

    Returns (face_areas, face_normals): half the length of each raw normal,
    and the raw normals divided by their length. Zero-length normals are
    divided by 1 instead, so degenerate faces keep a zero normal.
    """
    raw_normals = get_faces_normals(mesh)
    lengths = torch.norm(raw_normals, p=2, dim=1)
    # Replace zero lengths by 1 to avoid dividing by zero.
    safe_lengths = lengths.masked_fill(lengths == 0, 1)
    unit_normals = raw_normals / safe_lengths[:, None]
    return 0.5 * lengths, unit_normals
def sample_uvw(shape, device):
    """
    Draw random barycentric coordinates.

    shape: sequence of ints giving the batch shape.
    device: device passed to torch.rand.

    Returns a tensor of shape (*shape, 3) whose last axis holds (u, v, w)
    with w = 1 - u - v. Pairs with u + v > 1 are reflected to
    (1 - u, 1 - v) so that w is not negative.
    """
    u = torch.rand(*shape, device=device)
    v = torch.rand(*shape, device=device)
    outside = (u + v) > 1
    # Both reflections read the pre-update values of their own tensor only.
    u[outside] = 1 - u[outside]
    v[outside] = 1 - v[outside]
    w = -u - v + 1
    return torch.stack((u, v, w), dim=len(shape))
def sample_on_mesh(mesh, num_samples: int, face_areas=None, sample_s=SampleBy.HYB):
    """
    Sample points on a mesh surface, or subsample a point cloud.

    mesh: (vs, faces) pair; `faces` may be None for a point cloud.
    num_samples: number of points requested.
    face_areas: optional precomputed per-face areas (used by AREAS / HYB);
        computed with compute_face_areas when omitted. Not modified.
    sample_s: SampleBy strategy. AREAS weights faces by area, FACES weights
        them equally, HYB splits the samples between both weightings.

    Returns (samples, chosen_faces_inds, uvw):
        samples: sampled points.
        chosen_faces_inds: indices of the chosen faces (or of the chosen
            points for a point cloud). For AREAS / FACES this is a
            one-element list holding the index tensor; for HYB and point
            clouds it is the index tensor itself.
        uvw: barycentric coordinates of each sample, None for a point cloud.
    """
    vs, faces = mesh
    if faces is None:  # sample from pc
        uvw = None
        if vs.shape[0] < num_samples:
            # Fewer points than requested: return all of them.
            chosen_faces_inds = torch.arange(vs.shape[0])
        else:
            chosen_faces_inds = torch.argsort(torch.rand(vs.shape[0]))[:num_samples]
        samples = vs[chosen_faces_inds]
        return samples, chosen_faces_inds, uvw

    weighted_p = []
    if sample_s == SampleBy.AREAS or sample_s == SampleBy.HYB:
        if face_areas is None:
            face_areas, _ = compute_face_areas(mesh)
        # Out-of-place masked_fill leaves a caller-supplied tensor untouched.
        face_areas = face_areas.masked_fill(torch.isnan(face_areas), 0)
        weighted_p.append(face_areas / face_areas.sum())
    if sample_s == SampleBy.FACES or sample_s == SampleBy.HYB:
        weighted_p.append(torch.ones(faces.shape[0], device=vs.device))

    # Split num_samples across the strategies and hand the remainder to the
    # first ones, so the total number of faces always equals num_samples
    # (an odd num_samples with HYB used to come up one face short).
    base, extra = divmod(num_samples, len(weighted_p))
    counts = [base + (1 if k < extra else 0) for k in range(len(weighted_p))]
    per_strategy_inds = [
        torch.multinomial(weights, count, replacement=True)
        for weights, count in zip(weighted_p, counts)
        if count > 0  # skip empty draws
    ]
    # Index with a single tensor rather than a Python list of tensors.
    flat_inds = torch.cat(per_strategy_inds, dim=0)
    chosen_faces = faces[flat_inds]
    uvw = sample_uvw([num_samples], vs.device)
    samples = torch.einsum("sf,sfd->sd", uvw, vs[chosen_faces])
    # NOTE(review): single-strategy modes have always returned a one-element
    # list here; kept for backward compatibility with existing callers.
    chosen_faces_inds = flat_inds if sample_s == SampleBy.HYB else per_strategy_inds
    return samples, chosen_faces_inds, uvw
def repair_normals(v, f):
    """
    Run trimesh's normal repair on a mesh given as vertices and faces.

    v: vertices, f: faces; both are passed to trimesh.Trimesh.

    Returns (vertices, faces) read back from the mesh after
    trimesh.repair.fix_normals; faces are returned via np.asarray.
    """
    repaired = trimesh.Trimesh(v, f)
    trimesh.repair.fix_normals(repaired)
    return repaired.vertices, np.asarray(repaired.faces)