Spaces:
Paused
Paused
| // Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved. | |
| // | |
| // NVIDIA CORPORATION and its licensors retain all intellectual property | |
| // and proprietary rights in and to this software, related documentation | |
| // and any modifications thereto. Any use, reproduction, disclosure or | |
| // distribution of this software and related documentation without an express | |
| // license agreement from NVIDIA CORPORATION is strictly prohibited. | |
| //------------------------------------------------------------------------ | |
| // Lambert functions | |
// Lambertian diffuse lobe: max(dot(nrm, wi) / pi, 0).
// nrm and wi are used exactly as passed in; this function does not normalize them.
__device__ float fwdLambert(const vec3f nrm, const vec3f wi)
{
    // Cast M_PI so the division stays in single precision. Assuming M_PI is the usual
    // double-precision macro, dividing a float by it would promote the expression to FP64.
    return max(dot(nrm, wi) / (float)M_PI, 0.0f);
}
// Backward pass of fwdLambert.
// Where dot(nrm, wi) > 0, forwards d_out / pi to bwdDot, which writes the gradients of
// dot(nrm, wi) into d_nrm and d_wi.
// Where the max(..., 0) in the forward pass is active (dot <= 0), no gradient is produced.
__device__ void bwdLambert(const vec3f nrm, const vec3f wi, vec3f& d_nrm, vec3f& d_wi, const float d_out)
{
    if (dot(nrm, wi) > 0.0f)
        bwdDot(nrm, wi, d_nrm, d_wi, d_out / (float)M_PI);  // float pi: avoid an FP64 divide
}
| //------------------------------------------------------------------------ | |
| // Fresnel Schlick | |
// Schlick's Fresnel approximation, evaluated per channel:
//   F = f0 * (1 - s) + f90 * s,   s = (1 - cosTheta)^5
// cosTheta is first clamped to [SPECULAR_EPSILON, 1 - SPECULAR_EPSILON].
__device__ vec3f fwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta)
{
    const float c = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    const float oneMinusC = 1.0f - c;
    const float s = powf(oneMinusC, 5.0f);
    const float weight0 = 1.0f - s;
    return f0 * weight0 + f90 * s;
}
// Backward pass of fwdFresnelSchlick (F = f0 * (1 - s) + f90 * s, s = (1 - cos)^5).
// Adds d_out-weighted gradients into d_f0, d_f90 and d_cosTheta.
// d_cosTheta only receives a contribution when the clamp on cosTheta is inactive.
__device__ void bwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta, vec3f& d_f0, vec3f& d_f90, float& d_cosTheta, const vec3f d_out)
{
    // Recompute the blend weight exactly as fwdFresnelSchlick does (powf on 1 - clamped cos).
    // The clamp keeps 1 - _cosTheta >= SPECULAR_EPSILON, so no max(..., 0) guard is needed.
    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    float oneMinusCos = 1.0f - _cosTheta;
    float scale = powf(oneMinusCos, 5.0f);

    // 1.0f (not 1.0) keeps the weight in single precision.
    d_f0 += d_out * (1.0f - scale);
    d_f90 += d_out * scale;

    // No gradient through an active clamp. Inside the interval _cosTheta == cosTheta, and
    // dF/dcos = -5 * (f90 - f0) * (1 - cos)^4.
    if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
    {
        d_cosTheta += sum(d_out * (f90 - f0) * -5.0f * powf(oneMinusCos, 4.0f));
    }
}
| //------------------------------------------------------------------------ | |
| // Ndf GGX | |
// GGX normal distribution term, parameterized by alphaSqr:
//   D = alphaSqr / (pi * ((alphaSqr - 1) * cos^2 + 1)^2)
// cosTheta is first clamped to [SPECULAR_EPSILON, 1 - SPECULAR_EPSILON].
__device__ float fwdNdfGGX(const float alphaSqr, const float cosTheta)
{
    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    // d = (alphaSqr - 1) * cos^2 + 1, written in the factored form.
    float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f;
    // (float)M_PI keeps the denominator in single precision instead of promoting to FP64.
    return alphaSqr / (d * d * (float)M_PI);
}
// Backward pass of fwdNdfGGX. With d = (alphaSqr - 1) * cos^2 + 1:
//   dD/dalphaSqr = (1 - (alphaSqr + 1) * cos^2) / (pi * d^3)
//   dD/dcos      = -4 * (alphaSqr - 1) * alphaSqr * cos / (pi * d^3)
// Results are accumulated (+=) into d_alphaSqr and d_cosTheta.
__device__ void bwdNdfGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)
{
    // Torch only back propagates if clamp doesn't trigger
    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    float cosThetaSqr = _cosTheta * _cosTheta;

    // Shared denominator pi * d^3. It is computed once, in single precision
    // (1.0f and (float)M_PI avoid promotion to double).
    float denom = (float)M_PI * powf((alphaSqr - 1.0f) * cosThetaSqr + 1.0f, 3.0f);

    d_alphaSqr += d_out * (1.0f - (alphaSqr + 1.0f) * cosThetaSqr) / denom;
    if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
    {
        // Inside the clamp interval _cosTheta == cosTheta.
        d_cosTheta += d_out * -(4.0f * (alphaSqr - 1.0f) * alphaSqr * _cosTheta) / denom;
    }
}
| //------------------------------------------------------------------------ | |
| // Lambda GGX | |
// Smith Lambda term for GGX:
//   Lambda = 0.5 * (sqrt(1 + alphaSqr * tan^2) - 1),   tan^2 = (1 - cos^2) / cos^2
// cosTheta is first clamped to [SPECULAR_EPSILON, 1 - SPECULAR_EPSILON], which also keeps
// the division by cos^2 away from zero.
__device__ float fwdLambdaGGX(const float alphaSqr, const float cosTheta)
{
    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    float cosThetaSqr = _cosTheta * _cosTheta;
    // 1.0f (not 1.0) keeps the subtract/divide in single precision.
    float tanThetaSqr = (1.0f - cosThetaSqr) / cosThetaSqr;
    return 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);
}
// Backward pass of fwdLambdaGGX. Accumulates (+=) into d_alphaSqr and d_cosTheta:
//   dLambda/dalphaSqr = tan^2 / (4 * sqrt(1 + alphaSqr * tan^2))
//   dLambda/dcos      = -0.5 * alphaSqr / (cos^3 * sqrt(alphaSqr / cos^2 - alphaSqr + 1))
// d_cosTheta only receives a contribution when the clamp on cosTheta is inactive.
__device__ void bwdLambdaGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)
{
    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    float cosThetaSqr = _cosTheta * _cosTheta;
    float tanThetaSqr = (1.0f - cosThetaSqr) / cosThetaSqr;
    // (The forward value of Lambda is not needed for either derivative, so it is not recomputed.)

    // Float literals keep both gradients in single precision.
    d_alphaSqr += d_out * (0.25f * tanThetaSqr) / sqrtf(alphaSqr * tanThetaSqr + 1.0f);
    if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
        d_cosTheta += d_out * -(0.5f * alphaSqr) / (powf(_cosTheta, 3.0f) * sqrtf(alphaSqr / cosThetaSqr - alphaSqr + 1.0f));
}
| //------------------------------------------------------------------------ | |
| // Masking GGX | |
// Correlated Smith masking-shadowing term built from fwdLambdaGGX:
//   G = 1 / (1 + Lambda(cosThetaI) + Lambda(cosThetaO))
__device__ float fwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO)
{
    const float lambdaIn = fwdLambdaGGX(alphaSqr, cosThetaI);
    const float lambdaOut = fwdLambdaGGX(alphaSqr, cosThetaO);
    const float denom = 1.0f + lambdaIn + lambdaOut;
    return 1.0f / denom;
}
// Backward pass of fwdMaskingSmithGGXCorrelated.
// G = 1 / (1 + Lambda_I + Lambda_O), so dG/dLambda = -1 / denom^2 for both Lambdas. That
// gradient is pushed through bwdLambdaGGX once per direction. Both calls accumulate into
// the same d_alphaSqr.
__device__ void bwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO, float& d_alphaSqr, float& d_cosThetaI, float& d_cosThetaO, const float d_out)
{
    // Re-evaluate the forward denominator.
    const float lambdaIn = fwdLambdaGGX(alphaSqr, cosThetaI);
    const float lambdaOut = fwdLambdaGGX(alphaSqr, cosThetaO);
    const float denom = 1.0f + lambdaIn + lambdaOut;

    const float gradLambda = -d_out / powf(denom, 2.0f);

    bwdLambdaGGX(alphaSqr, cosThetaI, d_alphaSqr, d_cosThetaI, gradLambda);
    bwdLambdaGGX(alphaSqr, cosThetaO, d_alphaSqr, d_cosThetaO, gradLambda);
}
| //------------------------------------------------------------------------ | |
| // GGX specular | |
// GGX microfacet specular term for one light direction:
//   w = F * D * G / (4 * woDotN)
// where:
//   D = fwdNdfGGX;
//   G = fwdMaskingSmithGGXCorrelated;
//   F = Schlick Fresnel with f0 = col and f90 = 1;
//   h = safeNormalize(wo + wi).
// alpha is clamped to [min_roughness^2, 1] and then squared before use.
// Returns 0 unless both dot(wo, nrm) and dot(wi, nrm) exceed SPECULAR_EPSILON.
__device__ vec3f fwdPbrSpecular(const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness)
{
    float woDotN = dot(wo, nrm);
    float wiDotN = dot(wi, nrm);

    // Back-facing: the original code evaluated the full BRDF and then discarded it. Return early
    // instead; this yields the same result and skips the work (and the division by a tiny woDotN).
    bool frontfacing = (woDotN > SPECULAR_EPSILON) && (wiDotN > SPECULAR_EPSILON);
    if (!frontfacing)
        return vec3f(0.0f);

    float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);
    float alphaSqr = _alpha * _alpha;

    vec3f h = safeNormalize(wo + wi);
    float woDotH = dot(wo, h);
    float nDotH = dot(nrm, h);

    float D = fwdNdfGGX(alphaSqr, nDotH);
    float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
    vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);

    // 0.25f (not the double 0.25) stays in single precision and matches bwdPbrSpecular.
    return F * D * G * 0.25f / woDotN;
}
// Backward pass of fwdPbrSpecular (w = F * D * G * 0.25 / woDotN).
// Recomputes the forward quantities, then accumulates d_out-weighted gradients into d_col,
// d_nrm, d_wo, d_wi and d_alpha.
// Nothing is accumulated when the configuration is back-facing, matching the forward pass,
// which returns 0 there.
__device__ void bwdPbrSpecular(
    const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness,
    vec3f& d_col, vec3f& d_nrm, vec3f& d_wo, vec3f& d_wi, float& d_alpha, const vec3f d_out)
{
    ///////////////////////////////////////////////////////////////////////
    // FWD eval
    float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);
    float alphaSqr = _alpha * _alpha;

    vec3f h = safeNormalize(wo + wi);
    float woDotN = dot(wo, nrm);
    float wiDotN = dot(wi, nrm);
    float woDotH = dot(wo, h);
    float nDotH = dot(nrm, h);

    bool frontfacing = (woDotN > SPECULAR_EPSILON) && (wiDotN > SPECULAR_EPSILON);
    if (!frontfacing)
        return;

    float D = fwdNdfGGX(alphaSqr, nDotH);
    float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
    vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);
    // The forward result w itself is not needed by the backward pass.

    ///////////////////////////////////////////////////////////////////////
    // BWD eval: product rule on F * D * G * 0.25 / woDotN
    vec3f d_F = d_out * D * G * 0.25f / woDotN;
    float d_D = sum(d_out * F * G * 0.25f / woDotN);
    float d_G = sum(d_out * F * D * 0.25f / woDotN);
    float d_woDotN = -sum(d_out * F * D * G * 0.25f / (woDotN * woDotN));

    // f90 is the constant 1.0f, so its gradient is collected and discarded.
    vec3f d_f90(0);
    float d_woDotH(0), d_wiDotN(0), d_nDotH(0), d_alphaSqr(0);
    bwdFresnelSchlick(col, 1.0f, woDotH, d_col, d_f90, d_woDotH, d_F);
    bwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN, d_alphaSqr, d_woDotN, d_wiDotN, d_G);
    bwdNdfGGX(alphaSqr, nDotH, d_alphaSqr, d_nDotH, d_D);

    // Dot products -> vectors. h collects contributions from both nDotH and woDotH.
    vec3f d_h(0);
    bwdDot(nrm, h, d_nrm, d_h, d_nDotH);
    bwdDot(wo, h, d_wo, d_h, d_woDotH);
    bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);
    bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);

    // h = safeNormalize(wo + wi): the gradient of the unnormalized sum goes to both wo and wi.
    vec3f d_h_unnorm(0);
    bwdSafeNormalize(wo + wi, d_h_unnorm, d_h);
    d_wo += d_h_unnorm;
    d_wi += d_h_unnorm;

    // alphaSqr = _alpha^2 with _alpha = clamp(alpha, min_roughness^2, 1). As with the other clamps
    // in this file, gradient passes only where the clamp is inactive, so BOTH bounds must be
    // tested. For alpha > 1, _alpha is pinned at 1 and the true gradient is zero.
    if (alpha > min_roughness * min_roughness && alpha <= 1.0f)
        d_alpha += d_alphaSqr * 2.0f * _alpha;  // _alpha == alpha inside the interval
}
| //------------------------------------------------------------------------ | |
| // Full PBR BSDF | |
// Lambert diffuse + GGX specular response of a surface point for a single light.
// wo / wi are safeNormalize'd directions from pos toward view_pos / light_pos.
// Material channels, as used below:
//   arm.y: squared to obtain the alpha passed to fwdPbrSpecular;
//   arm.z: blends the specular color from a constant 0.04 toward kd, and scales the diffuse
//          color by (1 - arm.z) (a metalness-style weight);
//   arm.x: scales the specular color by (1 - arm.x) (presumably an occlusion term — confirm).
// min_roughness is forwarded to fwdPbrSpecular.
__device__ vec3f fwdPbrBSDF(const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness)
{
    vec3f wo = safeNormalize(view_pos - pos);
    vec3f wi = safeNormalize(light_pos - pos);

    float alpha = arm.y * arm.y;
    // 1.0f (not 1.0) keeps (1 - arm.x) in single precision.
    vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0f - arm.x);
    vec3f diff_col = kd * (1.0f - arm.z);

    float lambert = fwdLambert(nrm, wi);
    vec3f diffuse = diff_col * lambert;
    vec3f specular = fwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness);

    return diffuse + specular;
}
// Backward pass of fwdPbrBSDF.
// Recomputes the forward intermediates, runs the specular and Lambert backward passes, then
// chains the gradients through:
//   - the diffuse/specular color construction (into d_kd and d_arm);
//   - alpha = arm.y^2 (into d_arm.y);
//   - the two safeNormalize direction computations (into d_pos, d_view_pos, d_light_pos).
// All output gradients are accumulated (+= / -=) into the caller's variables.
__device__ void bwdPbrBSDF(
    const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness,
    vec3f& d_kd, vec3f& d_arm, vec3f& d_pos, vec3f& d_nrm, vec3f& d_view_pos, vec3f& d_light_pos, const vec3f d_out)
{
    ////////////////////////////////////////////////////////////////////////
    // FWD
    vec3f _wi = light_pos - pos;
    vec3f _wo = view_pos - pos;
    vec3f wi = safeNormalize(_wi);
    vec3f wo = safeNormalize(_wo);

    float alpha = arm.y * arm.y;
    // 1.0f (not 1.0) keeps (1 - arm.x) in single precision, matching the forward pass.
    vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0f - arm.x);
    vec3f diff_col = kd * (1.0f - arm.z);
    float lambert = fwdLambert(nrm, wi);

    ////////////////////////////////////////////////////////////////////////
    // BWD
    float d_alpha(0);
    vec3f d_spec_col(0), d_wi(0), d_wo(0);
    bwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness, d_spec_col, d_nrm, d_wo, d_wi, d_alpha, d_out);

    // diffuse = diff_col * lambert  ->  d_lambert = sum(diff_col * d_out)
    float d_lambert = sum(diff_col * d_out);
    bwdLambert(nrm, wi, d_nrm, d_wi, d_lambert);

    // Backprop: diff_col = kd * (1.0f - arm.z)
    vec3f d_diff_col = d_out * lambert;
    d_kd += d_diff_col * (1.0f - arm.z);
    d_arm.z -= sum(d_diff_col * kd);

    // Backprop: spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x)
    d_kd -= d_spec_col * (arm.x - 1.0f) * arm.z;
    d_arm.x += sum(d_spec_col * (arm.z * (0.04f - kd) - 0.04f));
    d_arm.z -= sum(d_spec_col * (kd - 0.04f) * (arm.x - 1.0f));

    // Backprop: alpha = arm.y * arm.y
    d_arm.y += d_alpha * 2 * arm.y;

    // Backprop: vec3f wi = safeNormalize(light_pos - pos);
    vec3f d__wi(0);
    bwdSafeNormalize(_wi, d__wi, d_wi);
    d_light_pos += d__wi;
    d_pos -= d__wi;

    // Backprop: vec3f wo = safeNormalize(view_pos - pos);
    vec3f d__wo(0);
    bwdSafeNormalize(_wo, d__wo, d_wo);
    d_view_pos += d__wo;
    d_pos -= d__wo;
}
| //------------------------------------------------------------------------ | |
| // Kernels | |
// Forward Lambert kernel: writes fwdLambert(nrm, wi) to p.out for each element.
// x/y come from the 2D block/thread index and z from blockIdx.z.
// Threads outside p.gridSize return immediately.
__global__ void LambertFwdKernel(LambertKernelParams p)
{
    const unsigned int x = threadIdx.x + blockIdx.x * blockDim.x;
    const unsigned int y = threadIdx.y + blockIdx.y * blockDim.y;
    const unsigned int z = blockIdx.z;
    if (!(x < p.gridSize.x && y < p.gridSize.y && z < p.gridSize.z))
        return;

    const vec3f normal = p.nrm.fetch3(x, y, z);
    const vec3f lightDir = p.wi.fetch3(x, y, z);
    const float value = fwdLambert(normal, lightDir);
    p.out.store(x, y, z, value);
}
// Backward Lambert kernel.
// Reads the forward inputs and the incoming gradient (from p.out), then stores the gradients
// for nrm and wi. Indexing and bounds handling match LambertFwdKernel.
__global__ void LambertBwdKernel(LambertKernelParams p)
{
    const unsigned int x = threadIdx.x + blockIdx.x * blockDim.x;
    const unsigned int y = threadIdx.y + blockIdx.y * blockDim.y;
    const unsigned int z = blockIdx.z;
    if (!(x < p.gridSize.x && y < p.gridSize.y && z < p.gridSize.z))
        return;

    const vec3f normal = p.nrm.fetch3(x, y, z);
    const vec3f lightDir = p.wi.fetch3(x, y, z);
    const float gradOut = p.out.fetch1(x, y, z);

    vec3f gradNormal(0), gradLightDir(0);
    bwdLambert(normal, lightDir, gradNormal, gradLightDir, gradOut);

    p.nrm.store_grad(x, y, z, gradNormal);
    p.wi.store_grad(x, y, z, gradLightDir);
}
// Forward Fresnel-Schlick kernel: writes fwdFresnelSchlick(f0, f90, cosTheta) to p.out.
// x/y come from the 2D block/thread index and z from blockIdx.z.
// Threads outside p.gridSize return immediately.
__global__ void FresnelShlickFwdKernel(FresnelShlickKernelParams p)
{
    const unsigned int x = threadIdx.x + blockIdx.x * blockDim.x;
    const unsigned int y = threadIdx.y + blockIdx.y * blockDim.y;
    const unsigned int z = blockIdx.z;
    if (!(x < p.gridSize.x && y < p.gridSize.y && z < p.gridSize.z))
        return;

    const vec3f reflect0 = p.f0.fetch3(x, y, z);
    const vec3f reflect90 = p.f90.fetch3(x, y, z);
    const float cosine = p.cosTheta.fetch1(x, y, z);
    const vec3f fresnel = fwdFresnelSchlick(reflect0, reflect90, cosine);
    p.out.store(x, y, z, fresnel);
}
// Backward Fresnel-Schlick kernel.
// Reads f0, f90, cosTheta and the incoming gradient (from p.out), then stores the gradients
// of all three inputs.
__global__ void FresnelShlickBwdKernel(FresnelShlickKernelParams p)
{
    const unsigned int x = threadIdx.x + blockIdx.x * blockDim.x;
    const unsigned int y = threadIdx.y + blockIdx.y * blockDim.y;
    const unsigned int z = blockIdx.z;
    if (!(x < p.gridSize.x && y < p.gridSize.y && z < p.gridSize.z))
        return;

    const vec3f reflect0 = p.f0.fetch3(x, y, z);
    const vec3f reflect90 = p.f90.fetch3(x, y, z);
    const float cosine = p.cosTheta.fetch1(x, y, z);
    const vec3f gradOut = p.out.fetch3(x, y, z);

    vec3f gradReflect0(0), gradReflect90(0);
    float gradCosine(0);
    bwdFresnelSchlick(reflect0, reflect90, cosine, gradReflect0, gradReflect90, gradCosine, gradOut);

    p.f0.store_grad(x, y, z, gradReflect0);
    p.f90.store_grad(x, y, z, gradReflect90);
    p.cosTheta.store_grad(x, y, z, gradCosine);
}
// Forward GGX NDF kernel: writes fwdNdfGGX(alphaSqr, cosTheta) to p.out for each element.
// Threads outside p.gridSize return immediately.
__global__ void ndfGGXFwdKernel(NdfGGXParams p)
{
    const unsigned int x = threadIdx.x + blockIdx.x * blockDim.x;
    const unsigned int y = threadIdx.y + blockIdx.y * blockDim.y;
    const unsigned int z = blockIdx.z;
    if (!(x < p.gridSize.x && y < p.gridSize.y && z < p.gridSize.z))
        return;

    const float a2 = p.alphaSqr.fetch1(x, y, z);
    const float cosine = p.cosTheta.fetch1(x, y, z);
    const float ndf = fwdNdfGGX(a2, cosine);
    p.out.store(x, y, z, ndf);
}
// Backward GGX NDF kernel.
// Reads alphaSqr, cosTheta and the incoming gradient (from p.out), then stores the gradients
// for alphaSqr and cosTheta.
__global__ void ndfGGXBwdKernel(NdfGGXParams p)
{
    const unsigned int x = threadIdx.x + blockIdx.x * blockDim.x;
    const unsigned int y = threadIdx.y + blockIdx.y * blockDim.y;
    const unsigned int z = blockIdx.z;
    if (!(x < p.gridSize.x && y < p.gridSize.y && z < p.gridSize.z))
        return;

    const float a2 = p.alphaSqr.fetch1(x, y, z);
    const float cosine = p.cosTheta.fetch1(x, y, z);
    const float gradOut = p.out.fetch1(x, y, z);

    float gradA2(0), gradCosine(0);
    bwdNdfGGX(a2, cosine, gradA2, gradCosine, gradOut);

    p.alphaSqr.store_grad(x, y, z, gradA2);
    p.cosTheta.store_grad(x, y, z, gradCosine);
}
// Forward GGX Lambda kernel: writes fwdLambdaGGX(alphaSqr, cosTheta) to p.out.
// It reuses the NdfGGXParams layout (alphaSqr, cosTheta, out).
// Threads outside p.gridSize return immediately.
__global__ void lambdaGGXFwdKernel(NdfGGXParams p)
{
    const unsigned int x = threadIdx.x + blockIdx.x * blockDim.x;
    const unsigned int y = threadIdx.y + blockIdx.y * blockDim.y;
    const unsigned int z = blockIdx.z;
    if (!(x < p.gridSize.x && y < p.gridSize.y && z < p.gridSize.z))
        return;

    const float a2 = p.alphaSqr.fetch1(x, y, z);
    const float cosine = p.cosTheta.fetch1(x, y, z);
    const float lambda = fwdLambdaGGX(a2, cosine);
    p.out.store(x, y, z, lambda);
}
// Backward GGX Lambda kernel.
// Reads alphaSqr, cosTheta and the incoming gradient (from p.out), then stores the gradients
// for alphaSqr and cosTheta.
__global__ void lambdaGGXBwdKernel(NdfGGXParams p)
{
    const unsigned int x = threadIdx.x + blockIdx.x * blockDim.x;
    const unsigned int y = threadIdx.y + blockIdx.y * blockDim.y;
    const unsigned int z = blockIdx.z;
    if (!(x < p.gridSize.x && y < p.gridSize.y && z < p.gridSize.z))
        return;

    const float a2 = p.alphaSqr.fetch1(x, y, z);
    const float cosine = p.cosTheta.fetch1(x, y, z);
    const float gradOut = p.out.fetch1(x, y, z);

    float gradA2(0), gradCosine(0);
    bwdLambdaGGX(a2, cosine, gradA2, gradCosine, gradOut);

    p.alphaSqr.store_grad(x, y, z, gradA2);
    p.cosTheta.store_grad(x, y, z, gradCosine);
}
// Forward Smith masking kernel: writes
// fwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO) to p.out.
// Threads outside p.gridSize return immediately.
__global__ void maskingSmithFwdKernel(MaskingSmithParams p)
{
    const unsigned int x = threadIdx.x + blockIdx.x * blockDim.x;
    const unsigned int y = threadIdx.y + blockIdx.y * blockDim.y;
    const unsigned int z = blockIdx.z;
    if (!(x < p.gridSize.x && y < p.gridSize.y && z < p.gridSize.z))
        return;

    const float a2 = p.alphaSqr.fetch1(x, y, z);
    const float cosIn = p.cosThetaI.fetch1(x, y, z);
    const float cosOut = p.cosThetaO.fetch1(x, y, z);
    const float masking = fwdMaskingSmithGGXCorrelated(a2, cosIn, cosOut);
    p.out.store(x, y, z, masking);
}
// Backward Smith masking kernel.
// Reads the three forward inputs and the incoming gradient (from p.out), then stores the
// gradients for alphaSqr, cosThetaI and cosThetaO.
__global__ void maskingSmithBwdKernel(MaskingSmithParams p)
{
    const unsigned int x = threadIdx.x + blockIdx.x * blockDim.x;
    const unsigned int y = threadIdx.y + blockIdx.y * blockDim.y;
    const unsigned int z = blockIdx.z;
    if (!(x < p.gridSize.x && y < p.gridSize.y && z < p.gridSize.z))
        return;

    const float a2 = p.alphaSqr.fetch1(x, y, z);
    const float cosIn = p.cosThetaI.fetch1(x, y, z);
    const float cosOut = p.cosThetaO.fetch1(x, y, z);
    const float gradOut = p.out.fetch1(x, y, z);

    float gradA2(0), gradCosIn(0), gradCosOut(0);
    bwdMaskingSmithGGXCorrelated(a2, cosIn, cosOut, gradA2, gradCosIn, gradCosOut, gradOut);

    p.alphaSqr.store_grad(x, y, z, gradA2);
    p.cosThetaI.store_grad(x, y, z, gradCosIn);
    p.cosThetaO.store_grad(x, y, z, gradCosOut);
}
// Forward GGX specular kernel: writes
// fwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness) to p.out.
// Threads outside p.gridSize return immediately.
__global__ void pbrSpecularFwdKernel(PbrSpecular p)
{
    const unsigned int x = threadIdx.x + blockIdx.x * blockDim.x;
    const unsigned int y = threadIdx.y + blockIdx.y * blockDim.y;
    const unsigned int z = blockIdx.z;
    if (!(x < p.gridSize.x && y < p.gridSize.y && z < p.gridSize.z))
        return;

    const vec3f color = p.col.fetch3(x, y, z);
    const vec3f normal = p.nrm.fetch3(x, y, z);
    const vec3f viewDir = p.wo.fetch3(x, y, z);
    const vec3f lightDir = p.wi.fetch3(x, y, z);
    const float roughAlpha = p.alpha.fetch1(x, y, z);

    const vec3f spec = fwdPbrSpecular(color, normal, viewDir, lightDir, roughAlpha, p.min_roughness);
    p.out.store(x, y, z, spec);
}
// Backward GGX specular kernel.
// Reads the forward inputs and the incoming gradient (from p.out), runs bwdPbrSpecular, and
// stores the gradients for col, nrm, wo, wi and alpha.
__global__ void pbrSpecularBwdKernel(PbrSpecular p)
{
    const unsigned int x = threadIdx.x + blockIdx.x * blockDim.x;
    const unsigned int y = threadIdx.y + blockIdx.y * blockDim.y;
    const unsigned int z = blockIdx.z;
    if (!(x < p.gridSize.x && y < p.gridSize.y && z < p.gridSize.z))
        return;

    const vec3f color = p.col.fetch3(x, y, z);
    const vec3f normal = p.nrm.fetch3(x, y, z);
    const vec3f viewDir = p.wo.fetch3(x, y, z);
    const vec3f lightDir = p.wi.fetch3(x, y, z);
    const float roughAlpha = p.alpha.fetch1(x, y, z);
    const vec3f gradOut = p.out.fetch3(x, y, z);

    float gradAlpha(0);
    vec3f gradColor(0), gradNormal(0), gradViewDir(0), gradLightDir(0);
    bwdPbrSpecular(color, normal, viewDir, lightDir, roughAlpha, p.min_roughness,
                   gradColor, gradNormal, gradViewDir, gradLightDir, gradAlpha, gradOut);

    p.col.store_grad(x, y, z, gradColor);
    p.nrm.store_grad(x, y, z, gradNormal);
    p.wo.store_grad(x, y, z, gradViewDir);
    p.wi.store_grad(x, y, z, gradLightDir);
    p.alpha.store_grad(x, y, z, gradAlpha);
}
// Forward full-BSDF kernel: writes
// fwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness) to p.out.
// Threads outside p.gridSize return immediately.
__global__ void pbrBSDFFwdKernel(PbrBSDF p)
{
    const unsigned int x = threadIdx.x + blockIdx.x * blockDim.x;
    const unsigned int y = threadIdx.y + blockIdx.y * blockDim.y;
    const unsigned int z = blockIdx.z;
    if (!(x < p.gridSize.x && y < p.gridSize.y && z < p.gridSize.z))
        return;

    const vec3f diffuseAlbedo = p.kd.fetch3(x, y, z);
    const vec3f material = p.arm.fetch3(x, y, z);
    const vec3f surfacePos = p.pos.fetch3(x, y, z);
    const vec3f normal = p.nrm.fetch3(x, y, z);
    const vec3f eyePos = p.view_pos.fetch3(x, y, z);
    const vec3f lightPos = p.light_pos.fetch3(x, y, z);

    const vec3f shaded = fwdPbrBSDF(diffuseAlbedo, material, surfacePos, normal, eyePos, lightPos, p.min_roughness);
    p.out.store(x, y, z, shaded);
}
// Backward full-BSDF kernel.
// Reads the forward inputs and the incoming gradient (from p.out), runs bwdPbrBSDF, and
// stores the gradients for kd, arm, pos, nrm, view_pos and light_pos.
__global__ void pbrBSDFBwdKernel(PbrBSDF p)
{
    const unsigned int x = threadIdx.x + blockIdx.x * blockDim.x;
    const unsigned int y = threadIdx.y + blockIdx.y * blockDim.y;
    const unsigned int z = blockIdx.z;
    if (!(x < p.gridSize.x && y < p.gridSize.y && z < p.gridSize.z))
        return;

    const vec3f diffuseAlbedo = p.kd.fetch3(x, y, z);
    const vec3f material = p.arm.fetch3(x, y, z);
    const vec3f surfacePos = p.pos.fetch3(x, y, z);
    const vec3f normal = p.nrm.fetch3(x, y, z);
    const vec3f eyePos = p.view_pos.fetch3(x, y, z);
    const vec3f lightPos = p.light_pos.fetch3(x, y, z);
    const vec3f gradOut = p.out.fetch3(x, y, z);

    vec3f gradAlbedo(0), gradMaterial(0), gradPos(0), gradNormal(0), gradEye(0), gradLight(0);
    bwdPbrBSDF(diffuseAlbedo, material, surfacePos, normal, eyePos, lightPos, p.min_roughness,
               gradAlbedo, gradMaterial, gradPos, gradNormal, gradEye, gradLight, gradOut);

    p.kd.store_grad(x, y, z, gradAlbedo);
    p.arm.store_grad(x, y, z, gradMaterial);
    p.pos.store_grad(x, y, z, gradPos);
    p.nrm.store_grad(x, y, z, gradNormal);
    p.view_pos.store_grad(x, y, z, gradEye);
    p.light_pos.store_grad(x, y, z, gradLight);
}