| """ViT encoder (§2.4). |
| |
| Patchify each view's RGB and per-view Plucker coordinates (patch ``p=8``), |
| linear-project both to ``C`` and **sum** them, add a small per-timestep temporal |
| embedding, then run cross-image self-attention across all ``V x |t|`` views so |
| the cameras fuse. Output: image tokens ``B in R^{N_I x C}``. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from mapgs.model.blocks import EncoderBlock, PatchEmbed |
|
|
|
|
class ViTEncoder(nn.Module):
    """Fuse multi-view RGB and Plucker-ray patches with cross-view self-attention.

    Each view is patchified twice: once for its RGB image (3 channels) and once
    for its per-pixel Plucker coordinates (6 channels). The two token sets are
    summed. An optional learned per-timestep embedding is added. The tokens of
    all views are then concatenated into one sequence, so every
    ``EncoderBlock`` attends across cameras.

    Args:
        dim: Token width ``C``.
        patch: Patch size passed to both patch embedders.
        depth: Number of ``EncoderBlock`` layers.
        n_heads: Attention heads per block.
        mlp_ratio: MLP hidden-width multiplier, forwarded to ``EncoderBlock``.
        qk_norm: Forwarded to ``EncoderBlock`` (presumably toggles query/key
            normalisation; confirm in ``EncoderBlock``).
        layerscale_init: Forwarded to ``EncoderBlock`` (presumably the initial
            LayerScale value; confirm in ``EncoderBlock``).
        max_timesteps: Number of rows in the temporal embedding table. Timestep
            ids must lie in ``[0, max_timesteps)``.
    """

    def __init__(
        self,
        dim: int = 1024,
        patch: int = 8,
        depth: int = 6,
        n_heads: int = 16,
        mlp_ratio: float = 4.0,
        qk_norm: bool = True,
        layerscale_init: float = 1e-5,
        max_timesteps: int = 16,
    ):
        super().__init__()
        self.dim = dim
        self.patch = patch
        # Separate patch projections for RGB (3 ch) and Plucker rays (6 ch).
        self.img_embed = PatchEmbed(3, dim, patch)
        self.ray_embed = PatchEmbed(6, dim, patch)
        # Small-std init keeps the temporal signal a minor perturbation at start.
        self.temporal_embed = nn.Embedding(max_timesteps, dim)
        nn.init.normal_(self.temporal_embed.weight, std=0.02)
        self.blocks = nn.ModuleList(
            [EncoderBlock(dim, n_heads, mlp_ratio, qk_norm, layerscale_init) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(dim)
        # Toggle externally. When True (and in training mode), block
        # activations are recomputed during backward instead of stored.
        self.grad_checkpoint = False

    def forward(
        self,
        images: torch.Tensor,
        plucker: torch.Tensor,
        timestep_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode a batch of multi-view frames into fused image tokens.

        Args:
            images: ``(B, V, 3, H, W)`` RGB views.
            plucker: ``(B, V, 6, H, W)`` per-pixel Plucker coordinates for the
                same views. Its leading ``(B, V)`` must match ``images``.
            timestep_ids: Optional integer ids into the temporal embedding.
                Either ``(B, V)`` (one id per view) or ``(V,)`` (shared across
                the batch). ``None`` skips the temporal embedding.

        Returns:
            ``(B, V * Np, dim)`` layer-normalised tokens. ``Np`` is the number
            of patches per view produced by the patch embedders.

        Raises:
            ValueError: If ``images`` is not 5-D, or if ``plucker``'s leading
                ``(B, V)`` dimensions differ from those of ``images``.
        """
        if images.dim() != 5:
            raise ValueError(
                f"images must be (B, V, C, H, W), got shape {tuple(images.shape)}"
            )
        B, V = images.shape[:2]
        # flatten(0, 1) below would silently misalign views if only B*V matched.
        if plucker.shape[:2] != images.shape[:2]:
            raise ValueError(
                f"plucker leading dims {tuple(plucker.shape[:2])} must match "
                f"images leading dims {tuple(images.shape[:2])}"
            )

        # Fold views into the batch axis so each view is patchified independently.
        tok_img, _ = self.img_embed(images.flatten(0, 1))
        tok_ray, _ = self.ray_embed(plucker.flatten(0, 1))
        tok = tok_img + tok_ray  # (B*V, Np, dim)
        Np = tok.shape[1]
        tok = tok.view(B, V, Np, self.dim)

        if timestep_ids is not None:
            # (B, V) ids give (B, V, dim) -> (B, V, 1, dim).
            # (V,) ids give (V, dim) -> (V, 1, dim).
            # In both cases one embedding is broadcast over all patches of its view.
            tok = tok + self.temporal_embed(timestep_ids).unsqueeze(-2)

        # One sequence over all views, so attention spans cameras.
        tok = tok.reshape(B, V * Np, self.dim)

        if self.grad_checkpoint and self.training:
            # Explicit import: `import torch` alone does not guarantee that the
            # `torch.utils.checkpoint` submodule has been loaded.
            from torch.utils.checkpoint import checkpoint

            for blk in self.blocks:
                tok = checkpoint(blk, tok, use_reentrant=False)
        else:
            for blk in self.blocks:
                tok = blk(tok)
        return self.norm(tok)
|
|