File size: 8,565 Bytes
e282e15
 
 
 
 
 
 
 
f426b7a
e282e15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30058ce
 
 
 
 
 
f426b7a
 
30058ce
 
f426b7a
30058ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e282e15
30058ce
 
f426b7a
 
 
e282e15
30058ce
adaf96f
30058ce
 
 
 
e282e15
30058ce
 
 
 
 
 
 
 
f426b7a
30058ce
 
f426b7a
 
30058ce
 
 
 
 
 
 
 
f426b7a
30058ce
 
 
 
 
f426b7a
 
 
 
 
30058ce
f426b7a
30058ce
 
f426b7a
 
 
30058ce
 
f426b7a
e282e15
adaf96f
f426b7a
 
 
 
 
30058ce
f426b7a
 
 
 
 
 
 
 
 
 
 
 
 
 
30058ce
 
f426b7a
 
 
30058ce
f426b7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30058ce
 
 
 
f426b7a
30058ce
e282e15
c0bcda9
e282e15
b11844e
adaf96f
c0bcda9
e282e15
 
30058ce
 
 
 
 
e282e15
 
 
 
f426b7a
 
 
 
 
 
 
 
 
 
 
 
adaf96f
8258179
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
import gradio as gr
import cv2
import numpy as np
import mediapipe as mp
from mediapipe.python.solutions import drawing_utils as mp_drawing
from PoseClassification.pose_embedding import FullBodyPoseEmbedding
from PoseClassification.pose_classifier import PoseClassifier
from PoseClassification.utils import EMADictSmoothing
import time

# --- Module-level configuration -------------------------------------------
# Pose labels. NOTE(review): not referenced elsewhere in the visible code.
class_names = ["chair", "cobra", "dog", "goddess", "plank", "tree", "warrior", "none"]
# Score a pose must reach in process_frame() before it is reported instead of
# "none".
position_threshold = 8.0

# --- Shared pose-estimation pipeline --------------------------------------
# A single tracker / embedder / classifier / smoothing filter is created at
# import time and reused for every frame.
mp_pose = mp.solutions.pose
pose_tracker = mp_pose.Pose()

pose_embedder = FullBodyPoseEmbedding()
pose_classifier = PoseClassifier(
    pose_embedder=pose_embedder,
    pose_samples_folder="data/yoga_poses_csvs_out",
    top_n_by_mean_distance=10,
    top_n_by_max_distance=30,
)
pose_classification_filter = EMADictSmoothing(alpha=0.2, window_size=10)


def check_major_current_position(positions_detected: dict, threshold_position) -> str:
    """Return the label with the highest score, or "none" if it is too weak.

    Args:
        positions_detected: Mapping of pose label to score. May be empty.
        threshold_position: Minimum score (anything ``float()`` accepts) the
            best label must reach to be reported.

    Returns:
        The key of ``positions_detected`` with the largest value when that
        value is at least ``threshold_position``; otherwise ``"none"``. An
        empty mapping also yields ``"none"`` instead of raising.
    """
    # max() raises ValueError on an empty mapping; treat "no scores" the same
    # as "no pose strong enough".
    if not positions_detected:
        return "none"

    best_label = max(positions_detected, key=positions_detected.get)
    if positions_detected[best_label] < float(threshold_position):
        return "none"
    return best_label


def process_frame(frame):
    """Classify the pose visible in one frame.

    The frame is converted with ``cv2.COLOR_BGR2RGB`` and passed to the shared
    ``pose_tracker``. When landmarks are found they are scaled to pixel
    coordinates, classified, and smoothed; otherwise a fixed
    ``{"none": 10.0}`` score is used.

    NOTE(review): the conversion presumes ``frame`` is BGR -- confirm against
    callers.

    Args:
        frame: Image array; ``shape[0]`` is used as height and ``shape[1]`` as
            width.

    Returns:
        ``(label, frame)`` where ``label`` is the result of
        ``check_major_current_position`` with ``position_threshold`` and
        ``frame`` is the input, returned as received.
    """
    rgb_image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    landmarks = pose_tracker.process(image=rgb_image).pose_landmarks

    if landmarks is None:
        # No landmarks detected: report a fixed "none" score.
        scores = {"none": 10.0}
    else:
        height, width = frame.shape[:2]
        # x and z are both scaled by the frame width, y by the frame height.
        points = np.array(
            [[p.x * width, p.y * height, p.z * width] for p in landmarks.landmark],
            dtype=np.float32,
        )
        raw_scores = pose_classifier(points)
        scores = pose_classification_filter(raw_scores)

    label = check_major_current_position(scores, position_threshold)
    return label, frame


