ABAO77 commited on
Commit
92f42f9
·
verified ·
1 Parent(s): eea1fcf

Update src/modules/feature_extractor.py

Browse files
Files changed (1) hide show
  1. src/modules/feature_extractor.py +2 -7
src/modules/feature_extractor.py CHANGED
@@ -46,7 +46,7 @@ class FeatureExtractor:
46
  self.model, _ = self.init_model(base_model)
47
  self.model.eval()
48
  self.device = torch.device("cpu")
49
- # self.model.to(self.device).half()
50
 
51
  # Create directory if it doesn't exist
52
  os.makedirs(os.path.dirname(onnx_path), exist_ok=True)
@@ -98,8 +98,7 @@ class FeatureExtractor:
98
  # add batch dimension
99
  x = x.unsqueeze(0)
100
 
101
- # Convert to float16 and then to numpy for ONNX Runtime
102
- x = x.half()
103
  x_numpy = x.numpy()
104
  # Get model size in MB
105
  model_size = os.path.getsize(self.onnx_path) / (1024 * 1024)
@@ -126,10 +125,6 @@ class FeatureExtractor:
126
  dummy_input = torch.randn(1, 3, 224, 224, device=self.device)
127
 
128
  # Export the model
129
- # Convert model to float16
130
- self.model.half()
131
- dummy_input = dummy_input.half()
132
-
133
  torch.onnx.export(
134
  self.model,
135
  dummy_input,
 
46
  self.model, _ = self.init_model(base_model)
47
  self.model.eval()
48
  self.device = torch.device("cpu")
49
+ self.model.to(self.device)
50
 
51
  # Create directory if it doesn't exist
52
  os.makedirs(os.path.dirname(onnx_path), exist_ok=True)
 
98
  # add batch dimension
99
  x = x.unsqueeze(0)
100
 
101
+ # Convert to numpy for ONNX Runtime
 
102
  x_numpy = x.numpy()
103
  # Get model size in MB
104
  model_size = os.path.getsize(self.onnx_path) / (1024 * 1024)
 
125
  dummy_input = torch.randn(1, 3, 224, 224, device=self.device)
126
 
127
  # Export the model
 
 
 
 
128
  torch.onnx.export(
129
  self.model,
130
  dummy_input,