# src/render_preview.py
"""Deterministic GGX preview render of predicted PBR maps.

Unlike the training rendering loss (random multi-light), this uses a single
fixed key light + small ambient so the demo preview is stable and
reproducible.

Two renderers:
  - render_preview: flat front-facing patch (quick sanity view).
  - render_sphere_preview: the predicted material mapped onto a lit sphere
    ("material ball"), which reads as a 3D object and can be relit by moving
    the light direction.
"""

import math

import torch
import torch.nn.functional as F

from src.rendering_loss import ggx_shade, metallic_to_specular

# Fixed key light from the upper-right, view straight-on.
# NOTE: _LIGHT_DIR is not unit length as written; render_preview normalizes
# it before shading so it matches the unit directions produced by _light_dir.
_LIGHT_DIR = (0.4, 0.4, 1.0)
_VIEW_DIR = (0.0, 0.0, 1.0)
_AMBIENT = 0.04

# Sphere "material ball" lighting: brighter key + a little more fill.
_KEY_INTENSITY = 2.4
_SPHERE_AMBIENT = 0.08


def _dir_tensor(vec, dtype, device):
    """Return the 3-tuple ``vec`` as a unit-length (1,1,1,3) tensor."""
    t = torch.tensor(vec, dtype=dtype, device=device).view(1, 1, 1, 3)
    return F.normalize(t, dim=-1)


def _to_bhwc(t):
    """Reshape a (C,H,W) map to the (1,H,W,C) channel-last layout used for shading."""
    return t.permute(1, 2, 0).unsqueeze(0)


def _decode_normal(n):
    """Map a channel-last normal map stored in [0,1] to signed unit vectors."""
    n = n * 2.0 - 1.0
    # The epsilon keeps all-zero (degenerate) texels from producing NaNs.
    return n / (torch.norm(n, dim=-1, keepdim=True) + 1e-8)


def _tonemap(radiance):
    """Reinhard tone-map HDR radiance to [0,1), then clamp for safety."""
    return (radiance / (1.0 + radiance)).clamp(0.0, 1.0)


def render_preview(normal, roughness, metallic, basecolor):
    """Render predicted maps on a flat patch under the fixed key light + ambient.

    Args:
        normal: (C,H,W) normal map stored in [0,1]; decoded to [-1,1] and
            normalized here.
        roughness: (C,H,W) roughness map in [0,1].
        metallic: (C,H,W) metallic map in [0,1].
        basecolor: (C,H,W) base-color map in [0,1].

    Returns:
        (3,H,W) tone-mapped RGB image in [0,1].
    """
    n = _decode_normal(_to_bhwc(normal))
    r = _to_bhwc(roughness)
    m = _to_bhwc(metallic)
    bc = _to_bhwc(basecolor)

    diffuse, specular = metallic_to_specular(bc, m)

    # Shade with unit directions: _LIGHT_DIR is not unit length, and the
    # sphere renderer already passes unit vectors from _light_dir.
    wi = _dir_tensor(_LIGHT_DIR, bc.dtype, bc.device)
    wo = _dir_tensor(_VIEW_DIR, bc.dtype, bc.device)

    radiance = ggx_shade(diffuse, specular, r, n, wi, wo)  # (1,H,W,3)
    radiance = radiance + _AMBIENT * diffuse

    return _tonemap(radiance)[0].permute(2, 0, 1).contiguous()  # (3,H,W)


def _light_dir(az_deg, el_deg, dtype, device):
    """Unit light direction from azimuth/elevation (degrees). +z faces camera.

    Azimuth rotates about +y (positive azimuth moves the light toward +x);
    elevation tilts toward +y. Returns a (1,1,1,3) tensor.
    """
    az = math.radians(az_deg)
    el = math.radians(el_deg)
    x = math.cos(el) * math.sin(az)
    y = math.sin(el)
    z = math.cos(el) * math.cos(az)
    return torch.tensor((x, y, z), dtype=dtype, device=device).view(1, 1, 1, 3)


