Spaces:
Runtime error
Runtime error
File size: 819 Bytes
b418fb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
from typing import Callable

import torch
import torchvision
from torch import nn
def create_vit(pretrained_weights: torchvision.models.WeightsEnum,
               model: Callable[..., nn.Module],
               in_features: int,
               out_features: int,
               device: torch.device) -> tuple[nn.Module, Callable]:
    """Create a pretrained Vision Transformer (ViT) feature extractor with a
    fresh classification head.

    Args:
        pretrained_weights: Weights enum (e.g.
            ``torchvision.models.ViT_B_16_Weights.DEFAULT``) used both to
            initialise the backbone and to obtain its matching inference
            transforms.
        model: Callable model builder to instantiate, e.g.
            ``torchvision.models.vit_b_16``. Non-callable values fall back to
            ``vit_b_16`` for backward compatibility.
        in_features: Input dimension of the new classification head
            (768 for ViT-B/16 — must match the chosen backbone).
        out_features: Number of output classes.
        device: Device the model (and new head) are moved to.

    Returns:
        A ``(model, transforms)`` tuple: the backbone with all pretrained
        parameters frozen and a single trainable linear head, plus the
        preprocessing transforms associated with ``pretrained_weights``.
    """
    # Bug fix: the original ignored the ``model`` argument and always built
    # vit_b_16. Honour a callable builder when one is given; fall back to
    # vit_b_16 otherwise so existing callers keep their old behaviour.
    builder = model if callable(model) else torchvision.models.vit_b_16
    vit = builder(weights=pretrained_weights).to(device)
    transforms = pretrained_weights.transforms()

    # Freeze the pretrained backbone so only the new head is trained.
    for param in vit.parameters():
        param.requires_grad = False

    # Replace the classification head with a single trainable linear layer.
    # NOTE(review): assumes ``in_features`` matches the backbone's hidden
    # dim (768 for ViT-B/16) — a mismatch only surfaces at forward time.
    vit.heads = nn.Sequential(
        nn.Linear(in_features=in_features, out_features=out_features)
    ).to(device)

    return vit, transforms
|