pbr-material-predictor / src /render_preview.py
Andrid1's picture
3D material-ball preview + relight sliders + real example textures
dec3137 verified
Raw
History Blame Contribute Delete
5.55 kB
# src/render_preview.py
"""Deterministic GGX preview render of predicted PBR maps.
Unlike the training rendering loss (random multi-light), this uses a single
fixed key light + small ambient so the demo preview is stable and reproducible.
Two renderers:
- render_preview: flat front-facing patch (quick sanity view).
- render_sphere_preview: the predicted material mapped onto a lit sphere
("material ball"), which reads as a 3D object and can be relit by moving
the light direction.
"""
import math
import torch
import torch.nn.functional as F
from src.rendering_loss import ggx_shade, metallic_to_specular
# Fixed key light from the upper-right, view straight-on (used by render_preview).
# Axis convention in this module: x right, y up, +z pointing toward the camera.
# NOTE: _LIGHT_DIR is not unit length (|(0.4, 0.4, 1.0)| ~= 1.15).
_LIGHT_DIR = (0.4, 0.4, 1.0)
_VIEW_DIR = (0.0, 0.0, 1.0)
# Flat-preview ambient fill, added to radiance as _AMBIENT * diffuse colour.
_AMBIENT = 0.04
# Sphere "material ball" lighting: brighter key + a little more fill.
# _KEY_INTENSITY scales the GGX key-light term; _SPHERE_AMBIENT * diffuse is
# then added as ambient (both used by render_sphere_preview).
_KEY_INTENSITY = 2.4
_SPHERE_AMBIENT = 0.08
def render_preview(normal, roughness, metallic, basecolor):
    """Render predicted maps as a flat, front-facing patch under fixed lighting.

    Args:
        normal: (C, H, W) tangent-space normal map stored in [0, 1]; it is
            remapped to [-1, 1] and normalized along the channel axis.
        roughness: (C, H, W) roughness map in [0, 1].
        metallic: (C, H, W) metallic map in [0, 1].
        basecolor: (C, H, W) base-colour map in [0, 1]; its dtype/device
            are used for the light and view vectors.

    Returns:
        (3, H, W) tensor in [0, 1]: GGX shading for the fixed key light
        ``_LIGHT_DIR`` seen along ``_VIEW_DIR``, plus ``_AMBIENT`` times the
        diffuse colour, Reinhard tone-mapped and clamped.
    """

    def to_bhwc(t):
        # (C, H, W) -> (1, H, W, C): channels-last layout fed to the shader.
        return t.permute(1, 2, 0).unsqueeze(0)

    n = to_bhwc(normal)
    r = to_bhwc(roughness)
    m = to_bhwc(metallic)
    bc = to_bhwc(basecolor)

    # Normal: [0,1] storage -> signed [-1,1], normalized.
    n = n * 2.0 - 1.0
    n = n / (torch.norm(n, dim=-1, keepdim=True) + 1e-8)

    diffuse, specular = metallic_to_specular(bc, m)

    wi = torch.tensor(_LIGHT_DIR, dtype=bc.dtype, device=bc.device).view(1, 1, 1, 3)
    # _LIGHT_DIR is not unit length (~1.15). Normalize it so the shader gets a
    # proper direction, matching the unit vectors _light_dir() produces for
    # the sphere renderer; harmless if ggx_shade also normalizes internally.
    wi = wi / torch.norm(wi, dim=-1, keepdim=True)
    wo = torch.tensor(_VIEW_DIR, dtype=bc.dtype, device=bc.device).view(1, 1, 1, 3)

    radiance = ggx_shade(diffuse, specular, r, n, wi, wo)  # (1,H,W,3)
    radiance = radiance + _AMBIENT * diffuse

    # Reinhard tone-map to [0,1), then clamp for safety.
    ldr = (radiance / (1.0 + radiance)).clamp(0.0, 1.0)
    return ldr[0].permute(2, 0, 1).contiguous()  # (3,H,W)
