| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import NamedTuple |
| import torch.nn as nn |
| import torch |
| from . import _C |
|
|
def cpu_deep_copy_tuple(input_tuple):
    """Return a copy of *input_tuple* with every tensor cloned onto the CPU.

    Tensor items are moved to the CPU and cloned, so the result does not
    share storage with the inputs. Any other item (numbers, bools, None, ...)
    is placed in the result as-is, by reference. The rasterizer uses this to
    take an argument snapshot that can be written with ``torch.save`` when a
    native call fails in debug mode.

    Args:
        input_tuple: Any iterable of arguments; usually a tuple.

    Returns:
        A new tuple of the same length and order as *input_tuple*.
    """
    snapshot = []
    for element in input_tuple:
        if isinstance(element, torch.Tensor):
            # clone() guarantees a fresh copy even when .cpu() returns the
            # same tensor (i.e. the input already lives on the CPU).
            element = element.cpu().clone()
        snapshot.append(element)
    return tuple(snapshot)
|
|
def rasterize_gaussians(
    means3D,
    means2D,
    sh,
    colors_precomp,
    language_feature_precomp,
    opacities,
    scales,
    rotations,
    cov3Ds_precomp,
    raster_settings,
):
    """Run the differentiable Gaussian rasterizer.

    This is a functional entry point for ``_RasterizeGaussians.apply``. All
    arguments are forwarded positionally, in the order the autograd
    function's ``forward`` expects.

    Returns:
        The ``(color, language_feature, radii)`` tuple produced by
        ``_RasterizeGaussians.forward``.
    """
    # The order here must mirror _RasterizeGaussians.forward's parameter list.
    forwarded = (
        means3D,
        means2D,
        sh,
        colors_precomp,
        language_feature_precomp,
        opacities,
        scales,
        rotations,
        cov3Ds_precomp,
        raster_settings,
    )
    return _RasterizeGaussians.apply(*forwarded)
