# query_point.py — uploaded via huggingface_hub (commit 7340df2)
import torch
import torch.nn as nn
from torch import einsum
import torch.nn.functional as F
from functools import partial
from timm.models.layers import DropPath
from einops import rearrange, repeat
# ---- PE: NeRF-style Position Encoding ----
class Embedder:
    """NeRF-style positional encoder built from sin/cos frequency bands.

    Configured entirely through keyword arguments:
        input_dims, include_input, max_freq_log2, num_freqs,
        log_sampling, periodic_fns.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Build the list of per-term embedding functions and the total output width."""
        dims = self.kwargs['input_dims']
        fns = []
        total = 0
        if self.kwargs['include_input']:
            # Keep the raw coordinate as the first term of the encoding.
            fns.append(self.identity_fn)
            total += dims
        n_freqs = self.kwargs['num_freqs']
        max_freq = self.kwargs['max_freq_log2']
        if self.kwargs['log_sampling']:
            # Frequencies evenly spaced in log2 space: 2^0 ... 2^max_freq.
            bands = 2. ** torch.linspace(0., max_freq, steps=n_freqs)
        else:
            # Frequencies evenly spaced in linear space.
            bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=n_freqs)
        for band in bands:
            for periodic in self.kwargs['periodic_fns']:
                fns.append(partial(self.periodic_fn, p_fn=periodic, freq=band))
                total += dims
        self.embed_fns = fns
        self.out_dim = total

    def identity_fn(self, x):
        # Pass-through term for the raw input.
        return x

    def periodic_fn(self, x, p_fn, freq):
        # One periodic term: p_fn(freq * x), e.g. sin or cos at a fixed frequency.
        return p_fn(x * freq)

    def embed(self, inputs):
        """Concatenate every embedding term along the last dimension."""
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
def get_embedder(multires, i=0):
    """Return ``(embed_fn, out_dim)`` for a 1-D NeRF positional encoding.

    Args:
        multires: number of frequency octaves (highest frequency is 2**(multires-1)).
        i: pass -1 to disable encoding; returns an identity module with out_dim 1.
    """
    if i == -1:
        return nn.Identity(), 1
    embedder = Embedder(
        include_input=True,
        input_dims=1,
        max_freq_log2=multires - 1,
        num_freqs=multires,
        log_sampling=True,
        periodic_fns=[torch.sin, torch.cos],
    )
    return embedder.embed, embedder.out_dim
class PE_NeRF(nn.Module):
    """NeRF positional encoding of 3-D points followed by a small MLP projection."""

    def __init__(self, out_channels=512, multires=10):
        super().__init__()
        self.multires = multires
        # Each coordinate axis is encoded independently with the same embedder.
        self.embed_fn, per_axis_dim = get_embedder(multires)
        self.embed_dim = per_axis_dim * 3  # x, y, z encodings concatenated
        self.coor_embed = nn.Sequential(
            nn.Linear(self.embed_dim, 256),
            nn.GELU(),
            nn.Linear(256, out_channels),
        )

    def forward(self, vertices: torch.Tensor) -> torch.Tensor:
        """
        Args:
            vertices: [N, 3] coordinates, expected in [-0.5, 0.5].
        Returns:
            encoded: [N, out_channels]
        """
        per_axis = [self.embed_fn(vertices[..., axis:axis + 1]) for axis in range(3)]
        pos_enc = torch.cat(per_axis, dim=-1)  # [N, embed_dim]
        return self.coor_embed(pos_enc)
def exists(val):
    """Return True unless *val* is None."""
    if val is None:
        return False
    return True
def default(val, d):
    """Return *val* when it is not None, otherwise the fallback *d*."""
    if exists(val):
        return val
    return d
# ---- Attention & FF blocks ----
class GEGLU(nn.Module):
    """Gated GELU: splits the last dim in half and gates one half by GELU of the other."""

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return value * F.gelu(gate)
class FeedForward(nn.Module):
    """Transformer feed-forward block with a GEGLU gate.

    Projects dim -> 2*dim*mult, gates down to dim*mult via GEGLU, then back to dim.
    """

    def __init__(self, dim, mult=4):
        super().__init__()
        hidden = dim * mult
        # First linear yields 2*hidden features: half value, half gate (for GEGLU).
        self.net = nn.Sequential(
            nn.Linear(dim, hidden * 2),
            GEGLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x):
        return self.net(x)
