File size: 5,456 Bytes
0ccc9b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import tqdm
import cv2
import numpy as np
from mediapipe.python.solutions import drawing_utils as mp_drawing
import mediapipe as mp
from PoseClassification.pose_embedding import FullBodyPoseEmbedding
from PoseClassification.pose_classifier import PoseClassifier
from PoseClassification.utils import EMADictSmoothing
from PoseClassification.utils import RepetitionCounter
from PoseClassification.visualize import PoseClassificationVisualizer

# MediaPipe pose solution: `mp_pose` provides POSE_CONNECTIONS for drawing and
# `pose_tracker` is fed RGB frames by the main loop to obtain landmarks.
mp_pose = mp.solutions.pose
pose_tracker = mp_pose.Pose()

# Folder of pose-sample CSVs for the classifier, and the class whose
# occurrences the repetition counter and visualizer track.
pose_samples_folder = "data/fitness_poses_csvs_out"
class_name = "pushups_down"

# Tuning knobs forwarded verbatim to PoseClassifier / RepetitionCounter; see
# those classes for their exact semantics.
_classifier_options = dict(top_n_by_max_distance=30, top_n_by_mean_distance=10)
_counter_thresholds = dict(enter_threshold=6, exit_threshold=4)

pose_embedder = FullBodyPoseEmbedding()

pose_classifier = PoseClassifier(
    pose_samples_folder=pose_samples_folder,
    pose_embedder=pose_embedder,
    **_classifier_options,
)

# EMA smoothing of the per-frame classification dicts (window of 10, alpha 0.2).
pose_classification_filter = EMADictSmoothing(window_size=10, alpha=0.2)

# Counts repetitions of `class_name`; presumably thresholds are compared with
# the smoothed score for that class — confirm against RepetitionCounter.
repetition_counter = RepetitionCounter(class_name=class_name, **_counter_thresholds)

# NOTE(review): constructed but not referenced elsewhere in this script.
pose_classification_visualizer = PoseClassificationVisualizer(
    class_name=class_name,
    plot_x_max=1000,
    plot_y_max=10,
)

# Open the default camera (device index 0) and request a 1280x720 capture.
# VideoCapture.set only *requests* a property; the backend may not honor it.
video_cap = cv2.VideoCapture(0)
video_fps = 30  # NOTE(review): never passed to the capture nor read elsewhere in this script.
video_width, video_height = 1280, 720
for _cap_prop, _cap_value in (
    (cv2.CAP_PROP_FRAME_WIDTH, video_width),
    (cv2.CAP_PROP_FRAME_HEIGHT, video_height),
):
    video_cap.set(_cap_prop, _cap_value)

# Loop state: count of processed frames and the most recently rendered frame.
frame_idx = 0
output_frame = None

# Static overlay settings for every rendered frame: a white banner across the
# top and a logo (read once, shrunk by ``logo_scale_divisor``) in the
# upper-right corner.
banner_height = 180
logo_path = "src/logo_impredalam.jpg"
logo_scale_divisor = 3


