Spaces:
Running
Running
| # src/render_preview.py | |
| """Deterministic GGX preview render of predicted PBR maps. | |
| Unlike the training rendering loss (random multi-light), this uses a single | |
| fixed key light + small ambient so the demo preview is stable and reproducible. | |
| Two renderers: | |
| - render_preview: flat front-facing patch (quick sanity view). | |
| - render_sphere_preview: the predicted material mapped onto a lit sphere | |
| ("material ball"), which reads as a 3D object and can be relit by moving | |
| the light direction. | |
| """ | |
| import math | |
| import torch | |
| import torch.nn.functional as F | |
| from src.rendering_loss import ggx_shade, metallic_to_specular | |
# --- Flat-patch preview lighting -------------------------------------------
# One fixed key light toward +x/+y (upper-right), a fixed straight-on view
# direction along +z, and a faint ambient term.
_LIGHT_DIR = (0.4, 0.4, 1.0)  # NOTE: not stored as a unit vector
_VIEW_DIR = (0.0, 0.0, 1.0)
_AMBIENT = 0.04

# --- Sphere "material ball" preview lighting -------------------------------
# A stronger key light plus more ambient fill than the flat preview.
_KEY_INTENSITY = 2.4
_SPHERE_AMBIENT = 0.08
def render_preview(normal, roughness, metallic, basecolor):
    """Render predicted maps on a flat, front-facing patch under fixed lighting.

    Args:
        normal: (C,H,W) tangent-space normal map stored in [0,1]; it is decoded
            to [-1,1] and normalized along the channel axis here.
        roughness: (C,H,W) roughness map in [0,1] (presumably single-channel;
            TODO confirm against ``ggx_shade``).
        metallic: (C,H,W) metallic map in [0,1] (presumably single-channel;
            TODO confirm against ``metallic_to_specular``).
        basecolor: (C,H,W) base color in [0,1]. Its dtype and device are used
            for the light and view vectors.

    Returns:
        (3,H,W) Reinhard tone-mapped image clamped to [0,1].
    """
    def to_nhwc(t):  # (C,H,W) -> (1,H,W,C), the layout the shading helpers take
        return t.permute(1, 2, 0).unsqueeze(0)

    n = to_nhwc(normal)
    r = to_nhwc(roughness)
    m = to_nhwc(metallic)
    bc = to_nhwc(basecolor)

    # Normal: [0,1] storage -> signed [-1,1], normalized.
    n = n * 2.0 - 1.0
    n = n / (torch.norm(n, dim=-1, keepdim=True) + 1e-8)

    diffuse, specular = metallic_to_specular(bc, m)

    dtype, device = bc.dtype, bc.device
    wi = torch.tensor(_LIGHT_DIR, dtype=dtype, device=device).view(1, 1, 1, 3)
    # _LIGHT_DIR is not stored as a unit vector. Normalize it so ggx_shade gets a
    # unit light direction, the same kind of vector _light_dir produces for the
    # sphere preview.
    wi = F.normalize(wi, dim=-1)
    wo = torch.tensor(_VIEW_DIR, dtype=dtype, device=device).view(1, 1, 1, 3)

    radiance = ggx_shade(diffuse, specular, r, n, wi, wo)  # (1,H,W,3)
    # Small ambient term proportional to the diffuse albedo.
    radiance = radiance + _AMBIENT * diffuse

    # Reinhard tone-map to [0,1), then clamp for safety.
    ldr = (radiance / (1.0 + radiance)).clamp(0.0, 1.0)
    return ldr[0].permute(2, 0, 1).contiguous()  # (3,H,W)
def _light_dir(az_deg, el_deg, dtype, device):
    """Return a unit light direction shaped (1, 1, 1, 3).

    ``az_deg`` rotates around the vertical axis (0 points at the camera along
    +z; positive values swing toward +x). ``el_deg`` raises the light toward +y.
    Both angles are in degrees.
    """
    azimuth, elevation = math.radians(az_deg), math.radians(el_deg)
    horizontal = math.cos(elevation)  # length of the projection onto the x/z plane
    components = (
        horizontal * math.sin(azimuth),
        math.sin(elevation),
        horizontal * math.cos(azimuth),
    )
    return torch.tensor(components, dtype=dtype, device=device).reshape(1, 1, 1, 3)
