"""ViT encoder (ยง2.4). Patchify each view's RGB and per-view Plucker coordinates (patch ``p=8``), linear-project both to ``C`` and **sum** them, add a small per-timestep temporal embedding, then run cross-image self-attention across all ``V x |t|`` views so the cameras fuse. Output: image tokens ``B in R^{N_I x C}``. """ from __future__ import annotations from typing import Optional import torch import torch.nn as nn from mapgs.model.blocks import EncoderBlock, PatchEmbed class ViTEncoder(nn.Module): def __init__( self, dim: int = 1024, patch: int = 8, depth: int = 6, n_heads: int = 16, mlp_ratio: float = 4.0, qk_norm: bool = True, layerscale_init: float = 1e-5, max_timesteps: int = 16, ): super().__init__() self.dim = dim self.patch = patch self.img_embed = PatchEmbed(3, dim, patch) self.ray_embed = PatchEmbed(6, dim, patch) self.temporal_embed = nn.Embedding(max_timesteps, dim) nn.init.normal_(self.temporal_embed.weight, std=0.02) self.blocks = nn.ModuleList( [EncoderBlock(dim, n_heads, mlp_ratio, qk_norm, layerscale_init) for _ in range(depth)] ) self.norm = nn.LayerNorm(dim) self.grad_checkpoint = False def forward( self, images: torch.Tensor, # [B, V, 3, H, W] plucker: torch.Tensor, # [B, V, 6, H, W] timestep_ids: Optional[torch.Tensor] = None, # [B, V] long ) -> torch.Tensor: B, V, _, H, W = images.shape img = images.flatten(0, 1) # [B*V, 3, H, W] ray = plucker.flatten(0, 1) # [B*V, 6, H, W] tok_img, (gh, gw) = self.img_embed(img) tok_ray, _ = self.ray_embed(ray) tok = tok_img + tok_ray # [B*V, Np, C] Np = tok.shape[1] tok = tok.view(B, V, Np, self.dim) if timestep_ids is not None: tok = tok + self.temporal_embed(timestep_ids)[:, :, None, :] tok = tok.reshape(B, V * Np, self.dim) # N_I = V*Np for blk in self.blocks: if self.grad_checkpoint and self.training: tok = torch.utils.checkpoint.checkpoint(blk, tok, use_reentrant=False) else: tok = blk(tok) return self.norm(tok)