File size: 623 Bytes
6085c77 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
from torchvision import models
import torch.nn as nn
def build_model(fine_tune=True, num_classes=4):
    """Build a torchvision Swin-T model with a replaced classification head.

    Loads ``models.swin_t(weights='DEFAULT')`` and prints its architecture.
    It then sets ``requires_grad`` on every existing parameter according to
    ``fine_tune``. Finally it swaps ``model.head`` for a new ``nn.Linear``
    with ``num_classes`` outputs.

    Args:
        fine_tune: If truthy, all pre-existing parameters are left trainable.
            If falsy, they are all frozen; only the new head (created after
            the freeze, so trainable by default) will receive gradients.
        num_classes: Number of output units of the new head.

    Returns:
        The modified Swin-T model.
    """
    model = models.swin_t(weights='DEFAULT')
    print(model)
    if fine_tune:
        print('[INFO]: Fine-tuning all layers...')
    else:
        print('[INFO]: Freezing hidden layers...')
    # requires_grad expects a real bool. The flag is applied before the head
    # is replaced, so the new head keeps its default requires_grad=True.
    trainable = bool(fine_tune)
    for params in model.parameters():
        params.requires_grad = trainable
    # Take the feature width from the existing head instead of hard-coding
    # 768 (swin_t's value), so this stays correct for other Swin variants.
    model.head = nn.Linear(
        in_features=model.head.in_features,
        out_features=num_classes,
        bias=True,
    )
    return model