Spaces:
Sleeping
Sleeping
File size: 5,456 Bytes
0ccc9b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import tqdm
import cv2
import numpy as np
from mediapipe.python.solutions import drawing_utils as mp_drawing
import mediapipe as mp
from PoseClassification.pose_embedding import FullBodyPoseEmbedding
from PoseClassification.pose_classifier import PoseClassifier
from PoseClassification.utils import EMADictSmoothing
from PoseClassification.utils import RepetitionCounter
from PoseClassification.visualize import PoseClassificationVisualizer
# ---------------------------------------------------------------------------
# Pose-classification pipeline configuration
# ---------------------------------------------------------------------------
# Directory holding the reference pose-sample CSVs for the classifier, and the
# class label that the repetition counter and visualizer track.
pose_samples_folder = "data/fitness_poses_csvs_out"
class_name = "pushups_down"

# MediaPipe pose landmark tracker, constructed with default settings.
mp_pose = mp.solutions.pose
pose_tracker = mp_pose.Pose()

# Landmark embedding + classifier over the reference samples.
# NOTE(review): the top_n_* parameters look like nearest-neighbour filtering
# knobs; their exact meaning lives in PoseClassifier — confirm there.
pose_embedder = FullBodyPoseEmbedding()
pose_classifier = PoseClassifier(
    pose_samples_folder=pose_samples_folder,
    pose_embedder=pose_embedder,
    top_n_by_max_distance=30,
    top_n_by_mean_distance=10,
)

# EMA smoothing of per-frame classification results.
pose_classification_filter = EMADictSmoothing(window_size=10, alpha=0.2)

# Counts repetitions of `class_name` using separate enter (6) and exit (4)
# thresholds on the smoothed result.
repetition_counter = RepetitionCounter(
    class_name=class_name,
    enter_threshold=6,
    exit_threshold=4,
)

# Plot helper for the classification signal.
# NOTE(review): not referenced anywhere in the capture loop of this file.
pose_classification_visualizer = PoseClassificationVisualizer(
    class_name=class_name,
    plot_x_max=1000,
    plot_y_max=10,
)
# ---------------------------------------------------------------------------
# Webcam capture
# ---------------------------------------------------------------------------
video_cap = cv2.VideoCapture(0)  # camera device index 0
video_fps = 30
video_width = 1280
video_height = 720
# These properties are requests: the driver/backend may ignore any of them, so
# downstream code must not assume frames arrive at exactly this size or rate.
video_cap.set(cv2.CAP_PROP_FRAME_WIDTH, video_width)
video_cap.set(cv2.CAP_PROP_FRAME_HEIGHT, video_height)
video_cap.set(cv2.CAP_PROP_FPS, video_fps)
if not video_cap.isOpened():
    # Report the real cause instead of letting the first failed read look like
    # a mid-stream error. No raise: the capture loop then exits on its first
    # read and its `finally` still releases the tracker, camera and windows.
    print("Unable to open camera device 0; no frames will be captured.")
