Spaces:
Sleeping
Sleeping
| import argparse | |
| import sys | |
| import cv2 | |
| import numpy as np | |
| from rich.console import Console | |
| from rich.panel import Panel | |
| from rich.align import Align | |
| from rich.layout import Layout | |
| from pyfiglet import Figlet | |
| import mediapipe as mp | |
| from PoseClassification.pose_embedding import FullBodyPoseEmbedding | |
| from PoseClassification.pose_classifier import PoseClassifier | |
| from PoseClassification.utils import EMADictSmoothing | |
| from PoseClassification.visualize import PoseClassificationVisualizer | |
| # For cross-platform compatibility | |
| try: | |
| import msvcrt # Windows | |
| except ImportError: | |
| import termios # Unix-like | |
| import tty | |
def getch():
    """Read a single keypress from the terminal without waiting for Enter.

    Returns:
        The character that was read, as a one-character string.
    """
    if sys.platform == "win32":
        # getwch() returns a str directly. getch() returns raw bytes, and
        # decoding those as UTF-8 raises UnicodeDecodeError for the 0xE0
        # prefix byte that special keys (arrows, function keys) produce.
        return msvcrt.getwch()

    fd = sys.stdin.fileno()
    old_settings = termios.tcgetattr(fd)
    try:
        # Raw mode delivers the keypress immediately, without line buffering.
        tty.setraw(fd)
        ch = sys.stdin.read(1)
    finally:
        # Always restore the terminal, even if the read is interrupted.
        termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
    return ch
def create_ascii_title(text):
    """Return *text* rendered as ASCII art using the "isometric2" Figlet font."""
    return Figlet(font="isometric2").renderText(text)
def _format_pose_timings(pose_timings):
    """Render accumulated per-pose durations as Rich markup, one pose per line."""
    tracked = ("chair", "cobra", "dog", "plank", "goddess", "tree", "warrior")
    trailing = ("no pose detected", "fallen")
    # Labels produced by the classifier that are not part of the fixed list.
    extras = [name for name in pose_timings if name not in tracked + trailing]
    lines = [
        f"[bold]{name.capitalize()}:[/bold] {pose_timings[name]:.2f}s"
        for name in (*tracked, *extras)
    ]
    lines.append("---")
    lines.append(
        f"[bold]No pose detected:[/bold] {pose_timings['no pose detected']:.2f}s"
    )
    lines.append(f"[bold]Fallen:[/bold] {pose_timings['fallen']:.2f}s")
    return "\n".join(lines)