|
|
class _RasterizeGaussians(torch.autograd.Function):
    """Autograd bridge to the native ``_C`` Gaussian rasterizer.

    ``forward`` calls ``_C.rasterize_gaussians`` and saves the tensors and
    opaque buffers it returns. ``backward`` passes them back to
    ``_C.rasterize_gaussians_backward``. The positional order of both
    argument tuples must match the native signatures exactly; do not
    reorder them.
    """

    @staticmethod
    def _call_with_debug_snapshot(native_fn, args, debug, dump_path, message):
        """Call ``native_fn(*args)`` and return its result.

        When *debug* is true, a CPU copy of *args* is taken before the call.
        If the call raises, that copy is written to *dump_path* with
        ``torch.save``, *message* is printed, and the original exception is
        re-raised unchanged.

        Args:
            native_fn: The native callable to invoke.
            args: Positional arguments for *native_fn*.
            debug: Whether to snapshot *args* and dump them on failure.
            dump_path: File the snapshot is written to on failure.
            message: Text printed after the snapshot is written.

        Returns:
            Whatever *native_fn* returns.
        """
        if not debug:
            return native_fn(*args)
        # Copy before the call so a snapshot is available if the call raises.
        cpu_args = cpu_deep_copy_tuple(args)
        try:
            return native_fn(*args)
        except Exception:
            torch.save(cpu_args, dump_path)
            print(message)
            # Bare raise re-raises the original exception with its traceback.
            raise

    @staticmethod
    def forward(
        ctx,
        means3D,
        means2D,
        sh,
        colors_precomp,
        language_feature_precomp,
        opacities,
        scales,
        rotations,
        cov3Ds_precomp,
        raster_settings,
    ):
        """Rasterize the Gaussians with ``_C.rasterize_gaussians``.

        ``means2D`` is not passed to the native forward call. It only
        receives ``grad_means2D`` in ``backward`` (presumably so callers can
        read screen-space gradients from ``means2D.grad``; confirm).

        Returns:
            ``(color, language_feature, radii)`` as produced by the native call.
        """
        # Positional order must match the native _C.rasterize_gaussians signature.
        args = (
            raster_settings.bg,
            means3D,
            colors_precomp,
            language_feature_precomp,
            opacities,
            scales,
            rotations,
            raster_settings.scale_modifier,
            cov3Ds_precomp,
            raster_settings.viewmatrix,
            raster_settings.projmatrix,
            raster_settings.tanfovx,
            raster_settings.tanfovy,
            raster_settings.image_height,
            raster_settings.image_width,
            sh,
            raster_settings.sh_degree,
            raster_settings.campos,
            raster_settings.prefiltered,
            raster_settings.debug,
            raster_settings.include_feature,
        )

        (
            num_rendered,
            color,
            language_feature,
            radii,
            geomBuffer,
            binningBuffer,
            imgBuffer,
        ) = _RasterizeGaussians._call_with_debug_snapshot(
            _C.rasterize_gaussians,
            args,
            raster_settings.debug,
            "snapshot_fw.dump",
            "\nAn error occured in forward. Please forward snapshot_fw.dump for debugging.",
        )

        # Settings and the rendered count are not tensors, so they go on ctx
        # directly. Everything backward needs as a tensor is saved below.
        ctx.raster_settings = raster_settings
        ctx.num_rendered = num_rendered
        ctx.save_for_backward(
            colors_precomp,
            language_feature_precomp,
            means3D,
            scales,
            rotations,
            cov3Ds_precomp,
            radii,
            sh,
            geomBuffer,
            binningBuffer,
            imgBuffer,
        )
        return color, language_feature, radii

    @staticmethod
    def backward(ctx, grad_out_color, grad_out_language_feature, _):
        """Compute input gradients with ``_C.rasterize_gaussians_backward``.

        The third incoming gradient (for ``radii``) is ignored.

        Returns:
            One gradient per ``forward`` input, in ``forward``'s parameter
            order. ``None`` is returned for ``raster_settings``.
        """
        num_rendered = ctx.num_rendered
        raster_settings = ctx.raster_settings
        (
            colors_precomp,
            language_feature_precomp,
            means3D,
            scales,
            rotations,
            cov3Ds_precomp,
            radii,
            sh,
            geomBuffer,
            binningBuffer,
            imgBuffer,
        ) = ctx.saved_tensors

        # Positional order must match the native
        # _C.rasterize_gaussians_backward signature.
        args = (
            raster_settings.bg,
            means3D,
            radii,
            colors_precomp,
            language_feature_precomp,
            scales,
            rotations,
            raster_settings.scale_modifier,
            cov3Ds_precomp,
            raster_settings.viewmatrix,
            raster_settings.projmatrix,
            raster_settings.tanfovx,
            raster_settings.tanfovy,
            grad_out_color,
            grad_out_language_feature,
            sh,
            raster_settings.sh_degree,
            raster_settings.campos,
            geomBuffer,
            num_rendered,
            binningBuffer,
            imgBuffer,
            raster_settings.debug,
            raster_settings.include_feature,
        )

        (
            grad_means2D,
            grad_colors_precomp,
            grad_language_feature_precomp,
            grad_opacities,
            grad_means3D,
            grad_cov3Ds_precomp,
            grad_sh,
            grad_scales,
            grad_rotations,
        ) = _RasterizeGaussians._call_with_debug_snapshot(
            _C.rasterize_gaussians_backward,
            args,
            raster_settings.debug,
            "snapshot_bw.dump",
            "\nAn error occured in backward. Writing snapshot_bw.dump for debugging.\n",
        )

        # Reorder the native output to match forward's parameter order.
        return (
            grad_means3D,
            grad_means2D,
            grad_sh,
            grad_colors_precomp,
            grad_language_feature_precomp,
            grad_opacities,
            grad_scales,
            grad_rotations,
            grad_cov3Ds_precomp,
            None,  # raster_settings is not differentiable
        )