frame_idx = 0  # number of frames processed by the capture loop
output_frame = None  # most recently rendered frame (None before the first one)
def _load_logo(path, scale=3):
    """Read the banner logo from ``path`` and shrink it by ``scale``.

    Meant to be called once, outside the per-frame loop. Returns the resized
    image, or None when the file cannot be read (``cv2.imread`` signals
    failure by returning None rather than raising).
    """
    logo_img = cv2.imread(path)
    if logo_img is None:
        print(f"Warning: could not read logo image {path!r}; continuing without it.")
        return None
    logo_h, logo_w = logo_img.shape[:2]
    # Clamp to >= 1 px so cv2.resize is never asked for an empty output size.
    return cv2.resize(logo_img, (max(1, logo_w // scale), max(1, logo_h // scale)))


def _overlay_logo(frame, logo):
    """Paste ``logo`` into the top-right corner of ``frame`` in place.

    Does nothing if ``logo`` is None. The logo is cropped to the frame, since
    the camera may deliver a smaller resolution than was requested.
    """
    if logo is None:
        return
    frame_h, frame_w = frame.shape[:2]
    h = min(logo.shape[0], frame_h)
    w = min(logo.shape[1], frame_w)
    frame[0:h, frame_w - w :] = logo[0:h, 0:w]


def _landmarks_to_array(landmarks, frame_width, frame_height):
    """Convert MediaPipe normalized landmarks into a float32 pixel array.

    x and z are scaled by the frame width and y by the frame height.

    Raises:
        ValueError: if the result is not shaped (33, 3).
    """
    coords = np.array(
        [
            [lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
            for lmk in landmarks.landmark
        ],
        dtype=np.float32,
    )
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if coords.shape != (33, 3):
        raise ValueError("Unexpected landmarks shape: {}".format(coords.shape))
    return coords


try:
    # The logo is identical for every frame, so load and scale it only once.
    logo = _load_logo("src/logo_impredalam.jpg")
    with tqdm.tqdm(position=0, leave=True) as pbar:
        while True:
            success, input_frame = video_cap.read()
            if not success:
                print("Unable to read input video frame, breaking!")
                break

            # Run pose tracker on an RGB copy of OpenCV's BGR frame.
            input_frame_rgb = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
            result = pose_tracker.process(image=input_frame_rgb)
            pose_landmarks = result.pose_landmarks

            # Prepare the output frame: white banner on top, logo top-right.
            output_frame = input_frame.copy()
            banner_height = 180
            output_frame[0:banner_height, :] = (255, 255, 255)  # White color
            _overlay_logo(output_frame, logo)

            if pose_landmarks is not None:
                mp_drawing.draw_landmarks(
                    image=output_frame,
                    landmark_list=pose_landmarks,
                    connections=mp_pose.POSE_CONNECTIONS,
                )
                frame_height, frame_width = output_frame.shape[:2]
                landmarks_px = _landmarks_to_array(
                    pose_landmarks, frame_width, frame_height
                )

                # Classify the pose on the current frame.
                pose_classification = pose_classifier(landmarks_px)
                # Smooth classification using EMA.
                pose_classification_filtered = pose_classification_filter(
                    pose_classification
                )
                # Count repetitions.
                repetitions_count = repetition_counter(pose_classification_filtered)

                cv2.putText(
                    output_frame,
                    f"Push-Ups: {repetitions_count}",
                    (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1,
                    (0, 0, 0),
                    2,
                    cv2.LINE_AA,
                )
                # NOTE(review): scale 1.2 is larger than the counter's 1.0
                # although the old comment called it "smaller" — confirm the
                # intended size. Stroke thickness 1 is thinner than the counter's 2.
                cv2.putText(
                    output_frame,
                    f"Pose: {pose_classification}",
                    (10, 70),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1.2,
                    (0, 0, 0),
                    1,
                    cv2.LINE_AA,
                )
            else:
                # No pose means no classification this frame. Still push an empty
                # classification through the smoother so its window keeps
                # advancing, as the reference MediaPipe pipeline does.
                # Assumes EMADictSmoothing accepts an empty dict — confirm.
                pose_classification_filter({})
                # Still display the last count.
                repetitions_count = repetition_counter.n_repeats
                cv2.putText(
                    output_frame,
                    f"Push-Ups: {repetitions_count}",
                    (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1,
                    (0, 255, 0),
                    2,
                    cv2.LINE_AA,
                )

            cv2.imshow("Push-Up Counter", output_frame)
            key = cv2.waitKey(1) & 0xFF
            if key == ord("q"):
                break
            elif key == ord("r"):
                # NOTE(review): relies on RepetitionCounter.reset(), which is
                # not defined in this file — confirm it exists in utils.
                repetition_counter.reset()
                print("Counter reset!")
            frame_idx += 1
            pbar.update()
finally:
    pose_tracker.close()
    video_cap.release()
    cv2.destroyAllWindows()
|