File size: 4,499 Bytes
982b011
 
 
 
5147546
 
 
982b011
 
 
 
 
5147546
982b011
 
 
5147546
982b011
 
 
 
5147546
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
982b011
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5147546
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import torchvision.models.feature_extraction
import torchvision
import os
import torch
import onnx
import onnxruntime
import numpy as np

from .config_extractor import MODEL_CONFIG

# Allow multiple OpenMP runtimes to coexist in one process; presumably a
# workaround for the "OMP: Error #15" abort that occurs when torch and
# onnxruntime each bundle their own libiomp — NOTE(review): confirm this is
# still needed for the deployed platform before removing.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"


class FeatureExtractor:
    """Class for extracting features from images using a pre-trained model.

    Inference is served through ONNX Runtime. If no ONNX model exists at
    ``onnx_path``, the PyTorch base model is instantiated, exported to
    ONNX, and the exported model is then loaded for inference.
    """

    def __init__(self, base_model, onnx_path=None):
        """Initialize the extractor, converting the model to ONNX if needed.

        Args:
            base_model: str, key into MODEL_CONFIG selecting the backbone
            onnx_path: str or None, where the ONNX model is stored; defaults
                to ``model/{base_model}_feature_extractor.onnx``
        """
        # set the base model
        self.base_model = base_model
        # number of features produced by the extraction layer
        self.feat_dims = MODEL_CONFIG[base_model]["feat_dims"]
        # name of the layer whose activations are extracted
        self.feat_layer = MODEL_CONFIG[base_model]["feat_layer"]

        # Set default ONNX path if not provided
        if onnx_path is None:
            onnx_path = f"model/{base_model}_feature_extractor.onnx"

        self.onnx_path = onnx_path
        self.onnx_session = None

        # Build the model and transforms exactly once. (Calling init_model a
        # second time on the conversion path below would instantiate the
        # weight-loaded torchvision model twice.)
        model, self.transforms = self.init_model(base_model)

        # Check if ONNX model exists
        if os.path.exists(onnx_path):
            print(f"Loading existing ONNX model from {onnx_path}")
            self.onnx_session = onnxruntime.InferenceSession(onnx_path)
        else:
            print(
                f"ONNX model not found at {onnx_path}. Initializing PyTorch model and converting to ONNX..."
            )
            # Reuse the already-built PyTorch model for export
            self.model = model
            self.model.eval()
            self.device = torch.device("cpu")
            self.model.to(self.device)

            # Create the target directory only when the path has one;
            # os.makedirs("") raises FileNotFoundError for bare filenames.
            onnx_dir = os.path.dirname(onnx_path)
            if onnx_dir:
                os.makedirs(onnx_dir, exist_ok=True)

            # Convert to ONNX
            self.convert_to_onnx(onnx_path)

            # Load the newly created ONNX model
            self.onnx_session = onnxruntime.InferenceSession(onnx_path)
            print(f"Successfully created and loaded ONNX model from {onnx_path}")

    def init_model(self, base_model):
        """Initialize the model for feature extraction

        Args:
            base_model: str, the name of the base model

        Returns:
            model: torch.nn.Module, the feature extraction model
            transforms: torchvision.transforms.Compose, the image transformations

        Raises:
            ValueError: if ``base_model`` is not a key of MODEL_CONFIG
        """
        if base_model not in MODEL_CONFIG:
            raise ValueError(f"Invalid base model: {base_model}")

        # get the model and weights; wrap the backbone so only the
        # configured feature layer's activations are returned
        weights = MODEL_CONFIG[base_model]["weights"]
        model = torchvision.models.feature_extraction.create_feature_extractor(
            MODEL_CONFIG[base_model]["model"](weights=weights),
            [MODEL_CONFIG[base_model]["feat_layer"]],
        )
        # the preprocessing pipeline matching the pretrained weights
        transforms = weights.transforms()
        return model, transforms

    def extract_features(self, img):
        """Extract features from an image

        Args:
            img: PIL.Image, the input image

        Returns:
            output: torch.Tensor, the extracted features
        """
        # apply transformations
        x = self.transforms(img)
        # add batch dimension
        x = x.unsqueeze(0)

        # Convert to numpy for ONNX Runtime
        x_numpy = x.numpy()
        # Run inference with ONNX Runtime; 'input' matches the name given
        # at export time in convert_to_onnx
        print("Running inference with ONNX Runtime")
        output = self.onnx_session.run(
            None,
            {'input': x_numpy}
        )[0]
        # Convert back to torch tensor
        output = torch.from_numpy(output)

        return output

    def convert_to_onnx(self, save_path):
        """Convert the model to ONNX format and save it

        Args:
            save_path: str, the path to save the ONNX model

        Returns:
            None
        """
        # Create a dummy input tensor (standard 224x224 RGB; the batch
        # axis is exported as dynamic below)
        dummy_input = torch.randn(1, 3, 224, 224, device=self.device)

        # Export the model
        torch.onnx.export(
            self.model,
            dummy_input,
            save_path,
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={
                'input': {0: 'batch_size'},
                'output': {0: 'batch_size'}
            }
        )

        # Verify the exported model is structurally valid
        onnx_model = onnx.load(save_path)
        onnx.checker.check_model(onnx_model)
        print(f"ONNX model saved to {save_path}")