Autonomous_Car / classification /inference_onnx.py
ABAO77's picture
update path
b39c057
import onnxruntime as ort
import cv2
import numpy as np
from numpy.typing import NDArray
import os
from a_utils_func_2_model import (
CLEAN_DATA_CSV_DIRECTION,
ADD_DATA_CSV_MASK_DIRECTION,
ADD_DATA_CSV_DIRECTION_STRAIGHT,
CLEAN_DATA_CSV_DIRECTION_STRAIGHT,
CHECK_PUSH,
ADD_DATA_CSV_CLASSIFICATION,
CHECK_CSV_CLASSIFICATION,
CLEAN_DATA_CSV_CLASSIFICATION,
)
def load_model(model_path: str):
    """
    Load ONNX model for inference with appropriate execution provider.

    The CUDA execution provider is preferred (with CPU as fallback) when
    ONNX Runtime reports it as available; otherwise only the CPU provider
    is requested. The chosen provider is printed to stdout.

    Args:
        model_path: Path to the ONNX model file

    Returns:
        ONNX Runtime InferenceSession

    Raises:
        FileNotFoundError: If model file doesn't exist
        RuntimeError: If model loading fails; the underlying exception is
            chained as ``__cause__`` so the original traceback is preserved
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")

    try:
        if "CUDAExecutionProvider" in ort.get_available_providers():
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
            print("Using CUDA provider")
        else:
            providers = ["CPUExecutionProvider"]
            print("Using CPU provider")
        return ort.InferenceSession(model_path, providers=providers)
    except Exception as e:
        # Re-raise under one exception type for callers, but keep the cause.
        raise RuntimeError(f"Failed to load model: {str(e)}") from e
# Build the model path relative to this file's directory rather than the
# current working directory.
dirname = os.path.dirname(__file__)
model_path = os.path.join(dirname, "./model/low_angle_model_float16.onnx")
# The session is created once, at import time, and stored in a module-level
# global; inference() reads this ``session`` on every call.
session = load_model(model_path)
print("model_path classification", model_path)
def prepare_input(image):
    """
    Prepare image input for model inference.

    The image is converted from BGR to RGB, resized to 224x224, scaled to
    [0, 1], standardized per channel with the mean/std constants below and
    laid out as a single-item batch in channel-first order.

    Args:
        image: Input image in BGR format with shape (H, W, 3)

    Returns:
        Preprocessed image as float16 array with shape (1, 3, 224, 224)

    Raises:
        ValueError: If image is None (cv2.imread returns None when a file
            cannot be read)
    """
    if image is None:
        raise ValueError("Input image is None; it could not be read or decoded")

    img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Do the arithmetic in float32 and round to float16 only once at the end;
    # float16 math would add a rounding error at every step.
    img = cv2.resize(img, (224, 224)).astype(np.float32)

    # Per-channel standardization (these values match the widely used
    # ImageNet channel statistics); the result is NOT confined to [-1, 1].
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    img = (img / 255.0 - mean) / std

    # Convert (H, W, 3) to (1, 3, H, W)
    img = img.transpose(2, 0, 1)
    img = np.expand_dims(img, axis=0)
    return img.astype(np.float16)
def softmax(x, axis=None):
    """
    Apply a numerically stable softmax to a numpy array.

    Args:
        x: Array-like of scores (logits)
        axis: Axis along which to normalize. None (the default) normalizes
            over all elements, which is what a single 1-D logits vector
            needs; pass e.g. ``axis=-1`` to normalize each row of a batch.

    Returns:
        numpy array with the same shape as x whose entries sum to 1 along
        the chosen axis (or overall when axis is None)
    """
    x = np.asarray(x)
    # Subtract max for numerical stability (avoids overflow in exp)
    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exp_x / exp_x.sum(axis=axis, keepdims=True)
classes = ["LEFT", "RIGHT", "STRAIGHT"]


def inference(image):
    """
    Classify a single image with the module-level ONNX session.

    Args:
        image: Input image in BGR format

    Returns:
        tuple containing:
            - predicted class name taken from ``classes`` (str)
            - confidence score, i.e. the highest softmax probability
    """
    tensor = prepare_input(image)
    in_name = session.get_inputs()[0].name
    out_name = session.get_outputs()[0].name
    raw_outputs = session.run([out_name], {in_name: tensor})
    logits = raw_outputs[0]

    # Row 0 of the first output holds the scores for our single image;
    # softmax turns them into probabilities.
    probs = softmax(logits[0])
    best = np.argmax(probs)
    return classes[best], np.max(probs)
def _classify_frame(session, input_name, output_name, frame):
    """Run *session* on one BGR frame; return (class index, confidence, probabilities)."""
    output = session.run([output_name], {input_name: prepare_input(frame)})[0]
    probabilities = softmax(output[0])
    max_index = int(np.argmax(probabilities))
    return max_index, float(probabilities[max_index]), probabilities


def process_video(
    video_path: str, session, output_path: str = None, display: bool = True
):
    """
    Process video file and perform inference on each frame.

    Every frame is classified with the given session, the predicted label
    and confidence are drawn onto the frame, and the annotated frame is
    optionally shown in a window and/or appended to an output video.
    Pressing "q" in the display window stops processing early.

    Args:
        video_path: Path to input video file
        session: ONNX runtime session used to classify each frame
        output_path: Path to save output video (optional)
        display: Whether to display video while processing

    Raises:
        ValueError: If the video file cannot be opened
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError("Error opening video file")

    labels = ["left", "right", "straight"]
    writer = None
    try:
        # Get video properties
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep the fractional frame rate (e.g. 29.97) so the output plays at
        # the source speed; fall back when the container reports no usable
        # value (0 or NaN).
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 30.0

        # Initialize video writer if output path is specified
        if output_path:
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            writer = cv2.VideoWriter(
                output_path, fourcc, fps, (frame_width, frame_height)
            )

        # The tensor names do not change between frames; look them up once.
        input_name = session.get_inputs()[0].name
        output_name = session.get_outputs()[0].name

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Perform inference
            max_index, confidence, _ = _classify_frame(
                session, input_name, output_name, frame
            )

            # Draw prediction on frame
            text = f"{labels[max_index]}: {confidence:.2f}"
            cv2.putText(
                frame,
                text,
                (50, 50),
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                (0, 255, 0),
                2,
                cv2.LINE_AA,
            )

            if display:
                cv2.imshow("Video Processing", frame)
                if cv2.waitKey(1) & 0xFF == ord("q"):
                    break

            if writer:
                writer.write(frame)
    finally:
        # Always release the capture/writer and close windows, even on error.
        cap.release()
        if writer:
            writer.release()
        if display:
            cv2.destroyAllWindows()
def inference_classification(image):
    """
    Classify an image and pass the predicted class to the CSV helpers.

    The predicted class name is given to ADD_DATA_CSV_CLASSIFICATION, then
    CHECK_CSV_CLASSIFICATION is called. Both helpers are imported from
    a_utils_func_2_model; what they do with the value is not visible here.

    Args:
        image: Input image in BGR format

    Returns:
        None
    """
    # inference() returns (class name, confidence); only the name is used.
    label, _confidence = inference(image)
    ADD_DATA_CSV_CLASSIFICATION(label)
    CHECK_CSV_CLASSIFICATION()
if __name__ == "__main__":
    # The model is already loaded at import time into the module-level
    # ``session``; this script only needs an image to classify.
    image_path = "../images/1.png"

    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread returns None instead of raising when the file is
        # missing or cannot be decoded.
        raise SystemExit(f"Could not read image: {image_path}")

    # Perform inference; the second value is the top-class confidence.
    predicted_class, confidence = inference(image)
    print(f"Predicted Class: {predicted_class}, Confidence: {confidence}")

    # To annotate a video instead, call e.g.:
    #     process_video("./data/IMG_2478.MOV", session)

# Class index mapping (order of the ``classes`` list):
# 0 left
# 1 right
# 2 straight