Spaces:
Runtime error
Runtime error
import torch
def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the dot product of ``x`` and ``y`` along the last dimension.

    The inputs are multiplied elementwise and the products are summed over
    the last axis. That axis is kept with size 1 (``keepdim=True``) so the
    result can be broadcast back against the inputs.

    Args:
        x: First operand.
        y: Second operand; must be broadcastable with ``x``.

    Returns:
        Tensor with the broadcast shape of ``x * y`` whose last dimension
        has size 1.
    """
    return torch.sum(x * y, dim=-1, keepdim=True)
def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
    """Return ``2 * dot(x, n) * n - x``, i.e. ``x`` mirrored about ``n``.

    Args:
        x: Vectors to reflect, with components along the last dimension.
        n: Vectors to reflect about.
            NOTE(review): the formula is a true reflection only when ``n``
            has unit length; no normalization is done here - confirm that
            callers pass normalized vectors.

    Returns:
        Tensor with the broadcast shape of ``x`` and ``n``.
    """
    # dot() keeps the last dimension (size 1), so the scalar projection
    # broadcasts across the components of n.
    projection = dot(x, n)
    return 2 * projection * n - x
def length(x: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
    """Return the Euclidean length of ``x`` along the last dimension.

    Args:
        x: Vectors with components along the last dimension.
        eps: Lower bound applied to the squared length before the square
            root is taken.
            NOTE(review): presumably ``eps`` has to be representable in the
            dtype of ``x`` for the clamp to have any effect - confirm for
            reduced-precision inputs.

    Returns:
        Tensor shaped like ``x`` but with a last dimension of size 1.
    """
    squared = dot(x, x)
    # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN
    return torch.sqrt(torch.clamp(squared, min=eps))
def safe_normalize(x: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
    """Return ``x`` divided by its clamped length along the last dimension.

    The divisor comes from ``length(x, eps)``, which bounds the squared
    length from below by ``eps`` before taking the square root.

    Args:
        x: Vectors with components along the last dimension.
        eps: Forwarded unchanged to ``length``.

    Returns:
        Tensor with the same shape as ``x``.
    """
    norm = length(x, eps)
    return x / norm
def to_hvec(x: torch.Tensor, w: float) -> torch.Tensor:
    """Append one component with value ``w`` to the last dimension of ``x``.

    Args:
        x: Input tensor; its last dimension grows by one.
        w: Constant written into the appended component.

    Returns:
        Tensor shaped like ``x`` except that the last dimension is one
        element longer, with ``w`` in the final slot.
    """
    # pad=(0, 1): nothing before, one element after, on the last dimension.
    return torch.nn.functional.pad(x, pad=(0, 1), mode='constant', value=w)