def _light_dir(az_deg, el_deg, dtype, device):
"""Unit light direction from azimuth/elevation (degrees). +z faces camera."""
az = math.radians(az_deg)
el = math.radians(el_deg)
x = math.cos(el) * math.sin(az)
y = math.sin(el)
z = math.cos(el) * math.cos(az)
return torch.tensor((x, y, z), dtype=dtype, device=device).view(1, 1, 1, 3)
def render_sphere_preview(
    normal,
    roughness,
    metallic,
    basecolor,
    light_az_deg=-35.0,
    light_el_deg=35.0,
    detail_strength=1.0,
    res=384,
):
    """Shade the predicted material on a "material ball" and return the image.

    The maps are (C, H, W) tensors in [0, 1]. Each is bilinearly resampled to
    ``res`` x ``res`` and front-projected onto a unit sphere viewed head-on.
    The sphere's geometric normal provides the round shading. The predicted
    tangent-space normal then perturbs it; ``detail_strength`` scales that
    normal's tangential (x, y) part. Lighting is one key light at
    (``light_az_deg``, ``light_el_deg``) scaled by ``_KEY_INTENSITY``, plus
    ``_SPHERE_AMBIENT`` times the diffuse colour, Reinhard tone-mapped. The
    ball is composited over a dark vertical-gradient background. Returns a
    (3, res, res) tensor in [0, 1]; the result depends only on the arguments.
    """
    dtype, device = basecolor.dtype, basecolor.device

    def as_bhwc(tex):
        # (C, H, W) -> (1, res, res, C), resampled to the output resolution.
        scaled = F.interpolate(
            tex.unsqueeze(0),
            size=(res, res),
            mode="bilinear",
            align_corners=False,
        )
        return scaled.permute(0, 2, 3, 1)

    def unit(vec):
        # Normalize along the channel axis; epsilon guards zero-length vectors.
        return vec / (torch.norm(vec, dim=-1, keepdim=True) + 1e-8)

    def const_vec(values):
        # Broadcastable (1, 1, 1, 3) vector in the working dtype/device.
        return torch.tensor(values, dtype=dtype, device=device).view(1, 1, 1, 3)

    albedo = as_bhwc(basecolor)
    rough = as_bhwc(roughness)
    metal = as_bhwc(metallic)
    detail = unit(as_bhwc(normal) * 2.0 - 1.0)

    # Screen grid in [-1, 1]; row 0 is the top (v = +1), column 0 the left.
    rows = torch.linspace(1.0, -1.0, res, dtype=dtype, device=device)
    cols = torch.linspace(-1.0, 1.0, res, dtype=dtype, device=device)
    v, u = torch.meshgrid(rows, cols, indexing="ij")
    rho_sq = u ** 2 + v ** 2
    on_ball = (rho_sq <= 1.0)  # pixels covered by the unit disk
    depth = torch.sqrt((1.0 - rho_sq).clamp(min=0.0))
    geo_n = torch.stack([u, v, depth], dim=-1).unsqueeze(0)  # (1,res,res,3)

    # Tangent = up x normal; where that vanishes (normal ~ up) use +x instead.
    world_up = const_vec((0.0, 1.0, 0.0))
    tangent = torch.cross(world_up.expand_as(geo_n), geo_n, dim=-1)
    degenerate = torch.norm(tangent, dim=-1, keepdim=True) < 1e-4
    fallback = const_vec((1.0, 0.0, 0.0))
    tangent = unit(torch.where(degenerate, fallback.expand_as(tangent), tangent))
    bitangent = torch.cross(geo_n, tangent, dim=-1)

    # Express the detail normal in the sphere's local frame.
    strength = float(detail_strength)
    shading_n = unit(
        strength * detail[..., 0:1] * tangent
        + strength * detail[..., 1:2] * bitangent
        + detail[..., 2:3] * geo_n
    )

    diffuse, specular = metallic_to_specular(albedo, metal)
    light = _light_dir(light_az_deg, light_el_deg, dtype, device)
    view = const_vec(_VIEW_DIR)
    hdr = _KEY_INTENSITY * ggx_shade(diffuse, specular, rough, shading_n, light, view)
    hdr = hdr + _SPHERE_AMBIENT * diffuse
    ball = (hdr / (1.0 + hdr)).clamp(0.0, 1.0)[0]  # (res,res,3)

    # Background fades from top_rgb (row 0) to bottom_rgb (last row).
    fade = torch.linspace(0.0, 1.0, res, dtype=dtype, device=device).view(res, 1, 1)
    top_rgb = torch.tensor((0.16, 0.16, 0.18), dtype=dtype, device=device).view(1, 1, 3)
    bottom_rgb = torch.tensor((0.05, 0.05, 0.06), dtype=dtype, device=device).view(1, 1, 3)
    backdrop = (top_rgb * (1.0 - fade) + bottom_rgb * fade).expand(res, res, 3)

    composite = torch.where(on_ball.unsqueeze(-1), ball, backdrop)
    return composite.permute(2, 0, 1).contiguous()  # (3,res,res)