import torch
import torch.nn as nn
from torchvision.models import swin_t
from transformers import PretrainedConfig, PreTrainedModel


# 1. Define a Config class
class SwinClassifierConfig(PretrainedConfig):
    """Configuration for :class:`SwinClassifier`.

    Args:
        num_classes: Number of logits produced by the final classification layer.
        head_hidden_dim: Width of the hidden ``nn.Linear`` layer in the
            classification head (default 256).
        head_dropout: Dropout probability applied after the hidden layer's
            ReLU (default 0.5).
        **kwargs: Forwarded unchanged to ``transformers.PretrainedConfig``.
    """

    model_type = "swin_classifier"

    def __init__(
        self,
        num_classes=18,
        head_hidden_dim=256,
        head_dropout=0.5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.head_hidden_dim = head_hidden_dim
        self.head_dropout = head_dropout


# 2. Update your Model class to inherit from PreTrainedModel
class SwinClassifier(PreTrainedModel):
    """torchvision Swin-T backbone with an MLP classification head.

    The backbone's original ``head`` is replaced by
    ``Linear -> ReLU -> Dropout -> Linear``. The final layer emits
    ``config.num_classes`` raw logits (no softmax is applied).

    Args:
        config: A :class:`SwinClassifierConfig` (or any config exposing
            ``num_classes``; the head fields fall back to 256 / 0.5).
    """

    config_class = SwinClassifierConfig

    def __init__(self, config):
        super().__init__(config)
        # swin_t() is called without ``weights``, so the backbone starts from
        # torchvision's default (random) initialisation, not ImageNet weights.
        self.backbone = swin_t()
        num_features = self.backbone.head.in_features

        # getattr fallbacks keep config objects that lack these fields (e.g. not
        # built through SwinClassifierConfig.__init__) producing the same head
        # as before: hidden width 256, dropout 0.5.
        hidden_dim = getattr(config, "head_hidden_dim", 256)
        dropout = getattr(config, "head_dropout", 0.5)

        self.backbone.head = nn.Sequential(
            nn.Linear(num_features, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            # Use the value from the config instead of a raw number.
            nn.Linear(hidden_dim, config.num_classes),
        )

    def forward(self, x):
        """Run the backbone and classification head.

        Args:
            x: Image batch passed straight to torchvision's ``swin_t``;
                presumably shaped (N, 3, H, W), e.g. (N, 3, 224, 224) -- TODO confirm
                against the data pipeline.

        Returns:
            Logits tensor whose last dimension is ``config.num_classes``.
        """
        return self.backbone(x)