|
|
class GaussianRasterizationSettings(NamedTuple):
    """Per-view configuration consumed by the Gaussian rasterizer.

    Field order is part of the interface, because instances may be built
    positionally. Meanings noted as "presumably" are inferred from names
    and are not established by this module.
    """

    # Output image size in pixels.
    image_height: int
    image_width: int
    # Presumably tan(fov / 2) along each image axis; confirm.
    tanfovx: float
    tanfovy: float
    # Presumably the background colour composited behind the Gaussians.
    bg: torch.Tensor
    # Presumably a global multiplier applied to Gaussian scales.
    scale_modifier: float
    # Camera matrices passed straight through to the native rasterizer.
    viewmatrix: torch.Tensor
    projmatrix: torch.Tensor
    # Presumably the active spherical-harmonics degree.
    sh_degree: int
    # Presumably the camera centre in world space.
    campos: torch.Tensor
    # Forwarded to the native forward pass unchanged.
    prefiltered: bool
    # Enables argument snapshots (snapshot_*.dump) when a native call fails.
    debug: bool
    # Forwarded to the native code; presumably toggles the language-feature
    # channel.
    include_feature: bool
|
|
class GaussianRasterizer(nn.Module):
    """``nn.Module`` front end for the differentiable Gaussian rasterizer.

    Stores one rasterization-settings object and, on each call, validates
    the per-Gaussian inputs and forwards them to ``rasterize_gaussians``.
    """

    def __init__(self, raster_settings):
        """Store *raster_settings* for later forward / visibility queries."""
        super().__init__()
        self.raster_settings = raster_settings

    def markVisible(self, positions):
        """Return the native visibility result for *positions*.

        Calls ``_C.mark_visible`` with the stored view and projection
        matrices. No autograd graph is recorded.
        """
        # Visibility is a pure query, so no gradients are tracked.
        with torch.no_grad():
            raster_settings = self.raster_settings
            visible = _C.mark_visible(
                positions,
                raster_settings.viewmatrix,
                raster_settings.projmatrix)

        return visible

    def forward(self, means3D, means2D, opacities, shs=None, colors_precomp=None,
                language_feature_precomp=None, scales=None, rotations=None,
                cov3D_precomp=None):
        """Validate inputs and rasterize the Gaussians.

        Exactly one colour source (``shs`` or ``colors_precomp``) and exactly
        one covariance source (the ``scales``/``rotations`` pair, or
        ``cov3D_precomp``) must be given. Omitted optional inputs are replaced
        with empty tensors before the call.

        Returns:
            ``(color, language_feature, radii)`` from ``rasterize_gaussians``.

        Raises:
            ValueError: If the colour or covariance sources are missing or
                over-specified. ``ValueError`` subclasses ``Exception``, so
                existing ``except Exception`` handlers still catch it.
        """
        raster_settings = self.raster_settings

        # Colour source: exactly one of SH coefficients or precomputed colours.
        if (shs is None) == (colors_precomp is None):
            raise ValueError('Please provide exactly one of either SHs or precomputed colors!')

        # Covariance source: a complete scale/rotation pair, or a precomputed
        # 3D covariance, but never both.
        if ((scales is None or rotations is None) and cov3D_precomp is None) or \
                ((scales is not None or rotations is not None) and cov3D_precomp is not None):
            raise ValueError('Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!')

        # Absent optional inputs are passed as empty tensors, presumably so the
        # native code can treat an empty tensor as "not provided"; confirm
        # against the _C implementation.
        if shs is None:
            shs = torch.Tensor([])
        if colors_precomp is None:
            colors_precomp = torch.Tensor([])
        if language_feature_precomp is None:
            language_feature_precomp = torch.Tensor([])

        if scales is None:
            scales = torch.Tensor([])
        if rotations is None:
            rotations = torch.Tensor([])
        if cov3D_precomp is None:
            cov3D_precomp = torch.Tensor([])

        return rasterize_gaussians(
            means3D,
            means2D,
            shs,
            colors_precomp,
            language_feature_precomp,
            opacities,
            scales,
            rotations,
            cov3D_precomp,
            raster_settings,
        )
|
|
|
|