def main(input_source, display=False, output_file=None):
    """Classify poses frame by frame and report the time spent in each pose.

    Args:
        input_source: Path of a video file, or the string "live" to read from
            camera index 0.
        display: When true, show each annotated frame in an OpenCV window;
            pressing "q" in that window stops processing.
        output_file: Optional path; when given, the annotated frames are
            written to this video file.
    """
    default_fps = 30

    console = Console()
    layout = Layout()
    layout.split(
        Layout(
            Panel(Align.center(create_ascii_title("YOGAI")), border_style="bold blue"),
            size=15,
        ),
        Layout(name="main"),
    )

    is_live = input_source == "live"
    if is_live:
        status = "Processing live video from camera"
    else:
        status = f"Processing video: {input_source}"
    layout["main"].update(
        Panel(status, title="Video Classification", border_style="bold blue")
    )
    console.print(layout)

    # Accumulated seconds per pose (keys are lowercase).
    pose_timings = {
        "chair": 0,
        "cobra": 0,
        "dog": 0,
        "plank": 0,
        "goddess": 0,
        "tree": 0,
        "warrior": 0,
        "no pose detected": 0,
        "fallen": 0,
    }
    frame_count = 0
    frames_written = 0

    pose_tracker = None
    video = None
    writer = None
    try:
        # Initialize pose tracker, embedder, and classifier.
        mp_pose = mp.solutions.pose
        pose_tracker = mp_pose.Pose()
        pose_embedder = FullBodyPoseEmbedding()
        pose_classifier = PoseClassifier(
            pose_samples_folder="data/yoga_poses_csvs_out",
            pose_embedder=pose_embedder,
            top_n_by_max_distance=30,
            top_n_by_mean_distance=10,
        )
        pose_classification_filter = EMADictSmoothing(window_size=10, alpha=0.2)

        # Open the video source.
        video = cv2.VideoCapture(0 if is_live else input_source)
        if not video.isOpened():
            console.print(
                f"[bold red]Could not open video source: {input_source}[/bold red]"
            )
            return

        if is_live:
            fps = default_fps  # Assume 30 fps for live video
            total_frames = float("inf")  # Infinite frames for live video
        else:
            fps = video.get(cv2.CAP_PROP_FPS)
            # Some containers report 0 (or NaN) fps; dividing by that would
            # crash or corrupt every timing, so fall back to the default.
            if not fps > 0:
                fps = default_fps
            total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        seconds_per_frame = 1 / fps

        while True:
            ret, frame = video.read()
            if not ret:
                if is_live:
                    console.print(
                        "[bold red]Error reading from camera. Exiting...[/bold red]"
                    )
                break

            # Process the frame.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            result = pose_tracker.process(image=frame_rgb)

            detected_pose = "no pose detected"
            if result.pose_landmarks is not None:
                # Draw landmarks on the frame.
                mp.solutions.drawing_utils.draw_landmarks(
                    frame, result.pose_landmarks, mp_pose.POSE_CONNECTIONS
                )
                frame_height, frame_width = frame.shape[:2]
                # Scale the landmark coordinates by the frame size
                # (z is scaled by the width, as in the original code).
                pose_landmarks = np.array(
                    [
                        [lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
                        for lmk in result.pose_landmarks.landmark
                    ],
                    dtype=np.float32,
                )

                # Classify the pose and smooth the result over time.
                pose_classification = pose_classifier(pose_landmarks)
                pose_classification_filtered = pose_classification_filter(
                    pose_classification
                )
                # Only the highest-confidence pose is credited. An empty
                # result is counted as "no pose detected" instead of
                # letting max() raise ValueError.
                if pose_classification_filtered:
                    detected_pose = max(
                        pose_classification_filtered,
                        key=pose_classification_filtered.get,
                    ).lower()

            # .get() tolerates classifier labels missing from the initial dict.
            pose_timings[detected_pose] = (
                pose_timings.get(detected_pose, 0) + seconds_per_frame
            )

            frame_count += 1
            if frame_count % 30 == 0:  # Update every 30 frames
                panel_content = _format_pose_timings(pose_timings)
                if not is_live:
                    panel_content += f"\n\nProcessed {frame_count}/{total_frames} frames"
                layout["main"].update(
                    Panel(
                        panel_content,
                        title="Classification Results",
                        border_style="bold green",
                    )
                )
                console.print(layout)

            if output_file:
                if writer is None:
                    # Created lazily because the frame size is only known
                    # once a frame has been read.
                    # NOTE(review): the "mp4v" codec is assumed to suit the
                    # chosen output container - confirm for non-.mp4 paths.
                    height, width = frame.shape[:2]
                    writer = cv2.VideoWriter(
                        output_file,
                        cv2.VideoWriter_fourcc(*"mp4v"),
                        fps,
                        (width, height),
                    )
                    if not writer.isOpened():
                        console.print(
                            f"[bold red]Could not open output file: {output_file}[/bold red]"
                        )
                        writer.release()
                        writer = None
                        output_file = None  # Stop retrying on every frame.
                if writer is not None:
                    writer.write(frame)
                    frames_written += 1

            if display:
                cv2.imshow("Video", frame)
                if cv2.waitKey(1) & 0xFF == ord("q"):
                    break
    finally:
        # Release every resource, even when processing fails midway.
        if video is not None:
            video.release()
        if writer is not None:
            writer.release()
        if pose_tracker is not None:
            pose_tracker.close()
        if display:
            cv2.destroyAllWindows()

    # Final results.
    layout["main"].update(
        Panel(
            _format_pose_timings(pose_timings),
            title="Final Classification Results",
            border_style="bold green",
        )
    )
    console.print(layout)

    # Only claim the output was saved if frames were actually written.
    if output_file and frames_written:
        console.print(f"[green]Output saved to: {output_file}[/green]")
if __name__ == "__main__":
    # Build the command-line interface.
    cli = argparse.ArgumentParser(
        description="Classify poses in a video file or from live camera."
    )
    cli.add_argument("input", help="Input video file or 'live' for camera feed")
    cli.add_argument(
        "--display",
        action="store_true",
        help="Display the video with detected poses",
    )
    cli.add_argument("--output", help="Output video file")

    # Invoked with no arguments at all: print the full help and exit with status 1.
    if len(sys.argv) == 1:
        cli.print_help(sys.stderr)
        sys.exit(1)

    options = cli.parse_args()
    main(
        input_source=options.input,
        display=options.display,
        output_file=options.output,
    )