|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision.models import swin_t |
|
|
from transformers import PretrainedConfig, PreTrainedModel |
|
|
|
|
|
|
|
|
class SwinClassifierConfig(PretrainedConfig):
    """Configuration for ``SwinClassifier``.

    Args:
        num_classes: Size of the classifier's output layer (the model's final
            ``nn.Linear`` is built with this many output features).
            Defaults to 18.
        **kwargs: Passed through untouched to ``PretrainedConfig.__init__``.
    """

    model_type = "swin_classifier"

    def __init__(self, num_classes: int = 18, **kwargs) -> None:
        # The base class handles every generic Hugging Face option first;
        # the model-specific field is attached afterwards.
        super(SwinClassifierConfig, self).__init__(**kwargs)
        self.num_classes = num_classes
|
|
|
|
|
|
|
|
class SwinClassifier(PreTrainedModel):
    """Image classifier built from a torchvision Swin-T backbone plus an MLP head.

    The backbone's original ``head`` is replaced by
    ``Linear -> ReLU -> Dropout -> Linear``, whose final layer has
    ``config.num_classes`` output features.

    Config attributes read:
        num_classes: Number of output features of the final layer (required).
        head_hidden_dim: Width of the head's hidden layer. Optional; defaults
            to 256 when the config does not define it.
        head_dropout: Dropout probability applied before the final layer.
            Optional; defaults to 0.5 when the config does not define it.

    The two optional attributes can be supplied as extra keyword arguments to
    the config (``PretrainedConfig`` keeps unknown kwargs as attributes).
    Configs that lack them produce exactly the original 256-wide, p=0.5 head.
    """

    config_class = SwinClassifierConfig

    def __init__(self, config):
        super().__init__(config)

        # swin_t() is called without a ``weights`` argument, so the backbone
        # starts from torchvision's default initialisation. NOTE(review):
        # presumably trained weights arrive later via from_pretrained; confirm
        # that no ImageNet-pretrained start is intended.
        self.backbone = swin_t()

        # Input width of the stock head, reused as the new head's input size.
        num_features = self.backbone.head.in_features

        # getattr keeps older configs, which lack these fields, loading to
        # the original architecture.
        hidden_dim = getattr(config, "head_hidden_dim", 256)
        dropout = getattr(config, "head_dropout", 0.5)

        self.backbone.head = nn.Sequential(
            nn.Linear(num_features, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, config.num_classes),
        )

    def forward(self, x):
        """Run the backbone (including the replaced head) on ``x``.

        Args:
            x: Input tensor passed directly to the torchvision Swin-T model.
                Presumably an image batch shaped (N, 3, H, W); confirm
                against callers.

        Returns:
            The backbone output, whose last dimension is
            ``config.num_classes`` (the final ``nn.Linear`` of the head).
        """
        return self.backbone(x)