Upload folder using huggingface_hub
Browse files- configuration_resnet.py +0 -12
- modeling_resnet.py +4 -2
configuration_resnet.py
CHANGED
|
@@ -1,7 +1,5 @@
|
|
| 1 |
from transformers import PretrainedConfig
|
| 2 |
from typing import List
|
| 3 |
-
from pprint import pprint
|
| 4 |
-
|
| 5 |
|
| 6 |
class ResnetConfig(PretrainedConfig):
|
| 7 |
model_type = "faen_resnet"
|
|
@@ -34,13 +32,3 @@ class ResnetConfig(PretrainedConfig):
|
|
| 34 |
self.stem_type = stem_type
|
| 35 |
self.avg_down = avg_down
|
| 36 |
super().__init__(**kwargs)
|
| 37 |
-
|
| 38 |
-
if __name__ == "__main__":
|
| 39 |
-
resnet50d_config = ResnetConfig(block_type="bottleneck", stem_width=32, stem_type="deep", avg_down=True)
|
| 40 |
-
print("init a ResnetConfig, it is:\n")
|
| 41 |
-
pprint(resnet50d_config)
|
| 42 |
-
resnet50d_config.save_pretrained("./")
|
| 43 |
-
resnet50d_config = ResnetConfig.from_pretrained("./")
|
| 44 |
-
print("\n")
|
| 45 |
-
print("saved to file config.json and reload it from config.json and it is:\n")
|
| 46 |
-
pprint(resnet50d_config)
|
|
|
|
| 1 |
from transformers import PretrainedConfig
|
| 2 |
from typing import List
|
|
|
|
|
|
|
| 3 |
|
| 4 |
class ResnetConfig(PretrainedConfig):
|
| 5 |
model_type = "faen_resnet"
|
|
|
|
| 32 |
self.stem_type = stem_type
|
| 33 |
self.avg_down = avg_down
|
| 34 |
super().__init__(**kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modeling_resnet.py
CHANGED
|
@@ -1,12 +1,13 @@
|
|
| 1 |
from transformers import PreTrainedModel
|
| 2 |
from timm.models.resnet import BasicBlock, Bottleneck, ResNet
|
| 3 |
-
from configuration_resnet import ResnetConfig
|
| 4 |
import torch
|
| 5 |
|
| 6 |
BLOCK_MAPPING = {"basic": BasicBlock, "bottleneck": Bottleneck}
|
| 7 |
|
| 8 |
|
| 9 |
class ResnetModel(PreTrainedModel):
|
|
|
|
| 10 |
|
| 11 |
def __init__(self, config):
|
| 12 |
super().__init__(config)
|
|
@@ -28,7 +29,8 @@ class ResnetModel(PreTrainedModel):
|
|
| 28 |
|
| 29 |
|
| 30 |
class ResnetModelForImageClassification(PreTrainedModel):
|
| 31 |
-
|
|
|
|
| 32 |
def __init__(self, config):
|
| 33 |
super().__init__(config)
|
| 34 |
block_layer = BLOCK_MAPPING[config.block_type]
|
|
|
|
| 1 |
from transformers import PreTrainedModel
|
| 2 |
from timm.models.resnet import BasicBlock, Bottleneck, ResNet
|
| 3 |
+
from .configuration_resnet import ResnetConfig
|
| 4 |
import torch
|
| 5 |
|
| 6 |
BLOCK_MAPPING = {"basic": BasicBlock, "bottleneck": Bottleneck}
|
| 7 |
|
| 8 |
|
| 9 |
class ResnetModel(PreTrainedModel):
|
| 10 |
+
config_class = ResnetConfig
|
| 11 |
|
| 12 |
def __init__(self, config):
|
| 13 |
super().__init__(config)
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
class ResnetModelForImageClassification(PreTrainedModel):
|
| 32 |
+
config_class = ResnetConfig
|
| 33 |
+
|
| 34 |
def __init__(self, config):
|
| 35 |
super().__init__(config)
|
| 36 |
block_layer = BLOCK_MAPPING[config.block_type]
|