# SignLanguage-pro / inference.py
# Author: thienphuc12339
# Last commit: "Update inference.py" (0bfec51, verified)
# inference.py
import logging
import shutil
from argparse import Namespace
from pathlib import Path
from time import time
from traceback import format_exc

import cv2
import mediapipe as mp
import numpy as np
import pandas as pd
from mediapipe.python.solutions.drawing_utils import DrawingSpec
from mediapipe.python.solutions.hands import HandLandmark
from mediapipe.python.solutions.pose import PoseLandmark
from pydantic import BaseModel

from configs import InferenceConfig, ModelConfig
from data import Arm, get_sample_timestamp, ok_to_get_frame
from tools.models import Predictions, get_predictions, load_pipeline
from utils import POSE_BASED_MODELS, config_logger
from visualization import draw_text_on_image
# Mapping from class id (string keys "0", "1", ...) to gloss label; this is
# what `inference` passes to `get_predictions` as `id2gloss`.
# Placeholder: replace these glosses with the real vocabulary of your model,
# listed in class-id order (presumably the model's output index order).
id2gloss = {
    str(class_id): gloss
    for class_id, gloss in enumerate(
        [
            "hello",
            "thanks",
            "yes",
            # Add the remaining glosses here, in class-id order.
        ]
    )
}
# Pose landmarks kept (highlighted) in the skeleton overlay drawn by
# `inference` when `show_skeleton` is enabled: nose, both eyes, and both
# arms (shoulder, elbow, wrist). Order is preserved as listed.
SPOTER_POSE_LANDMARKS = [
    PoseLandmark[name]
    for name in (
        "NOSE",
        "LEFT_EYE",
        "RIGHT_EYE",
        "RIGHT_SHOULDER",
        "LEFT_SHOULDER",
        "RIGHT_ELBOW",
        "LEFT_ELBOW",
        "RIGHT_WRIST",
        "LEFT_WRIST",
    )
]
# Hand landmarks kept (highlighted) in the skeleton overlay drawn by
# `inference` when `show_skeleton` is enabled. Order: wrist, then
# TIP -> DIP -> PIP -> MCP for index, middle, ring and pinky, then the thumb
# TIP -> IP -> MCP -> CMC (21 entries in total).
SPOTER_HAND_LANDMARKS = (
    [HandLandmark.WRIST]
    + [
        HandLandmark[f"{finger}_{joint}"]
        for finger in ("INDEX_FINGER", "MIDDLE_FINGER", "RING_FINGER", "PINKY")
        for joint in ("TIP", "DIP", "PIP", "MCP")
    ]
    + [HandLandmark[f"THUMB_{joint}"] for joint in ("TIP", "IP", "MCP", "CMC")]
)
def get_args() -> Namespace:
    """Parse the command-line arguments for running inference.

    Registers ``ModelConfig`` under the destination ``"model"`` and
    ``InferenceConfig`` under ``"inference"``; ``add_config_path_arg=True``
    presumably enables loading these values from a config file.

    Returns:
        The parsed ``Namespace``.
    """
    # NOTE(review): `ArgumentParser` is not imported in this module's visible
    # import block, and the stdlib `argparse.ArgumentParser` supports neither
    # `add_config_path_arg` nor `add_arguments(cls, dest)`. These match
    # `simple_parsing.ArgumentParser` -- confirm, then add
    # `from simple_parsing import ArgumentParser` to the imports.
    parser = ArgumentParser(
        # This script runs inference; the previous text ("Train a model on
        # VSL") was copied from the training entry point.
        description="Run inference with a model on VSL",
        add_config_path_arg=True,
    )
    parser.add_arguments(ModelConfig, "model")
    parser.add_arguments(InferenceConfig, "inference")
    return parser.parse_args()
