| | import numpy as np |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch.autograd import Function |
| | from torch.autograd.function import once_differentiable |
| | from torch.cuda.amp import custom_bwd, custom_fwd |
| |
|
| | try: |
| | import _shencoder as _backend |
| | except ImportError: |
| | from .backend import _backend |
| |
|
class _sh_encoder(Function):
    """Autograd bridge to the compiled spherical-harmonics (SH) encoder.

    ``forward`` maps a 2-D tensor ``(B, input_dim)`` to ``degree ** 2``
    values per row via ``_backend.sh_encode_forward``. When
    ``calc_grad_inputs`` is truthy it also allocates a per-row buffer of
    ``input_dim * degree ** 2`` entries for the backend to fill. ``backward``
    then hands that buffer to ``_backend.sh_encode_backward`` to produce the
    gradient w.r.t. ``inputs``.

    NOTE(review): ``backward`` is not wrapped in ``once_differentiable``.
    Whether higher-order gradients through the extension are meaningful is
    not visible here; confirm before using ``create_graph=True``.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)  # inside autocast regions, run the kernel in float32
    def forward(ctx, inputs, degree, calc_grad_inputs=False):
        """Encode each row of ``inputs`` into ``degree ** 2`` SH values.

        Args:
            ctx: autograd context; receives the tensors and sizes ``backward`` needs.
            inputs: 2-D tensor ``(B, input_dim)``; made contiguous before the kernel call.
            degree: SH degree; the output has ``degree ** 2`` columns.
            calc_grad_inputs: when truthy, allocate the ``(B, input_dim * degree ** 2)``
                buffer that ``backward`` needs; otherwise ``None`` is passed to the backend.

        Returns:
            Tensor ``(B, degree ** 2)`` with the same dtype and device as ``inputs``.
        """
        x = inputs.contiguous()
        batch, in_dim = x.shape
        out_dim = degree ** 2
        alloc = dict(dtype=x.dtype, device=x.device)

        encoded = torch.empty(batch, out_dim, **alloc)
        # The derivative buffer exists only when input gradients were requested;
        # how the backend treats a None buffer is not visible here.
        jacobian = torch.empty(batch, in_dim * out_dim, **alloc) if calc_grad_inputs else None

        _backend.sh_encode_forward(x, encoded, batch, in_dim, degree, jacobian)

        # A None jacobian is saved as-is; backward uses it as the "no grad" signal.
        ctx.save_for_backward(x, jacobian)
        ctx.dims = [batch, in_dim, degree]
        return encoded

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        """Back-propagate ``grad`` (shaped like forward's output) into ``inputs``.

        Returns:
            One entry per ``forward`` argument. The first is the gradient for
            ``inputs``, or ``None`` if no derivative buffer was recorded. The
            entries for ``degree`` and ``calc_grad_inputs`` are always ``None``.
        """
        x, jacobian = ctx.saved_tensors
        if jacobian is None:
            # forward ran without calc_grad_inputs: nothing to propagate.
            return None, None, None

        batch, in_dim, degree = ctx.dims
        grad_x = torch.zeros_like(x)  # output buffer written by the backend
        _backend.sh_encode_backward(grad.contiguous(), x, batch, in_dim, degree, jacobian, grad_x)
        return grad_x, None, None
| |
|
| |
|
| |
|
# Functional entry point: sh_encode(inputs, degree, calc_grad_inputs=False).
# Calling it runs _sh_encoder.forward through autograd; autograd supplies ctx,
# so only the three forward arguments are passed, positionally.
sh_encode = _sh_encoder.apply
| |
|
| |
|
class SHEncoder(nn.Module):
    """Spherical-harmonics encoder module built on ``sh_encode``.

    Maps a tensor of shape ``(..., input_dim)`` to ``(..., degree ** 2)``.
    Only ``input_dim == 3`` and ``1 <= degree <= 8`` are accepted.
    """

    def __init__(self, input_dim=3, degree=4):
        super().__init__()

        self.input_dim = input_dim
        self.degree = degree
        self.output_dim = degree ** 2  # width of the encoded feature vector

        # Validation uses assert (AssertionError); these checks disappear under `python -O`.
        assert input_dim == 3, "SH encoder only support input dim == 3"
        assert 0 < degree <= 8, "SH encoder only supports degree in [1, 8]"

    def __repr__(self):
        return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}"

    def forward(self, inputs, size=1):
        """Encode ``inputs / size`` with spherical harmonics.

        Args:
            inputs: tensor whose last dimension is ``input_dim``. Any leading
                dimensions are flattened for the kernel and then restored.
            size: divisor applied before encoding. Presumably this is the
                extent that maps positions into [-1, 1]; confirm with callers.

        Returns:
            Tensor of shape ``(*inputs.shape[:-1], degree ** 2)``.
        """
        scaled = inputs / size
        lead_shape = tuple(scaled.shape[:-1])
        flat = scaled.reshape(-1, self.input_dim)
        # Request the derivative buffer only when autograd tracks the input.
        encoded = sh_encode(flat, self.degree, flat.requires_grad)
        return encoded.reshape(lead_shape + (self.output_dim,))