File size: 313 Bytes
7131543 |
1 2 3 4 5 6 7 8 9 10 11 12 |
from torchvision import models
import torch.nn as nn
class ViTEncoder(nn.Module):
    """Feature extractor built on torchvision's ViT-B/16.

    The classification head is replaced with ``nn.Identity`` so that
    ``forward`` returns the backbone's pooled features rather than
    ImageNet class logits.
    """

    def __init__(self, weights="IMAGENET1K_V1"):
        """Build the encoder.

        Args:
            weights: Pretrained-weight spec forwarded to
                ``models.vit_b_16`` — a weight-name string, a
                ``ViT_B_16_Weights`` enum member, or ``None`` for
                random initialization. Defaults to the previously
                hard-coded ``"IMAGENET1K_V1"``.
        """
        super().__init__()
        self.vit = models.vit_b_16(weights=weights)
        # Drop the classifier so the module outputs features, not logits.
        self.vit.heads = nn.Identity()

    def forward(self, x):
        """Return ViT features for an input batch ``x``.

        NOTE(review): vit_b_16 expects (N, 3, 224, 224) inputs —
        confirm callers resize/normalize accordingly.
        """
        return self.vit(x)
|