# Hugging Face Spaces app. NOTE(review): the deployed Space showed "Runtime
# error" — see the Gradio component fix in the Interface construction below.
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from ultralytics import YOLO | |
| from ultralytics.nn.tasks import DetectionModel | |
# PyTorch 2.6+ defaults torch.load to weights_only=True; whitelist the
# Ultralytics DetectionModel class so the YOLO checkpoint can deserialize.
torch.serialization.add_safe_globals([DetectionModel])

# Trained YOLO detector — best.pt must be in the repository root.
model = YOLO("best.pt")
# Class names in the exact order used at training time; each entry's list
# index is the class id the detector emits for it.
POSE_OPTIONS = [
    "Bridge_Pose_or_Setu_Bandha_Sarvangasana_",
    "Cat_Cow_Pose_or_Marjaryasana_",
    "Child_Pose_or_Balasana_",
    "Cobra_Pose_or_Bhujangasana_",
    "Corpse_Pose_or_Savasana_",
    "Downward-Facing_Dog_pose_or_Adho_Mukha_Svanasana_",
    "Fish_Pose_or_Matsyasana_",
    "Half_Lord_of_the_Fishes_Pose_or_Ardha_Matsyendrasana_",
    "Legs-Up-the-Wall_Pose_or_Viparita_Karani_",
    "Pigeon_Pose_or_Kapotasana_",
    "Seated_Forward_Bend_pose_or_Paschimottanasana_",
    "Standing_Forward_Bend_pose_or_Uttanasana_",
    "Tree_Pose_or_Vrksasana_",
    "Warrior_II_Pose_or_Virabhadrasana_II_",
    "Warrior_I_Pose_or_Virabhadrasana_I_",
]

# Map each pose name to the class index the model predicts for it.
label_mapping = dict(zip(POSE_OPTIONS, range(len(POSE_OPTIONS))))
# Difficulty presets shown in the UI. The parenthesised duration is advisory
# text only — the detection logic does not currently consume it.
DIFFICULTY_OPTIONS = [
    "Beginner (10s)",
    "Intermediate (25s)",
    "Advanced (45s)",
]
def detect_pose(image: np.ndarray, pose: str, difficulty: str) -> np.ndarray:
    """Run YOLO inference on one frame and annotate it with pose feedback.

    Args:
        image: Input frame in RGB order (as Gradio supplies it), or None when
            the live webcam has not produced a frame yet.
        pose: The pose selected in the UI; expected to be one of POSE_OPTIONS.
        difficulty: Selected difficulty label. Currently unused by detection;
            kept so the Gradio input signature stays stable.

    Returns:
        An RGB copy of the frame with a status banner drawn at the top-left:
        green "Correct" when the selected pose is detected with >= 80%
        confidence, red "Incorrect" below that, red "Not Detected" otherwise.
    """
    # Guard: live webcam feeds can deliver None before the first frame.
    if image is None:
        return image

    results = model(image, verbose=False)
    annotated_image = image.copy()

    # Defaults cover the empty-results case; the original referenced `text`
    # and `color` that were only assigned inside `if results:`, which would
    # raise NameError if the results list were ever empty.
    color = (0, 0, 255)  # Red (BGR) for incorrect / not detected.
    text = f"{pose}: Not Detected"

    if results:
        res = results[0]  # Single-image inference -> only the first result.
        detections = res.boxes  # Detections exposing .cls / .conf tensors.
        selected_class = label_mapping.get(pose)

        # Best confidence among detections of the selected class, as a
        # percentage. Stays 0.0 when the pose was not detected at all.
        best_confidence = 0.0
        if detections is not None and selected_class is not None:
            for box in detections:
                if int(box.cls.item()) == selected_class:
                    conf = box.conf.item() * 100  # Fraction -> percentage.
                    best_confidence = max(best_confidence, conf)

        if best_confidence > 0:
            if best_confidence >= 80:  # Threshold for "correct" pose.
                color = (0, 255, 0)  # Green (BGR).
                status = "Correct"
            else:
                status = "Incorrect"
            text = f"{pose}: {status} ({best_confidence:.1f}%)"

    # OpenCV draws in BGR; convert, annotate, then convert back to RGB.
    annotated_image_bgr = cv2.cvtColor(annotated_image, cv2.COLOR_RGB2BGR)
    cv2.putText(annotated_image_bgr, text, (30, 50),
                cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
    return cv2.cvtColor(annotated_image_bgr, cv2.COLOR_BGR2RGB)
# Gradio UI wiring. NOTE(review): the original used gr.components.Camera,
# which does not exist in Gradio and is the likely cause of the Space's
# "Runtime error" at import time. gr.Image with a webcam source is the
# supported component for capturing live frames.
iface = gr.Interface(
    fn=detect_pose,
    inputs=[
        # Gradio 3.x syntax; on Gradio 4.x use sources=["webcam"] instead.
        gr.Image(source="webcam", type="numpy", label="Camera Feed"),
        gr.Dropdown(choices=POSE_OPTIONS, label="Select Pose"),
        gr.Dropdown(choices=DIFFICULTY_OPTIONS, label="Select Difficulty"),
    ],
    outputs=gr.Image(label="Detection Result"),
    live=True,  # Re-run detection continuously as frames arrive.
    title="Yoga Pose Detection",
    description=(
        "This app uses your trained YOLO model (best.pt) to detect the selected yoga pose "
        "and display a correctness percentage based on the model's confidence. "
        "Green indicates a correct pose; red indicates incorrect or not detected."
    ),
)
if __name__ == "__main__":
    # Start the Gradio server only when executed directly (not on import).
    iface.launch()