Lingteng Qiu (邱陵腾)
rm assets & wheels
434b0b0
import nvdiffrast.torch as dr
import torch
from matplotlib import image
def _warmup(glctx):
    """Perform one throwaway rasterization on ``glctx``.

    A single triangle is rasterized at 256x256 on the "cuda" device and the
    result is discarded. The original author notes this as a Windows
    workaround for https://github.com/NVlabs/nvdiffrast/issues/59.
    """
    # One triangle, batch of one, given as homogeneous (x, y, z, w) positions.
    corners = [[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]
    pos = torch.tensor([corners], device="cuda", dtype=torch.float32)
    tri = torch.tensor([[0, 1, 2]], device="cuda", dtype=torch.int32)
    dr.rasterize(glctx, pos, tri, resolution=[256, 256])
class NormalsRenderer:
    """Render a mesh's per-vertex normals from C fixed views with nvdiffrast.

    The projection and model-view matrices are multiplied once at
    construction. Each :meth:`render` call rasterizes the mesh in every view
    and returns antialiased RGBA images. RGB holds the interpolated normals
    remapped via ``(n + 1) / 2``; alpha is a coverage mask.
    """

    def __init__(
        self,
        mv: torch.Tensor,  # C,4,4
        proj: torch.Tensor,  # C,4,4
        image_size: tuple[int, int],
    ):
        """
        Args:
            mv: Per-view model-view matrices, shape (C, 4, 4).
            proj: Per-view projection matrices, shape (C, 4, 4).
            image_size: Passed to ``dr.rasterize`` as ``resolution``
                (nvdiffrast documents this as (height, width)).
        """
        self._mvp = proj @ mv  # C,4,4
        self._image_size = image_size
        # nvdiffrast also offers dr.RasterizeGLContext(); the CUDA rasterizer
        # is used here.
        self._glctx = dr.RasterizeCudaContext()
        _warmup(self._glctx)

    def render(
        self,
        vertices: torch.Tensor,  # V,3 float
        normals: torch.Tensor,  # V,3 float
        faces: torch.Tensor,  # F,3 long
    ) -> torch.Tensor:  # C,H,W,4
        """Rasterize the mesh in every view and color it by its vertex normals.

        Args:
            vertices: Vertex positions, shape (V, 3).
            normals: Per-vertex normals, shape (V, 3). Presumably unit vectors
                in [-1, 1]; values are remapped as ``(n + 1) / 2``.
            faces: Triangle vertex indices, shape (F, 3), any integer dtype.

        Returns:
            Antialiased RGBA images of shape (C, H, W, 4), one per view.
        """
        # nvdiffrast requires the triangle tensor to be contiguous int32.
        tris = faces.to(torch.int32).contiguous()
        # Append w = 1. Padding keeps the vertices' own dtype and device
        # instead of forcing the default dtype onto the extra column.
        vert_hom = torch.nn.functional.pad(vertices, (0, 1), value=1.0)  # V,4
        # Follow the mesh's device so matrices built elsewhere still work.
        mvp = self._mvp.to(vertices.device)
        # Row-vector convention: (V,4) @ (C,4,4)^T broadcasts to (C,V,4).
        vertices_clip = vert_hom @ mvp.transpose(-2, -1)  # C,V,4
        rast_out, _ = dr.rasterize(
            self._glctx,
            vertices_clip,
            tris,
            resolution=self._image_size,
            grad_db=False,
        )  # C,H,W,4
        vert_col = (normals + 1) / 2  # V,3
        col, _ = dr.interpolate(vert_col, rast_out, tris)  # C,H,W,3
        # The last rasterizer channel is triangle_id + 1 (0 where no triangle
        # is hit), so clamping to 1 yields a 0/1 coverage mask.
        alpha = torch.clamp(rast_out[..., -1:], max=1)  # C,H,W,1
        col = torch.cat((col, alpha), dim=-1)  # C,H,W,4
        col = dr.antialias(col, rast_out, vertices_clip, tris)  # C,H,W,4
        return col  # C,H,W,4