ABAO77 committed on
Commit
9102e2e
·
verified ·
1 Parent(s): 51b22ee

Upload feature_extractor.py

Browse files
Files changed (1) hide show
  1. src/modules/feature_extractor.py +17 -5
src/modules/feature_extractor.py CHANGED
@@ -34,6 +34,9 @@ class FeatureExtractor:
34
  # Check if ONNX model exists
35
  if os.path.exists(onnx_path):
36
  print(f"Loading existing ONNX model from {onnx_path}")
 
 
 
37
  self.onnx_session = onnxruntime.InferenceSession(onnx_path)
38
  else:
39
  print(
@@ -42,8 +45,8 @@ class FeatureExtractor:
42
  # Initialize PyTorch model
43
  self.model, _ = self.init_model(base_model)
44
  self.model.eval()
45
- self.device = torch.device("cpu")
46
- self.model.to(self.device)
47
 
48
  # Create directory if it doesn't exist
49
  os.makedirs(os.path.dirname(onnx_path), exist_ok=True)
@@ -53,7 +56,10 @@ class FeatureExtractor:
53
 
54
  # Load the newly created ONNX model
55
  self.onnx_session = onnxruntime.InferenceSession(onnx_path)
 
 
56
  print(f"Successfully created and loaded ONNX model from {onnx_path}")
 
57
 
58
  def init_model(self, base_model):
59
  """Initialize the model for feature extraction
@@ -92,10 +98,12 @@ class FeatureExtractor:
92
  # add batch dimension
93
  x = x.unsqueeze(0)
94
 
95
- # Convert to numpy for ONNX Runtime
 
96
  x_numpy = x.numpy()
97
- # Run inference with ONNX Runtime
98
- print("Running inference with ONNX Runtime")
 
99
  output = self.onnx_session.run(
100
  None,
101
  {'input': x_numpy}
@@ -118,6 +126,10 @@ class FeatureExtractor:
118
  dummy_input = torch.randn(1, 3, 224, 224, device=self.device)
119
 
120
  # Export the model
 
 
 
 
121
  torch.onnx.export(
122
  self.model,
123
  dummy_input,
 
34
  # Check if ONNX model exists
35
  if os.path.exists(onnx_path):
36
  print(f"Loading existing ONNX model from {onnx_path}")
37
+ # Get model size in MB
38
+ model_size = os.path.getsize(onnx_path) / (1024 * 1024)
39
+ print(f"Model size: {model_size:.2f} MB")
40
  self.onnx_session = onnxruntime.InferenceSession(onnx_path)
41
  else:
42
  print(
 
45
  # Initialize PyTorch model
46
  self.model, _ = self.init_model(base_model)
47
  self.model.eval()
48
+ # self.device = torch.device("cpu")
49
+ # self.model.to(self.device).half()
50
 
51
  # Create directory if it doesn't exist
52
  os.makedirs(os.path.dirname(onnx_path), exist_ok=True)
 
56
 
57
  # Load the newly created ONNX model
58
  self.onnx_session = onnxruntime.InferenceSession(onnx_path)
59
+ # Get model size in MB
60
+ model_size = os.path.getsize(onnx_path) / (1024 * 1024)
61
  print(f"Successfully created and loaded ONNX model from {onnx_path}")
62
+ print(f"Model size: {model_size:.2f} MB")
63
 
64
  def init_model(self, base_model):
65
  """Initialize the model for feature extraction
 
98
  # add batch dimension
99
  x = x.unsqueeze(0)
100
 
101
+ # Convert to float16 and then to numpy for ONNX Runtime
102
+ x = x.half()
103
  x_numpy = x.numpy()
104
+ # Get model size in MB
105
+ model_size = os.path.getsize(self.onnx_path) / (1024 * 1024)
106
+ print(f"Running inference with ONNX model (size: {model_size:.2f} MB)")
107
  output = self.onnx_session.run(
108
  None,
109
  {'input': x_numpy}
 
126
  dummy_input = torch.randn(1, 3, 224, 224, device=self.device)
127
 
128
  # Export the model
129
+ # Convert model to float16
130
+ self.model.half()
131
+ dummy_input = dummy_input.half()
132
+
133
  torch.onnx.export(
134
  self.model,
135
  dummy_input,