TripoSR / tsr /bake_texture.py
codemaker2015's picture
Initial commit
9ba6f9a
import numpy as np
import torch
import xatlas
# Slack applied to the barycentric inside-test so texels lying just outside a
# triangle edge are still filled (hides seams between neighbouring triangles).
_BARY_TOLERANCE = 0.01

# A UV triangle is treated as degenerate when its squared doubled area is below
# this fraction of the product of its squared edge lengths (scale-invariant).
_DEGENERATE_REL_EPS = 1e-12


def _rasterize_positions(uvs, faces, vertices, texture_resolution):
    """Rasterise triangles into UV space, storing interpolated 3D positions.

    Args:
        uvs: Per-vertex UV coordinates; indexed as ``uvs[face][:, 0]`` (u) and
            ``uvs[face][:, 1]`` (v).
        faces: Iterable of index triples into ``uvs`` / ``vertices``.
        vertices: Per-vertex 3D positions, indexed the same way as ``uvs``.
        texture_resolution: Edge length of the square texture, in texels.

    Returns:
        A ``(texture_resolution, texture_resolution, 4)`` float32 array. For
        every covered texel, channels 0-2 hold the interpolated position and
        channel 3 is 1.0; uncovered texels are all zero. When triangles
        overlap in UV space, the face that comes later in ``faces`` wins.
    """
    texture = np.zeros(
        (texture_resolution, texture_resolution, 4), dtype=np.float32
    )
    # Texel index <-> UV scale; clamped so a 1x1 texture does not divide by 0.
    scale = max(texture_resolution - 1, 1)

    for face in faces:
        # float64 avoids cancellation error in the tiny products computed below.
        face_uvs = np.asarray(uvs[face], dtype=np.float64)
        face_verts = vertices[face]

        # Per-face barycentric setup: constant for every texel of this face.
        uv0 = face_uvs[0]
        edge1 = face_uvs[1] - uv0
        edge2 = face_uvs[2] - uv0
        d00 = np.dot(edge1, edge1)
        d01 = np.dot(edge1, edge2)
        d11 = np.dot(edge2, edge2)
        denom = d00 * d11 - d01 * d01  # == (2 * UV area) ** 2

        # Relative test: an absolute threshold in normalised UV units would
        # also reject small but perfectly valid triangles (a few texels wide
        # at high resolution) and leave holes in the bake.
        if denom <= _DEGENERATE_REL_EPS * d00 * d11:
            continue

        # Texel-space bounding box of the triangle (v axis is flipped),
        # clipped to the texture.
        u_coords = face_uvs[:, 0] * scale
        v_coords = (1 - face_uvs[:, 1]) * scale
        u_lo = max(0, int(np.floor(u_coords.min())))
        u_hi = min(texture_resolution, int(np.ceil(u_coords.max())) + 1)
        v_lo = max(0, int(np.floor(v_coords.min())))
        v_hi = min(texture_resolution, int(np.ceil(v_coords.max())) + 1)
        if u_lo >= u_hi or v_lo >= v_hi:
            continue

        # UV offset of every texel in the box relative to uv0; rows and
        # columns are kept separate and combined by broadcasting.
        rel_u = np.arange(u_lo, u_hi) / scale - uv0[0]        # (W,)
        rel_v = (1 - np.arange(v_lo, v_hi) / scale) - uv0[1]  # (H,)
        d20 = rel_u[None, :] * edge1[0] + rel_v[:, None] * edge1[1]  # (H, W)
        d21 = rel_u[None, :] * edge2[0] + rel_v[:, None] * edge2[1]

        bary_v = (d11 * d20 - d01 * d21) / denom
        bary_w = (d00 * d21 - d01 * d20) / denom
        bary_u = 1.0 - bary_v - bary_w

        inside = (
            (bary_u >= -_BARY_TOLERANCE)
            & (bary_v >= -_BARY_TOLERANCE)
            & (bary_w >= -_BARY_TOLERANCE)
        )
        if not inside.any():
            continue

        # `region` is a view, so these writes land in `texture`.
        region = texture[v_lo:v_hi, u_lo:u_hi]
        region[inside, :3] = (
            bary_u[inside, None] * face_verts[0]
            + bary_v[inside, None] * face_verts[1]
            + bary_w[inside, None] * face_verts[2]
        )
        region[inside, 3] = 1.0

    return texture


def bake_texture(mesh, model, scene_code, texture_resolution):
    """Unwrap ``mesh`` with xatlas and bake model colours into a UV texture.

    The mesh is UV-unwrapped, every triangle is rasterised into the atlas to
    find the 3D position behind each texel, and those positions are passed to
    ``model.renderer.query_triplane`` to obtain colours.

    Args:
        mesh: Object exposing ``vertices``, ``faces`` and ``vertex_normals``
            arrays (a trimesh-style mesh is assumed - confirm with callers).
        model: Object exposing ``renderer.query_triplane`` and ``decoder``.
        scene_code: Scene representation forwarded unchanged to
            ``query_triplane``. If it has a ``device`` attribute, the query
            positions are created on that device.
        texture_resolution: Edge length of the square atlas / texture.

    Returns:
        dict with keys:
            ``"vmapping"``: for each atlas vertex, its original vertex index.
            ``"indices"``: triangle indices into the atlas vertices.
            ``"uvs"``: UV coordinates per atlas vertex.
            ``"colors"``: ``(res, res, 4)`` float32 texture; channels 0-2 hold
                the queried colour and channel 3 is 1.0 on covered texels,
                all zeros elsewhere.
            ``"vertices"`` / ``"normals"``: mesh vertices and vertex normals
                re-indexed through ``vmapping``.
    """
    atlas = xatlas.Atlas()
    atlas.add_mesh(mesh.vertices, mesh.faces)

    options = xatlas.PackOptions()
    options.resolution = texture_resolution
    options.padding = max(2, texture_resolution // 256)
    options.bilinear = True
    atlas.generate(pack_options=options)

    vmapping, indices, uvs = atlas[0]
    new_vertices = mesh.vertices[vmapping]
    new_normals = mesh.vertex_normals[vmapping]

    position_texture = _rasterize_positions(
        uvs, indices, new_vertices, texture_resolution
    )
    valid_mask = position_texture[:, :, 3] > 0
    # Boolean indexing walks the texture in row-major order, the same order
    # used below when the colours are scattered back.
    valid_positions = position_texture[valid_mask, :3]

    color_texture = np.zeros(
        (texture_resolution, texture_resolution, 4), dtype=np.float32
    )
    if len(valid_positions) > 0:
        # Create the query points on scene_code's device so a GPU-resident
        # model is not handed CPU tensors; None keeps torch's default device.
        device = getattr(scene_code, "device", None)
        positions_tensor = torch.tensor(
            valid_positions, dtype=torch.float32, device=device
        )
        with torch.no_grad():
            queried = model.renderer.query_triplane(
                model.decoder,
                positions_tensor,
                scene_code,
            )
        # Assumes queried["color"] holds one RGB row per query position.
        colors = queried["color"].cpu().numpy()
        color_texture[valid_mask, :3] = colors
        color_texture[valid_mask, 3] = 1.0

    return {
        "vmapping": vmapping,
        "indices": indices,
        "uvs": uvs,
        "colors": color_texture,
        "vertices": new_vertices,
        "normals": new_normals,
    }