def _sphere_frame(res, dtype, device):
    """Build the screen-space unit sphere and its per-pixel tangent frame.

    Returns ``(mask, n_geo, T, B)``:
      - ``mask``: (res,res) bool, True for pixels inside the unit disk.
      - ``n_geo``, ``T``, ``B``: (1,res,res,3) geometric normal, tangent and
        bitangent. Outside the disk ``n_geo`` is not unit length; those pixels
        are replaced by the background by the caller.
    """
    # Screen (u,v) in [-1,1], v up. Inside unit disk => surface.
    ys = torch.linspace(1.0, -1.0, res, dtype=dtype, device=device)  # +v up
    xs = torch.linspace(-1.0, 1.0, res, dtype=dtype, device=device)
    vv, uu = torch.meshgrid(ys, xs, indexing="ij")
    r2 = uu ** 2 + vv ** 2
    mask = (r2 <= 1.0)
    z = torch.sqrt((1.0 - r2).clamp(min=0.0))
    n_geo = torch.stack([uu, vv, z], dim=-1).unsqueeze(0)  # (1,res,res,3)

    # Tangent = up x n_geo. It degenerates where n_geo is (anti)parallel to up
    # (the top/bottom of the ball), so fall back to +x there.
    up = torch.tensor((0.0, 1.0, 0.0), dtype=dtype, device=device).view(1, 1, 1, 3)
    alt = torch.tensor((1.0, 0.0, 0.0), dtype=dtype, device=device).view(1, 1, 1, 3)
    T = torch.cross(up.expand_as(n_geo), n_geo, dim=-1)
    T_len = torch.norm(T, dim=-1, keepdim=True)
    T = torch.where(T_len < 1e-4, alt.expand_as(T), T)
    T = T / (torch.norm(T, dim=-1, keepdim=True) + 1e-8)
    B = torch.cross(n_geo, T, dim=-1)
    return mask, n_geo, T, B


def _sphere_background(res, dtype, device):
    """Return a (res,res,3) neutral vertical gradient, lighter at the top."""
    top = torch.tensor((0.16, 0.16, 0.18), dtype=dtype, device=device)
    bot = torch.tensor((0.05, 0.05, 0.06), dtype=dtype, device=device)
    grad = torch.linspace(0.0, 1.0, res, dtype=dtype, device=device).view(res, 1, 1)
    bg = top.view(1, 1, 3) * (1.0 - grad) + bot.view(1, 1, 3) * grad
    return bg.expand(res, res, 3)


def render_sphere_preview(
    normal,
    roughness,
    metallic,
    basecolor,
    light_az_deg=-35.0,
    light_el_deg=35.0,
    detail_strength=1.0,
    res=384,
):
    """Render the predicted material onto a lit sphere ("material ball").

    Inputs are (C,H,W) maps in [0,1]. The flat texture is front-projected onto a
    unit sphere: the geometric sphere normal gives the round 3D shading, and the
    predicted tangent-space normal perturbs it for surface detail.

    Args:
        normal, roughness, metallic, basecolor: (C,H,W) maps in [0,1]; each is
            bilinearly resampled to (res,res). ``basecolor`` sets dtype/device.
        light_az_deg: light azimuth in degrees (0 = toward the camera, +z;
            positive values move the light toward +x / screen right).
        light_el_deg: light elevation in degrees (positive = toward +y / up).
        detail_strength: scale applied to the tangent-plane (x,y) components of
            the detail normal; 0 shades the plain sphere wherever the detail
            normal's z is positive.
        res: output image side length in pixels.

    Returns:
        (3,res,res) RGB image in [0,1] with the ball composited over a neutral
        gradient background. Deterministic for fixed light/strength.
    """
    dtype, device = basecolor.dtype, basecolor.device

    def resize(t):  # (C,H,W) -> (1,res,res,C)
        t = F.interpolate(t.unsqueeze(0), size=(res, res), mode="bilinear",
                          align_corners=False)
        return t.permute(0, 2, 3, 1)

    bc = resize(basecolor)
    r = resize(roughness)
    m = resize(metallic)
    # Detail normal: [0,1] storage -> signed [-1,1], normalized.
    n_det = resize(normal) * 2.0 - 1.0
    n_det = n_det / (torch.norm(n_det, dim=-1, keepdim=True) + 1e-8)

    mask, n_geo, T, B = _sphere_frame(res, dtype, device)

    # Perturb the geometric normal by the predicted detail normal expressed in
    # the sphere's (T, B, n_geo) frame.
    s = float(detail_strength)
    N = (s * n_det[..., 0:1] * T
         + s * n_det[..., 1:2] * B
         + n_det[..., 2:3] * n_geo)
    N = N / (torch.norm(N, dim=-1, keepdim=True) + 1e-8)

    diffuse, specular = metallic_to_specular(bc, m)
    wi = _light_dir(light_az_deg, light_el_deg, dtype, device)
    wo = torch.tensor(_VIEW_DIR, dtype=dtype, device=device).view(1, 1, 1, 3)

    # Brighter key light so the lit hemisphere reads clearly; ambient fills shadow.
    radiance = _KEY_INTENSITY * ggx_shade(diffuse, specular, r, N, wi, wo)
    radiance = radiance + _SPHERE_AMBIENT * diffuse
    # Reinhard tone-map, then clamp for safety.
    ball = (radiance / (1.0 + radiance)).clamp(0.0, 1.0)[0]  # (res,res,3)

    # Composite the ball over the background; pixels outside the disk (where
    # the sphere geometry is meaningless) take the background color.
    out = torch.where(mask.unsqueeze(-1), ball, _sphere_background(res, dtype, device))
    return out.permute(2, 0, 1).contiguous()  # (3,res,res)