Spaces:
Runtime error
Runtime error
| # Open Source Model Licensed under the Apache License Version 2.0 | |
| # and Other Licenses of the Third-Party Components therein: | |
| # The below Model in this distribution may have been modified by THL A29 Limited | |
| # ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited. | |
| # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | |
| # The below software and/or models in this distribution may have been | |
| # modified by THL A29 Limited ("Tencent Modifications"). | |
| # All Tencent Modifications are Copyright (C) THL A29 Limited. | |
| # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT | |
| # except for the third-party components listed below. | |
| # Hunyuan 3D does not impose any additional limitations beyond what is outlined | |
| # in the respective licenses of these third-party components. | |
| # Users must comply with all terms and conditions of original licenses of these third-party | |
| # components and must ensure that the usage of the third party components adheres to | |
| # all relevant laws and regulations. | |
| # For avoidance of doubts, Hunyuan 3D means the large language models and | |
| # their software and algorithms, including trained model weights, parameters (including | |
| # optimizer states), machine-learning model code, inference-enabling code, training-enabling code, | |
| # fine-tuning enabling code and other elements of the foregoing made publicly available | |
| # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. | |
| import os | |
| from typing import Optional, Union, List | |
| import torch | |
| import torch.nn as nn | |
| from einops import rearrange | |
| from torch import Tensor | |
| from .attention_processors import CrossAttentionProcessor | |
| from ...utils.misc import logger | |
# Attention kernel used by the self-attention code in this module. The default
# is PyTorch's built-in scaled dot-product attention.
scaled_dot_product_attention = nn.functional.scaled_dot_product_attention

# Opt-in: setting the environment variable USE_SAGEATTN=1 swaps in the kernel
# from the optional ``sageattention`` package.
if os.environ.get("USE_SAGEATTN", "0") == "1":
    try:
        from sageattention import sageattn
    except ImportError as err:
        # Chain the original failure so its root cause stays in the traceback.
        raise ImportError(
            'Please install the package "sageattention" to use this USE_SAGEATTN.'
        ) from err
    scaled_dot_product_attention = sageattn