def _load_logo(path, scale_divisor):
    """Read the logo image at ``path`` and shrink it by ``scale_divisor``.

    Returns ``None`` (after printing a warning) when the image cannot be read,
    so the counter keeps running without a logo instead of crashing.
    """
    image = cv2.imread(path)
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # rather than raising.
        print(f"Unable to load logo image {path!r}, continuing without it.")
        return None
    height, width = image.shape[:2]
    # max(1, ...) keeps cv2.resize from receiving a zero-sized target.
    return cv2.resize(
        image, (max(1, width // scale_divisor), max(1, height // scale_divisor))
    )


def _overlay_logo(frame, logo_image):
    """Copy ``logo_image`` into the upper-right corner of ``frame`` in place.

    The logo is cropped to the frame's size so an oversized logo cannot make
    the slice assignment fail with a shape mismatch. ``None`` is a no-op.
    """
    if logo_image is None:
        return
    height = min(logo_image.shape[0], frame.shape[0])
    width = min(logo_image.shape[1], frame.shape[1])
    frame[:height, frame.shape[1] - width :] = logo_image[:height, :width]


def _put_text(frame, text, origin, font_scale, color, thickness):
    """Draw anti-aliased Hershey Simplex ``text`` on ``frame`` in place."""
    cv2.putText(
        frame,
        text,
        origin,
        cv2.FONT_HERSHEY_SIMPLEX,
        font_scale,
        color,
        thickness,
        cv2.LINE_AA,
    )


# Main loop: capture -> pose landmarks -> classify -> smooth -> count -> draw.
# Keys: "q" quits, "r" resets the repetition counter. The tracker, camera and
# windows are released in ``finally`` however the loop exits.
try:
    # Loaded once here instead of from disk on every frame.
    logo = _load_logo(logo_path, logo_scale_divisor)

    with tqdm.tqdm(position=0, leave=True) as pbar:
        while True:
            success, input_frame = video_cap.read()
            if not success:
                print("Unable to read input video frame, breaking!")
                break

            # The tracker is given RGB; OpenCV captures BGR.
            input_frame_rgb = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
            result = pose_tracker.process(image=input_frame_rgb)
            pose_landmarks = result.pose_landmarks

            # Draw on a copy so the captured frame is left untouched.
            output_frame = input_frame.copy()
            output_frame[0:banner_height, :] = (255, 255, 255)  # White banner
            _overlay_logo(output_frame, logo)

            if pose_landmarks is not None:
                mp_drawing.draw_landmarks(
                    image=output_frame,
                    landmark_list=pose_landmarks,
                    connections=mp_pose.POSE_CONNECTIONS,
                )

                # Scale normalized landmarks to pixels (z is scaled by the
                # frame width, matching x).
                frame_height, frame_width = output_frame.shape[:2]
                landmarks_px = np.array(
                    [
                        [lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
                        for lmk in pose_landmarks.landmark
                    ],
                    dtype=np.float32,
                )
                # Explicit check instead of `assert` so it survives `python -O`.
                if landmarks_px.shape != (33, 3):
                    raise ValueError(
                        "Unexpected landmarks shape: {}".format(landmarks_px.shape)
                    )

                # Classify the pose on the current frame, then smooth with EMA.
                pose_classification = pose_classifier(landmarks_px)
                pose_classification_filtered = pose_classification_filter(
                    pose_classification
                )

                repetitions_count = repetition_counter(pose_classification_filtered)

                _put_text(
                    output_frame,
                    f"Push-Ups: {repetitions_count}",
                    (10, 30),
                    1,
                    (0, 0, 0),
                    2,
                )
                # Raw (unsmoothed) classification, drawn at a slightly larger
                # scale (1.2) but thinner stroke than the counter.
                _put_text(
                    output_frame,
                    f"Pose: {pose_classification}",
                    (10, 70),
                    1.2,
                    (0, 0, 0),
                    1,
                )
            else:
                # No pose this frame: still feed an empty classification to
                # the smoother (as the MediaPipe pose-classification sample
                # does) so stale frames presumably age out of its window —
                # assumes EMADictSmoothing accepts an empty dict. The counter
                # is not advanced; its last value is shown.
                pose_classification_filter(dict())
                repetitions_count = repetition_counter.n_repeats
                _put_text(
                    output_frame,
                    f"Push-Ups: {repetitions_count}",
                    (10, 30),
                    1,
                    (0, 255, 0),
                    2,
                )

            cv2.imshow("Push-Up Counter", output_frame)

            key = cv2.waitKey(1) & 0xFF
            if key == ord("q"):
                break
            if key == ord("r"):
                repetition_counter.reset()
                print("Counter reset!")

            frame_idx += 1
            pbar.update()

finally:
    pose_tracker.close()
    video_cap.release()
    cv2.destroyAllWindows()