| | import torch |
| | import torch.nn as nn |
| | from torch import einsum |
| | import torch.nn.functional as F |
| | from functools import partial |
| | from timm.models.layers import DropPath |
| | from einops import rearrange, repeat |
| |
|
| | |
class Embedder:
    """NeRF-style Fourier-feature positional encoder.

    Builds a list of callables (optionally the identity, plus sin/cos at a
    bank of frequencies); `embed` concatenates all of their outputs along
    the last dimension.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Populate `self.embed_fns` and `self.out_dim` from `self.kwargs`."""
        dims = self.kwargs['input_dims']
        fns = []
        total_dim = 0

        if self.kwargs['include_input']:
            # Keep the raw coordinate alongside its periodic features.
            fns.append(self.identity_fn)
            total_dim += dims

        n_freqs = self.kwargs['num_freqs']
        max_freq = self.kwargs['max_freq_log2']

        # Frequency bands: log-spaced powers of two, or linearly spaced.
        if self.kwargs['log_sampling']:
            freq_bands = 2. ** torch.linspace(0., max_freq, steps=n_freqs)
        else:
            freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=n_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                fns.append(partial(self.periodic_fn, p_fn=p_fn, freq=freq))
                total_dim += dims

        self.embed_fns = fns
        self.out_dim = total_dim

    def identity_fn(self, x):
        """Pass-through used when 'include_input' is set."""
        return x

    def periodic_fn(self, x, p_fn, freq):
        """Apply a periodic function (e.g. sin/cos) at the given frequency."""
        return p_fn(x * freq)

    def embed(self, inputs):
        """Concatenate every embedding function's output on the last dim."""
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
| |
|
def get_embedder(multires, i=0):
    """Build a NeRF positional-encoding function for 1-D inputs.

    Args:
        multires: number of frequency octaves (max frequency is 2**(multires-1)).
        i: pass -1 to disable encoding and get an identity mapping instead.

    Returns:
        (embed_fn, out_dim) where embed_fn maps [..., 1] -> [..., out_dim].
    """
    if i == -1:
        # Encoding disabled: identity over the single input dimension.
        return nn.Identity(), 1

    embedder = Embedder(
        include_input=True,
        input_dims=1,
        max_freq_log2=multires - 1,
        num_freqs=multires,
        log_sampling=True,
        periodic_fns=[torch.sin, torch.cos],
    )
    return embedder.embed, embedder.out_dim
| |
|
| |
|
class PE_NeRF(nn.Module):
    """NeRF positional encoding of 3-D points followed by a small MLP.

    Each coordinate axis is encoded independently with the same sin/cos
    frequency bank; the three encodings are concatenated and projected to
    `out_channels` by a two-layer GELU MLP.
    """

    def __init__(self, out_channels=512, multires=10):
        super().__init__()
        self.multires = multires
        # One shared per-axis encoder, reused for x, y and z.
        self.embed_fn, embed_dim_per_dim = get_embedder(multires)
        self.embed_dim = embed_dim_per_dim * 3

        self.coor_embed = nn.Sequential(
            nn.Linear(self.embed_dim, 256),
            nn.GELU(),
            nn.Linear(256, out_channels),
        )

    def forward(self, vertices: torch.Tensor) -> torch.Tensor:
        """
        Args:
            vertices: [B, 3] or [N, 3] coordinates, expected in [-0.5, 0.5].
        Returns:
            encoded: [B, out_channels] (final Linear maps to out_channels).
        """
        per_axis = [self.embed_fn(vertices[..., i:i + 1]) for i in range(3)]
        pos_enc = torch.cat(per_axis, dim=-1)
        return self.coor_embed(pos_enc)
| |
|
def exists(val):
    """Return True when *val* is not None, False otherwise."""
    if val is None:
        return False
    return True
| |
|
def default(val, d):
    """Return *val* unless it is None, in which case the fallback *d*."""
    return d if val is None else val
