Spaces:
Sleeping
Sleeping
File size: 2,489 Bytes
0046a78 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 | """
ViT-based contextual encoder.
Takes a makeup image → spatial feature grid Z ∈ R^{h×w×d}.
Uses timm's ViT-B/16 with intermediate feature extraction.
"""
import torch
import torch.nn as nn
class ViTEncoder(nn.Module):
    """
    Wrap a timm Vision Transformer and expose its patch tokens as a 2-D grid.

    The backbone is created with ``num_classes=0`` (no classification head).
    ``forward`` strips the leading non-patch tokens, optionally projects the
    remaining patch tokens to ``out_dim`` channels, and reshapes them to
    ``(B, out_dim, grid_h, grid_w)`` so a downstream decoder can sample the
    grid spatially (e.g. via bilinear interpolation).

    Args:
        model_name: timm model identifier passed to ``timm.create_model``.
        out_dim: Channel count of the returned grid. A linear projection is
            added only when it differs from the backbone's ``embed_dim``.
        img_size: Image size handed to timm when building the backbone.
        pretrained: Whether timm should load pretrained weights.

    Raises:
        ImportError: If ``timm`` is not installed.
    """

    def __init__(
        self,
        model_name: str = "vit_base_patch16_224",
        out_dim: int = 768,
        img_size: int = 256,
        pretrained: bool = True,
    ):
        super().__init__()
        self.img_size = img_size

        # timm is imported lazily so the module can be imported without it.
        try:
            import timm
        except ImportError as exc:
            raise ImportError(
                "timm is required to construct ViTEncoder. Install dependencies from requirements.txt."
            ) from exc

        # Build the backbone at our resolution with the classification head
        # removed. `dynamic_img_size` is tried first; a TypeError is taken to
        # mean this timm version / model does not accept that keyword, in
        # which case the model is built without it.
        create_kwargs = dict(
            pretrained=pretrained,
            img_size=img_size,
            num_classes=0,  # remove classification head
        )
        try:
            self.vit = timm.create_model(
                model_name, dynamic_img_size=True, **create_kwargs
            )
        except TypeError:
            self.vit = timm.create_model(model_name, **create_kwargs)

        # Read the patch size from the backbone instead of assuming 16, so
        # that non-/16 models produce the right grid. Falls back to 16 when
        # the model exposes no `patch_embed.patch_size`.
        patch = getattr(getattr(self.vit, "patch_embed", None), "patch_size", 16)
        if isinstance(patch, int):
            self._patch_hw = (patch, patch)
        else:
            self._patch_hw = (int(patch[0]), int(patch[1]))
        # Kept as a plain int for callers that read `encoder.patch_size`.
        self.patch_size = self._patch_hw[0]

        # Number of leading non-patch tokens (CLS, register tokens, ...).
        # Falls back to 1 (a single CLS token) when the model does not
        # expose `num_prefix_tokens`.
        self._num_prefix_tokens = int(getattr(self.vit, "num_prefix_tokens", 1))

        vit_dim = self.vit.embed_dim  # 768 for ViT-B

        # Optional projection to the requested channel count.
        if out_dim != vit_dim:
            self.proj = nn.Linear(vit_dim, out_dim)
        else:
            self.proj = nn.Identity()
        self.out_dim = out_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode an image batch into a spatial feature grid.

        Args:
            x: (B, 3, H, W) makeup image in [-1, 1]

        Returns:
            Z: (B, out_dim, grid_h, grid_w) spatial feature grid, where
            ``grid_h = H // patch_h`` and ``grid_w = W // patch_w``.

        Raises:
            RuntimeError: If the number of patch tokens returned by the
                backbone does not equal ``grid_h * grid_w``.
        """
        # Expected layout: (B, num_prefix + N, D), prefix tokens first.
        tokens = self.vit.forward_features(x)
        patch_tokens = tokens[:, self._num_prefix_tokens:, :]  # (B, N, D)
        patch_tokens = self.proj(patch_tokens)                 # (B, N, out_dim)

        batch, num_patches, channels = patch_tokens.shape
        patch_h, patch_w = self._patch_hw
        grid_h = x.shape[-2] // patch_h
        grid_w = x.shape[-1] // patch_w

        # Fail with a descriptive message rather than an opaque reshape
        # error (same exception type that reshape itself would raise).
        if num_patches != grid_h * grid_w:
            raise RuntimeError(
                f"ViTEncoder: backbone returned {num_patches} patch tokens, "
                f"but input of size {tuple(x.shape[-2:])} with patch size "
                f"{(patch_h, patch_w)} implies a {grid_h}x{grid_w} grid."
            )

        # (B, N, C) -> (B, C, N) -> (B, C, grid_h, grid_w); tokens are
        # assumed to be in row-major patch order.
        return patch_tokens.permute(0, 2, 1).reshape(batch, channels, grid_h, grid_w)
|