ChestGPT / vit_encoder.py
safiaa02's picture
Create vit_encoder.py
7131543 verified
raw
history blame contribute delete
313 Bytes
from torchvision import models
import torch.nn as nn
class ViTEncoder(nn.Module):
    """Image encoder wrapping torchvision's ViT-B/16 with its classifier removed.

    The classification head is replaced by ``nn.Identity`` so ``forward``
    returns the backbone's feature vector instead of class logits.

    Args:
        weights: Pretrained weights forwarded unchanged to
            ``torchvision.models.vit_b_16(weights=...)``. Defaults to
            ``"IMAGENET1K_V1"`` (the previously hard-coded value). Pass
            ``None`` for random initialisation (no download), or any other
            value torchvision accepts for that argument.

    Attributes:
        vit: The wrapped torchvision ``VisionTransformer`` with ``heads``
            replaced by ``nn.Identity``.
        embed_dim: Width of the vectors returned by ``forward``
            (``vit.hidden_dim``; 768 for ViT-B/16).
    """

    def __init__(self, weights="IMAGENET1K_V1"):
        super().__init__()
        self.vit = models.vit_b_16(weights=weights)
        # Drop the classification head so forward() yields embeddings, not logits.
        self.vit.heads = nn.Identity()
        # Exposed so downstream layers (e.g. a projection head) can size
        # themselves without hard-coding the backbone width.
        self.embed_dim = self.vit.hidden_dim

    def forward(self, x):
        """Encode a batch of images into feature vectors.

        Args:
            x: Image batch in the format torchvision's ViT-B/16 expects:
                shape (batch, 3, 224, 224). torchvision asserts the spatial
                size. NOTE(review): chest X-rays are often single-channel.
                Confirm the caller expands them to 3 channels and applies the
                normalisation that matches ``weights``.

        Returns:
            Tensor of shape (batch, embed_dim): the class-token embedding
            produced by the backbone (the head is an identity).
        """
        return self.vit(x)