class FourierEmbedder(nn.Module):
    """Sin/cos positional embedding applied to the last dimension of a tensor.

    Each scalar feature ``x[..., i]`` is multiplied by every frequency ``f_k``.
    When ``num_freqs > 0`` the result is the concatenation, along the last
    dimension, of::

        x                       # only present if include_input is True
        sin(f_k * x[..., i])    # ordered by feature i, then by frequency k
        cos(f_k * x[..., i])    # same ordering as the sine part

    When ``num_freqs == 0`` the input is returned unchanged.

    Frequencies:
        * ``logspace=True``: ``f_k = 2 ** k`` for ``k = 0 .. num_freqs - 1``.
        * ``logspace=False``: ``num_freqs`` values linearly spaced over
          ``[1.0, 2 ** (num_freqs - 1)]``.
        * ``include_pi=True`` multiplies every frequency by pi.

    Args:
        num_freqs (int): number of frequencies, default is 6.
        logspace (bool): choose power-of-two or linearly spaced frequencies
            (see above), default is True.
        input_dim (int): size of the input's last dimension, only used to
            compute ``out_dim``; default is 3.
        include_input (bool): prepend the raw input to the embedding,
            default is True.
        include_pi (bool): scale the frequencies by pi, default is True.

    Attributes:
        frequencies (torch.Tensor): float32 buffer of shape ``[num_freqs]``;
            registered as non-persistent, so it is not part of the state dict.
        out_dim (int): embedding size, ``input_dim * (num_freqs * 2 + 1)`` if
            the raw input is kept (``include_input`` or ``num_freqs == 0``),
            otherwise ``input_dim * num_freqs * 2``.
    """

    def __init__(
        self,
        num_freqs: int = 6,
        logspace: bool = True,
        input_dim: int = 3,
        include_input: bool = True,
        include_pi: bool = True,
    ) -> None:
        """Build the frequency buffer and record the output width."""
        super().__init__()

        if logspace:
            freq_bands = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
        else:
            highest = 2.0 ** (num_freqs - 1)
            freq_bands = torch.linspace(
                1.0, highest, num_freqs, dtype=torch.float32
            )
        if include_pi:
            freq_bands = freq_bands * torch.pi

        self.register_buffer("frequencies", freq_bands, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs
        self.out_dim = self.get_dims(input_dim)

    def get_dims(self, input_dim):
        """Return the embedding width produced for an input of width ``input_dim``."""
        # With zero frequencies forward() is the identity, so the raw input counts.
        keeps_raw_input = self.include_input or self.num_freqs == 0
        per_feature = self.num_freqs * 2 + (1 if keeps_raw_input else 0)
        return input_dim * per_feature

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x``.

        Args:
            x: tensor of shape [..., dim]

        Returns:
            embedding: tensor of shape [..., dim * (num_freqs * 2 + temp)],
                where temp is 1 if include_input is True and 0 otherwise;
                ``x`` itself when ``num_freqs`` is 0.
        """
        if self.num_freqs <= 0:
            return x

        # [..., dim, 1] * [num_freqs] -> [..., dim, num_freqs], then fold the
        # last two axes together.
        scaled = x[..., None].contiguous() * self.frequencies
        embed = scaled.view(*x.shape[:-1], -1)

        pieces = (embed.sin(), embed.cos())
        if self.include_input:
            pieces = (x,) + pieces
        return torch.cat(pieces, dim=-1)
class DropPath(nn.Module):
    """Stochastic depth: randomly zero whole samples of a residual branch.

    Args:
        drop_prob: probability of dropping a sample's path.
        scale_by_keep: if True, surviving samples are divided by the keep
            probability.
    """

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        """Apply per-sample path dropping.

        The input is returned untouched in eval mode or when ``drop_prob`` is
        0. Otherwise one Bernoulli(keep_prob) value is drawn per entry of the
        first dimension and broadcast over all remaining dimensions.
        """
        if not self.training or self.drop_prob == 0.0:
            return x

        keep_prob = 1 - self.drop_prob
        # One mask entry per sample; trailing singleton axes make it broadcast
        # against inputs of any rank.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        if self.scale_by_keep and keep_prob > 0.0:
            mask.div_(keep_prob)
        return x * mask

    def extra_repr(self):
        """Show the drop probability with three decimals in the module repr."""
        return "drop_prob={:0.3f}".format(round(self.drop_prob, 3))
class MLP(nn.Module):
    """Feed-forward block: Linear -> GELU -> Linear, followed by optional DropPath.

    Args:
        width: input feature size.
        expand_ratio: the hidden layer has ``width * expand_ratio`` features.
        output_width: output feature size; falls back to ``width`` when None.
        drop_path_rate: when > 0 the output passes through ``DropPath`` with
            this rate, otherwise through an identity.
    """

    def __init__(
        self,
        *,
        width: int,
        expand_ratio: int = 4,
        output_width: int = None,
        drop_path_rate: float = 0.0,
    ):
        super().__init__()
        hidden_width = width * expand_ratio
        out_features = width if output_width is None else output_width

        self.width = width
        self.c_fc = nn.Linear(width, hidden_width)
        self.c_proj = nn.Linear(hidden_width, out_features)
        self.gelu = nn.GELU()
        if drop_path_rate > 0.0:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Project up, apply GELU, project down, then apply the drop-path layer."""
        hidden = self.gelu(self.c_fc(x))
        projected = self.c_proj(hidden)
        return self.drop_path(projected)
class QKVMultiheadCrossAttention(nn.Module):
    """Multi-head cross-attention core operating on already-projected q and kv.

    Args:
        heads: number of attention heads.
        width: model width; only read (as ``width // heads``) when ``qk_norm``
            is enabled.
        qk_norm: if True, normalise q and k per head with ``norm_layer``;
            otherwise both norms are identities.
        norm_layer: factory for the per-head normalisation layers.
    """

    def __init__(
        self,
        *,
        heads: int,
        width=None,
        qk_norm=False,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.heads = heads
        if qk_norm:
            head_dim = width // heads
            self.q_norm = norm_layer(head_dim, elementwise_affine=True, eps=1e-6)
            self.k_norm = norm_layer(head_dim, elementwise_affine=True, eps=1e-6)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        # Pluggable attention implementation, called as processor(self, q, k, v).
        self.attn_processor = CrossAttentionProcessor()

    def forward(self, q, kv):
        """Attend queries ``q`` to the keys/values packed in ``kv``.

        Args:
            q: 3-D tensor ``[bs, n_ctx, q_width]``.
            kv: 3-D tensor ``[bs, n_data, kv_width]``; after the per-head
                reshape, the first half of each head's channels is the key
                and the second half is the value.

        Returns:
            Tensor of shape ``[bs, n_ctx, -1]`` with the heads merged back.
        """
        _, n_ctx, _ = q.shape
        bs, n_data, kv_width = kv.shape
        head_dim = kv_width // self.heads // 2

        q_heads = q.view(bs, n_ctx, self.heads, -1)
        kv_heads = kv.view(bs, n_data, self.heads, -1)
        k_heads, v_heads = torch.split(kv_heads, head_dim, dim=-1)

        q_heads = self.q_norm(q_heads)
        k_heads = self.k_norm(k_heads)

        # Move the head axis in front of the sequence axis for the processor.
        q_heads, k_heads, v_heads = [
            rearrange(t, "b n h d -> b h n d", h=self.heads)
            for t in (q_heads, k_heads, v_heads)
        ]
        attended = self.attn_processor(self, q_heads, k_heads, v_heads)
        return attended.transpose(1, 2).reshape(bs, n_ctx, -1)
class MultiheadCrossAttention(nn.Module):
    """Cross-attention layer with q / kv input projections and an output projection.

    Args:
        width: feature size of the query stream and of the output.
        heads: number of attention heads.
        qkv_bias: whether the q and kv projections carry a bias.
        data_width: feature size of the attended data; defaults to ``width``.
        norm_layer: normalisation factory forwarded to the attention core.
        qk_norm: enable per-head q/k normalisation in the attention core.
        kv_cache: if True, the projected kv of the first call is stored and
            reused by every later call.
    """

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        data_width: Optional[int] = None,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        kv_cache: bool = False,
    ):
        super().__init__()
        self.width = width
        self.heads = heads
        self.data_width = data_width if data_width is not None else width

        self.c_q = nn.Linear(width, width, bias=qkv_bias)
        self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias)
        self.c_proj = nn.Linear(width, width)
        self.attention = QKVMultiheadCrossAttention(
            heads=heads,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
        )

        self.kv_cache = kv_cache
        # Cached kv projection; stays None until the first cached forward pass.
        self.data = None

    def forward(self, x, data):
        """Project ``x`` to queries and ``data`` to keys/values, attend, project out."""
        queries = self.c_q(x)

        if not self.kv_cache:
            keys_values = self.c_kv(data)
        else:
            if self.data is None:
                self.data = self.c_kv(data)
                logger.info(
                    "Save kv cache,this should be called only once for one mesh"
                )
            # Once cached, the `data` argument of later calls is ignored.
            keys_values = self.data

        attended = self.attention(queries, keys_values)
        return self.c_proj(attended)
