ABAO77 committed on
Commit
fc7a73b
·
1 Parent(s): 179a781

setup cuda

Browse files
classification/inference_onnx.py CHANGED
@@ -33,8 +33,14 @@ def load_model(model_path: str):
33
  raise FileNotFoundError(f"Model file not found: {model_path}")
34
 
35
  try:
36
- # Try CPU provider first
37
- providers = ["CPUExecutionProvider"]
 
 
 
 
 
 
38
  session = ort.InferenceSession(model_path, providers=providers)
39
  return session
40
 
 
33
  raise FileNotFoundError(f"Model file not found: {model_path}")
34
 
35
  try:
36
+ available_providers = ort.get_available_providers()
37
+
38
+ if "CUDAExecutionProvider" in available_providers:
39
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
40
+ print("Using CUDA provider")
41
+ else:
42
+ providers = ["CPUExecutionProvider"]
43
+ print("Using CPU provider")
44
  session = ort.InferenceSession(model_path, providers=providers)
45
  return session
46
 
ultrafast/ultrafastLaneDetector.py CHANGED
@@ -130,7 +130,14 @@ class UltrafastLaneDetector:
130
  self.initialize_model(model_path)
131
 
132
  def initialize_model(self, model_path):
133
- providers = ["CPUExecutionProvider"]
 
 
 
 
 
 
 
134
 
135
  self.session = onnxruntime.InferenceSession(model_path, providers=providers)
136
 
 
130
  self.initialize_model(model_path)
131
 
132
  def initialize_model(self, model_path):
133
+ available_providers = onnxruntime.get_available_providers()
134
+
135
+ if 'CUDAExecutionProvider' in available_providers:
136
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
137
+ print("Using CUDA provider")
138
+ else:
139
+ providers = ['CPUExecutionProvider']
140
+ print("Using CPU provider")
141
 
142
  self.session = onnxruntime.InferenceSession(model_path, providers=providers)
143