# YOGAI — pushups_counter.py
# Real-time push-up counter: tracks body landmarks with MediaPipe Pose,
# classifies each frame with a k-NN pose classifier, and counts repetitions.
import tqdm
import cv2
import numpy as np
from mediapipe.python.solutions import drawing_utils as mp_drawing
import mediapipe as mp
from PoseClassification.pose_embedding import FullBodyPoseEmbedding
from PoseClassification.pose_classifier import PoseClassifier
from PoseClassification.utils import EMADictSmoothing
from PoseClassification.utils import RepetitionCounter
from PoseClassification.visualize import PoseClassificationVisualizer
# --- Pose tracking and classification state (shared by the main loop) ---
mp_pose = mp.solutions.pose
pose_tracker = mp_pose.Pose()

# CSV pose samples that form the k-NN training set, and the target class
# whose confidence drives the repetition counter.
pose_samples_folder = "data/fitness_poses_csvs_out"
class_name = "pushups_down"

pose_embedder = FullBodyPoseEmbedding()
pose_classifier = PoseClassifier(
    pose_samples_folder=pose_samples_folder,
    pose_embedder=pose_embedder,
    top_n_by_max_distance=30,
    top_n_by_mean_distance=10,
)

# Exponential-moving-average smoothing of the raw per-frame classification.
pose_classification_filter = EMADictSmoothing(window_size=10, alpha=0.2)

# One repetition = smoothed confidence rises above 6, then falls below 4.
repetition_counter = RepetitionCounter(
    class_name=class_name, enter_threshold=6, exit_threshold=4
)

pose_classification_visualizer = PoseClassificationVisualizer(
    class_name=class_name, plot_x_max=1000, plot_y_max=10
)

# --- Webcam capture (device 0) ---
video_fps = 30
video_width, video_height = 1280, 720
video_cap = cv2.VideoCapture(0)
video_cap.set(cv2.CAP_PROP_FRAME_WIDTH, video_width)
video_cap.set(cv2.CAP_PROP_FRAME_HEIGHT, video_height)

frame_idx = 0
output_frame = None
# Load and scale the banner logo ONCE, before the capture loop — the
# original re-read and re-resized it from disk on every single frame.
logo = cv2.imread("src/logo_impredalam.jpg")
if logo is not None:
    logo_height, logo_width = logo.shape[:2]
    logo = cv2.resize(logo, (logo_width // 3, logo_height // 3))  # 1/3 scale

try:
    with tqdm.tqdm(position=0, leave=True) as pbar:
        while True:
            success, input_frame = video_cap.read()
            if not success:
                print("Unable to read input video frame, breaking!")
                break

            # Run pose tracker (MediaPipe expects RGB; OpenCV delivers BGR).
            input_frame_rgb = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
            result = pose_tracker.process(image=input_frame_rgb)
            pose_landmarks = result.pose_landmarks

            # Prepare the output frame with a white banner on top.
            output_frame = input_frame.copy()
            banner_height = 180
            output_frame[0:banner_height, :] = (255, 255, 255)  # White color

            # Overlay the logo in the upper-right corner. Skipped when the
            # logo file could not be read (cv2.imread returns None on
            # a missing/unreadable file instead of raising).
            if logo is not None:
                output_frame[
                    0 : logo.shape[0], output_frame.shape[1] - logo.shape[1] :
                ] = logo

            if pose_landmarks is not None:
                mp_drawing.draw_landmarks(
                    image=output_frame,
                    landmark_list=pose_landmarks,
                    connections=mp_pose.POSE_CONNECTIONS,
                )
                # Convert normalized landmarks to pixel coordinates; z is
                # scaled by frame width (matches the classifier's convention).
                frame_height, frame_width = output_frame.shape[0], output_frame.shape[1]
                pose_landmarks = np.array(
                    [
                        [lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
                        for lmk in pose_landmarks.landmark
                    ],
                    dtype=np.float32,
                )
                # Validate explicitly — `assert` would be stripped under -O.
                if pose_landmarks.shape != (33, 3):
                    raise ValueError(
                        "Unexpected landmarks shape: {}".format(pose_landmarks.shape)
                    )

                # Classify the pose on the current frame, smooth with EMA,
                # then update the repetition counter.
                pose_classification = pose_classifier(pose_landmarks)
                pose_classification_filtered = pose_classification_filter(
                    pose_classification
                )
                repetitions_count = repetition_counter(pose_classification_filtered)

                # Display repetitions count on the frame
                cv2.putText(
                    output_frame,
                    f"Push-Ups: {repetitions_count}",
                    (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1,
                    (0, 0, 0),
                    2,
                    cv2.LINE_AA,
                )
                # Display classified pose on the frame
                cv2.putText(
                    output_frame,
                    f"Pose: {pose_classification}",
                    (10, 70),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1.2,
                    (0, 0, 0),
                    1,  # Thinner line
                    cv2.LINE_AA,
                )
            else:
                # No landmarks detected: still display the last count (green).
                repetitions_count = repetition_counter.n_repeats
                cv2.putText(
                    output_frame,
                    f"Push-Ups: {repetitions_count}",
                    (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1,
                    (0, 255, 0),
                    2,
                    cv2.LINE_AA,
                )

            cv2.imshow("Push-Up Counter", output_frame)
            key = cv2.waitKey(1) & 0xFF
            if key == ord("q"):  # quit
                break
            elif key == ord("r"):  # reset the repetition counter
                repetition_counter.reset()
                print("Counter reset!")
            frame_idx += 1
            pbar.update()
finally:
    # Always release the tracker, camera, and GUI windows — even on error.
    pose_tracker.close()
    video_cap.release()
    cv2.destroyAllWindows()