# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

#############################################################################
# Example Differentiable Ray Tracer
#
# Shows how to use the built-in wp.Mesh data structure and wp.mesh_query_ray()
# function to implement a basic differentiable ray tracer
#
##############################################################################

import math
import os

import numpy as np
from pxr import Usd, UsdGeom

import warp as wp
from warp.optim import SGD

wp.init()


class RenderMode:
    """Rendering modes

    grayscale: Lambertian shading from multiple directional lights
    texture: 2D texture map
    normal_map: mesh normal computed from interpolated vertex normals
    """

    # Plain ints (not an Enum) because the value is passed to draw_kernel as
    # its `mode: int` argument and compared against 0 / 1 / 2 there.
    grayscale = 0
    texture = 1
    normal_map = 2


@wp.struct
class RenderMesh:
    """Mesh to be ray traced.

    Assumes a triangle mesh as input.
    Per-vertex normals are computed with compute_vertex_normals()
    """

    # id of the wp.Mesh used for ray queries
    id: wp.uint64
    # vertex positions and flat triangle index list (3 entries per face)
    vertices: wp.array(dtype=wp.vec3)
    indices: wp.array(dtype=int)
    # texture coordinates and the per-corner indices into them
    tex_coords: wp.array(dtype=wp.vec2)
    tex_indices: wp.array(dtype=int)
    vertex_normals: wp.array(dtype=wp.vec3)
    # single-element arrays holding the mesh transform (only element 0 is read)
    pos: wp.array(dtype=wp.vec3)
    rot: wp.array(dtype=wp.quat)


@wp.struct
class Camera:
    """Basic camera for ray tracing"""

    horizontal: float
    vertical: float
    aspect: float
    e: float
    tan: float
    pos: wp.vec3
    rot: wp.quat


@wp.struct
class DirectionalLights:
    """Stores arrays of directional light directions and intensities."""

    dirs: wp.array(dtype=wp.vec3)
    intensities: wp.array(dtype=float)
    num_lights: int


@wp.kernel
def vertex_normal_sum_kernel(
    verts: wp.array(dtype=wp.vec3), indices: wp.array(dtype=int), normal_sums: wp.array(dtype=wp.vec3)
):
    """Accumulate each face's unnormalized normal onto its three vertices.

    Launched with one thread per triangle. `normal_sums` must be zeroed by the
    caller; the result is normalized afterwards by `normalize_kernel`.
    """
    tid = wp.tid()

    i = indices[tid * 3]
    j = indices[tid * 3 + 1]
    k = indices[tid * 3 + 2]

    a = verts[i]
    b = verts[j]
    c = verts[k]

    ab = b - a
    ac = c - a

    # cross product of two edges: face normal scaled by (twice) the face area
    area_normal = wp.cross(ab, ac)
    wp.atomic_add(normal_sums, i, area_normal)
    wp.atomic_add(normal_sums, j, area_normal)
    wp.atomic_add(normal_sums, k, area_normal)


@wp.kernel
def normalize_kernel(
    normal_sums: wp.array(dtype=wp.vec3),
    vertex_normals: wp.array(dtype=wp.vec3),
):
    """Write the unit-length version of each summed normal into `vertex_normals`."""
    tid = wp.tid()
    vertex_normals[tid] = wp.normalize(normal_sums[tid])


@wp.func
def texture_interpolation(tex_interp: wp.vec2, texture: wp.array2d(dtype=wp.vec3)):
    """Bilinearly sample `texture` at the texture coordinate `tex_interp`.

    The first coordinate maps to columns, the second (flipped) to rows, both
    scaled by (size - 1).
    """
    tex_width = texture.shape[1]
    tex_height = texture.shape[0]
    tex = wp.vec2(tex_interp[0] * float(tex_width - 1), (1.0 - tex_interp[1]) * float(tex_height - 1))

    x0 = int(tex[0])
    # Clamp the upper neighbour: a coordinate that lands exactly on the last
    # texel would otherwise read one element past the end of the texture.
    # The interpolation weight for that neighbour is 0 there, so the result
    # for in-range coordinates is unchanged.
    x1 = wp.min(x0 + 1, tex_width - 1)
    alpha_x = tex[0] - float(x0)
    y0 = int(tex[1])
    y1 = wp.min(y0 + 1, tex_height - 1)
    alpha_y = tex[1] - float(y0)

    c00 = texture[y0, x0]
    c10 = texture[y0, x1]
    c01 = texture[y1, x0]
    c11 = texture[y1, x1]

    lower = (1.0 - alpha_x) * c00 + alpha_x * c10
    upper = (1.0 - alpha_x) * c01 + alpha_x * c11
    color = (1.0 - alpha_y) * lower + alpha_y * upper

    return color


