Thastp commited on
Commit
d1f81df
·
verified ·
1 Parent(s): 8d76e0c

Update modeling_efficientnet.py

Browse files
Files changed (1) hide show
  1. modeling_efficientnet.py +45 -45
modeling_efficientnet.py CHANGED
@@ -1,46 +1,46 @@
1
- from torch import nn
2
-
3
- from transformers import PreTrainedModel
4
- from transformers.modeling_outputs import BaseModelOutputWithPoolingAndNoAttention, ImageClassifierOutputWithNoAttention
5
- from timm import create_model
6
-
7
- from configuration_efficientnet import EfficientNetConfig
8
-
9
- class EfficientNetModel(PreTrainedModel):
10
- config_class = EfficientNetConfig
11
-
12
- def __init__(self, config):
13
- super().__init__(config)
14
-
15
- self.config = config
16
- self.model = create_model(config.model_name, pretrained = config.pretrained)
17
-
18
- def forward(self, pixel_values):
19
- last_hidden_state = self.model.forward_features(pixel_values)
20
- return BaseModelOutputWithPoolingAndNoAttention(
21
- last_hidden_state = last_hidden_state
22
- )
23
-
24
- class EfficientNetModelForImageClassification(PreTrainedModel):
25
- config_class = EfficientNetConfig
26
-
27
- def __init__(self, config):
28
- super().__init__(config)
29
-
30
- self.config = config
31
- self.model = create_model(config.model_name, pretrained = config.pretrained)
32
-
33
- def forward(self, pixel_values, labels=None):
34
- logits = self.model(pixel_values)
35
- loss = None
36
- if labels is not None:
37
- loss = nn.CrossEntropyLoss(logits, labels)
38
- return ImageClassifierOutputWithNoAttention(
39
- loss = loss,
40
- logits = logits
41
- )
42
-
43
- __all__ = [
44
- "EfficientNetModel",
45
- "EfficientNetModelForImageClassification"
46
  ]
 
1
+ from torch import nn
2
+
3
+ from transformers import PreTrainedModel
4
+ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndNoAttention, ImageClassifierOutputWithNoAttention
5
+ from timm import create_model
6
+
7
+ from .configuration_efficientnet import EfficientNetConfig
8
+
9
+ class EfficientNetModel(PreTrainedModel):
10
+ config_class = EfficientNetConfig
11
+
12
+ def __init__(self, config):
13
+ super().__init__(config)
14
+
15
+ self.config = config
16
+ self.model = create_model(config.model_name, pretrained = config.pretrained)
17
+
18
+ def forward(self, pixel_values):
19
+ last_hidden_state = self.model.forward_features(pixel_values)
20
+ return BaseModelOutputWithPoolingAndNoAttention(
21
+ last_hidden_state = last_hidden_state
22
+ )
23
+
24
+ class EfficientNetModelForImageClassification(PreTrainedModel):
25
+ config_class = EfficientNetConfig
26
+
27
+ def __init__(self, config):
28
+ super().__init__(config)
29
+
30
+ self.config = config
31
+ self.model = create_model(config.model_name, pretrained = config.pretrained)
32
+
33
+ def forward(self, pixel_values, labels=None):
34
+ logits = self.model(pixel_values)
35
+ loss = None
36
+ if labels is not None:
37
+ loss = nn.CrossEntropyLoss(logits, labels)
38
+ return ImageClassifierOutputWithNoAttention(
39
+ loss = loss,
40
+ logits = logits
41
+ )
42
+
43
+ __all__ = [
44
+ "EfficientNetModel",
45
+ "EfficientNetModelForImageClassification"
46
  ]