ABAO77 committed on
Commit
afec78c
·
verified ·
1 Parent(s): 92f42f9

Update src/modules/feature_extractor.py

Browse files
Files changed (1) hide show
  1. src/modules/feature_extractor.py +13 -5
src/modules/feature_extractor.py CHANGED
@@ -4,7 +4,8 @@ import os
4
  import torch
5
  import onnx
6
  import onnxruntime
7
-
 
8
  from src.modules.config_extractor import MODEL_CONFIG
9
 
10
  os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
@@ -99,7 +100,7 @@ class FeatureExtractor:
99
  x = x.unsqueeze(0)
100
 
101
  # Convert to numpy for ONNX Runtime
102
- x_numpy = x.numpy()
103
  # Get model size in MB
104
  model_size = os.path.getsize(self.onnx_path) / (1024 * 1024)
105
  print(f"Running inference with ONNX model (size: {model_size:.2f} MB)")
@@ -124,7 +125,7 @@ class FeatureExtractor:
124
  # Create a dummy input tensor
125
  dummy_input = torch.randn(1, 3, 224, 224, device=self.device)
126
 
127
- # Export the model
128
  torch.onnx.export(
129
  self.model,
130
  dummy_input,
@@ -140,7 +141,14 @@ class FeatureExtractor:
140
  }
141
  )
142
 
143
- # Verify the exported model
 
144
  onnx_model = onnx.load(save_path)
145
  onnx.checker.check_model(onnx_model)
146
- print(f"ONNX model saved to {save_path}")
 
 
 
 
 
 
 
4
  import torch
5
  import onnx
6
  import onnxruntime
7
+ from onnxconverter_common import float16
8
+ import numpy as np
9
  from src.modules.config_extractor import MODEL_CONFIG
10
 
11
  os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
 
100
  x = x.unsqueeze(0)
101
 
102
  # Convert to numpy for ONNX Runtime
103
+ x_numpy = x.numpy().astype(np.float16)
104
  # Get model size in MB
105
  model_size = os.path.getsize(self.onnx_path) / (1024 * 1024)
106
  print(f"Running inference with ONNX model (size: {model_size:.2f} MB)")
 
125
  # Create a dummy input tensor
126
  dummy_input = torch.randn(1, 3, 224, 224, device=self.device)
127
 
128
+ # Export the model to ONNX (FP32 first)
129
  torch.onnx.export(
130
  self.model,
131
  dummy_input,
 
141
  }
142
  )
143
 
144
+ # Load and verify the exported model
145
+ print("Converting model to float16...")
146
  onnx_model = onnx.load(save_path)
147
  onnx.checker.check_model(onnx_model)
148
+
149
+ # Convert to float16
150
+ model_fp16 = float16.convert_float_to_float16(onnx_model)
151
+
152
+ # Save the float16 model
153
+ onnx.save(model_fp16, save_path)
154
+ print(f"Float16 ONNX model saved to {save_path}")