import torch
import torch.nn as nn
from torch import einsum
import torch.nn.functional as F
from functools import partial
from timm.models.layers import DropPath
from einops import rearrange, repeat


# ---- PE: NeRF-style Position Encoding ----
class Embedder:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(self.identity_fn)
            out_dim += d

        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']

        if self.kwargs['log_sampling']:
            freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs)
        else:
            freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=N_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                embed_fns.append(partial(self.periodic_fn, p_fn=p_fn, freq=freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def identity_fn(self, x):
        return x

    def periodic_fn(self, x, p_fn, freq):
        return p_fn(x * freq)

    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)


def get_embedder(multires, i=0):
    if i == -1:
        return nn.Identity(), 1

    embed_kwargs = {
        'include_input': True,
        'input_dims': 1,
        'max_freq_log2': multires - 1,
        'num_freqs': multires,
        'log_sampling': True,
        'periodic_fns': [torch.sin, torch.cos],
    }

    embedder_obj = Embedder(**embed_kwargs)
    embed = embedder_obj.embed
    return embed, embedder_obj.out_dim


class PE_NeRF(nn.Module):
    def __init__(self, out_channels=512, multires=10):
        super().__init__()
        self.multires = multires
        self.embed_fn, embed_dim_per_dim = get_embedder(multires)  # per-dim embed
        self.embed_dim = embed_dim_per_dim * 3  # since 3D: x, y, z
        self.coor_embed = nn.Sequential(
            nn.Linear(self.embed_dim, 256),
            nn.GELU(),
            nn.Linear(256, out_channels)
        )

    def forward(self, vertices: torch.Tensor) -> torch.Tensor:
        """
        Args:
            vertices: [N, 3], coordinates in [-0.5, 0.5]
        Returns:
            encoded: [N, out_channels]
        """
        x_embed = self.embed_fn(vertices[..., 0:1])  # [N, D]
        y_embed = self.embed_fn(vertices[..., 1:2])
        z_embed = self.embed_fn(vertices[..., 2:3])
        pos_enc = torch.cat([x_embed, y_embed, z_embed], dim=-1)  # [N, D * 3]
        return self.coor_embed(pos_enc)


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


# ---- Attention & FF blocks ----
class GEGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        return self.net(x)


class PreNorm(nn.Module):
    def __init__(self, dim, fn, context_dim=None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None

    def forward(self, x, **kwargs):
        x = self.norm(x)
        if exists(self.norm_context):
            context = kwargs['context']
            normed_context = self.norm_context(context)
            kwargs.update(context=normed_context)
        return self.fn(x, **kwargs)


class Attention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, drop_path_rate=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x, context=None, mask=None):
        h = self.heads
        q = self.to_q(x)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            # boolean keep-mask over context positions: True = attend, False = mask out
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.drop_path(self.to_out(out))


class QueryPointDecoder(nn.Module):
    def __init__(self, query_dim=1536, context_dim=512, output_dim=1, depth=8,
                 using_nerf=True, quantize_bits=10, dim=512, heads=8, multires=10):
        super().__init__()
        self.using_nerf = using_nerf
        self.depth = depth

        if using_nerf:
            self.pe = PE_NeRF(out_channels=query_dim, multires=multires)
        else:
            # quantized-coordinate branch: expects integer indices in [0, 2**quantize_bits)
            self.embedding_x = nn.Embedding(2 ** quantize_bits, query_dim // 3)
            self.embedding_y = nn.Embedding(2 ** quantize_bits, query_dim // 3)
            self.embedding_z = nn.Embedding(2 ** quantize_bits, query_dim // 3)
            self.coord_proj = nn.Sequential(
                nn.Linear(query_dim, query_dim * 4),
                nn.GELU(),
                nn.Linear(query_dim * 4, query_dim)
            )

        # self.context_proj = nn.Linear(context_dim, query_dim)
        self.context_proj = nn.Linear(context_dim, dim)
        self.pe_ctx = PE_NeRF(out_channels=dim, multires=multires)

        self.context_self_attn_layers = nn.ModuleList([
            nn.ModuleList([
                PreNorm(dim, Attention(dim, dim_head=64, heads=heads)),
                PreNorm(dim, FeedForward(dim))
            ]) for _ in range(depth)
        ])

        # Cross-attention from query-point embeddings (query_dim) to context tokens (dim).
        # The query side is kept at query_dim so it matches the NeRF-encoded points; the
        # previous version built these blocks at `dim`, which breaks when query_dim != dim.
        self.cross_attn = PreNorm(query_dim, Attention(query_dim, dim, dim_head=dim, heads=1))
        self.cross_ff = PreNorm(query_dim, FeedForward(query_dim))
        self.to_outputs = nn.Linear(query_dim, output_dim)

    def forward(self, query_points, context_feats, context_mask=None, voxels_coords=None):
        B, N, _ = query_points.shape

        if self.using_nerf:
            x = self.pe(query_points.view(-1, 3)).view(B, N, -1)
        else:
            embeddings = torch.cat([
                self.embedding_x(query_points[..., 0]),
                self.embedding_y(query_points[..., 1]),
                self.embedding_z(query_points[..., 2]),
            ], dim=-1)
            x = self.coord_proj(embeddings)

        context = self.context_proj(context_feats)
        if voxels_coords is not None:
            M = voxels_coords.shape[1]
            normalized_coords = 2.0 * (voxels_coords.float() / 1024.) - 1.0
            context = context + self.pe_ctx(normalized_coords.view(-1, 3)).view(B, M, -1)

        attn_mask = context_mask[:, None, None, :] if context_mask is not None else None

        for self_attn, ff in self.context_self_attn_layers:
            context = self_attn(context, mask=attn_mask) + context
            context = ff(context) + context

        latents = self.cross_attn(x, context=context, mask=attn_mask)
        # feed-forward residual on the cross-attended latents (previously applied to x,
        # which appears to be a typo)
        latents = self.cross_ff(latents) + latents

        return self.to_outputs(latents).squeeze(-1)


if __name__ == '__main__':
    torch.manual_seed(42)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = QueryPointDecoder().to(device)
    model.eval()

    B, N, M = 2, 64, 20
    query_pts = torch.rand(B, N, 3, device=device) - 0.5  # [-0.5, 0.5]
    context_feats = torch.randn(B, M, 512, device=device)

    with torch.no_grad():
        logits = model(query_pts, context_feats)
    print("Logits shape:", logits.shape)  # [B, N]
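
    # A second pass exercising the optional inputs. This is a sketch; the shapes are
    # assumptions inferred from forward(): context_mask as a boolean [B, M] keep-mask
    # and voxels_coords as integer grid indices in [0, 1024) of shape [B, M, 3].
    context_mask = torch.ones(B, M, dtype=torch.bool, device=device)
    context_mask[:, -5:] = False  # pretend the last 5 context tokens are padding
    voxels_coords = torch.randint(0, 1024, (B, M, 3), device=device)

    with torch.no_grad():
        masked_logits = model(query_pts, context_feats,
                              context_mask=context_mask, voxels_coords=voxels_coords)
    print("Masked logits shape:", masked_logits.shape)  # [B, N]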