| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | from typing import Tuple, Literal, Optional |
| | |
| | import nvdiffrast.torch as dr |
| | import torch.nn.functional as F |
| | import torch |
| | import numpy as np |
| | from vhap.util import vector_ops as V |
| |
|
| |
|
def get_SH_shading(normals, sh_coefficients, sh_const):
    """
    Evaluate spherical-harmonics shading (first 9 basis functions) per normal.

    :param normals: normal vectors; the broadcasting below works for shape N, H, W, 3
    :param sh_coefficients: lighting coefficients, shape N, 9, 3
    :param sh_const: per-basis constant factors, shape 9 (moved to the device of the normals)
    :return: shading of shape N, H, W, 3
    """
    nx = normals[..., 0]
    ny = normals[..., 1]
    nz = normals[..., 2]

    # The nine un-normalised basis terms, in the order expected by `sh_const`.
    basis_terms = [
        nx * 0.0 + 1.0,
        nx,
        ny,
        nz,
        nx * ny,
        nx * nz,
        ny * nz,
        nx ** 2 - ny ** 2,
        3 * (nz ** 2) - 1,
    ]
    basis = torch.stack(basis_terms, dim=-1)

    # Scale every basis term by its constant factor: (N, H, W, 9).
    basis = basis * sh_const[None, None, None, :].to(basis.device)

    # (N, H, W, 9, 1) against (N, 1, 1, 9, 3) -> (N, H, W, 9, 3).
    weighted = sh_coefficients[:, None, None, :, :] * basis[..., None]

    # Collapse the basis axis to get one RGB shading value per pixel.
    return torch.sum(weighted, dim=3)
