safiaa02 commited on
Commit
7131543
·
verified ·
1 Parent(s): 8a70974

Create vit_encoder.py

Browse files
Files changed (1) hide show
  1. vit_encoder.py +11 -0
vit_encoder.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision import models
2
+ import torch.nn as nn
3
+
4
+ class ViTEncoder(nn.Module):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.vit = models.vit_b_16(weights="IMAGENET1K_V1")
8
+ self.vit.heads = nn.Identity() # remove classifier head
9
+
10
+ def forward(self, x):
11
+ return self.vit(x)