File size: 5,241 Bytes
322d2b7
 
 
 
 
 
afec78c
 
322d2b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9102e2e
 
 
322d2b7
 
 
 
 
 
 
 
eea1fcf
92f42f9
322d2b7
 
 
 
 
 
 
 
 
9102e2e
 
322d2b7
9102e2e
322d2b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92f42f9
afec78c
9102e2e
 
 
322d2b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
afec78c
322d2b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
afec78c
 
322d2b7
 
afec78c
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import torchvision.models.feature_extraction
import torchvision
import os
import torch
import onnx
import onnxruntime
from onnxconverter_common import float16
import numpy as np
from src.modules.config_extractor import MODEL_CONFIG

# Allow duplicate OpenMP runtimes in one process — torch and onnxruntime are
# both imported here and each can ship its own libiomp; without this flag the
# second load aborts with "OMP: Error #15" on some platforms.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"


class FeatureExtractor:
    """Extract image features with a pre-trained torchvision backbone.

    The backbone is exported once to a float16 ONNX model on disk; all
    subsequent inference runs through ONNX Runtime, so the PyTorch weights
    only need to be instantiated when no ONNX file exists yet.
    """

    def __init__(self, base_model, onnx_path=None):
        """Set up transforms and an ONNX Runtime session for `base_model`.

        Args:
            base_model: str, key into MODEL_CONFIG selecting the backbone
            onnx_path: str or None, path of the cached ONNX model; defaults
                to "model/{base_model}_feature_extractor.onnx"

        Raises:
            ValueError: if `base_model` is not a key of MODEL_CONFIG
        """
        # Validate early so a bad model name fails before any heavy work.
        if base_model not in MODEL_CONFIG:
            raise ValueError(f"Invalid base model: {base_model}")

        # set the base model
        self.base_model = base_model
        # number of output feature dimensions (from the project config)
        self.feat_dims = MODEL_CONFIG[base_model]["feat_dims"]
        # name of the layer whose activations are extracted
        self.feat_layer = MODEL_CONFIG[base_model]["feat_layer"]

        # Set default ONNX path if not provided
        if onnx_path is None:
            onnx_path = f"model/{base_model}_feature_extractor.onnx"
        self.onnx_path = onnx_path
        self.onnx_session = None

        # The preprocessing transforms come straight from the pretrained
        # weights object, so we do NOT need to build the (heavy) PyTorch
        # model just to obtain them.
        self.transforms = MODEL_CONFIG[base_model]["weights"].transforms()

        if os.path.exists(onnx_path):
            print(f"Loading existing ONNX model from {onnx_path}")
            self._load_session()
        else:
            print(
                f"ONNX model not found at {onnx_path}. Initializing PyTorch model and converting to ONNX..."
            )
            # Initialize the PyTorch model only for the one-time export.
            self.model, _ = self.init_model(base_model)
            self.model.eval()
            self.device = torch.device("cpu")
            self.model.to(self.device)

            # os.makedirs("") raises FileNotFoundError, so only create a
            # directory when the path actually contains one.
            out_dir = os.path.dirname(onnx_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)

            # Convert to ONNX, then open a session on the new file.
            self.convert_to_onnx(onnx_path)
            print(f"Successfully created and loaded ONNX model from {onnx_path}")
            self._load_session()

    def _load_session(self):
        """Open an ONNX Runtime session on self.onnx_path, reporting its size."""
        model_size = os.path.getsize(self.onnx_path) / (1024 * 1024)
        print(f"Model size: {model_size:.2f} MB")
        self.onnx_session = onnxruntime.InferenceSession(self.onnx_path)

    def init_model(self, base_model):
        """Initialize the model for feature extraction

        Args:
            base_model: str, the name of the base model

        Returns:
            model: torch.nn.Module, the feature extraction model
            transforms: torchvision.transforms.Compose, the image transformations

        Raises:
            ValueError: if `base_model` is not a key of MODEL_CONFIG
        """
        if base_model not in MODEL_CONFIG:
            raise ValueError(f"Invalid base model: {base_model}")

        # Build a feature extractor that returns only the configured layer.
        weights = MODEL_CONFIG[base_model]["weights"]
        model = torchvision.models.feature_extraction.create_feature_extractor(
            MODEL_CONFIG[base_model]["model"](weights=weights),
            [MODEL_CONFIG[base_model]["feat_layer"]],
        )
        # Preprocessing transforms that match the pretrained weights.
        transforms = weights.transforms()
        return model, transforms

    def extract_features(self, img):
        """Extract features from an image

        Args:
            img: PIL.Image, the input image

        Returns:
            output: torch.Tensor, the extracted features
        """
        # apply transformations, then add a batch dimension
        x = self.transforms(img)
        x = x.unsqueeze(0)

        # The exported model is float16 (see convert_to_onnx), so the input
        # must be cast to float16 before handing it to ONNX Runtime.
        x_numpy = x.numpy().astype(np.float16)
        # NOTE: no per-call file-size stat/print here — the model size is
        # already reported once when the session is loaded.
        output = self.onnx_session.run(
            None,
            {'input': x_numpy}
        )[0]
        # Convert back to a torch tensor for callers.
        output = torch.from_numpy(output)

        return output

    def convert_to_onnx(self, save_path):
        """Convert the model to ONNX format and save it

        The model is first exported in float32, verified, then converted to
        float16 in place to halve the on-disk size.

        Args:
            save_path: str, the path to save the ONNX model

        Returns:
            None
        """
        # Dummy input matching the expected preprocessed image shape
        # (batch dim is dynamic, see dynamic_axes below).
        dummy_input = torch.randn(1, 3, 224, 224, device=self.device)

        # Export the model to ONNX (FP32 first)
        torch.onnx.export(
            self.model,
            dummy_input,
            save_path,
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={
                'input': {0: 'batch_size'},
                'output': {0: 'batch_size'}
            }
        )

        # Load and verify the exported model before converting precision.
        print("Converting model to float16...")
        onnx_model = onnx.load(save_path)
        onnx.checker.check_model(onnx_model)

        # Convert weights (and I/O) to float16, then overwrite the FP32 file.
        model_fp16 = float16.convert_float_to_float16(onnx_model)
        onnx.save(model_fp16, save_path)
        print(f"Float16 ONNX model saved to {save_path}")