# SignLanguage / src / inference.py
# Author: thienphuc12339 — "Add all source code" (commit 9f83ce9)
import shutil
import logging
from time import time
import numpy as np
import pandas as pd
import cv2
from traceback import format_exc
from argparse import Namespace
from transformers import Pipeline
from simple_parsing import ArgumentParser
import mediapipe as mp
from mediapipe.python.solutions.pose import PoseLandmark
from mediapipe.python.solutions.hands import HandLandmark
from mediapipe.python.solutions.drawing_utils import DrawingSpec
from visualization import draw_text_on_image
from configs import ModelConfig, InferenceConfig
from utils import config_logger, POSE_BASED_MODELS
from data import Arm, get_sample_timestamp, ok_to_get_frame
from tools import load_pipeline, Predictions
# Subset of Mediapipe pose landmarks kept for the SPOTER-style model:
# face anchors (nose, eyes) plus both arms (shoulders, elbows, wrists).
# All other pose landmarks are hidden when drawing the skeleton.
SPOTER_POSE_LANDMARKS = [
PoseLandmark.NOSE,
PoseLandmark.LEFT_EYE,
PoseLandmark.RIGHT_EYE,
PoseLandmark.RIGHT_SHOULDER,
PoseLandmark.LEFT_SHOULDER,
PoseLandmark.RIGHT_ELBOW,
PoseLandmark.LEFT_ELBOW,
PoseLandmark.RIGHT_WRIST,
PoseLandmark.LEFT_WRIST ]
# Hand landmarks kept for the SPOTER-style model. NOTE(review): this lists
# all 21 Mediapipe hand landmarks, so no hand point is ever filtered out —
# presumably kept as an explicit list for symmetry with the pose subset.
SPOTER_HAND_LANDMARKS = [
HandLandmark.WRIST,
HandLandmark.INDEX_FINGER_TIP, HandLandmark.INDEX_FINGER_DIP, HandLandmark.INDEX_FINGER_PIP, HandLandmark.INDEX_FINGER_MCP,
HandLandmark.MIDDLE_FINGER_TIP, HandLandmark.MIDDLE_FINGER_DIP, HandLandmark.MIDDLE_FINGER_PIP, HandLandmark.MIDDLE_FINGER_MCP,
HandLandmark.RING_FINGER_TIP, HandLandmark.RING_FINGER_DIP, HandLandmark.RING_FINGER_PIP, HandLandmark.RING_FINGER_MCP,
HandLandmark.PINKY_TIP, HandLandmark.PINKY_DIP, HandLandmark.PINKY_PIP, HandLandmark.PINKY_MCP,
HandLandmark.THUMB_TIP, HandLandmark.THUMB_IP, HandLandmark.THUMB_MCP, HandLandmark.THUMB_CMC,
]
def get_args() -> Namespace:
    """Parse command-line arguments into config groups.

    Returns:
        Namespace with a `model` attribute (ModelConfig) and an
        `inference` attribute (InferenceConfig), plus `config_path`
        added by simple_parsing's `add_config_path_arg`.
    """
    parser = ArgumentParser(
        # BUGFIX: description previously said "Train a model on VSL" —
        # copy-pasted from the training script; this is the inference CLI.
        description="Run sign language inference on VSL",
        add_config_path_arg=True,
    )
    parser.add_arguments(ModelConfig, "model")
    parser.add_arguments(InferenceConfig, "inference")
    return parser.parse_args()
