"""Custom ResNet image-classification model wrapper for Hugging Face."""

import os

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoModelForImageClassification, PreTrainedModel

from .configuration_resnet import CustomResNetConfig

# File names shared by ``save_pretrained`` and ``from_pretrained`` so that the
# two methods always agree on the on-disk layout.
_WEIGHTS_NAME = "pytorch_model.bin"
_CONFIG_NAME = "config.json"


class CustomResNetModel(PreTrainedModel):
    """Pre-trained image classifier with its head resized to ``num_labels``.

    The backbone is loaded with ``AutoModelForImageClassification`` from
    ``config.model_name``; its ``classifier`` is then replaced by a fresh
    ``Flatten`` + ``Linear`` head producing ``config.num_labels`` outputs.
    """

    config_class = CustomResNetConfig

    def __init__(self, config):
        """Build the backbone and swap in a new classification head.

        Args:
            config: A ``CustomResNetConfig``; ``model_name`` and
                ``num_labels`` are read from it.
        """
        super().__init__(config)

        # Load pre-trained ResNet model
        self.resnet = AutoModelForImageClassification.from_pretrained(
            config.model_name
        )

        # Modify classifier.
        # Assumes ``classifier[1]`` is the final ``nn.Linear`` of the loaded
        # model (the code indexes it directly) — TODO confirm this holds for
        # every ``model_name`` that may be configured.
        in_features = self.resnet.classifier[1].in_features
        self.resnet.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, config.num_labels),
        )

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its output unchanged."""
        return self.resnet(x)

    def save_pretrained(self, save_directory, **kwargs):
        """Write the weights and the config into ``save_directory``.

        The directory is created if it does not exist. The state dict is
        stored as ``pytorch_model.bin`` via ``torch.save``; the config is
        written through its own ``save_pretrained``.

        Args:
            save_directory: Target directory.
            **kwargs: Accepted for signature compatibility; currently unused.
        """
        os.makedirs(save_directory, exist_ok=True)
        torch.save(
            self.state_dict(), os.path.join(save_directory, _WEIGHTS_NAME)
        )
        self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, repo_id, **kwargs):
        """Instantiate the model from a local directory or a Hub repository.

        Args:
            repo_id: Either a path to an existing directory laid out as
                ``save_pretrained`` writes it, or a Hugging Face Hub
                repository id to download the files from.
            **kwargs: Accepted for signature compatibility; currently unused.

        Returns:
            A ``CustomResNetModel`` with the stored weights loaded on CPU.
        """
        if os.path.isdir(repo_id):
            # Local checkpoint: read back exactly what save_pretrained wrote,
            # without going through the Hub.
            model_path = os.path.join(repo_id, _WEIGHTS_NAME)
            config_path = os.path.join(repo_id, _CONFIG_NAME)
        else:
            model_path = hf_hub_download(
                repo_id=repo_id, filename=_WEIGHTS_NAME
            )
            config_path = hf_hub_download(
                repo_id=repo_id, filename=_CONFIG_NAME
            )

        config = CustomResNetConfig.from_pretrained(config_path)
        model = cls(config)

        # NOTE(review): torch.load unpickles the file, so only load
        # checkpoints from trusted sources; consider ``weights_only=True``
        # where the installed torch version supports it.
        state_dict = torch.load(model_path, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)
        return model