File size: 1,440 Bytes
be6e105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39

import os
import torch
import torch.nn as nn
from transformers import AutoModelForImageClassification, PreTrainedModel, hf_hub_download
from .configuration_resnet import CustomResNetConfig

class CustomResNetModel(PreTrainedModel):
    """Image classifier that wraps a pre-trained backbone with a new head.

    The backbone is loaded with ``AutoModelForImageClassification`` from
    ``config.model_name`` and its ``classifier`` is replaced by a
    ``Flatten -> Linear`` head producing ``config.num_labels`` outputs.

    Checkpoints are stored as a raw ``state_dict`` in ``pytorch_model.bin``
    next to the configuration's ``config.json``.
    """

    config_class = CustomResNetConfig

    def __init__(self, config):
        """Build the backbone and swap in a classification head.

        Args:
            config: A ``CustomResNetConfig``; ``model_name`` and
                ``num_labels`` are read from it.
        """
        super().__init__(config)
        # Load pre-trained ResNet model
        self.resnet = AutoModelForImageClassification.from_pretrained(config.model_name)

        # Modify classifier.
        # NOTE(review): this assumes the backbone's ``classifier`` is
        # indexable and that element 1 is the linear layer -- confirm for
        # every ``model_name`` that is expected to be supported.
        in_features = self.resnet.classifier[1].in_features
        self.resnet.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, config.num_labels),
        )

    def forward(self, x):
        """Run ``x`` through the wrapped model and return its output as-is.

        Args:
            x: Input passed positionally to the wrapped model (presumably a
                batch of pixel values -- TODO confirm against callers).
        """
        return self.resnet(x)

    def save_pretrained(self, save_directory, **kwargs):
        """Write the weights and configuration into ``save_directory``.

        The directory is created if it does not exist. The weights are saved
        with ``torch.save`` as ``pytorch_model.bin``; the configuration is
        saved through ``self.config.save_pretrained``.

        Args:
            save_directory: Target directory.
            **kwargs: Accepted for signature compatibility; currently ignored.
        """
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
        self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, repo_id, **kwargs):
        """Instantiate the model from a local directory or a Hub repository.

        Args:
            repo_id: Either a path to an existing local directory (such as
                one written by ``save_pretrained``) or a repository id that
                is passed to ``hf_hub_download``. An existing local directory
                takes precedence.
            **kwargs: Accepted for signature compatibility; currently ignored.

        Returns:
            A new instance with the stored ``state_dict`` loaded on CPU.
        """
        if os.path.isdir(repo_id):
            # Local checkpoint: mirror the layout produced by save_pretrained
            # so that a save/load round trip works without the Hub.
            model_path = os.path.join(repo_id, "pytorch_model.bin")
            config_path = os.path.join(repo_id, "config.json")
        else:
            # NOTE(review): ``hf_hub_download`` is normally exported by
            # ``huggingface_hub`` -- confirm the module-level import resolves.
            model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
            config_path = hf_hub_download(repo_id=repo_id, filename="config.json")

        config = CustomResNetConfig.from_pretrained(config_path)
        model = cls(config)
        # SECURITY: torch.load unpickles the file; only load checkpoints from
        # trusted sources (consider ``weights_only=True`` where the installed
        # torch version supports it).
        state_dict = torch.load(model_path, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)
        return model