def _build_skeleton_styles(mp_holistic, mp_drawing_styles, show_skeleton):
    """Build landmark drawing specs and connection lists for the skeleton overlay.

    Starts from MediaPipe's default pose/hand landmark styles and the full
    ``POSE_CONNECTIONS`` / ``HAND_CONNECTIONS`` sets. When ``show_skeleton`` is
    true, landmarks listed in ``SPOTER_POSE_LANDMARKS`` / ``SPOTER_HAND_LANDMARKS``
    get a visible spec (pose color (0, 255, 0), right hand (0, 0, 255), left
    hand (255, 0, 0)); every other landmark gets a zero-size black spec and any
    connection touching it is removed.

    Returns:
        Tuple ``(pose_style, right_hand_style, left_hand_style,
        pose_connections, hand_connections)``.
    """
    pose_style = mp_drawing_styles.get_default_pose_landmarks_style()
    right_hand_style = mp_drawing_styles.get_default_hand_landmarks_style()
    left_hand_style = mp_drawing_styles.get_default_hand_landmarks_style()
    pose_connections = list(mp_holistic.POSE_CONNECTIONS)
    hand_connections = list(mp_holistic.HAND_CONNECTIONS)
    if not show_skeleton:
        return pose_style, right_hand_style, left_hand_style, pose_connections, hand_connections

    kept_pose = set(SPOTER_POSE_LANDMARKS)
    hidden_pose = set()
    for landmark in PoseLandmark:
        if landmark in kept_pose:
            pose_style[landmark] = DrawingSpec(color=(0, 255, 0), thickness=2, circle_radius=2)
        else:
            pose_style[landmark] = DrawingSpec(color=(0, 0, 0), thickness=0, circle_radius=0)
            hidden_pose.add(landmark.value)

    kept_hand = set(SPOTER_HAND_LANDMARKS)
    hidden_hand = set()
    for landmark in HandLandmark:
        if landmark in kept_hand:
            right_hand_style[landmark] = DrawingSpec(color=(0, 0, 255), thickness=2, circle_radius=2)
            left_hand_style[landmark] = DrawingSpec(color=(255, 0, 0), thickness=2, circle_radius=2)
        else:
            right_hand_style[landmark] = DrawingSpec(color=(0, 0, 0), thickness=0, circle_radius=0)
            left_hand_style[landmark] = DrawingSpec(color=(0, 0, 0), thickness=0, circle_radius=0)
            hidden_hand.add(landmark.value)

    # Drop every connection that has a hidden endpoint. Same result as
    # filtering once per hidden landmark, but done in a single pass.
    pose_connections = [conn for conn in pose_connections if hidden_pose.isdisjoint(conn)]
    hand_connections = [conn for conn in hand_connections if hidden_hand.isdisjoint(conn)]
    return pose_style, right_hand_style, left_hand_style, pose_connections, hand_connections


def _draw_skeleton(mp_drawing, frame, detection_results, styles):
    """Draw the pose and both hands from ``detection_results`` onto ``frame``.

    ``styles`` is the tuple returned by ``_build_skeleton_styles``.
    ``mp_drawing.draw_landmarks`` draws on ``frame`` in place.
    """
    pose_style, right_hand_style, left_hand_style, pose_connections, hand_connections = styles
    mp_drawing.draw_landmarks(
        frame,
        detection_results.pose_landmarks,
        connections=pose_connections,
        landmark_drawing_spec=pose_style,
    )
    mp_drawing.draw_landmarks(
        frame,
        detection_results.right_hand_landmarks,
        connections=hand_connections,
        landmark_drawing_spec=right_hand_style,
    )
    mp_drawing.draw_landmarks(
        frame,
        detection_results.left_hand_landmarks,
        connections=hand_connections,
        landmark_drawing_spec=left_hand_style,
    )


