Thastp commited on
Commit
9824c82
·
verified ·
1 Parent(s): aa7f33b

Update modeling_efficientnet.py

Browse files
Files changed (1) hide show
  1. modeling_efficientnet.py +46 -46
modeling_efficientnet.py CHANGED
@@ -1,47 +1,47 @@
1
- from torch import nn
2
- from functools import partial
3
-
4
- from transformers import PreTrainedModel
5
- from transformers.modeling_outputs import BaseModelOutputWithPoolingAndNoAttention, ImageClassifierOutputWithNoAttention
6
- from timm import create_model
7
-
8
- from configuration_efficientnet import EfficientNetConfig
9
-
10
- class EfficientNetModel(PreTrainedModel):
11
- config_class = EfficientNetConfig
12
-
13
- def __init__(self, config):
14
- super().__init__(config)
15
-
16
- self.config = config
17
- self.model = create_model(config.model_name, pretrained = config.pretrained)
18
-
19
- def forward(self, pixel_values):
20
- last_hidden_state = self.model.forward_features(pixel_values)
21
- return BaseModelOutputWithPoolingAndNoAttention(
22
- last_hidden_state = last_hidden_state
23
- )
24
-
25
- class EfficientNetModelForImageClassification(PreTrainedModel):
26
- config_class = EfficientNetConfig
27
-
28
- def __init__(self, config):
29
- super().__init__(config)
30
-
31
- self.config = config
32
- self.model = create_model(config.model_name, pretrained = config.pretrained)
33
-
34
- def forward(self, pixel_values, labels=None):
35
- logits = self.model(pixel_values)
36
- loss = None
37
- if labels is not None:
38
- loss = nn.CrossEntropyLoss(logits, labels)
39
- return ImageClassifierOutputWithNoAttention(
40
- loss = loss,
41
- logits = logits
42
- )
43
-
44
- __all__ = [
45
- "EfficientNetModel",
46
- "EfficientNetModelForImageClassification"
47
  ]
 
1
+ from torch import nn
2
+ from functools import partial
3
+
4
+ from transformers import PreTrainedModel
5
+ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndNoAttention, ImageClassifierOutputWithNoAttention
6
+ from timm import create_model
7
+
8
+ from .configuration_efficientnet import EfficientNetConfig
9
+
10
+ class EfficientNetModel(PreTrainedModel):
11
+ config_class = EfficientNetConfig
12
+
13
+ def __init__(self, config):
14
+ super().__init__(config)
15
+
16
+ self.config = config
17
+ self.model = create_model(config.model_name, pretrained = config.pretrained)
18
+
19
+ def forward(self, pixel_values):
20
+ last_hidden_state = self.model.forward_features(pixel_values)
21
+ return BaseModelOutputWithPoolingAndNoAttention(
22
+ last_hidden_state = last_hidden_state
23
+ )
24
+
25
+ class EfficientNetModelForImageClassification(PreTrainedModel):
26
+ config_class = EfficientNetConfig
27
+
28
+ def __init__(self, config):
29
+ super().__init__(config)
30
+
31
+ self.config = config
32
+ self.model = create_model(config.model_name, pretrained = config.pretrained)
33
+
34
+ def forward(self, pixel_values, labels=None):
35
+ logits = self.model(pixel_values)
36
+ loss = None
37
+ if labels is not None:
38
+ loss = nn.CrossEntropyLoss(logits, labels)
39
+ return ImageClassifierOutputWithNoAttention(
40
+ loss = loss,
41
+ logits = logits
42
+ )
43
+
44
+ __all__ = [
45
+ "EfficientNetModel",
46
+ "EfficientNetModelForImageClassification"
47
  ]