SCREAMIE's picture
vit finetune on food101 gradio demo app
1d12158
raw
history blame contribute delete
396 Bytes
from torch import nn
import torchvision
def pretrained_vit(num_classes: int = 101):
    """Build a pretrained ViT-B/16 with a fresh linear classification head.

    Args:
        num_classes: Number of outputs of the replacement head. Defaults to
            101, the value that was previously hard-coded here.

    Returns:
        A ``(model, transforms)`` tuple: the torchvision ViT-B/16 model whose
        ``heads`` attribute has been replaced by a single ``nn.Linear``, and
        the preprocessing transforms bundled with the pretrained weights.
    """
    weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    transforms = weights.transforms()
    model = torchvision.models.vit_b_16(weights=weights)

    # Swap the stock classifier for a newly constructed layer; its parameters
    # do not come from the pretrained weights. The input width is read from
    # the model (768 for ViT-B/16) instead of being repeated as a literal.
    # NOTE(review): `heads` is kept as a bare nn.Linear, exactly as before;
    # presumably saved checkpoints rely on the resulting `heads.weight` /
    # `heads.bias` parameter names -- confirm before changing its structure.
    model.heads = nn.Linear(
        in_features=model.hidden_dim,
        out_features=num_classes,
    )
    return model, transforms