| from torchvision import models | |
| import torch.nn as nn | |
class ViTEncoder(nn.Module):
    """Image encoder built on torchvision's ViT-B/16 with its classifier removed.

    The classification head of ``vit_b_16`` is replaced by ``nn.Identity``, so
    ``forward`` returns the backbone's pre-head features instead of class
    logits. Per torchvision's ``VisionTransformer.forward``, these are the
    class-token embeddings of width ``out_dim`` (768 for ViT-B/16).

    Args:
        weights: Pretrained weights forwarded unchanged to
            ``torchvision.models.vit_b_16``. Defaults to ``"IMAGENET1K_V1"``
            (the previously hard-coded choice). Pass ``None`` for random
            initialisation (no download) or a ``ViT_B_16_Weights`` member.
    """

    def __init__(self, weights="IMAGENET1K_V1"):
        super().__init__()
        self.vit = models.vit_b_16(weights=weights)
        # Remove the classifier so the model emits features, not logits.
        self.vit.heads = nn.Identity()
        # Width of the feature vector returned by forward(); lets callers size
        # downstream layers without hard-coding 768.
        self.out_dim = self.vit.hidden_dim

    def forward(self, x):
        """Encode a batch of images into feature vectors.

        Args:
            x: Image batch. torchvision's ViT-B/16 expects ``(N, 3, 224, 224)``
                inputs preprocessed as the chosen weights' transforms specify.
                NOTE(review): confirm callers resize/normalise accordingly.

        Returns:
            Feature tensor, presumably ``(N, self.out_dim)``.
        """
        return self.vit(x)