| |
|
| | |
class GEGLU(nn.Module):
    """Gated GELU activation.

    Splits the last dimension into two halves and multiplies the first
    half by GELU of the second (the gate), halving the channel count.
    """

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return value * F.gelu(gate)
| |
|
| |
|
class FeedForward(nn.Module):
    """Transformer feed-forward block with a GEGLU gate.

    Expands to `dim * mult * 2` channels (two halves for the gate), the
    GEGLU reduces that to `dim * mult`, and a final Linear projects back
    to `dim`.
    """

    def __init__(self, dim, mult=4):
        super().__init__()
        hidden = dim * mult
        self.net = nn.Sequential(
            nn.Linear(dim, hidden * 2),  # doubled width for the GEGLU split
            GEGLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x):
        return self.net(x)
| |
|
| |
|
class PreNorm(nn.Module):
    """Pre-normalization wrapper.

    LayerNorms the input (and, when `context_dim` is given, the 'context'
    keyword argument) before delegating to the wrapped module `fn`.
    """

    def __init__(self, dim, fn, context_dim=None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        # Context gets its own norm only when a context width is supplied.
        self.norm_context = nn.LayerNorm(context_dim) if context_dim is not None else None

    def forward(self, x, **kwargs):
        x = self.norm(x)
        if self.norm_context is not None:
            # 'context' must be present in kwargs when norm_context exists.
            kwargs['context'] = self.norm_context(kwargs['context'])
        return self.fn(x, **kwargs)
| |
|
| |
|
class Attention(nn.Module):
    """Multi-head (cross-)attention.

    Queries come from `x`; keys/values from `context` (defaults to `x`,
    i.e. self-attention). The attended values are projected back to
    `query_dim` and optionally passed through stochastic depth (DropPath).
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, drop_path_rate=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        if context_dim is None:
            context_dim = query_dim
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

        # Stochastic depth only when a positive rate is requested.
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x, context=None, mask=None):
        heads = self.heads

        if context is None:
            context = x
        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim=-1)

        def split_heads(t):
            # [b, n, h*d] -> [b*h, n, d], batch-major over heads.
            b, n, hd = t.shape
            d = hd // heads
            return t.view(b, n, heads, d).permute(0, 2, 1, 3).reshape(b * heads, n, d)

        q, k, v = (split_heads(t) for t in (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if mask is not None:
            # Collapse any trailing mask dims into one key axis, then
            # replicate per head and broadcast over the query axis.
            mask = mask.reshape(mask.shape[0], -1)
            mask = mask.repeat_interleave(heads, dim=0).unsqueeze(1)  # [b*h, 1, j]
            sim.masked_fill_(~mask, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        # [b*h, n, d] -> [b, n, h*d]
        bh, n, d = out.shape
        out = out.view(bh // heads, heads, n, d).permute(0, 2, 1, 3).reshape(bh // heads, n, heads * d)
        return self.drop_path(self.to_out(out))
| |
|
class QueryPointDecoder(nn.Module):
    """Decode per-query-point values (e.g. occupancy logits) from a set of
    context tokens: `depth` self-attention blocks refine the context, then a
    single cross-attention pulls information into each query point.

    Args:
        query_dim: channel width of the query-point embedding.
        context_dim: channel width of the incoming context features.
        output_dim: per-point output channels (squeezed when 1).
        depth: number of self-attention blocks over the context.
        using_nerf: if True, embed query coordinates with a NeRF positional
            encoding; otherwise look up learned embeddings of quantized
            integer coordinates.
        quantize_bits: per-axis vocabulary is 2**quantize_bits (quantized path).
        dim: internal transformer width for the context tokens.
        heads: attention heads in the context self-attention.
        multires: number of frequency octaves for the NeRF encodings.
    """

    def __init__(self, query_dim=1536, context_dim=512, output_dim=1, depth=8,
                 using_nerf=True, quantize_bits=10, dim=512, heads=8, multires=10):
        super().__init__()
        self.using_nerf = using_nerf
        self.depth = depth

        if using_nerf:
            self.pe = PE_NeRF(out_channels=query_dim, multires=multires)
        else:
            # Quantized path: one learned table per axis, concatenated to query_dim.
            self.embedding_x = nn.Embedding(2**quantize_bits, query_dim // 3)
            self.embedding_y = nn.Embedding(2**quantize_bits, query_dim // 3)
            self.embedding_z = nn.Embedding(2**quantize_bits, query_dim // 3)
            self.coord_proj = nn.Sequential(
                nn.Linear(query_dim, query_dim * 4),
                nn.GELU(),
                nn.Linear(query_dim * 4, query_dim)
            )

        self.context_proj = nn.Linear(context_dim, dim)
        # Positional encoding added to context tokens when voxel coords are given.
        self.pe_ctx = PE_NeRF(out_channels=dim, multires=multires)

        self.context_self_attn_layers = nn.ModuleList([
            nn.ModuleList([
                PreNorm(dim, Attention(dim, dim_head=64, heads=heads)),
                PreNorm(dim, FeedForward(dim))
            ]) for _ in range(depth)
        ])

        # BUG FIX: queries are `query_dim`-wide while the context lives in
        # `dim`. The original built this whole block in `dim`, so the
        # LayerNorm/attention shapes did not match the queries and the
        # forward pass raised a shape error. Build the cross block in
        # `query_dim`, with `dim`-wide keys/values from the context.
        self.cross_attn = PreNorm(query_dim,
                                  Attention(query_dim, dim,
                                            dim_head=dim, heads=1))
        self.cross_ff = PreNorm(query_dim, FeedForward(query_dim))

        self.to_outputs = nn.Linear(query_dim, output_dim)

    def forward(self, query_points, context_feats, context_mask=None, voxels_coords=None):
        """
        Args:
            query_points: [B, N, 3] coordinates — floats in ~[-0.5, 0.5] for
                the NeRF path; integer (long) indices for the quantized path.
            context_feats: [B, M, context_dim] context tokens.
            context_mask: optional [B, M] boolean mask (True = attend).
            voxels_coords: optional [B, M, 3] integer voxel coordinates used
                to positionally encode the context tokens.
        Returns:
            [B, N] when output_dim == 1, else [B, N, output_dim].
        """
        B, N, _ = query_points.shape

        if self.using_nerf:
            x = self.pe(query_points.view(-1, 3)).view(B, N, -1)
        else:
            # NOTE(review): nn.Embedding requires long indices here — callers
            # on this path must pass quantized integer coordinates.
            embeddings = torch.cat([
                self.embedding_x(query_points[..., 0]),
                self.embedding_y(query_points[..., 1]),
                self.embedding_z(query_points[..., 2]),
            ], dim=-1)
            x = self.coord_proj(embeddings)

        context = self.context_proj(context_feats)

        if voxels_coords is not None:
            M = voxels_coords.shape[1]
            # Map integer voxel coords (assumed in [0, 1024)) to [-1, 1).
            normalized_coords = 2.0 * (voxels_coords.float() / 1024.) - 1.0
            # Out-of-place add keeps autograd safe on the projected activations.
            context = context + self.pe_ctx(normalized_coords.view(-1, 3)).view(B, M, -1)

        # Attention flattens trailing mask dims itself, so [B, M] suffices
        # (the original's [:, None, None, :] was flattened right back).
        attn_mask = context_mask if context_mask is not None else None

        for self_attn, ff in self.context_self_attn_layers:
            context = self_attn(context, mask=attn_mask) + context
            context = ff(context) + context

        latents = self.cross_attn(x, context=context, mask=attn_mask)
        # BUG FIX: the feed-forward residual must transform the attended
        # latents; the original applied it to the raw query embedding `x`.
        latents = self.cross_ff(latents) + latents

        return self.to_outputs(latents).squeeze(-1)
| |
|
if __name__ == '__main__':
    torch.manual_seed(42)
    # BUG FIX: don't hard-require CUDA — fall back to CPU so this smoke
    # test runs on machines without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = QueryPointDecoder().to(device)
    model.eval()

    # B query batches of N points, M context tokens of width 512.
    B, N, M = 2, 64, 20
    query_pts = torch.rand(B, N, 3, device=device) - 0.5  # coords in [-0.5, 0.5)
    context_feats = torch.randn(B, M, 512, device=device)

    with torch.no_grad():
        logits = model(query_pts, context_feats)
    print("Logits shape:", logits.shape)
| |
|