File size: 970 Bytes
e89a14c
 
 
 
db7d3b8
e89a14c
 
 
695e3cb
e89a14c
db7d3b8
 
 
 
 
 
 
e89a14c
 
 
 
 
695e3cb
e89a14c
695e3cb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import timm

from resnet_model.configuration_resnet import ResnetConfig
from resnet_model.modeling_resnet import ResnetModel, ResnetModelForImageClassification
from transformers import AutoConfig, AutoModel, AutoModelForImageClassification

# Declare which transformers Auto* class each custom class should be loadable
# through once it is saved / pushed with its custom code (see the transformers
# "Sharing custom models" guide; loading then needs trust_remote_code=True).
ResnetConfig.register_for_auto_class()
ResnetModel.register_for_auto_class("AutoModel")
# The classification model must be registered under its task-specific Auto
# class. Registering it as "AutoModel" (the previous code) would make AutoModel
# resolve to the classifier and give AutoModelForImageClassification no entry.
ResnetModelForImageClassification.register_for_auto_class(
    "AutoModelForImageClassification"
)


# Alternative, in-process-only registration with the Auto* factories (does not
# export anything to the Hub). Kept for reference; not used by this script.
# AutoConfig.register("rgbdsod-resnet", ResnetConfig)
# AutoModel.register(ResnetConfig, ResnetModel)
# AutoModelForImageClassification.register(
#     ResnetConfig, ResnetModelForImageClassification
# )

# Configuration meant to mirror timm's "resnet50d" architecture, so that the
# pretrained timm weights loaded below fit the freshly constructed model.
resnet50d_config = ResnetConfig(
    block_type="bottleneck",
    stem_width=32,
    stem_type="deep",
    avg_down=True,
)
resnet50d = ResnetModelForImageClassification(resnet50d_config)

# Fetch timm's pretrained resnet50d and copy its weights into our model.
# NOTE(review): the ``.model.model`` path depends on how modeling_resnet nests
# the underlying network — confirm it matches the wrapper's structure.
pretrained_model = timm.create_model("resnet50d", pretrained=True)
pretrained_state = pretrained_model.state_dict()
resnet50d.model.model.load_state_dict(pretrained_state)

# Upload the model to the Hugging Face Hub repo "RGBD-SOD/custom-resnet50d".
resnet50d.push_to_hub("RGBD-SOD/custom-resnet50d")