Spaces:
Runtime error
Runtime error
| from dataclasses import dataclass | |
| from typing import Optional | |
| import torch | |
| from jaxtyping import Float | |
| from torch import Tensor | |
| from spar3d.models.transformers.transformer_1d import Transformer1D | |
| from spar3d.models.utils import BaseModule | |
class TransformerPointTokenizer(BaseModule):
    """Map per-point features to per-point tokens with a 1D transformer.

    The pipeline built in ``configure`` and run in ``forward`` is:
    ``linear_in`` (``in_channels`` -> transformer width), ``Transformer1D``,
    ``linear_out`` (transformer width -> ``out_channels``). The transformer
    width is ``num_attention_heads * attention_head_dim``.
    """

    # The decorator is required so the fields below are registered as
    # dataclass fields of this subclass rather than plain class attributes.
    @dataclass
    class Config(BaseModule.Config):
        # The product of these two values is the transformer width.
        num_attention_heads: int = 16
        attention_head_dim: int = 64
        # Size of the last axis of the tensor passed to ``forward``.
        in_channels: Optional[int] = 6
        # Size of the last axis of the tensor returned by ``forward``.
        out_channels: Optional[int] = 1024
        # The remaining options are forwarded unchanged to Transformer1D;
        # their exact meaning is defined there, not in this file.
        num_layers: int = 16
        norm_num_groups: int = 32
        attention_bias: bool = False
        activation_fn: str = "geglu"
        norm_elementwise_affine: bool = True

    cfg: Config

    def configure(self) -> None:
        """Build the transformer and the input/output projections.

        NOTE(review): presumably invoked by ``BaseModule`` during
        construction once ``self.cfg`` is populated -- confirm in BaseModule.
        """
        inner_dim = self.cfg.num_attention_heads * self.cfg.attention_head_dim

        # Every config entry is forwarded to Transformer1D; only
        # ``in_channels`` is overridden, because the transformer consumes the
        # output of ``linear_in`` rather than the raw point features.
        # NOTE(review): ``out_channels`` is forwarded as well -- whether
        # Transformer1D uses it is not visible from this file.
        transformer_cfg = dict(self.cfg.copy())
        transformer_cfg["in_channels"] = inner_dim

        self.model = Transformer1D(transformer_cfg)
        self.linear_in = torch.nn.Linear(self.cfg.in_channels, inner_dim)
        self.linear_out = torch.nn.Linear(inner_dim, self.cfg.out_channels)

    def forward(
        self, points: Float[Tensor, "B N Ci"], **kwargs
    ) -> Float[Tensor, "B N Cp"]:
        """Tokenize a batch of points.

        Args:
            points: Rank-3 tensor, annotated as ``(B, N, Ci)``; the last axis
                is fed to ``linear_in`` and so must have ``in_channels``
                entries.
            **kwargs: Accepted and ignored.

        Returns:
            The output of ``linear_out``, annotated as ``(B, N, Cp)``, whose
            last axis has ``out_channels`` entries.

        Raises:
            AssertionError: If ``points`` is not rank 3.
        """
        assert (
            points.ndim == 3
        ), f"expected points of shape (B, N, Ci), got {tuple(points.shape)}"

        # Transformer1D is given channels on axis 1, so swap the last two
        # axes on the way in and swap them back on the way out.
        hidden = self.linear_in(points).permute(0, 2, 1)  # B N D -> B D N
        hidden = self.model(hidden).permute(0, 2, 1)  # back to points-major
        return self.linear_out(hidden)  # last axis -> out_channels

    def detokenize(self, *args, **kwargs):
        """Not supported by this tokenizer; always raises NotImplementedError."""
        raise NotImplementedError