|
|
|
|
|
import os |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import AutoModelForImageClassification, PreTrainedModel, hf_hub_download |
|
|
from .configuration_resnet import CustomResNetConfig |
|
|
|
|
|
class CustomResNetModel(PreTrainedModel):
    """Image classifier wrapping a pretrained backbone with a resized head.

    The backbone named by ``config.model_name`` is loaded through
    ``AutoModelForImageClassification`` and its ``classifier`` is replaced by
    ``Flatten -> Linear(in_features, config.num_labels)``. Weights are
    persisted as a plain ``pytorch_model.bin`` state dict next to the config's
    ``config.json``.
    """

    config_class = CustomResNetConfig

    # File names shared by ``save_pretrained`` (writer) and ``from_pretrained``
    # (reader) so the two cannot drift apart.
    _WEIGHTS_NAME = "pytorch_model.bin"
    _CONFIG_NAME = "config.json"

    # ``from_pretrained`` keyword arguments forwarded to ``hf_hub_download``;
    # all other keywords are ignored, as before.
    _HUB_DOWNLOAD_KWARGS = (
        "revision",
        "cache_dir",
        "force_download",
        "local_files_only",
        "token",
    )

    def __init__(self, config):
        """Build the backbone from ``config.model_name`` and attach a new head.

        Args:
            config: A ``CustomResNetConfig``; ``model_name`` and ``num_labels``
                are read here.

        Note:
            The backbone's pretrained weights are loaded (and possibly
            downloaded) on every construction, including inside
            ``from_pretrained``, where they are then overwritten by the saved
            state dict.
        """
        super().__init__(config)

        self.resnet = AutoModelForImageClassification.from_pretrained(config.model_name)

        # NOTE(review): assumes the backbone's head is Sequential(Flatten, Linear)
        # so that index 1 is the Linear layer (true for HF ResNet image
        # classifiers with num_labels > 0) — confirm for other model_name values.
        in_features = self.resnet.classifier[1].in_features
        self.resnet.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, config.num_labels),
        )

    def forward(self, x):
        """Run the wrapped backbone on ``x`` and return its output unchanged.

        ``x`` is presumably a pixel tensor shaped (batch, channels, height,
        width) — TODO confirm against the callers / image processor.
        """
        return self.resnet(x)

    def save_pretrained(self, save_directory, **kwargs):
        """Write the state dict and config into ``save_directory``.

        The directory is created if it does not exist. Extra keyword arguments
        are accepted for signature compatibility with
        ``PreTrainedModel.save_pretrained`` and are ignored.
        """
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_directory, self._WEIGHTS_NAME))
        self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, repo_id, **kwargs):
        """Load a model written by ``save_pretrained``.

        Args:
            repo_id: Either a local directory containing ``pytorch_model.bin``
                and ``config.json`` (e.g. the output of ``save_pretrained``) or
                a Hugging Face Hub repository id.
            **kwargs: ``revision``, ``cache_dir``, ``force_download``,
                ``local_files_only`` and ``token`` are forwarded to
                ``hf_hub_download`` for Hub repos; other keywords are ignored.

        Returns:
            A ``cls`` instance with the saved weights loaded, in eval mode.
        """
        model_path = cls._resolve_file(repo_id, cls._WEIGHTS_NAME, kwargs)
        config_path = cls._resolve_file(repo_id, cls._CONFIG_NAME, kwargs)

        config = CustomResNetConfig.from_pretrained(config_path)
        model = cls(config)

        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered checkpoint cannot execute code on load.
        state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)

        # Match PreTrainedModel.from_pretrained, which returns models in eval
        # mode (dropout / batch-norm statistics frozen).
        model.eval()
        return model

    @classmethod
    def _resolve_file(cls, repo_id, filename, kwargs):
        """Return a local filesystem path to ``filename`` for ``repo_id``.

        A local directory is used directly; anything else is treated as a Hub
        repo id and fetched with ``hf_hub_download``.
        """
        if os.path.isdir(repo_id):
            return os.path.join(repo_id, filename)
        hub_kwargs = {key: kwargs[key] for key in cls._HUB_DOWNLOAD_KWARGS if key in kwargs}
        # NOTE(review): hf_hub_download lives in huggingface_hub; confirm the
        # top-of-file `from transformers import hf_hub_download` resolves in
        # the pinned transformers version.
        return hf_hub_download(repo_id=repo_id, filename=filename, **hub_kwargs)
|
|
|