class PreNorm(nn.Module):
    """Wraps *fn*, applying LayerNorm to the input (and optionally to the context) first."""

    def __init__(self, dim, fn, context_dim=None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        # The context gets its own LayerNorm only when a context width is supplied.
        self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None

    def forward(self, x, **kwargs):
        x = self.norm(x)
        if exists(self.norm_context):
            # Normalize the context in place in the kwargs passed through to fn.
            kwargs['context'] = self.norm_context(kwargs['context'])
        return self.fn(x, **kwargs)
class Attention(nn.Module):
    """Multi-head attention (self- or cross-) with an optional key padding mask.

    When *context* is None the layer attends over *x* itself (self-attention).
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, drop_path_rate=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)
        # Stochastic depth on the output projection; identity when rate == 0.
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x, context=None, mask=None):
        heads = self.heads
        q = self.to_q(x)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        # Fold heads into the batch dimension for the attention matmuls.
        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=heads) for t in (q, k, v))
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        if exists(mask):
            # Flatten any extra mask dims to [b, j], then broadcast over heads/queries.
            mask = rearrange(mask, 'b ... -> b (...)')
            mask = repeat(mask, 'b j -> (b h) () j', h=heads)
            sim.masked_fill_(~mask, -torch.finfo(sim.dtype).max)
        attn = sim.softmax(dim=-1)
        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=heads)
        return self.drop_path(self.to_out(out))
class QueryPointDecoder(nn.Module):
    """Predicts a value per 3-D query point by attending to context features.

    Pipeline: embed query points -> project context -> `depth` rounds of
    residual context self-attention -> one cross-attention from queries to
    context -> linear head to `output_dim`.
    """

    def __init__(self, query_dim=1536, context_dim=512, output_dim=1, depth=8,
                 using_nerf=True, quantize_bits=10, dim=512, heads=8, multires=10):
        super().__init__()
        self.using_nerf = using_nerf
        self.depth = depth
        if using_nerf:
            # Continuous coordinates -> NeRF positional encoding -> query_dim features.
            # NOTE(review): with the defaults query_dim (1536) != dim (512), yet
            # the query embedding x is fed to cross_attn whose LayerNorm/Linear
            # expect width dim — confirm callers pass query_dim == dim.
            self.pe = PE_NeRF(out_channels=query_dim, multires=multires)
        else:
            # Discrete path: quantized integer coordinates, one learned table per axis.
            self.embedding_x = nn.Embedding(2**quantize_bits, query_dim // 3)
            self.embedding_y = nn.Embedding(2**quantize_bits, query_dim // 3)
            self.embedding_z = nn.Embedding(2**quantize_bits, query_dim // 3)
            self.coord_proj = nn.Sequential(
                nn.Linear(query_dim, query_dim * 4),
                nn.GELU(),
                nn.Linear(query_dim * 4, query_dim)
            )
        # self.context_proj = nn.Linear(context_dim, query_dim)
        self.context_proj = nn.Linear(context_dim, dim)
        # Positional encoding added to projected context when voxel coords are given.
        self.pe_ctx = PE_NeRF(out_channels=dim, multires=multires)
        self.context_self_attn_layers = nn.ModuleList([
            nn.ModuleList([
                PreNorm(dim, Attention(dim, dim_head=64, heads=heads)),
                PreNorm(dim, FeedForward(dim))
            ]) for _ in range(depth)
        ])
        # Single-head cross-attention with a full-width head (dim_head=dim).
        self.cross_attn = PreNorm(dim,
                                  Attention(dim, dim,
                                            dim_head=dim, heads=1))
        self.cross_ff = PreNorm(dim, FeedForward(dim))
        self.to_outputs = nn.Linear(dim, output_dim)

    def forward(self, query_points, context_feats, context_mask=None, voxels_coords=None,):
        """
        Args:
            query_points: [B, N, 3] coordinates — floats in [-0.5, 0.5] when
                using_nerf, otherwise integer bin indices for the embeddings.
            context_feats: [B, M, context_dim] context features.
            context_mask: optional [B, M] boolean mask of valid context tokens.
            voxels_coords: optional [B, M, 3] voxel coordinates; when given,
                a positional encoding is added to the projected context.
        Returns:
            [B, N] when output_dim == 1 (last axis squeezed),
            else [B, N, output_dim].
        """
        B, N, _ = query_points.shape
        if self.using_nerf:
            # print('query_points.min()', query_points.min())
            # print('query_points.max()', query_points.max())
            x = self.pe(query_points.view(-1, 3)).view(B, N, -1)
        else:
            # Concatenate per-axis learned embeddings, then project.
            embeddings = torch.cat([
                self.embedding_x(query_points[..., 0]),
                self.embedding_y(query_points[..., 1]),
                self.embedding_z(query_points[..., 2]),
            ], dim=-1)
            x = self.coord_proj(embeddings)
        context = self.context_proj(context_feats)
        if voxels_coords is not None:
            M = voxels_coords.shape[1]
            # Map coords to [-1, 1]. NOTE(review): assumes voxel coordinates lie
            # in [0, 1024) — verify against the caller's quantization grid.
            normalized_coords = 2.0 * (voxels_coords.float() / 1024.) - 1.0
            context += self.pe_ctx(normalized_coords.view(-1, 3)).view(B, M, -1)
        # Broadcastable key-padding mask; Attention flattens it back to [B*h, 1, M].
        attn_mask = context_mask[:, None, None, :] if context_mask is not None else None
        for self_attn, ff in self.context_self_attn_layers:
            # Pre-norm residual self-attention over the context set.
            context = self_attn(context, mask=attn_mask) + context
            context = ff(context) + context
        latents = self.cross_attn(x, context=context, mask=attn_mask)
        # NOTE(review): the feed-forward is applied to x (the raw query
        # embedding), not to the cross-attention output; the conventional
        # residual form would be cross_ff(latents) + latents — confirm this
        # matches the trained checkpoint before "fixing" it.
        latents = self.cross_ff(x) + latents
        return self.to_outputs(latents).squeeze(-1)
if __name__ == '__main__':
    # Smoke test: random query points against random context features on GPU.
    torch.manual_seed(42)
    model = QueryPointDecoder().cuda()
    model.eval()
    batch, n_queries, n_context = 2, 64, 20
    query_pts = torch.rand(batch, n_queries, 3).cuda() - 0.5  # coords in [-0.5, 0.5]
    context_feats = torch.randn(batch, n_context, 512).cuda()
    with torch.no_grad():
        logits = model(query_pts, context_feats)
    print("Logits shape:", logits.shape)