Spaces:
Sleeping
Sleeping
"""
ViT-based contextual encoder.
Takes a makeup image and produces a spatial feature grid Z ∈ R^{h×w×d},
returned channels-first as a (B, d, h, w) tensor.
Uses a timm ViT (ViT-B/16 by default) and reshapes the patch tokens from its
`forward_features` output into the grid.
"""
| import torch | |
| import torch.nn as nn | |
class ViTEncoder(nn.Module):
    """
    Wrap a pretrained timm ViT and reshape its patch tokens into a 2-D
    spatial feature grid that the implicit decoder can query via bilinear
    interpolation.

    Args:
        model_name: timm model identifier (ViT-B/16 at 224 px by default).
        out_dim: channel dimension of the returned grid. When it differs
            from the backbone's ``embed_dim`` a linear projection is added.
        img_size: resolution passed to ``timm.create_model`` so the
            positional embeddings are built for this size.
        pretrained: whether timm should load pretrained weights.

    Raises:
        ImportError: if ``timm`` is not installed.
    """

    def __init__(
        self,
        model_name: str = "vit_base_patch16_224",
        out_dim: int = 768,
        img_size: int = 256,
        pretrained: bool = True,
    ):
        super().__init__()
        self.img_size = img_size
        try:
            import timm
        except ImportError as exc:
            raise ImportError(
                "timm is required to construct ViTEncoder. Install dependencies from requirements.txt."
            ) from exc

        create_kwargs = {
            "pretrained": pretrained,
            # override image size so positional embeddings match our resolution
            "img_size": img_size,
            "num_classes": 0,  # remove classification head
        }
        try:
            # dynamic_img_size lets the backbone accept inputs whose spatial
            # size differs from img_size.
            self.vit = timm.create_model(
                model_name, dynamic_img_size=True, **create_kwargs
            )
        except TypeError:
            # Fallback for timm versions that reject dynamic_img_size.
            self.vit = timm.create_model(model_name, **create_kwargs)

        # Take the patch size from the backbone rather than assuming 16, so
        # other ViT variants (e.g. patch8 / patch14) yield a correct grid.
        # Square patches keep the historical int type for this attribute.
        patch = getattr(getattr(self.vit, "patch_embed", None), "patch_size", 16)
        patch_h, patch_w = self._as_pair(patch)
        self.patch_size = patch_h if patch_h == patch_w else (patch_h, patch_w)

        vit_dim = self.vit.embed_dim  # 768 for ViT-B
        # optional projection to out_dim
        self.proj = nn.Identity()
        if out_dim != vit_dim:
            self.proj = nn.Linear(vit_dim, out_dim)
        self.out_dim = out_dim

    @staticmethod
    def _as_pair(size):
        """Return ``size`` (an int or a 2-sequence) as an ``(h, w)`` tuple of ints."""
        if isinstance(size, int):
            return size, size
        h, w = size
        return int(h), int(w)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode an image batch into a channels-first patch-feature grid.

        Args:
            x: (B, 3, H, W) makeup image in [-1, 1]. If the fallback model
               (no ``dynamic_img_size``) was built, H and W presumably must
               equal ``img_size`` — TODO confirm.
        Returns:
            Z: (B, out_dim, H // patch_h, W // patch_w) spatial feature grid.
        Raises:
            ValueError: if the number of patch tokens does not match the grid
                implied by the input size and the patch size.
        """
        tokens = self.vit.forward_features(x)  # (B, prefix + N, D)

        # Drop every prefix token (CLS, plus any distillation/register
        # tokens). timm exposes the count as num_prefix_tokens; default to a
        # single CLS token when the backbone does not report it.
        num_prefix = getattr(self.vit, "num_prefix_tokens", 1)
        patch_tokens = self.proj(tokens[:, num_prefix:, :])  # (B, N, out_dim)

        batch, num_tokens, channels = patch_tokens.shape
        patch_h, patch_w = self._as_pair(self.patch_size)
        grid_h = x.shape[-2] // patch_h
        grid_w = x.shape[-1] // patch_w
        if num_tokens != grid_h * grid_w:
            raise ValueError(
                f"ViTEncoder: got {num_tokens} patch tokens but input of size "
                f"{tuple(x.shape[-2:])} with patch size {(patch_h, patch_w)} "
                f"implies a {grid_h}x{grid_w} grid"
            )

        # The reshape assumes tokens are ordered row-major over the patch grid.
        return patch_tokens.permute(0, 2, 1).reshape(batch, channels, grid_h, grid_w)