"""
ViT-based contextual encoder.

Takes a makeup image → spatial feature grid Z ∈ R^{h×w×d}.
Uses timm's ViT-B/16 (by default) with intermediate feature extraction.
"""

import torch
import torch.nn as nn


class ViTEncoder(nn.Module):
    """
    Wraps a timm Vision Transformer and reshapes its patch tokens into a 2-D
    spatial feature grid that the implicit decoder can query via bilinear
    interpolation.

    Args:
        model_name: timm model identifier passed to ``timm.create_model``.
        out_dim: channel count of the returned grid. When it differs from the
            backbone's ``embed_dim`` a linear projection is applied per token.
        img_size: image size handed to timm so that positional embeddings are
            built for this resolution.
        pretrained: whether timm should load pretrained weights.

    Attributes:
        img_size: the ``img_size`` argument, stored unchanged.
        patch_size: patch edge length in pixels, read from the backbone's
            ``patch_embed.patch_size`` when available (falls back to 16).
            An ``int`` for square patches, otherwise an ``(h, w)`` tuple.
        vit: the timm backbone (created with ``num_classes=0``).
        proj: ``nn.Identity`` or ``nn.Linear(embed_dim, out_dim)``.
        out_dim: the ``out_dim`` argument, stored unchanged.

    Raises:
        ImportError: if ``timm`` is not installed.
    """

    def __init__(
        self,
        model_name: str = "vit_base_patch16_224",
        out_dim: int = 768,
        img_size: int = 256,
        pretrained: bool = True,
    ):
        super().__init__()
        self.img_size = img_size

        # Imported lazily so the module can be imported without timm present.
        try:
            import timm
        except ImportError as exc:
            raise ImportError(
                "timm is required to construct ViTEncoder. Install dependencies from requirements.txt."
            ) from exc

        # load ViT; override image size so positional embeddings are
        # interpolated to our resolution
        try:
            self.vit = timm.create_model(
                model_name,
                pretrained=pretrained,
                img_size=img_size,
                num_classes=0,
                dynamic_img_size=True,
            )
        except TypeError:
            # Retry without ``dynamic_img_size`` for timm versions / models
            # that reject that keyword.
            self.vit = timm.create_model(
                model_name,
                pretrained=pretrained,
                img_size=img_size,
                num_classes=0,  # remove classification head
            )

        # Take the patch size from the backbone instead of assuming 16, so
        # that e.g. patch32 / patch14 models yield a correctly sized grid.
        patch_embed = getattr(self.vit, "patch_embed", None)
        patch_h, patch_w = self._patch_hw(getattr(patch_embed, "patch_size", 16))
        # Keep the attribute a plain int in the (usual) square case so existing
        # readers of ``self.patch_size`` keep working.
        self.patch_size = patch_h if patch_h == patch_w else (patch_h, patch_w)

        vit_dim = self.vit.embed_dim  # 768 for ViT-B

        # optional projection
        self.proj = nn.Identity() if out_dim == vit_dim else nn.Linear(vit_dim, out_dim)
        self.out_dim = out_dim

    @staticmethod
    def _patch_hw(size):
        """Normalize an int or 2-element patch size to an ``(h, w)`` int tuple."""
        if isinstance(size, int):
            return size, size
        height, width = size
        return int(height), int(width)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, 3, H, W) makeup image in [-1, 1]
        Returns:
            Z: (B, out_dim, grid_h, grid_w) spatial feature grid, where
               grid_h = H // patch_h and grid_w = W // patch_w
        Raises:
            RuntimeError: if the number of patch tokens produced by the
                backbone does not equal grid_h * grid_w.
        """
        # forward_features is expected to return (B, P + N, D): P prefix
        # tokens (CLS / register / distillation) followed by N patch tokens.
        tokens = self.vit.forward_features(x)

        # Drop however many prefix tokens the backbone reports rather than
        # assuming exactly one CLS token; default to 1 when not reported.
        num_prefix = getattr(self.vit, "num_prefix_tokens", 1)
        patch_tokens = self.proj(tokens[:, num_prefix:, :])  # (B, N, out_dim)

        B, N, C = patch_tokens.shape
        patch_h, patch_w = self._patch_hw(self.patch_size)
        grid_h = x.shape[-2] // patch_h
        grid_w = x.shape[-1] // patch_w

        if N != grid_h * grid_w:
            raise RuntimeError(
                f"ViTEncoder: got {N} patch tokens but expected "
                f"{grid_h}x{grid_w}={grid_h * grid_w} for input of size "
                f"{tuple(x.shape[-2:])} with patch size ({patch_h}, {patch_w})."
            )

        # (B, N, C) -> (B, C, N) -> (B, C, grid_h, grid_w); tokens are laid
        # out row-major over the patch grid.
        Z = patch_tokens.permute(0, 2, 1).reshape(B, C, grid_h, grid_w)
        return Z