# app.py — Yoga Pose Detection Space
# (Hugging Face page header removed: author Priyansh01, revision bd5cd64)
import cv2
import gradio as gr
import numpy as np
import torch
from ultralytics import YOLO
from ultralytics.nn.tasks import DetectionModel
# Allow safe loading of the DetectionModel class.
# torch.load defaults to weights_only=True on recent PyTorch (>= 2.6), which
# rejects arbitrary pickled classes; whitelisting DetectionModel lets
# ultralytics deserialize the checkpoint. Must run BEFORE YOLO("best.pt").
torch.serialization.add_safe_globals([DetectionModel])
# Load your trained YOLO model from best.pt.
# Ensure that best.pt is in the repository root.
# NOTE(review): this is a module-level side effect — importing this module
# performs disk I/O and model initialization.
model = YOLO("best.pt")
# Pose class names, in the exact order used when the model was trained:
# index i in this list must equal class id i in best.pt.
POSE_OPTIONS = [
    "Bridge_Pose_or_Setu_Bandha_Sarvangasana_",
    "Cat_Cow_Pose_or_Marjaryasana_",
    "Child_Pose_or_Balasana_",
    "Cobra_Pose_or_Bhujangasana_",
    "Corpse_Pose_or_Savasana_",
    "Downward-Facing_Dog_pose_or_Adho_Mukha_Svanasana_",
    "Fish_Pose_or_Matsyasana_",
    "Half_Lord_of_the_Fishes_Pose_or_Ardha_Matsyendrasana_",
    "Legs-Up-the-Wall_Pose_or_Viparita_Karani_",
    "Pigeon_Pose_or_Kapotasana_",
    "Seated_Forward_Bend_pose_or_Paschimottanasana_",
    "Standing_Forward_Bend_pose_or_Uttanasana_",
    "Tree_Pose_or_Vrksasana_",
    "Warrior_II_Pose_or_Virabhadrasana_II_",
    "Warrior_I_Pose_or_Virabhadrasana_I_",
]

# Reverse lookup: pose name -> integer class index expected by the model.
label_mapping = dict(zip(POSE_OPTIONS, range(len(POSE_OPTIONS))))

# Hold-duration tiers offered in the UI (may be used to adjust parameters).
DIFFICULTY_OPTIONS = ["Beginner (10s)", "Intermediate (25s)", "Advanced (45s)"]
def detect_pose(image: np.ndarray, pose: str, difficulty: str) -> np.ndarray:
    """
    Run YOLO inference on one frame and overlay the selected pose's
    detection confidence as a correctness percentage.

    Args:
        image: RGB frame from the webcam (Gradio delivers RGB numpy arrays),
            or None before the live stream has produced a frame.
        pose: Name of the pose to look for; should be one of POSE_OPTIONS.
        difficulty: Selected difficulty label. Currently informational only;
            it does not affect inference.

    Returns:
        The annotated RGB frame, or None when no frame is available yet.
    """
    # Gradio's live webcam feed sends None until the camera starts; the
    # unguarded image.copy() would raise AttributeError on those ticks.
    if image is None:
        return None

    # Run the model on the provided frame.
    results = model(image, verbose=False)

    # Class index for the selected pose; None if the name is unknown.
    selected_class = label_mapping.get(pose)

    # Best confidence (percentage) among detections of the selected class.
    best_confidence = 0.0
    if results and selected_class is not None:
        for box in results[0].boxes:  # each box exposes .cls and .conf
            if int(box.cls.item()) == selected_class:
                best_confidence = max(best_confidence, box.conf.item() * 100)

    # Choose overlay text and color. Colors are BGR (OpenCV convention).
    if best_confidence > 0:
        if best_confidence >= 80:  # threshold for "correct"
            color = (0, 255, 0)    # green
            status = "Correct"
        else:
            color = (0, 0, 255)    # red
            status = "Incorrect"
        text = f"{pose}: {status} ({best_confidence:.1f}%)"
    else:
        color = (0, 0, 255)
        text = f"{pose}: Not Detected"

    # OpenCV draws in BGR: convert, annotate, then convert back to RGB.
    annotated_image_bgr = cv2.cvtColor(image.copy(), cv2.COLOR_RGB2BGR)
    cv2.putText(annotated_image_bgr, text, (30, 50),
                cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
    return cv2.cvtColor(annotated_image_bgr, cv2.COLOR_BGR2RGB)
# Gradio UI: webcam frames stream into detect_pose and the annotated frame is
# shown back. NOTE: the original used gr.components.Camera, which does not
# exist in Gradio and raises AttributeError at startup; the webcam input is
# gr.Image(sources=["webcam"]) (Gradio 4+).
iface = gr.Interface(
    fn=detect_pose,
    inputs=[
        gr.Image(sources=["webcam"], type="numpy", label="Camera Feed"),
        gr.Dropdown(choices=POSE_OPTIONS, label="Select Pose"),
        gr.Dropdown(choices=DIFFICULTY_OPTIONS, label="Select Difficulty"),
    ],
    outputs=gr.Image(label="Detection Result"),
    live=True,
    title="Yoga Pose Detection",
    description=(
        "This app uses your trained YOLO model (best.pt) to detect the selected yoga pose "
        "and display a correctness percentage based on the model's confidence. "
        "Green indicates a correct pose; red indicates incorrect or not detected."
    ),
)

if __name__ == "__main__":
    iface.launch()