| from torchvision import models | |
| import torch.nn as nn | |
def build_model(fine_tune=True, num_classes=4):
    """Build a Swin-T classifier with a freshly initialised classification head.

    Loads torchvision's ``swin_t`` with its default pretrained weights
    (``weights='DEFAULT'``), prints the architecture, sets whether the
    pretrained parameters are trainable, and replaces the final ``head``
    layer with a new linear layer that produces ``num_classes`` outputs.

    Args:
        fine_tune: If truthy, every pretrained parameter stays trainable.
            If falsy, all pretrained parameters are frozen; only the new
            head (created after freezing) has ``requires_grad=True``.
        num_classes: Number of output features of the new classification head.

    Returns:
        The ``swin_t`` model with its ``head`` replaced.
    """
    model = models.swin_t(weights='DEFAULT')
    print(model)
    if fine_tune:
        print('[INFO]: Fine-tuning all layers...')
    else:
        print('[INFO]: Freezing hidden layers...')
    # A single pass handles both modes: trainable when fine-tuning, frozen otherwise.
    trainable = bool(fine_tune)
    for params in model.parameters():
        params.requires_grad = trainable
    # Take the input width from the existing head rather than hard-coding 768,
    # so the replacement always matches the backbone's feature size.
    # The new layer is created after freezing, so it remains trainable.
    model.head = nn.Linear(
        in_features=model.head.in_features,
        out_features=num_classes,
        bias=True,
    )
    return model