File size: 623 Bytes
6085c77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from torchvision import models
import torch.nn as nn
def build_model(fine_tune=True, num_classes=4, *, weights='DEFAULT'):
    """Build a Swin-T classifier whose head is sized for ``num_classes``.

    The backbone is ``torchvision.models.swin_t`` constructed with
    ``weights``. ``'DEFAULT'`` selects torchvision's default pretrained
    weights, which may be downloaded on first use; pass ``None`` for a
    randomly initialised backbone.

    Args:
        fine_tune: If true, every backbone parameter is trainable. If false,
            every backbone parameter is frozen, so only the newly created
            classification head receives gradients.
        num_classes: Number of output features of the replacement head.
        weights: Forwarded unchanged to ``models.swin_t``.

    Returns:
        The Swin-T model with ``model.head`` replaced by a fresh
        ``nn.Linear``.
    """
    model = models.swin_t(weights=weights)
    print(model)

    if fine_tune:
        print('[INFO]: Fine-tuning all layers...')
    else:
        print('[INFO]: Freezing hidden layers...')
    # The requires_grad setter only accepts a real bool, so normalise the flag.
    trainable = bool(fine_tune)
    for params in model.parameters():
        params.requires_grad = trainable

    # Size the new head from the existing one rather than hard-coding 768,
    # so it always matches the backbone's feature width. The new Linear's
    # parameters are created with requires_grad=True, so the head is
    # trainable even when the backbone above was frozen.
    model.head = nn.Linear(
        in_features=model.head.in_features,
        out_features=num_classes,
        bias=True,
    )
    return model