def _open_writer(cap, output_dir):
    """Create an mp4v ``cv2.VideoWriter`` at ``output_dir / "output.mp4"``.

    The writer uses the capture's frame size and FPS. ``output_dir`` is created
    if missing, because ``cv2.VideoWriter`` does not raise when it cannot open
    its file; frames would then be dropped silently.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    # Some capture backends (e.g. webcams) report 0 FPS; fall back to 30 so
    # the writer gets a usable rate. NOTE(review): confirm this default.
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    frame_size = (
        int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    return cv2.VideoWriter(
        str(output_dir / "output.mp4"),
        cv2.VideoWriter_fourcc(*"mp4v"),
        fps,
        frame_size,
    )


def _emit_frame(frame, writer, visualize):
    """Write ``frame`` to ``writer`` (if any) and show it (if ``visualize``).

    Returns:
        True when the user pressed ``q`` in the preview window.
    """
    if writer is not None:
        writer.write(frame)
    if visualize:
        cv2.imshow("Video Visualization", frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            return True
    return False


def inference(model_config, inference_config: InferenceConfig, session: "ort.InferenceSession") -> dict:
    """Run sign recognition on a video file or camera stream.

    Frames are read from ``inference_config.source``. A value that is not an
    existing file falls back to camera index 0. MediaPipe Holistic tracks the
    pose, and ``ok_to_get_frame`` on each arm decides whether the current frame
    is collected. Once ``get_sample_timestamp`` reports a non-zero start and end
    time, the collected frames are stacked and passed to ``get_predictions``.

    Args:
        model_config: Not used in this function's body; kept for interface
            compatibility.
        inference_config: Source, thresholds, display and output options.
        session: Inference session passed through to ``get_predictions``
            (annotated as ONNX Runtime's ``InferenceSession``).

    Returns:
        ``{"results": results}``, where ``results`` is accumulated through
        ``Predictions.merge_results`` and is ``None`` if no prediction ran.

    Side effects:
        If ``output_dir`` is set, writes ``output.mp4`` and ``results.csv``
        there. If ``visualize`` is set, opens an OpenCV window; press ``q`` to
        stop early.
    """
    # NOTE: the `session` annotation is a string because `ort` (onnxruntime)
    # is not imported in this module; a bare `ort.InferenceSession` would
    # raise NameError when the function is defined.
    source = str(inference_config.source) if Path(inference_config.source).is_file() else 0
    cap = cv2.VideoCapture(source)
    output_dir = None if inference_config.output_dir is None else Path(inference_config.output_dir)
    writer = _open_writer(cap, output_dir) if output_dir is not None else None

    mp_holistic = mp.solutions.holistic
    mp_drawing = mp.solutions.drawing_utils
    styles = _build_skeleton_styles(
        mp_holistic, mp.solutions.drawing_styles, inference_config.show_skeleton
    )

    right_arm = Arm("right", inference_config.visibility)
    left_arm = Arm("left", inference_config.visibility)
    data = []
    results = None
    predictions = Predictions()

    try:
        with mp_holistic.Holistic(min_detection_confidence=0.9, min_tracking_confidence=0.5) as holistic:
            while cap.isOpened():
                success, frame = cap.read()
                if not success:
                    break

                # MediaPipe processes RGB images, while OpenCV decodes and
                # draws in BGR.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame.flags.writeable = False
                detection_results = holistic.process(frame)
                frame.flags.writeable = True
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

                if detection_results.pose_landmarks is None:
                    # No pose detected. Still emit the frame so the recorded
                    # video keeps every input frame and 'q' stays responsive.
                    frame = predictions.visualize(frame, (20, 70))
                    if _emit_frame(frame, writer, inference_config.visualize):
                        break
                    continue

                landmarks = detection_results.pose_landmarks.landmark
                left_arm.set_pose(landmarks)
                right_arm.set_pose(landmarks)

                # Evaluate both arms on every frame (no short-circuit `or`),
                # since ok_to_get_frame receives the Arm object and may update
                # its state.
                current_time = cap.get(cv2.CAP_PROP_POS_MSEC)
                left_arm_ok_to_get_frame = ok_to_get_frame(
                    arm=left_arm,
                    angle_threshold=inference_config.angle_threshold,
                    min_num_up_frames=inference_config.min_num_up_frames,
                    min_num_down_frames=inference_config.min_num_down_frames,
                    current_time=current_time,
                    delay=inference_config.delay,
                )
                right_arm_ok_to_get_frame = ok_to_get_frame(
                    arm=right_arm,
                    angle_threshold=inference_config.angle_threshold,
                    min_num_up_frames=inference_config.min_num_up_frames,
                    min_num_down_frames=inference_config.min_num_down_frames,
                    current_time=current_time,
                    delay=inference_config.delay,
                )
                if left_arm_ok_to_get_frame or right_arm_ok_to_get_frame:
                    # Reset the displayed prediction while a new sign is
                    # being recorded.
                    predictions = Predictions()
                    # Store a copy: the rendering below may draw on `frame`
                    # in place, and those overlays must not reach the model
                    # input.
                    data.append(frame.copy())

                # Sign boundaries, converted from milliseconds to seconds.
                start_time, end_time = get_sample_timestamp(left_arm, right_arm)
                start_time /= 1_000
                end_time /= 1_000
                if start_time != 0 and end_time != 0:
                    if inference_config.visualize:
                        wait_frame = draw_text_on_image(
                            np.zeros_like(frame),
                            text="Please wait for the prediction...",
                            position=(20, 20),
                            color=(255, 255, 255),
                            font_size=20,
                        )
                        cv2.imshow("Video Visualization", wait_frame)
                        if cv2.waitKey(1) & 0xFF == ord('q'):
                            break
                    if data:
                        start_inference_time = time()
                        # Stack the collected frames along a new leading
                        # (time/batch) axis for the model.
                        data_np = np.stack(data, axis=0)
                        predictions = get_predictions(
                            data_np, session, id2gloss=id2gloss, k=inference_config.top_k
                        )
                        predictions.inference_time = time() - start_inference_time
                        predictions.start_time = start_time
                        predictions.end_time = end_time
                        logging.info(str(predictions))
                        results = predictions.merge_results(results)
                    else:
                        # np.stack([]) would raise; skip this empty segment instead.
                        logging.warning(
                            "Sign boundaries detected but no frames were collected; skipping prediction."
                        )
                    left_arm.reset_state()
                    right_arm.reset_state()
                    data = []

                # Render overlays.
                frame = left_arm.visualize(frame, (20, 10), "Left arm angle")
                frame = right_arm.visualize(frame, (20, 40), "Right arm angle")
                frame = predictions.visualize(frame, (20, 70))
                if inference_config.show_skeleton:
                    _draw_skeleton(mp_drawing, frame, detection_results, styles)
                if _emit_frame(frame, writer, inference_config.visualize):
                    break
    finally:
        # Release capture/writer and close windows even if the loop raised.
        cap.release()
        if writer is not None:
            writer.release()
        cv2.destroyAllWindows()

    if output_dir is not None:
        logging.info(f"Video is recorded and saved to {output_dir / 'output.mp4'}")
        pd.DataFrame(results).to_csv(output_dir / "results.csv", index=False)
        logging.info(f"Results saved to {output_dir / 'results.csv'}")
    return {"results": results}