wendys-llc commited on
Commit
f1eeeea
·
verified ·
1 Parent(s): cb53782

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +37 -0
model.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel, PretrainedConfig
2
+ import torch.nn as nn
3
+ from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
4
+
5
+ class CheckboxConfig(PretrainedConfig):
6
+ model_type = "checkbox-classifier"
7
+
8
+ def __init__(self, num_labels=2, dropout_rate=0.3, **kwargs):
9
+ super().__init__(num_labels=num_labels, **kwargs)
10
+ self.dropout_rate = dropout_rate
11
+
12
+ class CheckboxClassifier(PreTrainedModel):
13
+ config_class = CheckboxConfig
14
+
15
+ def __init__(self, config):
16
+ super().__init__(config)
17
+ self.num_labels = config.num_labels
18
+
19
+ self.backbone = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
20
+ num_features = self.backbone.classifier[1].in_features
21
+
22
+ self.backbone.classifier = nn.Sequential(
23
+ nn.Dropout(config.dropout_rate),
24
+ nn.Linear(num_features, 512),
25
+ nn.SiLU(inplace=True),
26
+ nn.BatchNorm1d(512),
27
+ nn.Dropout(config.dropout_rate),
28
+ nn.Linear(512, 256),
29
+ nn.SiLU(inplace=True),
30
+ nn.BatchNorm1d(256),
31
+ nn.Dropout(config.dropout_rate/2),
32
+ nn.Linear(256, config.num_labels)
33
+ )
34
+
35
+ def forward(self, pixel_values):
36
+ outputs = self.backbone(pixel_values)
37
+ return {"logits": outputs}