| | from abc import abstractmethod |
| | from typing import Dict, Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | from .perceiver import SimplePerceiver |
| | from .transformer import Transformer |
| |
|
| |
|
class PointCloudSDFModel(nn.Module):
    """
    Abstract interface for models that predict signed distances at query
    coordinates, conditioned on an encoded batch of point clouds.
    """

    @property
    @abstractmethod
    def device(self) -> torch.device:
        """
        The device on which input tensors for this model should live.
        """

    @property
    @abstractmethod
    def default_batch_size(self) -> int:
        """
        A reasonable default number of query points per call.

        Some implementations may only support this exact size.
        """

    @abstractmethod
    def encode_point_clouds(self, point_clouds: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Pre-compute and cache the point-cloud-dependent part of the SDF
        computation performed by forward().

        :param point_clouds: a batch of [batch x 3 x N] points.
        :return: a state dict representing the encoded point cloud batch.
        """

    def forward(
        self,
        x: torch.Tensor,
        point_clouds: Optional[torch.Tensor] = None,
        encoded: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Predict the SDF at the coordinates x for a batch of point clouds.

        Exactly one of point_clouds or encoded must be provided.

        :param x: a [batch x 3 x N'] tensor of query points.
        :param point_clouds: a [batch x 3 x N] batch of point clouds.
        :param encoded: a cached result of encode_point_clouds().
        :return: a [batch x N'] tensor of SDF predictions.
        """
        # Exactly one conditioning source must be supplied.
        assert (point_clouds is None) != (encoded is None)
        if encoded is None:
            encoded = self.encode_point_clouds(point_clouds)
        return self.predict_sdf(x, encoded)

    @abstractmethod
    def predict_sdf(
        self, x: torch.Tensor, encoded: Optional[Dict[str, torch.Tensor]]
    ) -> torch.Tensor:
        """
        Predict the SDF at the query points given the encoded point clouds.

        Each query point is treated independently, conditioning only on the
        encoded point clouds themselves.
        """
| |
|
| |
|
class CrossAttentionPointCloudSDFModel(PointCloudSDFModel):
    """
    Encode point clouds using a transformer, and query points using cross
    attention to the encoded latents.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int = 4096,
        width: int = 512,
        encoder_layers: int = 12,
        encoder_heads: int = 8,
        decoder_layers: int = 4,
        decoder_heads: int = 8,
        init_scale: float = 0.25,
    ):
        """
        :param device: device for parameters and expected input tensors.
        :param dtype: dtype for all parameters.
        :param n_ctx: number of points per input point cloud (the encoder's
            context length and the decoder's data length).
        :param width: hidden dimension shared by encoder and decoder.
        :param encoder_layers: transformer layers in the point-cloud encoder.
        :param encoder_heads: attention heads in the encoder.
        :param decoder_layers: cross-attention layers in the query decoder.
        :param decoder_heads: attention heads in the decoder.
        :param init_scale: weight initialization scale for both networks.
        """
        super().__init__()
        self._device = device
        self.n_ctx = n_ctx

        # Lift raw 3D coordinates into the transformer's hidden width.
        self.encoder_input_proj = nn.Linear(3, width, device=device, dtype=dtype)
        self.encoder = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx,
            width=width,
            layers=encoder_layers,
            heads=encoder_heads,
            init_scale=init_scale,
        )
        self.decoder_input_proj = nn.Linear(3, width, device=device, dtype=dtype)
        self.decoder = SimplePerceiver(
            device=device,
            dtype=dtype,
            n_data=n_ctx,
            width=width,
            layers=decoder_layers,
            heads=decoder_heads,
            init_scale=init_scale,
        )
        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        # Map each decoded query latent to a single scalar SDF value.
        self.output_proj = nn.Linear(width, 1, device=device, dtype=dtype)

    @property
    def device(self) -> torch.device:
        return self._device

    @property
    def default_batch_size(self) -> int:
        # Bug fix: this previously returned self.n_query, an attribute that
        # __init__ never sets, so accessing the property always raised
        # AttributeError. n_ctx is the configured point count and serves as
        # the natural default number of query points.
        return self.n_ctx

    def encode_point_clouds(self, point_clouds: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Run the encoder transformer over a batch of point clouds.

        :param point_clouds: a [batch x 3 x N] tensor of points.
        :return: a dict with key "latents" holding [batch x N x width] latents.
        """
        # Permute to [batch x N x 3] so each point becomes a token.
        h = self.encoder_input_proj(point_clouds.permute(0, 2, 1))
        h = self.encoder(h)
        return dict(latents=h)

    def predict_sdf(
        self, x: torch.Tensor, encoded: Optional[Dict[str, torch.Tensor]]
    ) -> torch.Tensor:
        """
        Cross-attend the query points to the encoded point-cloud latents.

        :param x: a [batch x 3 x N'] tensor of query points.
        :param encoded: the output of encode_point_clouds(); must not be None.
        :return: a [batch x N'] tensor of SDF predictions.
        """
        data = encoded["latents"]
        # Permute queries to [batch x N' x 3] token layout before projecting.
        x = self.decoder_input_proj(x.permute(0, 2, 1))
        x = self.decoder(x, data)
        x = self.ln_post(x)
        x = self.output_proj(x)
        # Drop the trailing singleton channel dimension: [batch x N' x 1] -> [batch x N'].
        return x[..., 0]
| |
|