Thastp commited on
Commit
fbcd6c6
·
verified ·
1 Parent(s): b265421

Update modeling_efficientnet.py

Browse files
Files changed (1) hide show
  1. modeling_efficientnet.py +51 -51
modeling_efficientnet.py CHANGED
@@ -1,52 +1,52 @@
1
- from torch import nn
2
-
3
- from transformers import PreTrainedModel
4
- from transformers.modeling_outputs import BaseModelOutputWithPoolingAndNoAttention, ImageClassifierOutputWithNoAttention
5
- from timm import create_model
6
-
7
- from configuration_efficientnet import EfficientNetConfig
8
-
9
- class EfficientNetModel(PreTrainedModel):
10
- config_class = EfficientNetConfig
11
-
12
- def __init__(self, config):
13
- super().__init__(config)
14
-
15
- self.config = config
16
- self.model = create_model(config.model_name,
17
- pretrained = config.pretrained,
18
- num_classes = config.num_classes,
19
- global_pool = config.global_pool)
20
-
21
- def forward(self, pixel_values):
22
- last_hidden_state = self.model.forward_features(pixel_values)
23
- return BaseModelOutputWithPoolingAndNoAttention(
24
- last_hidden_state = last_hidden_state
25
- )
26
-
27
- class EfficientNetModelForImageClassification(PreTrainedModel):
28
- config_class = EfficientNetConfig
29
-
30
- def __init__(self, config):
31
- super().__init__(config)
32
-
33
- self.config = config
34
- self.model = create_model(config.model_name,
35
- pretrained = config.pretrained,
36
- num_classes = config.num_classes,
37
- global_pool = config.global_pool)
38
-
39
- def forward(self, pixel_values, labels=None):
40
- logits = self.model(pixel_values)
41
- loss = None
42
- if labels is not None:
43
- loss = nn.CrossEntropyLoss(logits, labels)
44
- return ImageClassifierOutputWithNoAttention(
45
- loss = loss,
46
- logits = logits
47
- )
48
-
49
- __all__ = [
50
- "EfficientNetModel",
51
- "EfficientNetModelForImageClassification"
52
  ]
 
1
+ from torch import nn
2
+
3
+ from transformers import PreTrainedModel
4
+ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndNoAttention, ImageClassifierOutputWithNoAttention
5
+ from timm import create_model
6
+
7
+ from .configuration_efficientnet import EfficientNetConfig
8
+
9
+ class EfficientNetModel(PreTrainedModel):
10
+ config_class = EfficientNetConfig
11
+
12
+ def __init__(self, config):
13
+ super().__init__(config)
14
+
15
+ self.config = config
16
+ self.model = create_model(config.model_name,
17
+ pretrained = config.pretrained,
18
+ num_classes = config.num_classes,
19
+ global_pool = config.global_pool)
20
+
21
+ def forward(self, pixel_values):
22
+ last_hidden_state = self.model.forward_features(pixel_values)
23
+ return BaseModelOutputWithPoolingAndNoAttention(
24
+ last_hidden_state = last_hidden_state
25
+ )
26
+
27
+ class EfficientNetModelForImageClassification(PreTrainedModel):
28
+ config_class = EfficientNetConfig
29
+
30
+ def __init__(self, config):
31
+ super().__init__(config)
32
+
33
+ self.config = config
34
+ self.model = create_model(config.model_name,
35
+ pretrained = config.pretrained,
36
+ num_classes = config.num_classes,
37
+ global_pool = config.global_pool)
38
+
39
+ def forward(self, pixel_values, labels=None):
40
+ logits = self.model(pixel_values)
41
+ loss = None
42
+ if labels is not None:
43
+ loss = nn.CrossEntropyLoss(logits, labels)
44
+ return ImageClassifierOutputWithNoAttention(
45
+ loss = loss,
46
+ logits = logits
47
+ )
48
+
49
+ __all__ = [
50
+ "EfficientNetModel",
51
+ "EfficientNetModelForImageClassification"
52
  ]