lling0212 commited on
Commit
be6e105
·
1 Parent(s): 64c9b8b

Add custom ResNet files

Browse files
Files changed (3) hide show
  1. __init__.py +0 -0
  2. configuration_resnet.py +9 -0
  3. modeling_resnet.py +38 -0
__init__.py ADDED
File without changes
configuration_resnet.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from transformers import PretrainedConfig
3
+
4
+ class CustomResNetConfig(PretrainedConfig):
5
+ model_type = "custom-resnet"
6
+
7
+ def __init__(self, num_labels=2, **kwargs):
8
+ super().__init__(**kwargs)
9
+ self.num_labels = num_labels # Register number of labels (output dimensions)
modeling_resnet.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import AutoModelForImageClassification, PreTrainedModel, hf_hub_download
6
+ from .configuration_resnet import CustomResNetConfig
7
+
8
+ class CustomResNetModel(PreTrainedModel):
9
+ config_class = CustomResNetConfig
10
+
11
+ def __init__(self, config):
12
+ super().__init__(config)
13
+ # Load pre-trained ResNet model
14
+ self.resnet = AutoModelForImageClassification.from_pretrained(config.model_name)
15
+
16
+ # Modify classifier
17
+ in_features = self.resnet.classifier[1].in_features
18
+ self.resnet.classifier = nn.Sequential(
19
+ nn.Flatten(),
20
+ nn.Linear(in_features, config.num_labels)
21
+ )
22
+
23
+ def forward(self, x):
24
+ return self.resnet(x)
25
+
26
+ def save_pretrained(self, save_directory, **kwargs):
27
+ os.makedirs(save_directory, exist_ok=True)
28
+ torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
29
+ self.config.save_pretrained(save_directory)
30
+
31
+ @classmethod
32
+ def from_pretrained(cls, repo_id, **kwargs):
33
+ model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
34
+ config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
35
+ config = CustomResNetConfig.from_pretrained(config_path)
36
+ model = cls(config)
37
+ model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
38
+ return model