class ResidualCrossAttentionBlock(nn.Module):
    """Pre-norm residual block: cross-attention followed by an MLP.

    Args:
        width: feature size of the query stream.
        heads: number of attention heads.
        mlp_expand_ratio: hidden expansion factor of the MLP.
        data_width: feature size of the attended data; defaults to ``width``.
        qkv_bias: whether the q and kv projections carry a bias.
        norm_layer: factory for all normalisation layers.
        qk_norm: enable per-head q/k normalisation inside the attention.
    """

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        mlp_expand_ratio: int = 4,
        data_width: Optional[int] = None,
        qkv_bias: bool = True,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
    ):
        super().__init__()
        data_width = width if data_width is None else data_width

        self.attn = MultiheadCrossAttention(
            width=width,
            heads=heads,
            data_width=data_width,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
        )
        # ln_1 normalises the queries, ln_2 the attended data, ln_3 the MLP input.
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
        self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)

    def forward(self, x: torch.Tensor, data: torch.Tensor):
        """Add the cross-attention update to ``x``, then add the MLP update."""
        attn_update = self.attn(self.ln_1(x), self.ln_2(data))
        x = x + attn_update
        mlp_update = self.mlp(self.ln_3(x))
        return x + mlp_update
class QKVMultiheadAttention(nn.Module):
    """Multi-head self-attention core operating on an already-projected qkv tensor.

    Args:
        heads: number of attention heads.
        width: model width; only read (as ``width // heads``) when ``qk_norm``
            is enabled.
        qk_norm: if True, normalise q and k per head with ``norm_layer``;
            otherwise both norms are identities.
        norm_layer: factory for the per-head normalisation layers.
    """

    def __init__(
        self, *, heads: int, width=None, qk_norm=False, norm_layer=nn.LayerNorm
    ):
        super().__init__()
        self.heads = heads
        if qk_norm:
            head_dim = width // heads
            self.q_norm = norm_layer(head_dim, elementwise_affine=True, eps=1e-6)
            self.k_norm = norm_layer(head_dim, elementwise_affine=True, eps=1e-6)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

    def forward(self, qkv):
        """Run attention on a packed qkv tensor.

        Args:
            qkv: 3-D tensor ``[bs, n_ctx, qkv_width]``; after the per-head
                reshape each head's channels are split into equal q, k and v
                thirds, in that order.

        Returns:
            Tensor of shape ``[bs, n_ctx, -1]`` with the heads merged back.
        """
        bs, n_ctx, qkv_width = qkv.shape
        head_dim = qkv_width // self.heads // 3

        per_head = qkv.view(bs, n_ctx, self.heads, -1)
        q, k, v = torch.split(per_head, head_dim, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)

        # Move the head axis in front of the sequence axis for the kernel.
        q, k, v = [
            rearrange(t, "b n h d -> b h n d", h=self.heads) for t in (q, k, v)
        ]
        attended = scaled_dot_product_attention(q, k, v)
        return attended.transpose(1, 2).reshape(bs, n_ctx, -1)
