# mapvggt/mapgs/model/encoder.py
# (Uploaded with huggingface_hub; revision b2efbe4.)
"""ViT encoder (§2.4).
Patchify each view's RGB and per-view Plucker coordinates (patch ``p=8``),
linear-project both to ``C`` and **sum** them, add a small per-timestep temporal
embedding, then run cross-image self-attention across all ``V x |t|`` views so
the cameras fuse. Output: image tokens ``B in R^{N_I x C}``.
"""
from __future__ import annotations
from typing import Optional
import torch
import torch.nn as nn
from mapgs.model.blocks import EncoderBlock, PatchEmbed
class ViTEncoder(nn.Module):
    """Multi-view ViT encoder that fuses RGB and Plucker-ray patch tokens.

    Each view's image and Plucker map are patch-embedded separately, summed
    token-wise, optionally offset by a learned per-timestep embedding, and then
    processed jointly by a stack of ``EncoderBlock`` layers over the
    concatenated tokens of all views, followed by a final ``LayerNorm``.

    Attributes:
        dim: Token channel width ``C``.
        patch: Patch size handed to both ``PatchEmbed`` modules.
        grad_checkpoint: When ``True`` and the module is in training mode, each
            block is run under activation checkpointing. Defaults to ``False``;
            callers flip it after construction.
    """

    def __init__(
        self,
        dim: int = 1024,
        patch: int = 8,
        depth: int = 6,
        n_heads: int = 16,
        mlp_ratio: float = 4.0,
        qk_norm: bool = True,
        layerscale_init: float = 1e-5,
        max_timesteps: int = 16,
    ) -> None:
        """Build the embedding layers and the encoder block stack.

        Args:
            dim: Token channel width ``C``.
            patch: Patch size used by both patch embeddings.
            depth: Number of ``EncoderBlock`` layers.
            n_heads: Forwarded to every ``EncoderBlock``.
            mlp_ratio: Forwarded to every ``EncoderBlock``.
            qk_norm: Forwarded to every ``EncoderBlock``.
            layerscale_init: Forwarded to every ``EncoderBlock``.
            max_timesteps: Number of rows in the temporal embedding table;
                ``timestep_ids`` passed to ``forward`` must be in
                ``[0, max_timesteps)``.
        """
        super().__init__()
        self.dim = dim
        self.patch = patch

        # Separate patch embeddings for 3-channel RGB and 6-channel Plucker rays.
        self.img_embed = PatchEmbed(3, dim, patch)
        self.ray_embed = PatchEmbed(6, dim, patch)

        # One learned vector per timestep index, initialised with a small std.
        self.temporal_embed = nn.Embedding(max_timesteps, dim)
        nn.init.normal_(self.temporal_embed.weight, std=0.02)

        self.blocks = nn.ModuleList(
            [
                EncoderBlock(dim, n_heads, mlp_ratio, qk_norm, layerscale_init)
                for _ in range(depth)
            ]
        )
        self.norm = nn.LayerNorm(dim)
        self.grad_checkpoint = False

    def forward(
        self,
        images: torch.Tensor,  # [B, V, 3, H, W]
        plucker: torch.Tensor,  # [B, V, 6, H, W]
        timestep_ids: Optional[torch.Tensor] = None,  # [B, V] long
    ) -> torch.Tensor:
        """Encode all views of a batch into one fused token sequence.

        Args:
            images: RGB views, ``[B, V, 3, H, W]``.
            plucker: Per-view Plucker coordinates, ``[B, V, 6, H, W]``; the
                batch, view and spatial dims must match ``images``.
            timestep_ids: Optional ``[B, V]`` integer indices into the temporal
                embedding table. When ``None`` no temporal embedding is added.

        Returns:
            Normalised image tokens of shape ``[B, V * Np, C]`` where ``Np`` is
            the number of patch tokens produced per view.

        Raises:
            ValueError: If ``images`` is not 5-D, or if ``plucker`` disagrees
                with ``images`` on the batch, view or spatial dimensions.
        """
        B, V, _, H, W = images.shape

        # Without this check a plucker tensor with singleton leading dims would
        # silently broadcast in the token sum below and corrupt the output.
        if plucker.shape[:2] != (B, V) or plucker.shape[-2:] != (H, W):
            raise ValueError(
                "plucker must match images on batch, view and spatial dims; "
                f"got images {tuple(images.shape)} and plucker {tuple(plucker.shape)}"
            )

        img = images.flatten(0, 1)  # [B*V, 3, H, W]
        ray = plucker.flatten(0, 1)  # [B*V, 6, H, W]

        # PatchEmbed returns (tokens, grid_size); only the tokens are needed.
        tok_img, _ = self.img_embed(img)
        tok_ray, _ = self.ray_embed(ray)
        tok = tok_img + tok_ray  # [B*V, Np, C]

        Np = tok.shape[1]
        # reshape (not view) so a non-contiguous embedding output cannot fail.
        tok = tok.reshape(B, V, Np, self.dim)
        if timestep_ids is not None:
            # [B, V, C] -> [B, V, 1, C]: the same offset for every patch of a view.
            tok = tok + self.temporal_embed(timestep_ids)[:, :, None, :]
        tok = tok.reshape(B, V * Np, self.dim)  # N_I = V*Np

        if self.grad_checkpoint and self.training:
            # Import explicitly: `import torch` alone does not guarantee that
            # the `torch.utils.checkpoint` submodule has been loaded.
            from torch.utils.checkpoint import checkpoint

            for blk in self.blocks:
                tok = checkpoint(blk, tok, use_reentrant=False)
        else:
            for blk in self.blocks:
                tok = blk(tok)

        return self.norm(tok)