ABAO77 committed on
Commit
5147546
·
verified ·
1 Parent(s): 4dc9354

Upload feature_extractor.py

Browse files
Files changed (1) hide show
  1. src/modules/feature_extractor.py +87 -14
src/modules/feature_extractor.py CHANGED
@@ -1,28 +1,60 @@
1
-
2
  import torchvision.models.feature_extraction
3
  import torchvision
4
  import os
5
  import torch
 
 
 
6
 
7
  from .config_extractor import MODEL_CONFIG
8
 
9
  os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
10
 
 
11
  class FeatureExtractor:
12
  """Class for extracting features from images using a pre-trained model"""
13
 
14
- def __init__(self, base_model):
15
  # set the base model
16
  self.base_model = base_model
17
  # get the number of features
18
  self.feat_dims = MODEL_CONFIG[base_model]["feat_dims"]
19
- # initialize the image transformations
20
- self.model, self.transforms = self.init_model(base_model)
21
- self.model.eval() # set the model to evaluation mode
22
- # move the model to MPS
23
- self.device = torch.device("cpu")
24
- print("device is", self.device)
25
- self.model.to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  def init_model(self, base_model):
28
  """Initialize the model for feature extraction
@@ -60,8 +92,49 @@ class FeatureExtractor:
60
  x = self.transforms(img)
61
  # add batch dimension
62
  x = x.unsqueeze(0)
63
- # move the tensor to MPS
64
- x = x.to(self.device)
65
- # output now has the features corresponding to input x
66
- output = self.model(x)[MODEL_CONFIG[self.base_model]["feat_layer"]]
67
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torchvision.models.feature_extraction
2
  import torchvision
3
  import os
4
  import torch
5
+ import onnx
6
+ import onnxruntime
7
+ import numpy as np
8
 
9
  from .config_extractor import MODEL_CONFIG
10
 
11
  os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
12
 
13
+
14
  class FeatureExtractor:
15
  """Class for extracting features from images using a pre-trained model"""
16
 
17
+ def __init__(self, base_model, onnx_path=None):
18
  # set the base model
19
  self.base_model = base_model
20
  # get the number of features
21
  self.feat_dims = MODEL_CONFIG[base_model]["feat_dims"]
22
+ # get the feature layer name
23
+ self.feat_layer = MODEL_CONFIG[base_model]["feat_layer"]
24
+
25
+ # Set default ONNX path if not provided
26
+ if onnx_path is None:
27
+ onnx_path = f"model/{base_model}_feature_extractor.onnx"
28
+
29
+ self.onnx_path = onnx_path
30
+ self.onnx_session = None
31
+
32
+ # Initialize transforms (needed for both ONNX and PyTorch)
33
+ _, self.transforms = self.init_model(base_model)
34
+
35
+ # Check if ONNX model exists
36
+ if os.path.exists(onnx_path):
37
+ print(f"Loading existing ONNX model from {onnx_path}")
38
+ self.onnx_session = onnxruntime.InferenceSession(onnx_path)
39
+ else:
40
+ print(
41
+ f"ONNX model not found at {onnx_path}. Initializing PyTorch model and converting to ONNX..."
42
+ )
43
+ # Initialize PyTorch model
44
+ self.model, _ = self.init_model(base_model)
45
+ self.model.eval()
46
+ self.device = torch.device("cpu")
47
+ self.model.to(self.device)
48
+
49
+ # Create directory if it doesn't exist
50
+ os.makedirs(os.path.dirname(onnx_path), exist_ok=True)
51
+
52
+ # Convert to ONNX
53
+ self.convert_to_onnx(onnx_path)
54
+
55
+ # Load the newly created ONNX model
56
+ self.onnx_session = onnxruntime.InferenceSession(onnx_path)
57
+ print(f"Successfully created and loaded ONNX model from {onnx_path}")
58
 
59
  def init_model(self, base_model):
60
  """Initialize the model for feature extraction
 
92
  x = self.transforms(img)
93
  # add batch dimension
94
  x = x.unsqueeze(0)
95
+
96
+ # Convert to numpy for ONNX Runtime
97
+ x_numpy = x.numpy()
98
+ # Run inference with ONNX Runtime
99
+ print("Running inference with ONNX Runtime")
100
+ output = self.onnx_session.run(
101
+ None,
102
+ {'input': x_numpy}
103
+ )[0]
104
+ # Convert back to torch tensor
105
+ output = torch.from_numpy(output)
106
+
107
+ return output
108
+
109
+ def convert_to_onnx(self, save_path):
110
+ """Convert the model to ONNX format and save it
111
+
112
+ Args:
113
+ save_path: str, the path to save the ONNX model
114
+
115
+ Returns:
116
+ None
117
+ """
118
+ # Create a dummy input tensor
119
+ dummy_input = torch.randn(1, 3, 224, 224, device=self.device)
120
+
121
+ # Export the model
122
+ torch.onnx.export(
123
+ self.model,
124
+ dummy_input,
125
+ save_path,
126
+ export_params=True,
127
+ opset_version=14,
128
+ do_constant_folding=True,
129
+ input_names=['input'],
130
+ output_names=['output'],
131
+ dynamic_axes={
132
+ 'input': {0: 'batch_size'},
133
+ 'output': {0: 'batch_size'}
134
+ }
135
+ )
136
+
137
+ # Verify the exported model
138
+ onnx_model = onnx.load(save_path)
139
+ onnx.checker.check_model(onnx_model)
140
+ print(f"ONNX model saved to {save_path}")