def render_sphere_preview(
    normal,
    roughness,
    metallic,
    basecolor,
    light_az_deg=-35.0,
    light_el_deg=35.0,
    detail_strength=1.0,
    res=384,
):
    """Render the predicted material onto a lit sphere ("material ball").

    Inputs are (C,H,W) maps in [0,1]. The flat texture is front-projected
    onto a unit sphere: the geometric sphere normal gives the round 3D
    shading, and the predicted tangent-space normal perturbs it for surface
    detail.

    Args:
        normal, roughness, metallic, basecolor: (C,H,W) maps in [0,1],
            bilinearly resized to (res, res).
        light_az_deg, light_el_deg: key-light direction (see _light_dir).
        detail_strength: scale on the tangent/bitangent components of the
            predicted normal (0 shows the bare sphere shading).
        res: output image side length in pixels.

    Returns:
        A (3,res,res) RGB image in [0,1] with the ball composited on a
        neutral background. Deterministic for fixed light/strength.
    """
    dtype, device = basecolor.dtype, basecolor.device

    def resize(t):
        # (C,H,W) -> (1,res,res,C)
        t = F.interpolate(
            t.unsqueeze(0), size=(res, res), mode="bilinear", align_corners=False
        )
        return t.permute(0, 2, 3, 1)

    bc = resize(basecolor)
    r = resize(roughness)
    m = resize(metallic)
    n_det = _decode_normal(resize(normal))

    # Sphere geometry: screen (u,v) in [-1,1], v up. Inside unit disk => surface.
    ys = torch.linspace(1.0, -1.0, res, dtype=dtype, device=device)  # +v up (row 0 = top)
    xs = torch.linspace(-1.0, 1.0, res, dtype=dtype, device=device)
    vv, uu = torch.meshgrid(ys, xs, indexing="ij")
    r2 = uu ** 2 + vv ** 2
    mask = r2 <= 1.0
    z = torch.sqrt((1.0 - r2).clamp(min=0.0))
    # Inside the disk u^2 + v^2 + z^2 == 1, so n_geo is already unit length there.
    n_geo = torch.stack([uu, vv, z], dim=-1).unsqueeze(0)  # (1,res,res,3)

    # Tangent frame on the sphere; fall back near the poles where up ~ n_geo.
    # At the screen center this gives T = +x and B = +y, so the detail
    # normal's x/y map to screen right/up.
    up = torch.tensor((0.0, 1.0, 0.0), dtype=dtype, device=device).view(1, 1, 1, 3)
    T = torch.cross(up.expand_as(n_geo), n_geo, dim=-1)
    T_len = torch.norm(T, dim=-1, keepdim=True)
    alt = torch.tensor((1.0, 0.0, 0.0), dtype=dtype, device=device).view(1, 1, 1, 3)
    T = torch.where(T_len < 1e-4, alt.expand_as(T), T)
    T = T / (torch.norm(T, dim=-1, keepdim=True) + 1e-8)
    B = torch.cross(n_geo, T, dim=-1)

    # Perturb the geometric normal by the predicted detail normal.
    s = float(detail_strength)
    N = (
        s * n_det[..., 0:1] * T
        + s * n_det[..., 1:2] * B
        + n_det[..., 2:3] * n_geo
    )
    N = N / (torch.norm(N, dim=-1, keepdim=True) + 1e-8)

    diffuse, specular = metallic_to_specular(bc, m)

    wi = _light_dir(light_az_deg, light_el_deg, dtype, device)
    wo = _dir_tensor(_VIEW_DIR, dtype, device)

    # Brighter key light so the lit hemisphere reads clearly; ambient fills shadow.
    radiance = _KEY_INTENSITY * ggx_shade(diffuse, specular, r, N, wi, wo)
    radiance = radiance + _SPHERE_AMBIENT * diffuse
    ball = _tonemap(radiance)[0]  # (res,res,3)

    # Neutral vertical-gradient background (lighter at the top); composite
    # the ball over it. Pixels outside the disk are taken from bg only.
    top = torch.tensor((0.16, 0.16, 0.18), dtype=dtype, device=device)
    bot = torch.tensor((0.05, 0.05, 0.06), dtype=dtype, device=device)
    grad = torch.linspace(0.0, 1.0, res, dtype=dtype, device=device).view(res, 1, 1)
    bg = top.view(1, 1, 3) * (1.0 - grad) + bot.view(1, 1, 3) * grad
    bg = bg.expand(res, res, 3)

    out = torch.where(mask.unsqueeze(-1), ball, bg)
    return out.permute(2, 0, 1).contiguous()  # (3,res,res)