| | import torch |
| | from diffusers.configuration_utils import ConfigMixin, register_to_config |
| | from diffusers import ModelMixin |
| | from torch import Tensor |
| |
|
| | from .temporaltrans.temptrans import SimpleTransModel |
| |
|
class PointModel(ModelMixin, ConfigMixin):
    """Wrapper module that selects a per-point backbone via ``model_type``.

    Only ``model_type='simple'`` is implemented in this class; it builds a
    ``SimpleTransModel``. Every other value -- including the default
    ``'pvcnn'`` -- raises ``NotImplementedError`` at construction time.
    """

    @register_to_config
    def __init__(
        self,
        model_type: str = 'pvcnn',
        in_channels: int = 3,
        out_channels: int = 3,
        embed_dim: int = 64,
        dropout: float = 0.1,
        width_multiplier: int = 1,
        voxel_resolution_multiplier: int = 1,
    ):
        """Build the backbone selected by ``model_type``.

        Args:
            model_type: Backbone identifier. Only ``'simple'`` is supported.
            in_channels: Number of channels per input point. The backbone
                receives ``in_channels - 3`` as its extra feature channels
                (presumably the first three are xyz coordinates -- confirm
                against ``SimpleTransModel``).
            out_channels: Passed to the backbone as ``num_classes``.
            embed_dim: Passed to the backbone as ``embed_dim``.
            dropout: Not read by this class for the ``'simple'`` backbone.
            width_multiplier: Not read by this class for the ``'simple'``
                backbone.
            voxel_resolution_multiplier: Not read by this class for the
                ``'simple'`` backbone.

        Raises:
            NotImplementedError: If ``model_type`` is not ``'simple'``.
        """
        super().__init__()
        self.model_type = model_type

        # Guard clause: reject unsupported backbones with an actionable message
        # instead of a bare NotImplementedError.
        if self.model_type != 'simple':
            raise NotImplementedError(
                f"Unsupported model_type {model_type!r}; "
                "only 'simple' is implemented."
            )

        # Context manager entered around every forward pass; it requests
        # float32 as the CUDA autocast dtype.
        # NOTE(review): presumably meant to keep the backbone in full
        # precision under mixed-precision training -- confirm.
        self.autocast_context = torch.autocast('cuda', dtype=torch.float32)

        self.model = SimpleTransModel(
            embed_dim=embed_dim,
            num_classes=out_channels,
            extra_feature_channels=(in_channels - 3),
        )

        # Re-initialise the output projection with tiny Gaussian values
        # (mean 0, std 1e-6) so its initial output is close to zero.
        # Bias is drawn before weight to keep the original RNG order.
        projection = self.model.output_projection
        torch.nn.init.normal_(projection.bias, 0, 1e-6)
        torch.nn.init.normal_(projection.weight, 0, 1e-6)

    def forward(self, inputs: Tensor, t: Tensor, context=None) -> Tensor:
        """Run the backbone inside the stored autocast context.

        Args:
            inputs: Tensor of shape (B, N, in_channels).
            t: Tensor forwarded unchanged to the backbone (presumably the
                diffusion timestep -- confirm against ``SimpleTransModel``).
            context: Optional extra argument forwarded unchanged to the
                backbone.

        Returns:
            Tensor of shape (B, N, out_channels).
        """
        with self.autocast_context:
            return self.model(inputs, t, context)
| |
|