| |
|
| |
|
class NVDiffRenderer(torch.nn.Module):
    """Mesh renderer built on nvdiffrast (rasterize / interpolate / texture / antialias).

    The geometry helpers (projection, space conversions, screen coordinates) are
    plain PyTorch and follow the device of their inputs; only the methods that
    call into ``dr`` need an nvdiffrast context.

    All image-like outputs of the ``render_*`` methods are flipped along dim 1
    (image rows) before being returned.
    """

    def __init__(
        self,
        use_opengl: bool = False,
        lighting_type: Literal['constant', 'front', 'front-range', 'SH'] = 'front',
        lighting_space: Literal['camera', 'world'] = 'world',
        disturb_rate_fg: Optional[float] = 0.5,
        disturb_rate_bg: Optional[float] = 0.5,
        fid2cid: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            use_opengl: use ``dr.RasterizeGLContext`` instead of ``dr.RasterizeCudaContext``.
            lighting_type: shading model used by :meth:`shade`.
            lighting_space: space in which vertex normals are computed in :meth:`render_rgba`.
            disturb_rate_fg: per-pixel probability of colour disturbance for foreground
                clusters (``None`` disables it).
            disturb_rate_bg: same, for the background cluster.
            fid2cid: optional mapping from face index to cluster index; required when
                :meth:`render_rgba` is called with ``enable_disturbance=True``.
        """
        super().__init__()
        self.backend = 'nvdiffrast'
        self.lighting_type = lighting_type
        self.lighting_space = lighting_space
        self.disturb_rate_fg = disturb_rate_fg
        self.disturb_rate_bg = disturb_rate_bg
        self.glctx = dr.RasterizeGLContext() if use_opengl else dr.RasterizeCudaContext()
        self.fragment_cache = None

        if fid2cid is not None:
            # The rasterizer's last channel is used as "face index + 1" with 0 for empty
            # pixels, so prepend a 0 entry: empty pixels then map to cluster 0.
            fid2cid = F.pad(fid2cid, [1, 0], value=0)
            self.register_buffer("fid2cid", fid2cid, persistent=False)

        # Constant factors applied to the nine SH basis terms in `get_SH_shading`.
        pi = np.pi
        sh_const = torch.tensor(
            [
                1 / np.sqrt(4 * pi),
                ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
                ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
                ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
                (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
                (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
                (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
                (pi / 4) * (3 / 2) * (np.sqrt(5 / (12 * pi))),
                (pi / 4) * (1 / 2) * (np.sqrt(5 / (4 * pi))),
            ],
            dtype=torch.float32,
        )
        self.register_buffer("sh_const", sh_const, persistent=False)

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _to_homogeneous(vtx):
        """Return `vtx` with a trailing 1 appended if it is 3D; pass 4D points through."""
        last = vtx.shape[-1]
        if last == 4:
            return vtx
        if last == 3:
            ones = torch.ones([*vtx.shape[:2], 1], device=vtx.device)
            return torch.cat([vtx, ones], dim=-1)
        raise ValueError(f"Expected 3D or 4D points but got: {vtx.shape[-1]}")

    @staticmethod
    def _background_rgba(background_color, rgba):
        """Build an RGBA background (alpha 0) shaped like `rgba`, flipped along dim 1."""
        if isinstance(background_color, (list, tuple)):
            # Background as a constant color
            rgba_bg = torch.tensor(list(background_color) + [0]).to(rgba).expand_as(rgba)
        elif isinstance(background_color, torch.Tensor):
            # Background as an image: append a zero alpha channel
            rgba_bg = torch.cat(
                [background_color, torch.zeros_like(background_color[..., :1])], dim=-1
            )
        else:
            raise ValueError(f"Unknown background type: {type(background_color)}")
        return rgba_bg.flip(1)

    def _render_flat_shaded(self, verts, faces, RT, K, image_size, background_color, v_color=None):
        """Rasterize and shade with per-face normals; albedo is `v_color` or all ones."""
        verts_camera_ = self.world_to_camera(verts, RT)
        verts_camera = verts_camera_[..., :3]
        verts_clip = self.camera_to_clip(verts_camera_, K, image_size)
        faces = faces.int()
        rast_out, _ = dr.rasterize(self.glctx, verts_clip, faces, image_size)

        fg_mask = torch.clamp(rast_out[..., -1:], 0, 1).bool()
        # Empty pixels are clamped to face 0; they are overwritten by the background below.
        face_id = torch.clamp(rast_out[..., -1:].long() - 1, 0)
        H, W = face_id.shape[1:3]

        # Look up the normal of the face covering each pixel.
        face_normals = self.compute_face_normals(verts_camera, faces)
        face_normals_ = face_normals[:, None, None, :, :].expand(-1, H, W, -1, -1)
        face_id_ = face_id[:, :, :, None].expand(-1, -1, -1, -1, 3)
        normal = torch.gather(face_normals_, -2, face_id_).squeeze(-2)

        if v_color is None:
            albedo = torch.ones_like(normal)
        else:
            attr, _ = dr.interpolate(v_color.contiguous(), rast_out, faces)
            albedo = attr[..., :3]

        diffuse = self.shade(normal)

        rgb = albedo * diffuse
        alpha = fg_mask.float()
        rgba = torch.cat([rgb, alpha], dim=-1)

        rgba_bg = self._background_rgba(background_color, rgba)
        normal = torch.where(fg_mask, normal, rgba_bg[..., :3])
        diffuse = torch.where(fg_mask, diffuse, rgba_bg[..., :3])
        rgba = torch.where(fg_mask, rgba, rgba_bg)

        rgba_aa = dr.antialias(rgba, rast_out, verts_clip, faces)

        return {
            'albedo': albedo.flip(1),
            'normal': normal.flip(1),
            'diffuse': diffuse.flip(1),
            'rgba': rgba_aa.flip(1),
            'verts_clip': verts_clip,
        }

    # ------------------------------------------------------------------
    # camera / projection
    # ------------------------------------------------------------------
    def clear_cache(self):
        """Drop the cached rasterization result used by :meth:`rasterize_fragments`."""
        self.fragment_cache = None

    def mvp_from_camera_param(self, RT, K, image_size):
        """Return the batched model-view-projection matrix ``proj @ RT``.

        Args:
            RT: extrinsics, (N, 3, 4) or (N, 4, 4); a 3x4 matrix is completed with
                the row ``[0, 0, 0, 1]``.
            K: intrinsics, see :meth:`projection_from_intrinsics`.
            image_size: (height, width).

        Raises:
            ValueError: if RT has neither 3 nor 4 rows.
        """
        proj = self.projection_from_intrinsics(K, image_size)

        rows = RT.shape[-2]
        if rows == 3:
            mv = torch.nn.functional.pad(RT, [0, 0, 0, 1])
            mv[..., 3, 3] = 1
        elif rows == 4:
            mv = RT
        else:
            raise ValueError(f"Expected RT to be (N, 3, 4) or (N, 4, 4) but got: {RT.shape}")
        return torch.bmm(proj, mv)

    def projection_from_intrinsics(self, K: torch.Tensor, image_size: Tuple[int, int], near: float=0.1, far:float=10):
        """
        Transform points from camera space (x: right, y: up, z: out) to clip space (x: right, y: down, z: in)
        Args:
            K: Intrinsic matrix, (N, 3, 3), or packed as (N, 4) = [fx, fy, cx, cy]
                K = [[
                        [fx, 0, cx],
                        [0, fy, cy],
                        [0,  0,  1],
                    ]
                ]
            image_size: (height, width)
            near: near clipping distance
            far: far clipping distance
        Output:
            proj = [[
                    [2*fx/w, 0.0,     (w - 2*cx)/w,             0.0                     ],
                    [0.0,    2*fy/h,  (h - 2*cy)/h,             0.0                     ],
                    [0.0,    0.0,     -(far+near) / (far-near), -2*far*near / (far-near)],
                    [0.0,    0.0,     -1.0,                     0.0                     ]
                ]
            ]
        Raises:
            ValueError: if K is neither (N, 3, 3) nor (N, 4).
        """
        B = K.shape[0]
        h, w = image_size

        if K.shape[-2:] == (3, 3):
            fx = K[..., 0, 0]
            fy = K[..., 1, 1]
            cx = K[..., 0, 2]
            cy = K[..., 1, 2]
        elif K.shape[-1] == 4:
            # unbind gives (N,) tensors; (N, 1) slices could not be assigned into
            # proj[:, i, j] for batches larger than one.
            fx, fy, cx, cy = K.unbind(dim=-1)
        else:
            raise ValueError(f"Expected K to be (N, 3, 3) or (N, 4) but got: {K.shape}")

        proj = torch.zeros([B, 4, 4], device=K.device)
        proj[:, 0, 0] = fx * 2 / w
        proj[:, 1, 1] = fy * 2 / h
        proj[:, 0, 2] = (w - 2 * cx) / w
        proj[:, 1, 2] = (h - 2 * cy) / h
        proj[:, 2, 2] = -(far+near) / (far-near)
        proj[:, 2, 3] = -2*far*near / (far-near)
        proj[:, 3, 2] = -1
        return proj

    def world_to_camera(self, vtx, RT):
        """Transform vertex positions from the world space to the camera space.

        Args:
            vtx: (N, V, 3) or homogeneous (N, V, 4) positions.
            RT: (N, 3, 4) or (N, 4, 4) extrinsics, tensor or numpy array (a numpy
                array is moved to the device of `vtx`).

        Returns:
            ``vtx_h @ RT^T``; RT is applied as given, so the result has as many
            components per vertex as RT has rows.

        Raises:
            ValueError: if the points are neither 3D nor 4D.
        """
        if isinstance(RT, np.ndarray):
            RT = torch.from_numpy(RT).to(vtx.device)
        posw = self._to_homogeneous(vtx)
        return torch.bmm(posw, RT.transpose(-1, -2))

    def camera_to_clip(self, vtx, K, image_size):
        """Transform vertex positions from the camera space to the clip space.

        Args:
            vtx: (N, V, 3) or homogeneous (N, V, 4) camera-space positions.
            K: intrinsics, tensor or numpy array (a numpy array is moved to the
                device of `vtx`); see :meth:`projection_from_intrinsics`.
            image_size: (height, width).

        Raises:
            ValueError: if the points are neither 3D nor 4D.
        """
        if isinstance(K, np.ndarray):
            K = torch.from_numpy(K).to(vtx.device)
        proj = self.projection_from_intrinsics(K, image_size)
        posw = self._to_homogeneous(vtx)
        return torch.bmm(posw, proj.transpose(-1, -2))

    def world_to_clip(self, vtx, RT, K, image_size):
        """Transform vertex positions from the world space to the clip space.

        Numpy `RT` / `K` are converted to tensors on the device of `vtx` first.
        """
        if isinstance(RT, np.ndarray):
            RT = torch.from_numpy(RT).to(vtx.device)
        if isinstance(K, np.ndarray):
            K = torch.from_numpy(K).to(vtx.device)
        mvp = self.mvp_from_camera_param(RT, K, image_size)
        posw = self._to_homogeneous(vtx)
        return torch.bmm(posw, mvp.transpose(-1, -2))

    def world_to_ndc(self, vtx, RT, K, image_size, flip_y=False):
        """Transform vertex positions from the world space to the NDC space.

        Performs the perspective divide on the clip-space positions; with
        `flip_y` the y component is negated.
        """
        verts_clip = self.world_to_clip(vtx, RT, K, image_size)
        verts_ndc = verts_clip[:, :, :3] / verts_clip[:, :, 3:]
        if flip_y:
            verts_ndc[:, :, 1] *= -1
        return verts_ndc

    # ------------------------------------------------------------------
    # rasterization
    # ------------------------------------------------------------------
    def rasterize(self, verts, faces, RT, K, image_size, use_cache=False, require_grad=False):
        """
        Project the mesh to clip space and rasterize it.

        Args:
            verts: world-space vertex positions.
            faces: triangle vertex indices (converted to int32).
            RT: extrinsics, see :meth:`world_to_camera`.
            K: intrinsics, see :meth:`projection_from_intrinsics`.
            image_size: (height, width).
            use_cache: reuse the cached fragments if there are any.
            require_grad: rasterize with autograd left enabled.

        Returns:
            dict with the rasterizer outputs ("rast_out", "rast_out_db") and the
            vertices in world ("verts"), camera ("verts_camera", xyz only) and
            clip ("verts_clip") space.
        """
        verts_camera = self.world_to_camera(verts, RT)
        verts_clip = self.camera_to_clip(verts_camera, K, image_size)
        tri = faces.int()
        rast_out, rast_out_db = self.rasterize_fragments(verts_clip, tri, image_size, use_cache, require_grad)
        return {
            "rast_out": rast_out,
            "rast_out_db": rast_out_db,
            "verts": verts,
            "verts_camera": verts_camera[..., :3],
            "verts_clip": verts_clip,
        }

    def rasterize_fragments(self, verts_clip, tri, image_size, use_cache, require_grad=False):
        """
        Either rasterizes meshes or returns cached result.

        A fresh rasterization always replaces the cache. Unless `require_grad`
        is set, rasterization runs under ``torch.no_grad()``.
        """
        if not use_cache or self.fragment_cache is None:
            if require_grad:
                rast_out, rast_out_db = dr.rasterize(self.glctx, verts_clip, tri, image_size)
            else:
                with torch.no_grad():
                    rast_out, rast_out_db = dr.rasterize(self.glctx, verts_clip, tri, image_size)
            self.fragment_cache = (rast_out, rast_out_db)

        return self.fragment_cache

    def compute_screen_coords(self, rast_out: torch.Tensor, verts:torch.Tensor, faces:torch.Tensor, image_size: Tuple[int]):
        """ Compute screen coords for visible pixels
        Args:
            rast_out: (N, H, W, C) rasterizer output; the first two channels are used as
                barycentrics of face vertices 0 and 1, the last as face index + 1 (0 = empty)
            verts: (N, V, 3), the verts should lie in the ndc space
            faces: (F, 3)
            image_size: unused, kept for interface compatibility
        Returns:
            (N, H, W, 2) x/y of the surface point seen by each pixel, zeros where empty
        """
        N = verts.shape[0]
        num_faces = faces.shape[0]

        # (N * F, 3, 3): the three vertex positions of every face of every mesh.
        face_verts = verts[:, faces.long()].reshape(N * num_faces, 3, verts.shape[-1])

        pix2face = rast_out[..., -1:].long() - 1
        is_visible = pix2face > -1
        # Offset face ids by mesh so that they index the flattened face_verts.
        pix2face_packed = pix2face + torch.arange(0, N)[:, None, None, None].to(pix2face) * num_faces

        bary_coords = rast_out[..., :2]
        bary_coords = torch.cat([bary_coords, 1 - bary_coords.sum(dim=-1, keepdim=True)], dim=-1)

        visible_faces = pix2face_packed[is_visible]
        visible_face_verts = face_verts[visible_faces]
        visible_bary_coords = bary_coords[is_visible[..., 0]]

        # Barycentric blend of the three face vertices.
        visible_surface_point = (visible_face_verts * visible_bary_coords[..., None]).sum(dim=1)

        screen_coords = torch.zeros(*pix2face_packed.shape[:3], 2, device=verts.device)
        screen_coords[is_visible[..., 0]] = visible_surface_point[:, :2]

        return screen_coords

    # ------------------------------------------------------------------
    # normals / shading
    # ------------------------------------------------------------------
    def compute_v_normals(self, verts, faces):
        """Compute unit vertex normals by accumulating (unnormalised) face normals.

        Vertices whose accumulated normal is (near) zero get ``[0, 0, 1]``.
        """
        i0 = faces[..., 0].long()
        i1 = faces[..., 1].long()
        i2 = faces[..., 2].long()

        v0 = verts[..., i0, :]
        v1 = verts[..., i1, :]
        v2 = verts[..., i2, :]
        face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)

        # Add every face normal onto each of the face's three vertices.
        v_normals = torch.zeros_like(verts)
        N = verts.shape[0]
        for corner in (i0, i1, i2):
            v_normals.scatter_add_(1, corner[..., None].repeat(N, 1, 3), face_normals)

        fallback = torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=verts.device)
        v_normals = torch.where(V.dot(v_normals, v_normals) > 1e-20, v_normals, fallback)
        v_normals = V.safe_normalize(v_normals)
        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(v_normals))
        return v_normals

    def compute_face_normals(self, verts, faces):
        """Compute normalised per-face normals from the cross product of two edges."""
        i0 = faces[..., 0].long()
        i1 = faces[..., 1].long()
        i2 = faces[..., 2].long()

        v0 = verts[..., i0, :]
        v1 = verts[..., i1, :]
        v2 = verts[..., i2, :]
        face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
        face_normals = V.safe_normalize(face_normals)
        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(face_normals))
        return face_normals

    def shade(self, normal, lighting_coeff=None):
        """Compute the diffuse term for `normal` according to ``self.lighting_type``.

        Args:
            normal: per-pixel normals.
            lighting_coeff: SH coefficients, only read for lighting type 'SH'.

        Raises:
            NotImplementedError: for an unknown lighting type.
        """
        if self.lighting_type == 'constant':
            diffuse = torch.ones_like(normal[..., :3])
        elif self.lighting_type == 'front':
            z_axis = torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=normal.device)
            diffuse = V.dot(normal, z_axis)
            # Back-facing normals are lit at 30% of their absolute response.
            mask_backface = diffuse < 0
            diffuse[mask_backface] = diffuse[mask_backface].abs()*0.3
        elif self.lighting_type == 'front-range':
            bias = 0.75
            z_axis = torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=normal.device)
            diffuse = torch.clamp(V.dot(normal, z_axis) + bias, 0.0, 1.0)
        elif self.lighting_type == 'SH':
            diffuse = get_SH_shading(normal, lighting_coeff, self.sh_const)
        else:
            raise NotImplementedError(f"Unknown lighting type: {self.lighting_type}")
        return diffuse

    def detach_by_indices(self, x, indices):
        """Return a copy of `x` whose entries ``x[:, indices]`` are detached from autograd."""
        x = x.clone()
        x[:, indices] = x[:, indices].detach()
        return x

    # ------------------------------------------------------------------
    # rendering
    # ------------------------------------------------------------------
    def render_rgba(
        self, rast_dict, verts, faces, verts_uv, faces_uv, tex, lights, background_color=[1., 1., 1.],
        align_texture_except_fid=None, align_boundary_except_vid=None, enable_disturbance=False,
    ):
        """
        Renders flame RGBA images

        Args:
            rast_dict: output of :meth:`rasterize`; vertex positions are taken from it.
            verts: ignored (overridden by ``rast_dict["verts"]``), kept for compatibility.
            faces: triangle vertex indices.
            verts_uv: UV coordinates, shared across the batch.
            faces_uv: triangle UV indices.
            tex: texture, permuted here from channel-second to channel-last.
            lights: lighting coefficients forwarded to :meth:`shade`.
            background_color: constant RGB (list/tuple; the default list is never
                mutated) or a background image tensor.
            align_texture_except_fid: face ids whose texture coordinates are detached.
            align_boundary_except_vid: vertex ids whose clip positions are detached
                before antialiasing.
            enable_disturbance: randomly resample pixel colours within clusters
                (requires ``fid2cid`` to have been given to the constructor).

        Returns:
            dict with 'albedo', 'normal', 'diffuse', 'diffuse_detach_normal', 'rgba',
            'aa' (mask of pixels changed by antialiasing) and, with disturbance, 'cid'.

        Raises:
            ValueError: unknown background type, or disturbance requested without fid2cid.
            NotImplementedError: unknown lighting space.
        """
        rast_out = rast_dict["rast_out"]
        rast_out_db = rast_dict["rast_out_db"]
        verts = rast_dict["verts"]
        verts_camera = rast_dict["verts_camera"]
        verts_clip = rast_dict["verts_clip"]
        faces = faces.int()
        faces_uv = faces_uv.int()
        fg_mask = torch.clamp(rast_out[..., -1:], 0, 1).bool()

        out_dict = {}

        # ---- normals
        if self.lighting_space == 'world':
            v_normal = self.compute_v_normals(verts, faces)
        elif self.lighting_space == 'camera':
            v_normal = self.compute_v_normals(verts_camera, faces)
        else:
            raise NotImplementedError(f"Unknown lighting space: {self.lighting_space}")

        attr, _ = dr.interpolate(v_normal.contiguous(), rast_out, faces)
        normal = V.safe_normalize(attr[..., :3])

        # ---- texture lookup
        texc, texd = dr.interpolate(verts_uv[None, ...], rast_out, faces_uv, rast_db=rast_out_db, diff_attrs='all')
        if align_texture_except_fid is not None:
            fid = rast_out[..., -1:].long()
            # +1 because index 0 of the rasterizer's face channel means "empty".
            mask = torch.zeros(faces.shape[0]+1, dtype=torch.bool, device=fid.device)
            mask[align_texture_except_fid + 1] = True
            rast_mask = mask[fid]
            texc = torch.where(rast_mask, texc.detach(), texc)

        tex = tex.permute(0, 2, 3, 1).contiguous()
        albedo = dr.texture(tex, texc, texd, filter_mode='linear-mipmap-linear', max_mip_level=None)

        # ---- shading
        diffuse = self.shade(normal, lights)
        diffuse_detach_normal = self.shade(normal.detach(), lights)

        rgb = albedo * diffuse
        alpha = fg_mask.float()
        rgba = torch.cat([rgb, alpha], dim=-1)

        # ---- background
        rgba_bg = self._background_rgba(background_color, rgba)
        rgba = torch.where(fg_mask, rgba, rgba_bg)

        if enable_disturbance:
            fid2cid = getattr(self, "fid2cid", None)
            if fid2cid is None:
                raise ValueError("enable_disturbance=True requires `fid2cid` to be passed to the constructor")

            B, H, W, _ = rgba.shape

            # Per-pixel 0/1 switches deciding which pixels get a resampled colour.
            if self.disturb_rate_fg is not None:
                w_fg = (torch.rand_like(rgba[..., :1]) < self.disturb_rate_fg).int()
            else:
                w_fg = torch.zeros_like(rgba[..., :1]).int()
            if self.disturb_rate_bg is not None:
                w_bg = (torch.rand_like(rgba[..., :1]) < self.disturb_rate_bg).int()
            else:
                w_bg = torch.zeros_like(rgba[..., :1]).int()

            # Cluster id of every pixel (0 for empty pixels).
            fid = rast_out[..., -1:].long()
            num_clusters = int(fid2cid.max()) + 1
            cid = fid2cid[fid]
            out_dict['cid'] = cid.flip(1)

            rgba_ = torch.zeros_like(rgba)
            for i in range(num_clusters):
                # Cluster 0 draws from the background, every other one from the render.
                c_rgba = rgba_bg if i == 0 else rgba
                w = w_bg if i == 0 else w_fg

                c_mask = cid == i
                c_pixels = c_rgba[c_mask[..., 0]].detach()

                if i != 1:
                    if len(c_pixels) > 0:
                        # Replace selected pixels by the colour of a random pixel of the same cluster.
                        c_idx = torch.randint(0, len(c_pixels), (B * H * W, ), device=c_pixels.device)
                        c_sample = c_pixels[c_idx].reshape(B, H, W, 4)
                        rgba_ += c_mask * (c_sample * w + c_rgba * (1 - w))
                else:
                    # Cluster 1 is copied through without disturbance.
                    rgba_ += c_mask * c_rgba
            rgba = rgba_

        # ---- antialiasing
        if align_boundary_except_vid is not None:
            verts_clip = self.detach_by_indices(verts_clip, align_boundary_except_vid)
        rgba_aa = dr.antialias(rgba, rast_out, verts_clip, faces)
        aa = ((rgba - rgba_aa) != 0).any(dim=-1, keepdim=True).repeat_interleave(4, dim=-1)

        out_dict.update({
            'albedo': albedo.flip(1),
            'normal': normal.flip(1),
            'diffuse': diffuse.flip(1),
            'diffuse_detach_normal': diffuse_detach_normal.flip(1),
            'rgba': rgba_aa.flip(1),
            'aa': aa[..., :3].float().flip(1),
        })
        return out_dict

    def render_without_texture(
        self, verts, faces, RT, K, image_size, background_color=[1., 1., 1.],
    ):
        """
        Renders meshes into RGBA images using flat (per-face) normals and a white albedo.

        Returns:
            dict with 'albedo', 'normal', 'diffuse', 'rgba' and 'verts_clip'.
        """
        return self._render_flat_shaded(verts, faces, RT, K, image_size, background_color)

    def render_v_color(
        self, verts, v_color, faces, RT, K, image_size, background_color=[1., 1., 1.],
    ):
        """
        Renders meshes into RGBA images using flat (per-face) normals and
        interpolated per-vertex colours as albedo.

        Returns:
            dict with 'albedo', 'normal', 'diffuse' and 'rgba'.
        """
        out = self._render_flat_shaded(verts, faces, RT, K, image_size, background_color, v_color=v_color)
        out.pop('verts_clip')
        return out
| |
|