Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
def vector_dot(A: torch.Tensor, B: torch.Tensor, min=0.0) -> torch.Tensor:
    """Return the dot product of ``A`` and ``B`` along dim 1, clamped to ``[min, 1.0]``.

    The element-wise product is summed over dimension 1 and that dimension is
    kept with size 1 in the result.

    Args:
        A: First operand.
        B: Second operand; must broadcast against ``A``.
        min: Lower bound applied to the result (the upper bound is fixed at 1.0).
    """
    # NOTE: the parameter name ``min`` shadows the builtin inside this body;
    # it is part of the public signature and therefore left as-is.
    dot = torch.sum(A * B, dim=1, keepdim=True)
    return dot.clamp(min=min, max=1.0)
def srgb_to_rgb(f: torch.Tensor) -> torch.Tensor:
    """Map sRGB-encoded values to linear RGB, element-wise.

    Values ``<= 0.04045`` use the linear segment ``f / 12.92``; larger values
    use ``((f + 0.055) / 1.055) ** 2.4``. The result is cast back to
    ``f.dtype``.
    """
    threshold = 0.04045
    linear_part = f / 12.92
    # torch.where evaluates both branches, so clamp the input of the power
    # branch to the threshold before raising it to a fractional exponent.
    curve_base = (torch.clamp(f, threshold) + 0.055) / 1.055
    curve_part = torch.pow(curve_base, 2.4)
    converted = torch.where(f <= threshold, linear_part, curve_part)
    return converted.to(f.dtype)
def rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:
    """Map linear RGB values to sRGB encoding, element-wise.

    Values ``<= 0.0031308`` use the linear segment ``f * 12.92``; larger values
    use ``f ** (1 / 2.4) * 1.055 - 0.055``. The result is cast back to
    ``f.dtype``.
    """
    threshold = 0.0031308
    linear_part = f * 12.92
    # torch.where evaluates both branches, so clamp the input of the power
    # branch to the threshold before raising it to a fractional exponent.
    curve_part = torch.pow(torch.clamp(f, threshold), 1.0 / 2.4) * 1.055 - 0.055
    encoded = torch.where(f <= threshold, linear_part, curve_part)
    return encoded.to(f.dtype)
def tone_gamma(x: torch.Tensor) -> torch.Tensor:
    """Apply ``(1 - exp(-x)) ** (1 / 2.2)`` element-wise.

    The exponential term compresses non-negative inputs into ``[0, 1)`` and the
    power applies a 1/2.2 gamma. Negative inputs make the base negative, so the
    fractional power yields NaN for them.
    """
    compressed = 1 - torch.exp(-x)
    return compressed.pow(1.0 / 2.2)
# Safe division for operands expected to lie in the 0-1 range.
class safe_01_div(torch.autograd.Function):
    """Compute ``a / b`` with the divisor clamped away from zero.

    Use as ``safe_01_div.apply(a, b)``.

    Forward clamps ``b`` to ``[1e-4, 1.0]`` before dividing. Backward returns
    the quotient-rule gradients evaluated on clamped divisors and does not
    zero the gradient of ``b`` where the clamp saturates.
    """

    @staticmethod
    def forward(ctx, a, b):
        """Return ``a / clamp(b, 1e-4, 1.0)`` and stash both inputs for backward."""
        ctx.save_for_backward(a, b)
        return torch.div(a, torch.clamp(b, min=1e-4, max=1.0))

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients with respect to ``a`` and ``b`` (in that order)."""
        a, b = ctx.saved_tensors
        # d(a/b)/da = 1/b, using the same clamp as the forward pass.
        grad_a = torch.div(1, torch.clamp(b, min=1e-4, max=1.0)) * grad_output
        # d(a/b)/db = -a/b^2. The divisor is clamped at 1e-2 here (not 1e-4),
        # so 1/b^2 is capped at 1e4 -- the same cap as 1/b above.
        # NOTE(review): presumably intentional to bound gradient magnitude; confirm.
        grad_b = -1 * torch.div(a, torch.clamp(b, min=1e-2, max=1.0) ** 2) * grad_output
        return grad_a, grad_b
def get_positions(h, w, real_size, use_pixel_centers=True) -> torch.Tensor:
    """Build an ``(h, w, 3)`` grid of positions centred on the origin.

    For the pixel at row ``i`` and column ``j`` the entry is
    ``((i / h - 0.5) * size_0, (j / w - 0.5) * size_1, 0)``.

    Args:
        h: Number of rows.
        w: Number of columns.
        real_size: Extent of the grid. A list or tuple supplies
            ``(size_0, size_1)``; any other value is used for both axes.
        use_pixel_centers: When true, indices are offset by 0.5 so positions
            refer to pixel centres, not pixel corners.

    Returns:
        Tensor of shape ``(h, w, 3)`` whose last component is zero.
    """
    offset = 0.5 if use_pixel_centers else 0
    rows, cols = torch.meshgrid(
        torch.arange(h) + offset,
        torch.arange(w) + offset,
        indexing='ij',
    )
    # Accept tuples as well as lists: a tuple used to be treated as a scalar
    # and then failed when multiplied with a tensor.
    if isinstance(real_size, (list, tuple)):
        size_0, size_1 = real_size[0], real_size[1]
    else:
        size_0 = size_1 = real_size
    along_h = (rows / h - 0.5) * size_0
    along_w = (cols / w - 0.5) * size_1
    # Derive the zero plane from a float coordinate so torch.stack never has to
    # mix integer and floating tensors when use_pixel_centers is False.
    return torch.stack([along_h, along_w, torch.zeros_like(along_h)], dim=-1)
# Original note: N, H are (Bx3xHxW) and roughness is (Bx1xHxW). This function
# receives the cosine cosNH directly -- TODO confirm the expected shapes.
# The "D" term (facet distribution function) of the Cook-Torrance model.
def DistributionGGX(cosNH, roughness):
    """Evaluate the GGX distribution term ``a2 / (pi * d^2)``.

    Here ``a2 = roughness ** 4`` and ``d = cosNH^2 * (a2 - 1) + 1``.

    Note: when ``roughness`` is 0 and ``cosNH`` is 1, both ``a2`` and ``d`` are
    zero, so the result is 0/0 (NaN for tensors); no guard is applied here.
    """
    alpha = roughness * roughness
    alpha_sq = alpha * alpha
    base = (cosNH * cosNH) * (alpha_sq - 1.0) + 1.0
    return alpha_sq / (torch.pi * base * base)
# NdotV, roughness: (Bx1xHxW) per the original note.
def GeometrySchlickGGX(NdotV: torch.Tensor, roughness: torch.Tensor) -> torch.Tensor:
    """Evaluate the Schlick-GGX geometry term ``NdotV / (NdotV * (1 - k) + k)``.

    The remapping used is ``k = (roughness + 1)^2 / 8``.
    """
    shifted = roughness + 1.0
    k = (shifted * shifted) / 8.0
    return NdotV / (NdotV * (1.0 - k) + k)
# cosTheta, F0: (Bx1xHxW) per the original note.
# The "F" (Fresnel) term.
def fresnelSchlick(cosTheta: torch.Tensor, F0: torch.Tensor) -> torch.Tensor:
    """Evaluate Schlick's Fresnel approximation ``F0 + (1 - F0) * (1 - cosTheta)^5``."""
    falloff = torch.pow(1.0 - cosTheta, 5.0)
    return F0 + (1.0 - F0) * falloff