Paul committed on
Commit
35ee8c6
·
verified ·
1 Parent(s): 9d30e4d

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +18 -10
model.py CHANGED
@@ -1,23 +1,31 @@
1
  import torch
2
  import torch.nn as nn
3
- from torchvision.models import swin_t, Swin_T_Weights
4
- from huggingface_hub import PyTorchModelHubMixin
5
 
6
- class SwinClassifier(nn.Module, PyTorchModelHubMixin):
7
- def __init__(self, num_classes):
8
- super().__init__()
9
- # 1. Load the backbone
 
 
 
 
 
 
 
 
 
 
10
  self.backbone = swin_t()
11
-
12
- # 2. Get features and replace the head
13
  num_features = self.backbone.head.in_features
14
 
15
- # We replace the original head with our custom Sequential block
16
  self.backbone.head = nn.Sequential(
17
  nn.Linear(num_features, 256),
18
  nn.ReLU(inplace=True),
19
  nn.Dropout(0.5),
20
- nn.Linear(256, num_classes)
 
21
  )
22
 
23
  def forward(self, x):
 
1
  import torch
2
  import torch.nn as nn
3
+ from torchvision.models import swin_t
4
+ from transformers import PretrainedConfig, PreTrainedModel
5
 
6
# 1. Define a Config class
class SwinClassifierConfig(PretrainedConfig):
    """Configuration object for SwinClassifier.

    Holds the single model hyperparameter — the size of the
    classification head — so the model can be saved/loaded through the
    ``transformers`` config machinery.
    """

    # Identifier used by transformers to register/dispatch this config type.
    model_type = "swin_classifier"

    def __init__(self, num_classes=18, **kwargs):
        """Create the config.

        Args:
            num_classes: Number of output classes for the classifier head
                (default 18).
            **kwargs: Forwarded unchanged to ``PretrainedConfig``.
        """
        super().__init__(**kwargs)
        self.num_classes = num_classes
12
+
13
+ # 2. Update your Model class to inherit from PreTrainedModel
14
+ class SwinClassifier(PreTrainedModel):
15
+ config_class = SwinClassifierConfig
16
+
17
+ def __init__(self, config):
18
+ super().__init__(config)
19
+ # Use config.num_classes instead of a raw number
20
  self.backbone = swin_t()
 
 
21
  num_features = self.backbone.head.in_features
22
 
 
23
  self.backbone.head = nn.Sequential(
24
  nn.Linear(num_features, 256),
25
  nn.ReLU(inplace=True),
26
  nn.Dropout(0.5),
27
+ # Use the value from the config
28
+ nn.Linear(256, config.num_classes)
29
  )
30
 
31
  def forward(self, x):