Thastp commited on
Commit
04cce47
·
verified ·
1 Parent(s): b31a063

Upload model

Browse files
Files changed (2) hide show
  1. config.json +2 -4
  2. modeling_efficientnet.py +9 -3
config.json CHANGED
@@ -1,12 +1,10 @@
1
  {
2
- "_name_or_path": "./efficientnet/temp",
3
  "architectures": [
4
- "EfficientNetModelForImageClassification"
5
  ],
6
  "auto_map": {
7
  "AutoConfig": "configuration_efficientnet.EfficientNetConfig",
8
- "AutoModel": "modeling_efficientnet.EfficientNetModel",
9
- "AutoModelForImageClassification": "modeling_efficientnet.EfficientNetModelForImageClassification"
10
  },
11
  "model_name": "efficientnet_b0",
12
  "model_type": "efficientnet",
 
1
  {
 
2
  "architectures": [
3
+ "EfficientNetModel"
4
  ],
5
  "auto_map": {
6
  "AutoConfig": "configuration_efficientnet.EfficientNetConfig",
7
+ "AutoModel": "modeling_efficientnet.EfficientNetModel"
 
8
  },
9
  "model_name": "efficientnet_b0",
10
  "model_type": "efficientnet",
modeling_efficientnet.py CHANGED
@@ -2,6 +2,7 @@ from torch import nn
2
  from functools import partial
3
 
4
  from transformers import PreTrainedModel
 
5
  from timm import create_model
6
 
7
  from configuration_efficientnet import EfficientNetConfig
@@ -15,7 +16,10 @@ class EfficientNetModel(PreTrainedModel):
15
  self.model = create_model(config.model_name, pretrained = config.pretrained)
16
 
17
  def forward(self, pixel_values):
18
- return self.model.forward_features(pixel_values)
 
 
 
19
 
20
  class EfficientNetModelForImageClassification(PreTrainedModel):
21
  config_class = EfficientNetConfig
@@ -29,8 +33,10 @@ class EfficientNetModelForImageClassification(PreTrainedModel):
29
  logits = self.model(pixel_values)
30
  if labels is not None:
31
  loss = nn.CrossEntropyLoss(logits, labels)
32
- return {"loss": loss, "logits": logits}
33
- return logits
 
 
34
 
35
  __all__ = [
36
  "EfficientNetModel",
 
2
  from functools import partial
3
 
4
  from transformers import PreTrainedModel
5
+ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndNoAttention, ImageClassifierOutputWithNoAttention
6
  from timm import create_model
7
 
8
  from configuration_efficientnet import EfficientNetConfig
 
16
  self.model = create_model(config.model_name, pretrained = config.pretrained)
17
 
18
  def forward(self, pixel_values):
19
+ last_hidden_state = self.model.forward_features(pixel_values)
20
+ return BaseModelOutputWithPoolingAndNoAttention(
21
+ last_hidden_state = last_hidden_state
22
+ )
23
 
24
  class EfficientNetModelForImageClassification(PreTrainedModel):
25
  config_class = EfficientNetConfig
 
33
  logits = self.model(pixel_values)
34
  if labels is not None:
35
  loss = nn.CrossEntropyLoss(logits, labels)
36
+ return ImageClassifierOutputWithNoAttention(
37
+ loss = loss,
38
+ logits = logits
39
+ )
40
 
41
  __all__ = [
42
  "EfficientNetModel",