Add custom ResNet files
Browse files- __init__.py +0 -0
- configuration_resnet.py +9 -0
- modeling_resnet.py +38 -0
__init__.py
ADDED
|
File without changes
|
configuration_resnet.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from transformers import PretrainedConfig
|
| 3 |
+
|
| 4 |
+
class CustomResNetConfig(PretrainedConfig):
|
| 5 |
+
model_type = "custom-resnet"
|
| 6 |
+
|
| 7 |
+
def __init__(self, num_labels=2, **kwargs):
|
| 8 |
+
super().__init__(**kwargs)
|
| 9 |
+
self.num_labels = num_labels # Register number of labels (output dimensions)
|
modeling_resnet.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from transformers import AutoModelForImageClassification, PreTrainedModel, hf_hub_download
|
| 6 |
+
from .configuration_resnet import CustomResNetConfig
|
| 7 |
+
|
| 8 |
+
class CustomResNetModel(PreTrainedModel):
|
| 9 |
+
config_class = CustomResNetConfig
|
| 10 |
+
|
| 11 |
+
def __init__(self, config):
|
| 12 |
+
super().__init__(config)
|
| 13 |
+
# Load pre-trained ResNet model
|
| 14 |
+
self.resnet = AutoModelForImageClassification.from_pretrained(config.model_name)
|
| 15 |
+
|
| 16 |
+
# Modify classifier
|
| 17 |
+
in_features = self.resnet.classifier[1].in_features
|
| 18 |
+
self.resnet.classifier = nn.Sequential(
|
| 19 |
+
nn.Flatten(),
|
| 20 |
+
nn.Linear(in_features, config.num_labels)
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
def forward(self, x):
|
| 24 |
+
return self.resnet(x)
|
| 25 |
+
|
| 26 |
+
def save_pretrained(self, save_directory, **kwargs):
|
| 27 |
+
os.makedirs(save_directory, exist_ok=True)
|
| 28 |
+
torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
|
| 29 |
+
self.config.save_pretrained(save_directory)
|
| 30 |
+
|
| 31 |
+
@classmethod
|
| 32 |
+
def from_pretrained(cls, repo_id, **kwargs):
|
| 33 |
+
model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
|
| 34 |
+
config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
|
| 35 |
+
config = CustomResNetConfig.from_pretrained(config_path)
|
| 36 |
+
model = cls(config)
|
| 37 |
+
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
|
| 38 |
+
return model
|