import torch
import torch.nn as nn
from typing import Optional, Union
from einops import repeat
import math
from torch_cluster import fps
import random
import time
import numpy as np

from ..modules import checkpoint
from ..modules.embedder import FourierEmbedder
from ..modules.transformer_blocks import (
    ResidualCrossAttentionBlock,
    Transformer
)

from .tsal_base import ShapeAsLatentModule


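# CrossAttentionEncoder: embeds an input point cloud with Fourier features, cross-attends
# a set of latent query tokens to it, and refines the tokens with a self-attention Transformer.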
class CrossAttentionEncoder(nn.Module):

    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[torch.dtype],
                 num_latents: int,
                 fourier_embedder: FourierEmbedder,
                 point_feats: int,
                 width: int,
                 heads: int,
                 layers: int,
                 init_scale: float = 0.25,
                 qkv_bias: bool = True,
                 flash: bool = False,
                 use_ln_post: bool = False,
                 use_checkpoint: bool = False,
                 query_method: bool = False,
                 use_full_input: bool = True,
                 token_num: int = 256,
                 no_query: bool = False):

        super().__init__()

        self.query_method = query_method
        self.token_num = token_num
        self.use_full_input = use_full_input

        self.use_checkpoint = use_checkpoint
        self.num_latents = num_latents

        if no_query:
            self.query = None
        else:
            self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)

        self.fourier_embedder = fourier_embedder
        self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype)
        self.cross_attn = ResidualCrossAttentionBlock(
            device=device,
            dtype=dtype,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
        )

        self.self_attn = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=num_latents,
            width=width,
            layers=layers,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_checkpoint=False
        )

        if use_ln_post:
            self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device)
        else:
            self.ln_post = None

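    # _forward has two modes: with query_method, a fixed set of learned query tokens
    # cross-attends to the embedded point cloud; otherwise the query tokens are built
    # from a randomly pre-sampled, then farthest-point-sampled, subset of the input points.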
    def _forward(self, pc, feats):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]

        Returns:
            latents (torch.FloatTensor): [B, M, width]
            pc (torch.FloatTensor): [B, N, 3]
            token_num (int): number of latent tokens used in this pass
            pre_pc (torch.FloatTensor or None): pre-sampled points concatenated with their
                features, or None when query_method is used
        """
        if self.query_method:
            token_num = self.num_latents
            bs = pc.shape[0]
            data = self.fourier_embedder(pc)
            if feats is not None:
                data = torch.cat([data, feats], dim=-1)
            data = self.input_proj(data)

            query = repeat(self.query, "m c -> b m c", b=bs)

            latents = self.cross_attn(query, data)
            latents = self.self_attn(latents)

            if self.ln_post is not None:
                latents = self.ln_post(latents)

            pre_pc = None
        else:
            if isinstance(self.token_num, int):
                token_num = self.token_num
            else:
                token_num = random.choice(self.token_num)

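            # Randomly pre-sample 4x the target token count (with replacement if the cloud
            # is smaller); use a fixed seed at eval time so the sampled queries are deterministic.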
            if self.training:
                rng = np.random.default_rng()
            else:
                rng = np.random.default_rng(seed=0)
            ind = rng.choice(pc.shape[1], token_num * 4, replace=token_num * 4 > pc.shape[1])

            pre_pc = pc[:, ind, :]
            pre_feats = feats[:, ind, :]

            B, N, D = pre_pc.shape
            C = pre_feats.shape[-1]

            pos = pre_pc.view(B * N, D)
            pos_feats = pre_feats.view(B * N, C)
            batch = torch.arange(B).to(pc.device)
            batch = torch.repeat_interleave(batch, N)

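            # Farthest point sampling keeps 1/4 of the pre-sampled points,
            # i.e. token_num query points per shape.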
            idx = fps(pos, batch, ratio=1. / 4, random_start=self.training)

            sampled_pc = pos[idx]
            sampled_pc = sampled_pc.view(B, -1, 3)

            sampled_feats = pos_feats[idx]
            sampled_feats = sampled_feats.view(B, -1, C)

            if self.use_full_input:
                data = self.fourier_embedder(pc)
            else:
                data = self.fourier_embedder(pre_pc)

            if feats is not None:
                if not self.use_full_input:
                    feats = pre_feats
                data = torch.cat([data, feats], dim=-1)
            data = self.input_proj(data)

            sampled_data = self.fourier_embedder(sampled_pc)
            if feats is not None:
                sampled_data = torch.cat([sampled_data, sampled_feats], dim=-1)
            sampled_data = self.input_proj(sampled_data)

            latents = self.cross_attn(sampled_data, data)
            latents = self.self_attn(latents)

            if self.ln_post is not None:
                latents = self.ln_post(latents)

            pre_pc = torch.cat([pre_pc, pre_feats], dim=-1)

        return latents, pc, token_num, pre_pc

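    # forward wraps _forward with (optional) gradient checkpointing.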
    def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]

        Returns:
            see `_forward`: (latents, pc, token_num, pre_pc)
        """

        return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint)


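# CrossAttentionDecoder: predicts occupancy (and optionally SDF) logits at query points
# by cross-attending Fourier-embedded queries into the latent set.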
class CrossAttentionDecoder(nn.Module):

    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[torch.dtype],
                 num_latents: int,
                 out_channels: int,
                 fourier_embedder: FourierEmbedder,
                 width: int,
                 heads: int,
                 init_scale: float = 0.25,
                 qkv_bias: bool = True,
                 flash: bool = False,
                 use_checkpoint: bool = False,
                 mlp_width_scale: int = 4,
                 supervision_type: str = 'occupancy'):

        super().__init__()

        self.use_checkpoint = use_checkpoint
        self.fourier_embedder = fourier_embedder
        self.supervision_type = supervision_type

        self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)

        self.cross_attn_decoder = ResidualCrossAttentionBlock(
            device=device,
            dtype=dtype,
            n_data=num_latents,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            mlp_width_scale=mlp_width_scale,
        )

        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype)
        if self.supervision_type == 'occupancy-sdf':
            self.output_proj_sdf = nn.Linear(width, out_channels, device=device, dtype=dtype)

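    # _forward matches the query/latent dtype to the module weights under fp16, projects the
    # embedded queries, and cross-attends into the latents; with 'occupancy-sdf' supervision
    # a second head also predicts an SDF value.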
    def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
        if next(self.query_proj.parameters()).dtype == torch.float16:
            queries = queries.half()
            latents = latents.half()

        queries = self.query_proj(self.fourier_embedder(queries))
        x = self.cross_attn_decoder(queries, latents)
        x = self.ln_post(x)
        x_1 = self.output_proj(x)
        if self.supervision_type == 'occupancy-sdf':
            x_2 = self.output_proj_sdf(x)
            return x_1, x_2
        else:
            return x_1

    def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
        return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint)


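# ShapeAsLatentPerceiver: full shape autoencoder (CrossAttentionEncoder, optional KL
# bottleneck, latent Transformer, and a CrossAttentionDecoder over query points).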
class ShapeAsLatentPerceiver(ShapeAsLatentModule):
    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[torch.dtype],
                 num_latents: int,
                 point_feats: int = 0,
                 embed_dim: int = 0,
                 num_freqs: int = 8,
                 include_pi: bool = True,
                 width: int,
                 heads: int,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 decoder_width: Optional[int] = None,
                 init_scale: float = 0.25,
                 qkv_bias: bool = True,
                 flash: bool = False,
                 use_ln_post: bool = False,
                 use_checkpoint: bool = False,
                 supervision_type: str = 'occupancy',
                 query_method: bool = False,
                 token_num: int = 256,
                 grad_type: str = "numerical",
                 grad_interval: float = 0.005,
                 use_full_input: bool = True,
                 freeze_encoder: bool = False,
                 decoder_mlp_width_scale: int = 4,
                 residual_kl: bool = False,
                 ):

        super().__init__()

        self.use_checkpoint = use_checkpoint

        self.num_latents = num_latents
        assert grad_type in ["numerical", "analytical"]
        self.grad_type = grad_type
        self.grad_interval = grad_interval
        self.supervision_type = supervision_type
        self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)

        init_scale = init_scale * math.sqrt(1.0 / width)
        self.encoder = CrossAttentionEncoder(
            device=device,
            dtype=dtype,
            fourier_embedder=self.fourier_embedder,
            num_latents=num_latents,
            point_feats=point_feats,
            width=width,
            heads=heads,
            layers=num_encoder_layers,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_ln_post=use_ln_post,
            use_checkpoint=use_checkpoint,
            query_method=query_method,
            use_full_input=use_full_input,
            token_num=token_num
        )

        self.embed_dim = embed_dim
        self.residual_kl = residual_kl
        if decoder_width is None:
            decoder_width = width

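        # KL bottleneck: pre_kl projects latents to (mean, logvar); post_kl maps sampled
        # codes back to the decoder width. residual_kl requires the widths to divide evenly.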
        if embed_dim > 0:
            self.pre_kl = nn.Linear(width, embed_dim * 2, device=device, dtype=dtype)
            self.post_kl = nn.Linear(embed_dim, decoder_width, device=device, dtype=dtype)
            self.latent_shape = (num_latents, embed_dim)
            if self.residual_kl:
                assert self.post_kl.out_features % self.post_kl.in_features == 0
                assert self.pre_kl.in_features % self.pre_kl.out_features == 0
        else:
            self.latent_shape = (num_latents, width)

        self.transformer = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=num_latents,
            width=decoder_width,
            layers=num_decoder_layers,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_checkpoint=use_checkpoint
        )

        self.geo_decoder = CrossAttentionDecoder(
            device=device,
            dtype=dtype,
            fourier_embedder=self.fourier_embedder,
            out_channels=1,
            num_latents=num_latents,
            width=decoder_width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_checkpoint=use_checkpoint,
            supervision_type=supervision_type,
            mlp_width_scale=decoder_mlp_width_scale
        )

        if freeze_encoder:
            for p in self.encoder.parameters():
                p.requires_grad = False
            # pre_kl only exists when embed_dim > 0; guard so freezing does not fail otherwise.
            if self.embed_dim > 0:
                for p in self.pre_kl.parameters():
                    p.requires_grad = False
            print("freeze encoder and pre kl")

    def forward(self,
                pc: torch.FloatTensor,
                feats: torch.FloatTensor,
                volume_queries: torch.FloatTensor,
                sample_posterior: bool = True):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]
            volume_queries (torch.FloatTensor): [B, P, 3]
            sample_posterior (bool): whether to sample from the latent posterior

        Returns:
            logits (torch.FloatTensor): [B, P]
            center_pos (torch.FloatTensor): [B, M, 3]
            posterior (DiagonalGaussianDistribution or None):
        """

        latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior)

        latents = self.decode(latents)
        logits = self.query_geometry(volume_queries, latents)

        return logits, center_pos, posterior


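# AlignedShapeLatentPerceiver: reserves one extra latent token (1 + num_latents in total)
# whose output is returned as a global shape embedding; dtype may be given as a string name.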
class AlignedShapeLatentPerceiver(ShapeAsLatentPerceiver):

    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[str],
                 num_latents: int,
                 point_feats: int = 0,
                 embed_dim: int = 0,
                 num_freqs: int = 8,
                 include_pi: bool = True,
                 width: int,
                 heads: int,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 decoder_width: Optional[int] = None,
                 init_scale: float = 0.25,
                 qkv_bias: bool = True,
                 flash: bool = False,
                 use_ln_post: bool = False,
                 use_checkpoint: bool = False,
                 supervision_type: str = 'occupancy',
                 grad_type: str = "numerical",
                 grad_interval: float = 0.005,
                 query_method: bool = False,
                 use_full_input: bool = True,
                 token_num: int = 256,
                 freeze_encoder: bool = False,
                 decoder_mlp_width_scale: int = 4,
                 residual_kl: bool = False,
                 ):

        MAP_DTYPE = {
            'float32': torch.float32,
            'float16': torch.float16,
            'bfloat16': torch.bfloat16,
        }
        if dtype is not None:
            dtype = MAP_DTYPE[dtype]
        super().__init__(
            device=device,
            dtype=dtype,
            num_latents=1 + num_latents,
            point_feats=point_feats,
            embed_dim=embed_dim,
            num_freqs=num_freqs,
            include_pi=include_pi,
            width=width,
            decoder_width=decoder_width,
            heads=heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_ln_post=use_ln_post,
            use_checkpoint=use_checkpoint,
            supervision_type=supervision_type,
            grad_type=grad_type,
            grad_interval=grad_interval,
            query_method=query_method,
            token_num=token_num,
            use_full_input=use_full_input,
            freeze_encoder=freeze_encoder,
            decoder_mlp_width_scale=decoder_mlp_width_scale,
            residual_kl=residual_kl,
        )

        self.width = width

    def encode(self,
               pc: torch.FloatTensor,
               feats: Optional[torch.FloatTensor] = None,
               sample_posterior: bool = True,
               only_shape: bool = False):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, c]
            sample_posterior (bool): whether to sample from the latent posterior
            only_shape (bool): if True, return only the global shape embedding

        Returns:
            shape_embed (torch.FloatTensor):
            kl_embed (torch.FloatTensor):
            posterior (DiagonalGaussianDistribution or None):
            token_num (int):
            pre_pc (torch.FloatTensor or None):
        """

        shape_embed, latents, token_num, pre_pc = self.encode_latents(pc, feats)
        if only_shape:
            return shape_embed
        kl_embed, posterior = self.encode_kl_embed(latents, sample_posterior)

        return shape_embed, kl_embed, posterior, token_num, pre_pc

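    # encode_latents runs the encoder and splits off the first token as the global
    # shape embedding; the full token sequence is kept as the latent set.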
    def encode_latents(self,
                       pc: torch.FloatTensor,
                       feats: Optional[torch.FloatTensor] = None):

        x, _, token_num, pre_pc = self.encoder(pc, feats)

        shape_embed = x[:, 0]
        latents = x

        return shape_embed, latents, token_num, pre_pc

    def forward(self,
                pc: torch.FloatTensor,
                feats: torch.FloatTensor,
                volume_queries: torch.FloatTensor,
                sample_posterior: bool = True):
        raise NotImplementedError()


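# ShapeAsLatentPerceiverEncoder: encoder-only variant (no learned query tokens, no KL
# projection layers, no decoder), for cases where only point-cloud latents are needed.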
class ShapeAsLatentPerceiverEncoder(ShapeAsLatentModule):
    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[Union[torch.dtype, str]],
                 num_latents: int,
                 point_feats: int = 0,
                 embed_dim: int = 0,
                 num_freqs: int = 8,
                 include_pi: bool = True,
                 width: int,
                 heads: int,
                 num_encoder_layers: int,
                 init_scale: float = 0.25,
                 qkv_bias: bool = True,
                 flash: bool = False,
                 use_ln_post: bool = False,
                 use_checkpoint: bool = False,
                 supervision_type: str = 'occupancy',
                 query_method: bool = False,
                 token_num: int = 256,
                 grad_type: str = "numerical",
                 grad_interval: float = 0.005,
                 use_full_input: bool = True,
                 freeze_encoder: bool = False,
                 residual_kl: bool = False,
                 ):

        super().__init__()

        MAP_DTYPE = {
            'float32': torch.float32,
            'float16': torch.float16,
            'bfloat16': torch.bfloat16,
        }

        if dtype is not None and isinstance(dtype, str):
            dtype = MAP_DTYPE[dtype]

        self.use_checkpoint = use_checkpoint

        self.num_latents = num_latents
        assert grad_type in ["numerical", "analytical"]
        self.grad_type = grad_type
        self.grad_interval = grad_interval
        self.supervision_type = supervision_type
        self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)

        init_scale = init_scale * math.sqrt(1.0 / width)
        self.encoder = CrossAttentionEncoder(
            device=device,
            dtype=dtype,
            fourier_embedder=self.fourier_embedder,
            num_latents=num_latents,
            point_feats=point_feats,
            width=width,
            heads=heads,
            layers=num_encoder_layers,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_ln_post=use_ln_post,
            use_checkpoint=use_checkpoint,
            query_method=query_method,
            use_full_input=use_full_input,
            token_num=token_num,
            no_query=True,
        )

        self.embed_dim = embed_dim
        self.residual_kl = residual_kl
        if freeze_encoder:
            for p in self.encoder.parameters():
                p.requires_grad = False
            print("freeze encoder")
        self.width = width

    def encode_latents(self,
                       pc: torch.FloatTensor,
                       feats: Optional[torch.FloatTensor] = None):

        x, _, token_num, pre_pc = self.encoder(pc, feats)

        shape_embed = x[:, 0]
        latents = x

        return shape_embed, latents, token_num, pre_pc

    def forward(self):
        raise NotImplementedError()
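

# Usage sketch (hedged: assumes the surrounding package layout and the base-class helpers
# such as encode_kl_embed are available; shapes and hyperparameters below are illustrative,
# not the values used in training):
#
#   model = AlignedShapeLatentPerceiver(
#       device=None, dtype=None, num_latents=256, point_feats=3, embed_dim=64,
#       width=768, heads=12, num_encoder_layers=8, num_decoder_layers=16,
#   )
#   pc = torch.rand(2, 4096, 3) * 2 - 1      # [B, N, 3] query point cloud
#   feats = torch.rand(2, 4096, 3)           # [B, N, C] per-point features (C == point_feats)
#   shape_embed, kl_embed, posterior, token_num, pre_pc = model.encode(pc, feats)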