Spaces:
Sleeping
Sleeping
Update src/modules/feature_extractor.py
Browse files
src/modules/feature_extractor.py
CHANGED
|
@@ -4,7 +4,8 @@ import os  (old side)
  4 | import torch
  5 | import onnx
  6 | import onnxruntime
  7 | - (blank line removed here)
  8 | from src.modules.config_extractor import MODEL_CONFIG
  9 |
 10 | os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
|
@@ -99,7 +100,7 @@ class FeatureExtractor:  (old side)
  99 |     x = x.unsqueeze(0)
 100 |
 101 |     # Convert to numpy for ONNX Runtime
 102 | -   x_numpy = x.numpy()
 103 |     # Get model size in MB
 104 |     model_size = os.path.getsize(self.onnx_path) / (1024 * 1024)
 105 |     print(f"Running inference with ONNX model (size: {model_size:.2f} MB)")
|
@@ -124,7 +125,7 @@ class FeatureExtractor:  (old side)
 124 |     # Create a dummy input tensor
 125 |     dummy_input = torch.randn(1, 3, 224, 224, device=self.device)
 126 |
 127 | -   # Export the model
 128 |     torch.onnx.export(
 129 |         self.model,
 130 |         dummy_input,
|
@@ -140,7 +141,14 @@ class FeatureExtractor:  (old side)
 140 |     }
 141 | )
 142 |
 143 | -   #
 144 | onnx_model = onnx.load(save_path)
 145 | onnx.checker.check_model(onnx_model)
 146 | - (blank line removed here)
|
(new side)
  4 | import torch
  5 | import onnx
  6 | import onnxruntime
  7 | + from onnxconverter_common import float16
  8 | + import numpy as np
  9 | from src.modules.config_extractor import MODEL_CONFIG
 10 |
 11 | os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
|
|
|
| 100 |
x = x.unsqueeze(0)
|
| 101 |
|
| 102 |
# Convert to numpy for ONNX Runtime
|
| 103 |
+
x_numpy = x.numpy().astype(np.float16)
|
| 104 |
# Get model size in MB
|
| 105 |
model_size = os.path.getsize(self.onnx_path) / (1024 * 1024)
|
| 106 |
print(f"Running inference with ONNX model (size: {model_size:.2f} MB)")
|
|
|
|
| 125 |
# Create a dummy input tensor
|
| 126 |
dummy_input = torch.randn(1, 3, 224, 224, device=self.device)
|
| 127 |
|
| 128 |
+
# Export the model to ONNX (FP32 first)
|
| 129 |
torch.onnx.export(
|
| 130 |
self.model,
|
| 131 |
dummy_input,
|
|
|
|
| 141 |
}
|
| 142 |
)
|
| 143 |
|
| 144 |
+
# Load and verify the exported model
|
| 145 |
+
print("Converting model to float16...")
|
| 146 |
onnx_model = onnx.load(save_path)
|
| 147 |
onnx.checker.check_model(onnx_model)
|
| 148 |
+
|
| 149 |
+
# Convert to float16
|
| 150 |
+
model_fp16 = float16.convert_float_to_float16(onnx_model)
|
| 151 |
+
|
| 152 |
+
# Save the float16 model
|
| 153 |
+
onnx.save(model_fp16, save_path)
|
| 154 |
+
print(f"Float16 ONNX model saved to {save_path}")
|