File size: 6,938 Bytes
04c78c7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 | import torch
import numpy as np
from .scene import Light, Scene, Camera, dot_product, normalize, generate_normalized_random_direction, gamma_encode
class Renderer:
def __init__(self, return_params=False):
self.use_augmentation = False
self.return_params = return_params
def xi(self, x):
return (x > 0.0) * torch.ones_like(x)
def compute_microfacet_distribution(self, roughness, NH):
alpha = roughness**2
alpha_squared = alpha**2
NH_squared = NH**2
denominator_part = torch.clamp(NH_squared * (alpha_squared + (1 - NH_squared) / NH_squared), min=0.001)
return (alpha_squared * self.xi(NH)) / (np.pi * denominator_part**2)
def compute_fresnel(self, F0, VH):
# https://cdn2.unrealengine.com/Resources/files/2013SiggraphPresentationsNotes-26915738.pdf
return F0 + (1.0 - F0) * (1.0 - VH)**5
def compute_g1(self, roughness, XH, XN):
alpha = roughness**2
alpha_squared = alpha**2
XN_squared = XN**2
return 2 * self.xi(XH / XN) / (1 + torch.sqrt(1 + alpha_squared * (1.0 - XN_squared) / XN_squared))
def compute_geometry(self, roughness, VH, LH, VN, LN):
return self.compute_g1(roughness, VH, VN) * self.compute_g1(roughness, LH, LN)
def compute_specular_term(self, wi, wo, albedo, normals, roughness, metalness):
F0 = 0.04 * (1. - metalness) + metalness * albedo
# Compute the half direction
H = normalize((wi + wo) / 2.0)
# Precompute some dot product
NH = torch.clamp(dot_product(normals, H), min=0.001)
VH = torch.clamp(dot_product(wo, H), min=0.001)
LH = torch.clamp(dot_product(wi, H), min=0.001)
VN = torch.clamp(dot_product(wo, normals), min=0.001)
LN = torch.clamp(dot_product(wi, normals), min=0.001)
F = self.compute_fresnel(F0, VH)
G = self.compute_geometry(roughness, VH, LH, VN, LN)
D = self.compute_microfacet_distribution(roughness, NH)
return F * G * D / (4.0 * VN * LN)
def compute_diffuse_term(self, albedo, metalness):
return albedo * (1. - metalness) / np.pi
def evaluate_brdf(self, wi, wo, normals, albedo, roughness, metalness):
diffuse_term = self.compute_diffuse_term(albedo, metalness)
specular_term = self.compute_specular_term(wi, wo, albedo, normals, roughness, metalness)
return diffuse_term, specular_term
def render(self, scene, svbrdf):
#normals, albedo, roughness, displacement = svbrdf
normals, albedo, roughness = svbrdf
device = albedo.device
# Generate surface coordinates for the material patch
# The center point of the patch is located at (0, 0, 0) which is the center of the global coordinate system.
# The patch itself spans from (-1, -1, 0) to (1, 1, 0).
xcoords_row = torch.linspace(-1, 1, albedo.shape[-1], device=device)
xcoords = xcoords_row.unsqueeze(0).expand(albedo.shape[-2], albedo.shape[-1]).unsqueeze(0)
ycoords = -1 * torch.transpose(xcoords, dim0=1, dim1=2)
coords = torch.cat((xcoords, ycoords, torch.zeros_like(xcoords)), dim=0)
# We treat the center of the material patch as focal point of the camera
camera_pos = scene.camera.pos.unsqueeze(-1).unsqueeze(-1).to(device)
relative_camera_pos = camera_pos - coords
wo = normalize(relative_camera_pos)
# Avoid zero roughness (i. e., potential division by zero)
roughness = torch.clamp(roughness, min=0.001)
light_pos = scene.light.pos.unsqueeze(-1).unsqueeze(-1).to(device)
relative_light_pos = light_pos - coords
wi = normalize(relative_light_pos)
fdiffuse, fspecular = self.evaluate_brdf(wi, wo, normals, albedo, roughness, metalness=0)
f = fdiffuse + fspecular
color = scene.light.color if torch.is_tensor(scene.light.color) else torch.tensor(scene.light.color)
light_color = color.unsqueeze(-1).unsqueeze(-1).unsqueeze(0).to(device)
falloff = 1.0 / torch.sqrt(dot_product(relative_light_pos, relative_light_pos))**2 # Radial light intensity falloff
LN = torch.clamp(dot_product(wi, normals), min=0.0) # Only consider the upper hemisphere
radiance = torch.mul(torch.mul(f, light_color * falloff), LN)
return radiance
def _get_input_params(self, n_samples, light, pose):
min_eps = 0.001
max_eps = 0.02
light_distance = 2.197
view_distance = 2.75
# Generate scenes (camera and light configurations)
# In the first configuration, the light and view direction are guaranteed to be perpendicular to the material sample.
# For the remaining cases, both are randomly sampled from a hemisphere.
view_dist = torch.ones(n_samples-1) * view_distance
if pose is None:
view_poses = torch.cat([torch.Tensor(2).uniform_(-0.25, 0.25), torch.ones(1) * view_distance], dim=-1).unsqueeze(0)
if n_samples > 1:
hemi_views = generate_normalized_random_direction(n_samples - 1, min_eps=min_eps, max_eps=max_eps) * view_distance
view_poses = torch.cat([view_poses, hemi_views])
else:
assert torch.is_tensor(pose)
view_poses = pose.cpu()
if light is None:
light_poses = torch.cat([torch.Tensor(2).uniform_(-0.75, 0.75), torch.ones(1) * light_distance], dim=-1).unsqueeze(0)
if n_samples > 1:
hemi_lights = generate_normalized_random_direction(n_samples - 1, min_eps=min_eps, max_eps=max_eps) * light_distance
light_poses = torch.cat([light_poses, hemi_lights])
else:
assert torch.is_tensor(light)
light_poses = light.cpu()
light_colors = torch.Tensor([10.0]).unsqueeze(-1).expand(n_samples, 3)
return view_poses, light_poses, light_colors
def __call__(self, svbrdf, n_samples=1, lights=None, poses=None):
view_poses, light_poses, light_colors = self._get_input_params(n_samples, lights, poses)
renderings = []
for wo, wi, c in zip(view_poses, light_poses, light_colors):
scene = Scene(Camera(wo), Light(wi, c))
rendering = self.render(scene, svbrdf)
# Simulate noise
std_deviation_noise = torch.exp(torch.Tensor(1).normal_(mean = np.log(0.005), std=0.3)).numpy()[0]
noise = torch.zeros_like(rendering).normal_(mean=0.0, std=std_deviation_noise)
# clipping
post_noise = torch.clamp(rendering + noise, min=0.0, max=1.0)
# gamma encoding
post_gamma = gamma_encode(post_noise)
renderings.append(post_gamma)
renderings = torch.cat(renderings, dim=0)
if self.return_params:
return renderings, (view_poses, light_poses, light_colors)
return renderings |