def inference(model_config, inference_config: InferenceConfig, pipeline: Pipeline) -> None:
    """Run sign-language recognition on a video file or webcam stream.

    Reads frames from `inference_config.source` (falls back to webcam 0
    when the path is not a file), detects pose/hand landmarks with
    Mediapipe Holistic, segments signs by tracking when either arm is
    raised, and feeds the collected frames (or pose detections, for
    pose-based models) to `pipeline` for classification. Optionally
    visualizes the skeleton and predictions, and writes an annotated
    video plus a CSV of results to `inference_config.output_dir`.

    Args:
        model_config: Model configuration (currently unused here beyond
            the commented-out arch check; kept for interface stability).
        inference_config: Runtime options (source, thresholds, output dir,
            visualization flags, etc.).
        pipeline: A transformers Pipeline that maps an array of frames or
            detections to predictions.
    """
    # Load video: use the file path if it exists, otherwise the default webcam.
    source = str(inference_config.source) if inference_config.source.is_file() else 0
    cap = cv2.VideoCapture(source)
    if inference_config.output_dir is not None:
        writer = cv2.VideoWriter(
            str(inference_config.output_dir / "output.mp4"),
            cv2.VideoWriter_fourcc(*"mp4v"),
            cap.get(cv2.CAP_PROP_FPS),
            # Use named property constants instead of magic indices 3/4.
            (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))),
        )
    # Init Mediapipe drawing helpers and default styles.
    mp_holistic = mp.solutions.holistic
    mp_drawing = mp.solutions.drawing_utils
    mp_drawing_styles = mp.solutions.drawing_styles
    custom_pose_style = mp_drawing_styles.get_default_pose_landmarks_style()
    custom_right_hand_style = mp_drawing_styles.get_default_hand_landmarks_style()
    custom_left_hand_style = mp_drawing_styles.get_default_hand_landmarks_style()
    custom_pose_connections = list(mp_holistic.POSE_CONNECTIONS)
    custom_hand_connections = list(mp_holistic.HAND_CONNECTIONS)
    if inference_config.show_skeleton:
        # if model_config.arch == 'spoter':
        pose_landmarks = SPOTER_POSE_LANDMARKS
        hand_landmarks = SPOTER_HAND_LANDMARKS
        # Highlight kept pose landmarks in green; hide the rest.
        for landmark in PoseLandmark:
            if landmark in pose_landmarks:
                custom_pose_style[landmark] = DrawingSpec(color=(0, 255, 0), thickness=2, circle_radius=2)
            else:
                custom_pose_style[landmark] = DrawingSpec(color=(0, 0, 0), thickness=0, circle_radius=0)
        # BUGFIX: the original removed entries from custom_pose_connections
        # while iterating over the same list, which silently skips elements.
        # Build a filtered copy instead: drop any connection touching a
        # hidden landmark.
        hidden_pose = {lm.value for lm in PoseLandmark if lm not in pose_landmarks}
        custom_pose_connections = [
            conn for conn in custom_pose_connections if hidden_pose.isdisjoint(conn)
        ]
        # Kept hand landmarks: right hand red (BGR), left hand blue.
        for landmark in HandLandmark:
            if landmark in hand_landmarks:
                custom_right_hand_style[landmark] = DrawingSpec(color=(0, 0, 255), thickness=2, circle_radius=2)
                custom_left_hand_style[landmark] = DrawingSpec(color=(255, 0, 0), thickness=2, circle_radius=2)
            else:
                custom_right_hand_style[landmark] = DrawingSpec(color=(0, 0, 0), thickness=0, circle_radius=0)
                custom_left_hand_style[landmark] = DrawingSpec(color=(0, 0, 0), thickness=0, circle_radius=0)
        # Same mutate-while-iterating fix for hand connections.
        hidden_hand = {lm.value for lm in HandLandmark if lm not in hand_landmarks}
        custom_hand_connections = [
            conn for conn in custom_hand_connections if hidden_hand.isdisjoint(conn)
        ]
    # Init state: per-arm trackers, the frame buffer for the current sign,
    # accumulated results, and the latest prediction set.
    right_arm = Arm("right", inference_config.visibility)
    left_arm = Arm("left", inference_config.visibility)
    data = []
    results = None
    predictions = Predictions()
    with mp_holistic.Holistic(min_detection_confidence=0.9, min_tracking_confidence=0.5) as holistic:
        while cap.isOpened():
            success, frame = cap.read()
            if not success:
                break
            # Recolor image to RGB, because mp processes on RGB image
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame.flags.writeable = False
            # Make detections
            detection_results = holistic.process(frame)
            # Recolor image back to BGR, because cv2 processes on BGR image
            frame.flags.writeable = True
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            # BUGFIX: replaced a broad `except Exception: continue` with an
            # explicit check — pose_landmarks is None when no pose is found,
            # and the broad catch could hide unrelated errors.
            if detection_results.pose_landmarks is None:
                continue
            landmarks = detection_results.pose_landmarks.landmark
            left_arm.set_pose(landmarks)
            right_arm.set_pose(landmarks)
            # Check if arms are up or down
            left_arm_ok_to_get_frame = ok_to_get_frame(
                arm=left_arm,
                angle_threshold=inference_config.angle_threshold,
                min_num_up_frames=inference_config.min_num_up_frames,
                min_num_down_frames=inference_config.min_num_down_frames,
                current_time=cap.get(cv2.CAP_PROP_POS_MSEC),
                delay=inference_config.delay,
            )
            right_arm_ok_to_get_frame = ok_to_get_frame(
                arm=right_arm,
                angle_threshold=inference_config.angle_threshold,
                min_num_up_frames=inference_config.min_num_up_frames,
                min_num_down_frames=inference_config.min_num_down_frames,
                current_time=cap.get(cv2.CAP_PROP_POS_MSEC),
                delay=inference_config.delay,
            )
            if left_arm_ok_to_get_frame or right_arm_ok_to_get_frame:
                # A sign is in progress: buffer the frame (or the raw
                # detections for pose-based models) and clear stale predictions.
                predictions = Predictions()
                data.append(detection_results if inference_config.use_pose_model else frame)
            # Calculate the start and end time of sign
            start_time, end_time = get_sample_timestamp(left_arm, right_arm)
            # Convert from milliseconds to seconds
            start_time /= 1_000
            end_time /= 1_000
            if start_time != 0 and end_time != 0:
                # A complete sign segment was captured — run the pipeline.
                # Render waiting screen
                if inference_config.visualize:
                    wait_frame = draw_text_on_image(
                        np.zeros_like(frame),
                        text="Please wait for the prediction...",
                        position=(20, 20),
                        color=(255, 255, 255),
                        font_size=20,
                    )
                    cv2.imshow("Video Visualization", wait_frame)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
                start_inference_time = time()
                predictions = Predictions(predictions=pipeline(np.array(data)))
                predictions.inference_time = time() - start_inference_time
                predictions.start_time = start_time
                predictions.end_time = end_time
                logging.info(str(predictions))
                results = predictions.merge_results(results)
                # Reset variables for the next sign segment.
                start_time = 0
                end_time = 0
                left_arm.reset_state()
                right_arm.reset_state()
                data = []
            # Render detections
            frame = left_arm.visualize(frame, (20, 10), "Left arm angle")
            frame = right_arm.visualize(frame, (20, 40), "Right arm angle")
            frame = predictions.visualize(frame, (20, 70))
            if inference_config.show_skeleton:
                mp_drawing.draw_landmarks(
                    frame,
                    detection_results.pose_landmarks,
                    connections=custom_pose_connections,  # passing the modified connections list
                    landmark_drawing_spec=custom_pose_style)  # and drawing style
                mp_drawing.draw_landmarks(
                    frame,
                    detection_results.right_hand_landmarks,
                    connections=custom_hand_connections,  # passing the modified connections list
                    landmark_drawing_spec=custom_right_hand_style)  # and drawing style
                mp_drawing.draw_landmarks(
                    frame,
                    detection_results.left_hand_landmarks,
                    connections=custom_hand_connections,  # passing the modified connections list
                    landmark_drawing_spec=custom_left_hand_style)  # and drawing style
            if inference_config.output_dir is not None:
                writer.write(frame)
            if inference_config.visualize:
                cv2.imshow("Video Visualization", frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    cap.release()
    cv2.destroyAllWindows()
    if inference_config.output_dir is not None:
        writer.release()
        # BUGFIX: the log previously claimed 'output.avi' while the writer
        # actually produces 'output.mp4'.
        logging.info(f"Video is recorded and saved to {inference_config.output_dir / 'output.mp4'}")
        pd.DataFrame(results).to_csv(inference_config.output_dir / "results.csv", index=False)
        logging.info(f"Results saved to {inference_config.output_dir / 'results.csv'}")
def main(args: Namespace) -> None:
    """Wire the parsed configs together and run the inference loop.

    Args:
        args: Parsed namespace holding `model` and `inference` config groups.
    """
    model_config = args.model
    logging.info(model_config)
    inference_config = args.inference
    logging.info(inference_config)
    # Pose-based architectures consume Mediapipe detection results
    # instead of raw frames downstream.
    inference_config.use_pose_model = model_config.arch in POSE_BASED_MODELS
    pipeline = load_pipeline(model_config, inference_config)
    logging.info("Pipeline loaded")
    inference(model_config, inference_config, pipeline)
    logging.info("Inference completed")
if __name__ == "__main__":
    # Top-level boundary: any uncaught error is printed with its traceback
    # rather than crashing with a bare stack dump before logging is set up.
    try:
        args = get_args()
        output_dir = args.inference.output_dir
        config_path = args.config_path[0]
        config_logger(output_dir / "inference.log")
        logging.info(f"Config file loaded from {config_path}")
        # Keep a copy of the config next to the run's outputs for reproducibility.
        shutil.copy(config_path, output_dir / "inference.yaml")
        logging.info(f"Config file saved to {output_dir}")
        main(args=args)
    except Exception:
        print(format_exc())