class MultiheadAttention(nn.Module):
    """Self-attention layer: fused qkv projection, attention core, output projection.

    Args:
        width: feature size of the input and output.
        heads: number of attention heads.
        qkv_bias: whether the fused qkv projection carries a bias.
        norm_layer: normalisation factory forwarded to the attention core.
        qk_norm: enable per-head q/k normalisation in the attention core.
        drop_path_rate: when > 0 the output passes through ``DropPath`` with
            this rate, otherwise through an identity.
    """

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0,
    ):
        super().__init__()
        self.width = width
        self.heads = heads
        self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias)
        self.c_proj = nn.Linear(width, width)
        self.attention = QKVMultiheadAttention(
            heads=heads,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
        )
        if drop_path_rate > 0.0:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Project to qkv, attend, project out, then apply the drop-path layer."""
        qkv = self.c_qkv(x)
        attended = self.attention(qkv)
        return self.drop_path(self.c_proj(attended))
class ResidualAttentionBlock(nn.Module):
    """Pre-norm residual block: self-attention followed by an MLP.

    Args:
        width: feature size of the token stream.
        heads: number of attention heads.
        qkv_bias: whether the fused qkv projection carries a bias.
        norm_layer: factory for both normalisation layers.
        qk_norm: enable per-head q/k normalisation inside the attention.
        drop_path_rate: drop-path rate passed to both the attention and the MLP.
    """

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0,
    ):
        super().__init__()
        self.attn = MultiheadAttention(
            width=width,
            heads=heads,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate,
        )
        # ln_1 feeds the attention, ln_2 feeds the MLP.
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
        self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)

    def forward(self, x: torch.Tensor):
        """Add the self-attention update to ``x``, then add the MLP update."""
        attn_update = self.attn(self.ln_1(x))
        x = x + attn_update
        mlp_update = self.mlp(self.ln_2(x))
        return x + mlp_update
class Transformer(nn.Module):
    """A stack of ``layers`` identical ``ResidualAttentionBlock`` modules.

    Args:
        width: feature size of the token stream.
        layers: number of residual attention blocks.
        heads: number of attention heads per block.
        qkv_bias: whether the fused qkv projections carry a bias.
        norm_layer: factory for the normalisation layers of each block.
        qk_norm: enable per-head q/k normalisation in each block.
        drop_path_rate: drop-path rate used by every block.
    """

    def __init__(
        self,
        *,
        width: int,
        layers: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0,
    ):
        super().__init__()
        self.width = width
        self.layers = layers

        # Every block is built from the same configuration.
        block_config = dict(
            width=width,
            heads=heads,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate,
        )
        self.resblocks = nn.ModuleList(
            [ResidualAttentionBlock(**block_config) for _ in range(layers)]
        )

    def forward(self, x: torch.Tensor):
        """Pass ``x`` through the blocks in order and return the result."""
        hidden = x
        for resblock in self.resblocks:
            hidden = resblock(hidden)
        return hidden
class CrossAttentionDecoder(nn.Module):
    """Decode per-query outputs by cross-attending query points to latents.

    Query coordinates are Fourier-embedded and projected to ``width``, attend
    to the latents through one ``ResidualCrossAttentionBlock``, optionally go
    through a final LayerNorm, and are projected to ``out_channels``.

    Args:
        out_channels: size of the output per query.
        fourier_embedder: embedder applied to the raw query coordinates; its
            ``out_dim`` fixes the input size of ``query_proj``.
        width: feature size of the query stream and latents.
        heads: number of attention heads.
        mlp_expand_ratio: hidden expansion factor of the block's MLP.
        downsample_ratio: when not 1, latents of size
            ``width * downsample_ratio`` are first projected to ``width``.
        enable_ln_post: apply a LayerNorm before the output projection.
        qkv_bias: whether the q and kv projections carry a bias.
        qk_norm: enable per-head q/k normalisation; forced off whenever
            ``enable_ln_post`` is falsy.
        label_type: stored on the instance; not read inside this class.
    """

    def __init__(
        self,
        *,
        out_channels: int,
        fourier_embedder: FourierEmbedder,
        width: int,
        heads: int,
        mlp_expand_ratio: int = 4,
        downsample_ratio: int = 1,
        enable_ln_post: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary",
    ):
        super().__init__()
        self.enable_ln_post = enable_ln_post
        self.fourier_embedder = fourier_embedder
        self.downsample_ratio = downsample_ratio
        self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width)
        if self.downsample_ratio != 1:
            self.latents_proj = nn.Linear(width * downsample_ratio, width)
        if not self.enable_ln_post:
            # QK normalisation is switched off together with the post-LayerNorm.
            qk_norm = False
        self.cross_attn_decoder = ResidualCrossAttentionBlock(
            width=width,
            mlp_expand_ratio=mlp_expand_ratio,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
        )
        if self.enable_ln_post:
            self.ln_post = nn.LayerNorm(width)
        self.output_proj = nn.Linear(width, out_channels)
        self.label_type = label_type
        # Running total of query positions processed by forward().
        self.count = 0

    def set_cross_attention_processor(self, processor):
        """Replace the attention processor used by the inner cross-attention."""
        self.cross_attn_decoder.attn.attention.attn_processor = processor

    def forward(self, queries=None, query_embeddings=None, latents=None):
        """Decode outputs for a set of queries.

        Args:
            queries: raw query coordinates; embedded and projected only when
                ``query_embeddings`` is None.
            query_embeddings: pre-computed query embeddings; when given,
                ``queries`` is ignored.
            latents: latent tokens to attend to.

        Returns:
            The output of ``output_proj`` for every query.

        Raises:
            ValueError: if both ``queries`` and ``query_embeddings`` are None.
        """
        if query_embeddings is None:
            if queries is None:
                raise ValueError(
                    "Either `queries` or `query_embeddings` must be provided."
                )
            # Match the latents' dtype before the linear projection.
            query_embeddings = self.query_proj(
                self.fourier_embedder(queries).to(latents.dtype)
            )
        self.count += query_embeddings.shape[1]
        if self.downsample_ratio != 1:
            latents = self.latents_proj(latents)
        x = self.cross_attn_decoder(query_embeddings, latents)
        if self.enable_ln_post:
            x = self.ln_post(x)
        occ = self.output_proj(x)
        return occ
def fps(
    src: torch.Tensor,
    batch: Optional[Tensor] = None,
    ratio: Optional[Union[Tensor, float]] = None,
    random_start: bool = True,
    batch_size: Optional[int] = None,
    ptr: Optional[Union[Tensor, List[int]]] = None,
):
    """Thin wrapper around ``torch_cluster.fps``.

    ``src`` is cast to float32 first; every other argument is forwarded
    positionally and unchanged. ``torch_cluster`` is imported on each call, so
    the dependency is only required when this function is actually used.

    Returns:
        Whatever ``torch_cluster.fps`` returns for the given arguments.
    """
    points = src.float()
    from torch_cluster import fps as cluster_fps

    return cluster_fps(points, batch, ratio, random_start, batch_size, ptr)
class PointCrossAttentionEncoder(nn.Module):
    """Encode a surface point cloud into latent tokens.

    The input cloud is expected to hold ``pc_size`` "random" surface points
    followed by ``pc_sharpedge_size`` "sharpedge" points along dim 1. Each
    group is randomly subsampled, query points are picked from the subsample
    with ``fps``, both sets are Fourier-embedded (plus optional per-point
    features), and the queries cross-attend to the subsampled points. An
    optional self-attention stack and LayerNorm follow.

    Args:
        num_latents: target number of latent tokens.
        downsample_ratio: number of subsampled input points per query point.
        pc_size: number of random surface points in the input.
        pc_sharpedge_size: number of sharpedge points in the input.
        fourier_embedder: positional embedder for the coordinates.
        point_feats: number of per-point feature channels (0 disables them).
        width: feature size of the token stream.
        heads: number of attention heads.
        layers: number of self-attention blocks (0 disables the stack).
        normal_pe: if True, the first three feature channels are
            Fourier-embedded before being concatenated.
        qkv_bias: whether attention projections carry a bias.
        use_ln_post: apply a final LayerNorm to the latents.
        use_checkpoint: stored on the instance; not read inside this class.
        qk_norm: enable per-head q/k normalisation in the attention layers.
    """

    def __init__(
        self,
        *,
        num_latents: int,
        downsample_ratio: float,
        pc_size: int,
        pc_sharpedge_size: int,
        fourier_embedder: FourierEmbedder,
        point_feats: int,
        width: int,
        heads: int,
        layers: int,
        normal_pe: bool = False,
        qkv_bias: bool = True,
        use_ln_post: bool = False,
        use_checkpoint: bool = False,
        qk_norm: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.num_latents = num_latents
        self.downsample_ratio = downsample_ratio
        self.point_feats = point_feats
        self.normal_pe = normal_pe

        # NOTE(review): the first message says pc_size is used as
        # pc_sharpedge_size, but pc_sharpedge_size is stored unchanged below.
        if pc_sharpedge_size == 0:
            print(
                "PointCrossAttentionEncoder INFO: pc_sharpedge_size is not given,"
                " using pc_size as pc_sharpedge_size"
            )
        else:
            print(
                "PointCrossAttentionEncoder INFO: pc_sharpedge_size is given, using"
                f" pc_size={pc_size}, pc_sharpedge_size={pc_sharpedge_size}"
            )

        self.pc_size = pc_size
        self.pc_sharpedge_size = pc_sharpedge_size

        self.fourier_embedder = fourier_embedder

        if normal_pe:
            # With normal_pe the first (up to) three feature channels are
            # replaced by their Fourier embedding in sample_points_and_latents,
            # so the projection must be sized for the embedded width.
            n_embedded = min(3, point_feats)
            feats_width = self.fourier_embedder.get_dims(n_embedded) + (
                point_feats - n_embedded
            )
        else:
            feats_width = point_feats
        self.input_proj = nn.Linear(self.fourier_embedder.out_dim + feats_width, width)

        self.cross_attn = ResidualCrossAttentionBlock(
            width=width, heads=heads, qkv_bias=qkv_bias, qk_norm=qk_norm
        )

        self.self_attn = None
        if layers > 0:
            self.self_attn = Transformer(
                width=width,
                layers=layers,
                heads=heads,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
            )

        if use_ln_post:
            self.ln_post = nn.LayerNorm(width)
        else:
            self.ln_post = None

    @staticmethod
    def _subsample_and_fps(points, subset_size, query_ratio):
        """Randomly keep ``subset_size`` points per batch item, then pick queries with fps.

        Returns ``(idx_subset, subset, idx_query, query)``: the kept indices
        along dim 1, the kept points ``[B, subset_size, D]``, the fps indices
        into the flattened subset, and the query points ``[B, -1, D]``.
        """
        B, _, D = points.shape
        # The same random permutation is shared by every batch item.
        idx_subset = torch.randperm(points.shape[1], device=points.device)[
            :subset_size
        ]
        subset = points[:, idx_subset, :]
        flat = subset.view(B * subset_size, D)
        # Batch assignment vector expected by fps: [0, ..., 0, 1, ..., 1, ...].
        batch_index = torch.repeat_interleave(
            torch.arange(B, device=points.device), flat.shape[0] // B
        )
        idx_query = fps(flat, batch_index, ratio=query_ratio)
        query = flat[idx_query].view(B, -1, D)
        return idx_subset, subset, idx_query, query

    @staticmethod
    def _gather_feats(feats, batch_size, idx_subset, idx_query, subset_size):
        """Select the features of the subsampled points and of the fps queries.

        Returns ``(subset_feats, query_feats)`` with shapes
        ``[B, subset_size, C]`` and ``[B, -1, C]``.
        """
        subset_feats = feats[:, idx_subset, :]
        flat = subset_feats.view(batch_size * subset_size, -1)
        query_feats = flat[idx_query].view(batch_size, -1, flat.shape[-1])
        return subset_feats, query_feats

    def sample_points_and_latents(
        self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None
    ):
        """Sample input points and query points and build their embeddings.

        Args:
            pc: ``[B, pc_size + pc_sharpedge_size, D]`` coordinates, random
                surface points first, sharpedge points second.
            feats: per-point features aligned with ``pc`` along dim 1;
                required when ``point_feats != 0``.

        Returns:
            A tuple ``(query, data, pc_infos)`` where ``query`` and ``data``
            are the embedded query / input points (with features appended
            when enabled) and ``pc_infos`` is the list ``[query_pc, input_pc,
            query_random_pc, input_random_pc, query_sharpedge_pc,
            input_sharpedge_pc]``. When no sharpedge points are sampled, the
            last two entries are ``[B, 1, D]`` zero placeholders.

        Raises:
            ValueError: if ``point_feats != 0`` and ``feats`` is None.
        """
        if self.point_feats != 0 and feats is None:
            raise ValueError("`feats` must be provided when point_feats != 0.")

        B, _, D = pc.shape
        num_pts = self.num_latents * self.downsample_ratio

        # Compute number of latents (round trip through the point count, kept
        # exactly as in the original computation).
        num_latents = int(num_pts / self.downsample_ratio)

        # Split the latent budget between the two groups in proportion to
        # their configured sizes; these are floats and are truncated below.
        num_random_query = (
            self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents
        )
        num_sharpedge_query = num_latents - num_random_query

        # Split random and sharpedge surface points
        random_pc, sharpedge_pc = torch.split(
            pc, [self.pc_size, self.pc_sharpedge_size], dim=1
        )
        assert (
            random_pc.shape[1] <= self.pc_size
        ), "Random surface points size must be less than or equal to pc_size"
        assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, (
            "Sharpedge surface points size must be less than or equal to"
            " pc_sharpedge_size"
        )

        # Randomly select random surface points and random query points
        input_random_pc_size = int(num_random_query * self.downsample_ratio)
        random_query_ratio = num_random_query / input_random_pc_size
        (
            idx_random_pc,
            input_random_pc,
            idx_query_random,
            query_random_pc,
        ) = self._subsample_and_fps(
            random_pc, input_random_pc_size, random_query_ratio
        )

        # Randomly select sharpedge surface points and sharpedge query points
        input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)
        has_sharpedge = input_sharpedge_pc_size != 0
        if has_sharpedge:
            sharpedge_query_ratio = num_sharpedge_query / input_sharpedge_pc_size
            (
                idx_sharpedge_pc,
                input_sharpedge_pc,
                idx_query_sharpedge,
                query_sharpedge_pc,
            ) = self._subsample_and_fps(
                sharpedge_pc, input_sharpedge_pc_size, sharpedge_query_ratio
            )
        else:
            # Empty tensors keep the concatenations below uniform.
            idx_sharpedge_pc = idx_query_sharpedge = None
            input_sharpedge_pc = torch.zeros(
                B, 0, D, dtype=input_random_pc.dtype, device=pc.device
            )
            query_sharpedge_pc = torch.zeros(
                B, 0, D, dtype=query_random_pc.dtype, device=pc.device
            )

        # Concatenate random and sharpedge surface points and query points
        query_pc = torch.cat([query_random_pc, query_sharpedge_pc], dim=1)
        input_pc = torch.cat([input_random_pc, input_sharpedge_pc], dim=1)

        # PE
        query = self.fourier_embedder(query_pc)
        data = self.fourier_embedder(input_pc)

        # Concat normal if given
        if self.point_feats != 0:
            random_surface_feats, sharpedge_surface_feats = torch.split(
                feats, [self.pc_size, self.pc_sharpedge_size], dim=1
            )
            input_random_surface_feats, query_random_feats = self._gather_feats(
                random_surface_feats,
                B,
                idx_random_pc,
                idx_query_random,
                input_random_pc_size,
            )

            if has_sharpedge:
                (
                    input_sharpedge_surface_feats,
                    query_sharpedge_feats,
                ) = self._gather_feats(
                    sharpedge_surface_feats,
                    B,
                    idx_sharpedge_pc,
                    idx_query_sharpedge,
                    input_sharpedge_pc_size,
                )
            else:
                input_sharpedge_surface_feats = torch.zeros(
                    B,
                    0,
                    self.point_feats,
                    dtype=input_random_surface_feats.dtype,
                    device=pc.device,
                )
                query_sharpedge_feats = torch.zeros(
                    B,
                    0,
                    self.point_feats,
                    dtype=query_random_feats.dtype,
                    device=pc.device,
                )

            query_feats = torch.cat([query_random_feats, query_sharpedge_feats], dim=1)
            input_feats = torch.cat(
                [input_random_surface_feats, input_sharpedge_surface_feats], dim=1
            )

            if self.normal_pe:
                # Embed the first three feature channels and keep the rest raw.
                query_normal_pe = self.fourier_embedder(query_feats[..., :3])
                input_normal_pe = self.fourier_embedder(input_feats[..., :3])
                query_feats = torch.cat([query_normal_pe, query_feats[..., 3:]], dim=-1)
                input_feats = torch.cat([input_normal_pe, input_feats[..., 3:]], dim=-1)

            query = torch.cat([query, query_feats], dim=-1)
            data = torch.cat([data, input_feats], dim=-1)

        if not has_sharpedge:
            # Returned placeholders only; the embeddings above already used
            # the empty sharpedge tensors.
            query_sharpedge_pc = torch.zeros(B, 1, D, device=pc.device)
            input_sharpedge_pc = torch.zeros(B, 1, D, device=pc.device)

        return (
            query.view(B, -1, query.shape[-1]),
            data.view(B, -1, data.shape[-1]),
            [
                query_pc,
                input_pc,
                query_random_pc,
                input_random_pc,
                query_sharpedge_pc,
                input_sharpedge_pc,
            ],
        )

    def forward(self, pc, feats):
        """Encode a point cloud into latents.

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]

        Returns:
            A tuple ``(latents, pc_infos)``: the latent tokens and the list of
            sampled point sets returned by ``sample_points_and_latents``.
        """
        query, data, pc_infos = self.sample_points_and_latents(pc, feats)

        # Queries and data share the same input projection.
        query = self.input_proj(query)
        data = self.input_proj(data)

        latents = self.cross_attn(query, data)
        if self.self_attn is not None:
            latents = self.self_attn(latents)

        if self.ln_post is not None:
            latents = self.ln_post(latents)

        return latents, pc_infos