@wp.kernel
def draw_kernel(
    mesh: RenderMesh,
    camera: Camera,
    texture: wp.array2d(dtype=wp.vec3),
    rays_width: int,
    rays_height: int,
    rays: wp.array(dtype=wp.vec3),
    lights: DirectionalLights,
    mode: int,
):
    """Trace one camera ray per thread against `mesh` and store its color.

    `mode` selects the shading (see RenderMode): 0 = white Lambertian,
    1 = textured Lambertian, 2 = interpolated normal visualised as a color.
    Rays that miss the mesh write black. Each channel is capped at 1.0.
    """
    tid = wp.tid()

    x = tid % rays_width
    y = rays_height - tid // rays_width

    # map the ray grid position to screen space, roughly [-1, 1] on both axes
    sx = 2.0 * float(x) / float(rays_width) - 1.0
    sy = 2.0 * float(y) / float(rays_height) - 1.0

    # compute view ray in world space
    ro_world = camera.pos
    rd_world = wp.normalize(wp.quat_rotate(camera.rot, wp.vec3(sx * camera.tan * camera.aspect, sy * camera.tan, -1.0)))

    # compute view ray in mesh space
    inv = wp.transform_inverse(wp.transform(mesh.pos[0], mesh.rot[0]))
    ro = wp.transform_point(inv, ro_world)
    rd = wp.transform_vector(inv, rd_world)

    # outputs of wp.mesh_query_ray
    t = float(0.0)
    ur = float(0.0)
    vr = float(0.0)
    sign = float(0.0)
    n = wp.vec3()
    f = int(0)

    color = wp.vec3(0.0, 0.0, 0.0)

    if wp.mesh_query_ray(mesh.id, ro, rd, 1.0e6, t, ur, vr, sign, n, f):
        i = mesh.indices[f * 3]
        j = mesh.indices[f * 3 + 1]
        k = mesh.indices[f * 3 + 2]

        a = mesh.vertices[i]
        b = mesh.vertices[j]
        c = mesh.vertices[k]

        p = wp.mesh_eval_position(mesh.id, f, ur, vr)

        # barycentric coordinates
        # (u, v, w) weight vertices (a, b, c); each is a sub-triangle area
        # divided by the full triangle area
        tri_area = wp.length(wp.cross(b - a, c - a))
        w = wp.length(wp.cross(b - a, p - a)) / tri_area
        v = wp.length(wp.cross(p - a, c - a)) / tri_area
        u = 1.0 - w - v

        a_n = mesh.vertex_normals[i]
        b_n = mesh.vertex_normals[j]
        c_n = mesh.vertex_normals[k]

        # vertex normal interpolation
        normal = u * a_n + v * b_n + w * c_n

        if mode == 0 or mode == 1:
            if mode == 0:
                # grayscale
                color = wp.vec3(1.0)

            elif mode == 1:
                # texture interpolation
                tex_a = mesh.tex_coords[mesh.tex_indices[f * 3]]
                tex_b = mesh.tex_coords[mesh.tex_indices[f * 3 + 1]]
                tex_c = mesh.tex_coords[mesh.tex_indices[f * 3 + 2]]

                tex = u * tex_a + v * tex_b + w * tex_c

                color = texture_interpolation(tex, texture)

            # lambertian directional lighting
            lambert = float(0.0)
            for i in range(lights.num_lights):
                # bring the light direction into mesh space, like the ray
                dir = wp.transform_vector(inv, lights.dirs[i])
                val = lights.intensities[i] * wp.dot(normal, dir)
                # surfaces facing away from the light receive nothing
                if val < 0.0:
                    val = 0.0
                lambert = lambert + val

            color = lambert * color

        elif mode == 2:
            # normal map
            color = normal * 0.5 + wp.vec3(0.5, 0.5, 0.5)

    # cap each channel at 1.0
    if color[0] > 1.0:
        color = wp.vec3(1.0, color[1], color[2])
    if color[1] > 1.0:
        color = wp.vec3(color[0], 1.0, color[2])
    if color[2] > 1.0:
        color = wp.vec3(color[0], color[1], 1.0)

    rays[tid] = color


