# Swin-T model builder (uploaded by Subh775 via huggingface_hub, commit 6085c77).
from torchvision import models
import torch.nn as nn
def build_model(fine_tune=True, num_classes=4, weights='DEFAULT'):
    """Build a torchvision Swin-T model with a new classification head.

    Args:
        fine_tune: If True, every backbone parameter is trainable.
            If False, all backbone parameters are frozen and only the
            freshly created head is trainable.
        num_classes: Number of output classes for the new linear head.
        weights: Passed straight to ``torchvision.models.swin_t``.
            ``'DEFAULT'`` (the original behaviour) loads torchvision's
            default pretrained weights, which may be downloaded on first
            use; ``None`` gives a randomly initialised backbone.

    Returns:
        The Swin-T model whose ``head`` has been replaced by an
        ``nn.Linear`` producing ``num_classes`` outputs.
    """
    model = models.swin_t(weights=weights)
    # Note: this prints the architecture with the ORIGINAL head, i.e. before
    # the replacement below.
    print(model)

    if fine_tune:
        print('[INFO]: Fine-tuning all layers...')
    else:
        print('[INFO]: Freezing hidden layers...')
    # Sets requires_grad on every parameter of the model in one call.
    model.requires_grad_(fine_tune)

    # Replace the head AFTER (un)freezing: a newly constructed nn.Linear has
    # requires_grad=True, so the head stays trainable even when the backbone
    # is frozen. Read in_features from the existing head instead of
    # hard-coding 768 so the input width always matches the backbone.
    model.head = nn.Linear(
        in_features=model.head.in_features,
        out_features=num_classes,
        bias=True,
    )
    return model