def yoga_position_from_stream():
    """Build the live-stream pose classification UI and return its root column.

    Creates the Gradio components (webcam feed, pose/duration text boxes, a
    debug view with recording controls) and wires their events. The state the
    callbacks share (current pose, pose timer, recording buffer) lives in this
    function's closure.

    The ``__main__`` block of this file calls it inside a ``gr.Blocks``
    context.

    Returns:
        The ``gr.Column`` that contains the whole interface.
    """
    current_position = "none"
    position_timer = 0
    # None until the first frame has been timed. Starting at 0 would add the
    # whole raw tick-clock value to the timer when the first frame is
    # classified as the initial "none" pose.
    last_update_time = None
    recording = False
    recorded_frames = []
    start_time = 0  # wall-clock time of the first recorded frame
    last_frame_time = 0  # wall-clock time of the most recent recorded frame
    frame_count = 0

    def classify_pose(frame):
        """Classify one streamed frame, update the pose timer, and annotate it.

        Returns a 4-tuple for (video_feed, classified_video, pose_output,
        timer_output). When ``frame`` is None both image outputs are None and
        the last known pose/duration are reported.
        """
        nonlocal current_position, position_timer, last_update_time
        nonlocal start_time, last_frame_time, frame_count
        if frame is None:
            return (
                None,
                None,
                current_position,
                f"Duration: {int(position_timer)} seconds",
            )

        new_position, processed_frame = process_frame(frame)

        now = cv2.getTickCount() / cv2.getTickFrequency()
        if last_update_time is None or new_position != current_position:
            # First timed frame, or the pose changed: restart the timer.
            current_position = new_position
            position_timer = 0
        else:
            position_timer += now - last_update_time
        last_update_time = now

        # NOTE(review): process_frame() already ran the tracker on this frame
        # but does not return the landmarks, so they are computed again here
        # for drawing.
        mp_drawing.draw_landmarks(
            image=processed_frame,
            landmark_list=pose_tracker.process(
                cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB)
            ).pose_landmarks,
            connections=mp_pose.POSE_CONNECTIONS,
        )

        cv2.putText(
            processed_frame,
            f"Pose: {current_position}",
            (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (0, 255, 0),
            2,
        )
        cv2.putText(
            processed_frame,
            f"Duration: {int(position_timer)} seconds",
            (10, 70),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (0, 255, 0),
            2,
        )

        if recording:
            frame_time = time.time()
            if frame_count == 0:
                start_time = frame_time
            recorded_frames.append(processed_frame)
            frame_count += 1
            # Remember when the last frame arrived so the saved FPS does not
            # depend on how long the user waits before pressing "Save".
            last_frame_time = frame_time

        return (
            frame,
            processed_frame,
            current_position,
            f"Duration: {int(position_timer)} seconds",
        )

    def toggle_debug(debug_mode):
        """Show the debug view and classified feed, hiding the raw feed (or the reverse)."""
        return [
            gr.update(visible=debug_mode),
            gr.update(visible=not debug_mode),
            gr.update(visible=debug_mode),
        ]

    def start_recording():
        """Clear any previous recording and start buffering annotated frames."""
        nonlocal recording, recorded_frames, start_time, last_frame_time, frame_count
        recording = True
        recorded_frames = []
        start_time = 0
        last_frame_time = 0
        frame_count = 0
        return "Recording started"

    def stop_recording():
        """Stop buffering frames; frames already recorded are kept."""
        nonlocal recording
        recording = False
        return "Recording stopped"

    def save_video():
        """Write the buffered frames to an MP4 file.

        Returns:
            ``(path, status_message)``; ``path`` is None when there is nothing
            to save or the output file cannot be opened.
        """
        if not recorded_frames:
            return None, "No recorded frames available"

        output_path = "recorded_yoga_session.mp4"
        height, width, _ = recorded_frames[0].shape

        # Estimate the capture rate from the span between the first and last
        # recorded frame: N frames cover N - 1 intervals. Fall back to 30 FPS
        # when the span cannot be measured (e.g. a single frame).
        elapsed_time = last_frame_time - start_time
        if frame_count > 1 and elapsed_time > 0:
            fps = (frame_count - 1) / elapsed_time
        else:
            fps = 30.0

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            out.release()
            return None, "Could not open video file for writing"

        try:
            for frame in recorded_frames:
                # Convert frame to BGR color space before writing
                frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                out.write(frame_bgr)
        finally:
            # Always release the writer so the file is closed even if a frame
            # fails to convert or write.
            out.release()

        return output_path, f"Video saved successfully at {fps:.2f} FPS"

    with gr.Column() as yoga_stream:
        gr.Markdown("# Yoga Position Classifier", elem_classes=["custom-title"])
        gr.Markdown(
            "Stream live yoga sessions and get real-time pose classification.",
            elem_classes=["custom-subtitle"],
        )

        with gr.Row():
            with gr.Column(scale=3):
                video_feed = gr.Webcam(streaming=True, elem_classes=["custom-webcam"])

            with gr.Column(scale=2):
                pose_output = gr.Textbox(
                    label="Current Pose", elem_classes=["custom-textbox"]
                )
                timer_output = gr.Textbox(
                    label="Pose Duration", elem_classes=["custom-textbox"]
                )
                debug_toggle = gr.Checkbox(
                    label="Debug Mode", value=False, elem_classes=["custom-checkbox"]
                )

        # Hidden until "Debug Mode" is ticked (see toggle_debug).
        with gr.Column(visible=False) as debug_view:
            classified_video = gr.Image(
                label="Classified Video Feed", elem_classes=["custom-image"]
            )
            with gr.Row():
                start_button = gr.Button(
                    "Start Recording", elem_classes=["custom-button"]
                )
                stop_button = gr.Button(
                    "Stop Recording", elem_classes=["custom-button"]
                )
            save_button = gr.Button("Save Recording", elem_classes=["custom-button"])
            recording_status = gr.Textbox(
                label="Recording Status", elem_classes=["custom-textbox"]
            )
            recorded_video = gr.Video(
                label="Recorded Video", elem_classes=["custom-video"]
            )
            download_button = gr.Button(
                "Download Recorded Video", elem_classes=["custom-button"]
            )

        debug_toggle.change(
            toggle_debug,
            inputs=[debug_toggle],
            outputs=[debug_view, video_feed, classified_video],
        )

        video_feed.stream(
            classify_pose,
            inputs=[video_feed],
            outputs=[video_feed, classified_video, pose_output, timer_output],
            show_progress=False,
        )

        start_button.click(start_recording, outputs=[recording_status])
        stop_button.click(stop_recording, outputs=[recording_status])
        save_button.click(save_video, outputs=[recorded_video, recording_status])
        download_button.click(lambda: "recorded_yoga_session.mp4", outputs=[gr.File()])

    return yoga_stream


if __name__ == "__main__":
    # Stylesheet for the elem_classes used by the components that
    # yoga_position_from_stream() creates.
    custom_css = """
        .custom-title { font-size: 36px; font-weight: bold; margin-bottom: 10px; }
        .custom-subtitle { font-size: 18px; margin-bottom: 20px; }
        .custom-webcam { height: 480px; }
        .custom-textbox input { font-size: 24px; }
        .custom-checkbox label { font-size: 18px; }
        .custom-button { font-size: 18px; }
        .custom-image img { max-height: 400px; }
        .custom-video video { max-height: 400px; }
        """

    # Build the interface inside a Blocks app, then launch it.
    with gr.Blocks(css=custom_css) as demo:
        yoga_position_from_stream()
    demo.launch()