@wp.kernel
def downsample_kernel(
    rays: wp.array(dtype=wp.vec3), pixels: wp.array(dtype=wp.vec3), rays_width: int, num_samples: int
):
    """Average each num_samples x num_samples tile of `rays` into one pixel.

    Launched with one thread per output pixel.
    """
    tid = wp.tid()

    pixels_width = rays_width / num_samples
    px = tid % pixels_width
    py = tid // pixels_width
    # index of the top-left ray of this pixel's tile
    start_idx = py * num_samples * rays_width + px * num_samples

    color = wp.vec3(0.0, 0.0, 0.0)

    for i in range(0, num_samples):
        for j in range(0, num_samples):
            ray = rays[start_idx + i * rays_width + j]
            color = wp.vec3(color[0] + ray[0], color[1] + ray[1], color[2] + ray[2])

    num_samples_sq = float(num_samples * num_samples)
    color = wp.vec3(color[0] / num_samples_sq, color[1] / num_samples_sq, color[2] / num_samples_sq)
    pixels[tid] = color


@wp.kernel
def loss_kernel(pixels: wp.array(dtype=wp.vec3), target_pixels: wp.array(dtype=wp.vec3), loss: wp.array(dtype=float)):
    """Add this pixel's per-channel pseudo-Huber error to `loss[0]`."""
    tid = wp.tid()

    pixel = pixels[tid]
    target_pixel = target_pixels[tid]

    diff = target_pixel - pixel

    # pseudo Huber loss
    delta = 1.0
    x = delta * delta * (wp.sqrt(1.0 + (diff[0] / delta) * (diff[0] / delta)) - 1.0)
    y = delta * delta * (wp.sqrt(1.0 + (diff[1] / delta) * (diff[1] / delta)) - 1.0)
    z = delta * delta * (wp.sqrt(1.0 + (diff[2] / delta) * (diff[2] / delta)) - 1.0)
    sum = x + y + z

    wp.atomic_add(loss, 0, sum)


@wp.kernel
def normalize(x: wp.array(dtype=wp.quat)):
    """Renormalize each quaternion in `x` in place."""
    tid = wp.tid()
    x[tid] = wp.normalize(x[tid])


