# resnet_model / modeling_resnet.py
# Author: lling0212
# Commit message: Add custom ResNet files
# Commit: be6e105
import os
import torch
import torch.nn as nn
from transformers import AutoModelForImageClassification, PreTrainedModel, hf_hub_download
from .configuration_resnet import CustomResNetConfig
class CustomResNetModel(PreTrainedModel):
    """Image classifier wrapping a pre-trained backbone with a re-sized head.

    The backbone is loaded with ``AutoModelForImageClassification`` from
    ``config.model_name``. Its ``classifier`` is replaced by
    ``Flatten -> Linear(in_features, config.num_labels)``. Weights are
    persisted as a plain ``state_dict`` in ``pytorch_model.bin``, with the
    configuration in ``config.json``.
    """

    config_class = CustomResNetConfig

    def __init__(self, config):
        """Build the model from ``config``.

        Args:
            config: a ``CustomResNetConfig``. Two fields are read:
                ``config.model_name`` is the checkpoint given to
                ``AutoModelForImageClassification.from_pretrained``.
                ``config.num_labels`` is the output size of the new head.
        """
        super().__init__(config)
        # NOTE: this loads the base checkpoint on every construction. That
        # includes construction inside from_pretrained(), where these weights
        # are immediately overwritten by the saved state_dict.
        self.resnet = AutoModelForImageClassification.from_pretrained(config.model_name)
        # Index 1 of the existing head is read as the Linear layer. This
        # assumes the backbone's classifier is Sequential(Flatten, Linear).
        # TODO: confirm for backbones other than the HF ResNet family.
        in_features = self.resnet.classifier[1].in_features
        self.resnet.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, config.num_labels),
        )

    def forward(self, x):
        """Run the wrapped backbone on ``x`` and return its output unchanged.

        ``x`` is passed positionally to the wrapped model. It is presumably a
        batch of pixel values; TODO confirm its shape against the preprocessing.
        """
        return self.resnet(x)

    def save_pretrained(self, save_directory, **kwargs):
        """Write the weights and config to ``save_directory``.

        The state_dict is written to ``pytorch_model.bin`` and the config via
        ``self.config.save_pretrained``. The directory is created if missing.
        Extra ``kwargs`` are accepted for signature compatibility with
        ``PreTrainedModel.save_pretrained``, but they are ignored.
        """
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
        self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, repo_id, **kwargs):
        """Load a model written by ``save_pretrained``, from disk or the Hub.

        Args:
            repo_id: either a local directory containing ``pytorch_model.bin``
                and ``config.json`` (for example, one written by
                ``save_pretrained``), or a Hugging Face Hub repo id. An
                existing local directory takes precedence over a Hub repo
                with the same name.
            **kwargs: the Hub download options ``revision``, ``cache_dir``,
                ``token``, ``force_download`` and ``local_files_only`` are
                forwarded to ``hf_hub_download`` when given. All other
                keyword arguments are ignored.

        Returns:
            A ``cls`` instance with the saved weights loaded, in eval mode.
        """
        if os.path.isdir(repo_id):
            # Read a local directory directly. hf_hub_download would reject
            # a filesystem path as a repo id.
            model_path = os.path.join(repo_id, "pytorch_model.bin")
            config_path = os.path.join(repo_id, "config.json")
        else:
            # NOTE(review): hf_hub_download is imported from `transformers` at
            # the top of this file. It is defined in `huggingface_hub` and may
            # not be re-exported by transformers, so verify that import.
            hub_keys = ("revision", "cache_dir", "token", "force_download", "local_files_only")
            hub_kwargs = {key: kwargs[key] for key in hub_keys if key in kwargs}
            model_path = hf_hub_download(
                repo_id=repo_id, filename="pytorch_model.bin", **hub_kwargs
            )
            config_path = hf_hub_download(
                repo_id=repo_id, filename="config.json", **hub_kwargs
            )
        config = CustomResNetConfig.from_pretrained(config_path)
        model = cls(config)
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code. Only load checkpoints from trusted sources, or pass
        # weights_only=True on torch versions that support it.
        state_dict = torch.load(model_path, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)
        # Match PreTrainedModel.from_pretrained, which returns models in
        # eval mode. Callers that want to fine-tune should call model.train().
        model.eval()
        return model