class Example:
    """A basic differentiable ray tracer

    Non-differentiable variables:
    camera.horizontal: camera horizontal aperture size
    camera.vertical: camera vertical aperture size
    camera.aspect: camera aspect ratio
    camera.e: focal length
    camera.pos: camera displacement
    camera.rot: camera rotation (quaternion)
    pix_width: final image width in pixels
    pix_height: final image height in pixels
    num_samples: anti-aliasing. calculated as pow(2, num_samples)
    directional_lights: characterized by intensity (scalar) and direction (vec3)
    render_mesh.indices: mesh vertex indices
    render_mesh.tex_indices: texture indices

    Differentiable variables:
    render_mesh.pos: parent transform displacement
    render_mesh.rot: parent transform rotation (quaternion)
    render_mesh.vertices: mesh vertex positions
    render_mesh.vertex_normals: mesh vertex normals
    render_mesh.tex_coords: 2D texture coordinates
    """

    def __init__(self, stage=None, rot_array=(0.0, 0.0, 0.0, 1.0), verbose=False):
        """Load the bunny asset, build the scene and capture the training graph.

        Args:
            stage: Not used by this example.
            rot_array: Initial mesh rotation as quaternion components, in the
                same order as the ``cam_rot`` identity below (x, y, z, w). Any
                4-element sequence accepted by ``np.array`` works.
            verbose: If True, `update` prints the loss every ``period``
                iterations.
        """
        self.device = wp.get_device()

        self.verbose = verbose
        cam_pos = wp.vec3(0.0, 0.75, 7.0)
        cam_rot = wp.quat(0.0, 0.0, 0.0, 1.0)
        horizontal_aperture = 36.0
        vertical_aperture = 20.25
        aspect = horizontal_aperture / vertical_aperture
        focal_length = 50.0
        self.height = 1024
        self.width = int(aspect * self.height)
        self.num_pixels = self.width * self.height

        asset_stage = Usd.Stage.Open(os.path.join(os.path.dirname(__file__), "assets/bunny.usd"))
        mesh_geom = UsdGeom.Mesh(asset_stage.GetPrimAtPath("/bunny/bunny"))

        points = np.array(mesh_geom.GetPointsAttr().Get())
        indices = np.array(mesh_geom.GetFaceVertexIndicesAttr().Get())
        num_points = points.shape[0]
        num_faces = int(indices.shape[0] / 3)

        # manufacture texture coordinates + indices for this asset
        # (both coordinates are the vertex's distance from the origin,
        # scaled so the farthest vertex maps to 1.0)
        distance = np.linalg.norm(points, axis=1)
        radius = np.max(distance)
        distance = distance / radius
        tex_coords = np.stack((distance, distance), axis=1)
        tex_indices = indices

        # manufacture texture for this asset
        # (256x256 gradient: first channel follows the column, second the row)
        x = np.arange(256.0)
        xx, yy = np.meshgrid(x, x)
        zz = np.zeros_like(xx)
        texture_host = np.stack((xx, yy, zz), axis=2) / 255.0

        # set anti-aliasing
        self.num_samples = 1

        # set render mode
        self.render_mode = RenderMode.texture

        # set training iterations
        self.train_rate = 5.00e-8
        self.momentum = 0.5
        self.dampening = 0.1
        self.weight_decay = 0.0
        self.train_iters = 150
        self.period = 10
        self.iter = 0

        # storage for training animation
        # train_graph() stores a frame whenever i % period == 0, i.e.
        # ceil(train_iters / period) frames in total.
        num_frames = (self.train_iters + self.period - 1) // self.period
        self.images = np.zeros((self.height, self.width, 3, num_frames))
        self.image_counter = 0

        # construct RenderMesh
        self.render_mesh = RenderMesh()
        self.mesh = wp.Mesh(
            points=wp.array(points, dtype=wp.vec3, requires_grad=True), indices=wp.array(indices, dtype=int)
        )
        self.render_mesh.id = self.mesh.id
        self.render_mesh.vertices = self.mesh.points
        self.render_mesh.indices = self.mesh.indices
        self.render_mesh.tex_coords = wp.array(tex_coords, dtype=wp.vec2, requires_grad=True)
        self.render_mesh.tex_indices = wp.array(tex_indices, dtype=int)
        self.normal_sums = wp.zeros(num_points, dtype=wp.vec3, requires_grad=True)
        self.render_mesh.vertex_normals = wp.zeros(num_points, dtype=wp.vec3, requires_grad=True)
        self.render_mesh.pos = wp.zeros(1, dtype=wp.vec3, requires_grad=True)
        self.render_mesh.rot = wp.array(np.array(rot_array), dtype=wp.quat, requires_grad=True)

        # compute vertex normals
        wp.launch(
            kernel=vertex_normal_sum_kernel,
            dim=num_faces,
            inputs=[self.render_mesh.vertices, self.render_mesh.indices, self.normal_sums],
        )
        wp.launch(
            kernel=normalize_kernel,
            dim=num_points,
            inputs=[self.normal_sums, self.render_mesh.vertex_normals],
        )

        # construct camera
        self.camera = Camera()
        self.camera.horizontal = horizontal_aperture
        self.camera.vertical = vertical_aperture
        self.camera.aspect = aspect
        self.camera.e = focal_length
        # half the vertical aperture over the focal length; draw_kernel uses
        # it to scale screen-space coordinates into ray directions
        self.camera.tan = vertical_aperture / (2.0 * focal_length)
        self.camera.pos = cam_pos
        self.camera.rot = cam_rot

        # construct texture
        self.texture = wp.array2d(texture_host, dtype=wp.vec3, requires_grad=True)

        # construct lights
        self.lights = DirectionalLights()
        self.lights.dirs = wp.array(np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), dtype=wp.vec3, requires_grad=True)
        self.lights.intensities = wp.array(np.array([2.0, 0.2]), dtype=float, requires_grad=True)
        self.lights.num_lights = 2

        # construct rays
        self.rays_width = self.width * pow(2, self.num_samples)
        self.rays_height = self.height * pow(2, self.num_samples)
        self.num_rays = self.rays_width * self.rays_height
        self.rays = wp.zeros(self.num_rays, dtype=wp.vec3, requires_grad=True)

        # construct pixels
        self.pixels = wp.zeros(self.num_pixels, dtype=wp.vec3, requires_grad=True)
        self.target_pixels = wp.zeros(self.num_pixels, dtype=wp.vec3)

        # loss array
        self.loss = wp.zeros(1, dtype=float, requires_grad=True)

        # capture graph
        # (forward pass + backward pass are recorded once and replayed by
        # update(); capture_end runs even if recording raises)
        wp.capture_begin(self.device)
        try:
            self.tape = wp.Tape()
            with self.tape:
                self.compute_loss()
            self.tape.backward(self.loss)
        finally:
            self.graph = wp.capture_end(self.device)

        # only the mesh rotation is optimized
        self.optimizer = SGD(
            [self.render_mesh.rot],
            self.train_rate,
            momentum=self.momentum,
            dampening=self.dampening,
            weight_decay=self.weight_decay,
        )

    def ray_trace(self, is_live=False):
        """Render the scene into ``self.pixels``.

        Args:
            is_live: Not used by this method.
        """
        # raycast
        wp.launch(
            kernel=draw_kernel,
            dim=self.num_rays,
            inputs=[
                self.render_mesh,
                self.camera,
                self.texture,
                self.rays_width,
                self.rays_height,
                self.rays,
                self.lights,
                self.render_mode,
            ],
            device=self.device,
        )

        # downsample
        wp.launch(
            kernel=downsample_kernel,
            dim=self.num_pixels,
            inputs=[self.rays, self.pixels, self.rays_width, pow(2, self.num_samples)],
            device=self.device,
        )

    def compute_loss(self):
        """Render the scene and accumulate the image loss into ``self.loss``."""
        self.ray_trace()
        wp.launch(
            loss_kernel, dim=self.num_pixels, inputs=[self.pixels, self.target_pixels, self.loss], device=self.device
        )

    def get_image(self):
        """Return the current pixels as a (height, width, 3) NumPy array."""
        return self.pixels.numpy().reshape((self.height, self.width, 3))

    def get_animation(self):
        """Build a matplotlib ``ArtistAnimation`` from the stored training frames."""
        # Imported here so the method also works when this module is imported
        # rather than run as a script (the script-level imports live under
        # the __main__ guard).
        import matplotlib.animation as animation
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots()
        plt.axis("off")
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)

        frames = []
        for i in range(self.images.shape[3]):
            frame = ax.imshow(self.images[:, :, :, i], animated=True)
            frames.append([frame])

        ani = animation.ArtistAnimation(fig, frames, interval=50, blit=True, repeat_delay=1000)
        return ani

    def update(self):
        """Run one optimization step on the mesh rotation."""
        # replay the captured forward + backward pass
        wp.capture_launch(self.graph)

        rot_grad = self.tape.gradients[self.render_mesh.rot]
        self.optimizer.step([rot_grad])
        # the gradient step does not preserve unit length
        wp.launch(normalize, dim=1, inputs=[self.render_mesh.rot])

        if self.verbose and self.iter % self.period == 0:
            print(f"Iter: {self.iter} Loss: {self.loss}")

        # reset gradients and the accumulated loss for the next replay
        self.tape.zero()
        self.loss.zero_()

        self.iter = self.iter + 1

    def render(self):
        """Store the current image as the next animation frame."""
        self.images[:, :, :, self.image_counter] = self.get_image()
        self.image_counter += 1

    def train_graph(self):
        """Run all training iterations, storing a frame every ``period`` steps."""
        # train
        for i in range(self.train_iters):
            self.update()

            if i % self.period == 0:
                self.render()


if __name__ == "__main__":
    import matplotlib.animation as animation
    import matplotlib.image as img
    import matplotlib.pyplot as plt

    output_dir = os.path.join(os.path.dirname(__file__), "outputs")
    # imsave / video.save do not create missing directories
    os.makedirs(output_dir, exist_ok=True)

    reference_example = Example()

    # render target rotation
    reference_example.ray_trace()
    target_image = reference_example.get_image()
    img.imsave(os.path.join(output_dir, "target_image.png"), target_image)

    # offset mesh rotation
    rotated_example = Example(
        rot_array=[0.0, (math.sqrt(3) - 1) / (2.0 * math.sqrt(2.0)), 0.0, (math.sqrt(3) + 1) / (2.0 * math.sqrt(2.0))],
        verbose=True,
    )

    wp.copy(rotated_example.target_pixels, reference_example.pixels)

    # recover target rotation
    rotated_example.train_graph()
    final_image = rotated_example.get_image()
    img.imsave(os.path.join(output_dir, "final_image.png"), final_image)

    video = rotated_example.get_animation()
    video.save(os.path.join(output_dir, "animation.gif"), dpi=300, writer=animation.PillowWriter(fps=5))