Commit a422282 by Emily Chen: "initial commit"
Signed-off-by: Emily Chen <emilychen@Emilys-iMac.lan>
- README.md +92 -0
- app/__init__.py +0 -0
- app/components/__init__.py +0 -0
- app/main.py +102 -0
- app/models/__init__.py +0 -0
- app/models/llm_analyzer.py +172 -0
- app/models/pose_estimator.py +156 -0
- app/models/swing_analyzer.py +174 -0
- app/streamlit_app.py +281 -0
- app/utils/__init__.py +0 -0
- app/utils/video_downloader.py +79 -0
- app/utils/video_processor.py +91 -0
- app/utils/visualizer.py +225 -0
- requirements.txt +12 -0
- run_streamlit.sh +10 -0
- setup_directories.sh +19 -0
README.md
ADDED
@@ -0,0 +1,92 @@
# Golf Swing Analysis

A Python application that analyzes golf swings from YouTube videos using computer vision and AI.

## Features

- YouTube video retrieval and processing using yt-dlp
- Golfer, club, and ball detection using YOLOv8
- Pose estimation for swing analysis
- Swing phase segmentation (setup, backswing, downswing, impact, follow-through)
- Trajectory and speed analysis
- AI-powered swing evaluation and coaching tips
- Visual feedback with annotations
- Streamlit web interface

## Installation

1. Clone this repository
2. Run the setup script to create the necessary directories:
```
chmod +x setup_directories.sh
./setup_directories.sh
```
3. Create a virtual environment:
```
python -m venv .venv
source .venv/bin/activate  # On Windows: .venv\Scripts\activate
```
4. Install dependencies:
```
pip install -r requirements.txt
```
5. Create a `.env` file in the project root with your OpenAI API key:
```
OPENAI_API_KEY=your_api_key_here
```

## Usage

### Command Line Interface

Run the main application:

```
python app/main.py
```

Follow the prompts to input a YouTube URL containing a golf swing recording.

### Streamlit Web Interface

Run the Streamlit web app using the provided shell script:

```
./run_streamlit.sh
```

Or manually with:

```
source .venv/bin/activate
streamlit run app/streamlit_app.py
```

The web interface provides:
- Options to upload a video or use a YouTube URL
- Control over the frame skip rate for YOLO detection
- A toggle for enabling/disabling GPT analysis
- Interactive display of analysis results
- An option to create and view annotated videos

## File Organization

- **downloads/**: Contains both downloaded YouTube videos and annotated videos
  - All videos (both original and annotated) are stored in the same directory for easy access

## Troubleshooting

If you encounter issues with the "Create Annotated Video" button:
1. Make sure you've run the setup script to create the downloads directory
2. Check that the `downloads` directory has write permissions
3. Try restarting the Streamlit app

## Requirements

- Python 3.8+
- OpenCV
- YOLOv8
- MediaPipe
- yt-dlp
- OpenAI API key
- Streamlit
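Step 5 above only requires that the key be visible to the process at runtime; the app itself loads `.env` via `python-dotenv`. A minimal standard-library sketch of the check the app effectively performs (the `have_openai_key` helper is hypothetical, not part of this repository):

```python
import os


def have_openai_key(environ=os.environ):
    """Return True when a non-empty OPENAI_API_KEY is configured."""
    key = environ.get("OPENAI_API_KEY", "")
    return bool(key.strip())


# Check against an explicit mapping instead of the real environment
print(have_openai_key({"OPENAI_API_KEY": "sk-example"}))  # True
print(have_openai_key({}))                                # False
```

If this returns False at startup, the GPT analysis step will be skipped or error out, so it is worth checking before a long video run.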
app/__init__.py
ADDED
File without changes

app/components/__init__.py
ADDED
File without changes
app/main.py
ADDED
@@ -0,0 +1,102 @@
#!/usr/bin/env python3
"""
Golf Swing Analysis - Main Application
"""

import os
import sys
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Add the app directory to the path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from app.utils.video_downloader import download_youtube_video
from app.utils.video_processor import process_video
from app.models.pose_estimator import analyze_pose
from app.models.swing_analyzer import segment_swing, analyze_trajectory
from app.models.llm_analyzer import generate_swing_analysis
from app.utils.visualizer import create_annotated_video


def main():
    """Main application function"""
    print("\n===== Golf Swing Analysis =====\n")

    # Step 1: Get YouTube URL from user
    youtube_url = input("Enter YouTube URL of golf swing: ")

    # Step 2: Configure analysis options
    enable_gpt = input("\nEnable GPT analysis? (y/n, default: y): ").lower() != 'n'

    sample_rate_input = input("\nFrame skip rate for YOLO (1-10, default: 5): ")
    sample_rate = 5  # Default value
    if sample_rate_input.isdigit():
        sample_rate = max(1, min(10, int(sample_rate_input)))

    try:
        # Step 3: Download the video
        print("\nDownloading video...")
        video_path = download_youtube_video(youtube_url)
        print(f"Video downloaded to: {video_path}")

        # Step 4: Process video and detect golfer, club, and ball
        print("\nProcessing video and detecting objects...")
        frames, detections = process_video(video_path, sample_rate=sample_rate)

        # Step 5: Analyze pose throughout the swing
        print("\nAnalyzing golfer's pose...")
        pose_data = analyze_pose(frames)

        # Step 6: Segment swing into phases
        print("\nSegmenting swing phases...")
        swing_phases = segment_swing(pose_data, detections, sample_rate=sample_rate)

        # Step 7: Analyze trajectory and speed
        print("\nAnalyzing trajectory and speed...")
        trajectory_data = analyze_trajectory(frames, detections, swing_phases,
                                             sample_rate=sample_rate)

        # Step 8: Generate swing analysis using LLM (if enabled)
        if enable_gpt:
            print("\nGenerating swing analysis and coaching tips...")
            analysis = generate_swing_analysis(pose_data, swing_phases,
                                               trajectory_data)

            # Display results
            print("\n===== Swing Analysis Results =====\n")
            print(analysis)
        else:
            print("\nGPT analysis disabled. Skipping swing evaluation.")

        # Step 9: Create annotated video (optional)
        create_video = input("\nCreate annotated video? (y/n): ").lower() == 'y'
        if create_video:
            print("\nCreating annotated video...")
            output_path = create_annotated_video(video_path, frames, detections,
                                                 pose_data, swing_phases,
                                                 trajectory_data,
                                                 sample_rate=sample_rate)
            print(f"Annotated video saved to: {output_path}")

        print("\nAnalysis complete!")

    except Exception as e:
        print(f"\nError: {str(e)}")
        return


if __name__ == "__main__":
    main()
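The frame-skip prompt in `main()` accepts only digit strings and clamps them into the 1-10 range, falling back to the default otherwise. The same logic in isolation (the `clamp_sample_rate` helper is illustrative, not part of the module):

```python
def clamp_sample_rate(raw: str, default: int = 5) -> int:
    """Mirror of main.py's input handling: digit strings are clamped
    to the range 1-10; anything else yields the default."""
    if raw.isdigit():
        return max(1, min(10, int(raw)))
    return default


print(clamp_sample_rate("3"))    # 3
print(clamp_sample_rate("42"))   # 10
print(clamp_sample_rate(""))     # 5
print(clamp_sample_rate("-2"))   # 5  ("-2".isdigit() is False)
```

Note that negative or non-numeric input silently falls back to the default rather than raising, which keeps the interactive prompt forgiving.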
app/models/__init__.py
ADDED
File without changes
app/models/llm_analyzer.py
ADDED
@@ -0,0 +1,172 @@
"""
LLM-based golf swing analysis module
"""

import os
import json
from openai import OpenAI


def generate_swing_analysis(pose_data, swing_phases, trajectory_data):
    """
    Generate swing analysis and coaching tips using LLM

    Args:
        pose_data (dict): Dictionary mapping frame indices to pose keypoints
        swing_phases (dict): Dictionary mapping phase names to lists of frame indices
        trajectory_data (dict): Dictionary mapping frame indices to trajectory data

    Returns:
        str: Detailed swing analysis and coaching tips
    """
    # Check if OpenAI API key is available
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        return "Error: OpenAI API key not found. Please set the OPENAI_API_KEY environment variable."

    # Create OpenAI client
    client = OpenAI(api_key=api_key)

    # Prepare data for LLM
    analysis_data = prepare_data_for_llm(pose_data, swing_phases, trajectory_data)

    # Generate prompt for LLM
    prompt = create_llm_prompt(analysis_data)

    try:
        # Call OpenAI API
        response = client.chat.completions.create(
            model="gpt-4",
            messages=[{
                "role": "system",
                "content": ("You are a professional golf coach with expertise in "
                            "analyzing golf swings. Provide detailed, actionable "
                            "feedback based on the swing data provided.")
            }, {
                "role": "user",
                "content": prompt
            }],
            temperature=0.7,
            max_tokens=1000)

        # Extract and return analysis
        analysis = response.choices[0].message.content
        return analysis

    except Exception as e:
        return f"Error generating swing analysis: {str(e)}"


def prepare_data_for_llm(pose_data, swing_phases, trajectory_data):
    """
    Prepare swing data for LLM analysis

    Args:
        pose_data (dict): Dictionary mapping frame indices to pose keypoints
        swing_phases (dict): Dictionary mapping phase names to lists of frame indices
        trajectory_data (dict): Dictionary mapping frame indices to trajectory data

    Returns:
        dict: Processed data for LLM analysis
    """
    analysis_data = {"swing_phases": {}, "joint_angles": {}, "trajectory": {}}

    # Process swing phases
    for phase, frames in swing_phases.items():
        if frames:
            # Get a representative frame for each phase
            mid_frame = frames[len(frames) // 2]

            # Get joint angles for the representative frame
            if mid_frame in pose_data:
                keypoints = pose_data[mid_frame]

                # Calculate key metrics for each phase
                analysis_data["swing_phases"][phase] = {
                    "frame_index": mid_frame,
                    "duration_frames": len(frames)
                }

    # Process trajectory data
    impact_frames = swing_phases.get("impact", [])
    if impact_frames:
        impact_frame = impact_frames[len(impact_frames) // 2]
        if impact_frame in trajectory_data:
            impact_data = trajectory_data[impact_frame]
            if "club_speed" in impact_data and impact_data["club_speed"]:
                analysis_data["trajectory"]["club_speed_mph"] = impact_data["club_speed"]

    # Add additional metrics that would be calculated in a real implementation
    # These are placeholder values for demonstration
    analysis_data["metrics"] = {
        "tempo_ratio": 3.0,  # Backswing to downswing time ratio
        "swing_plane_consistency": 0.85,  # 0-1 scale
        "weight_shift": 0.7,  # 0-1 scale
        "hip_rotation": 45,  # degrees
        "shoulder_rotation": 90,  # degrees
        "wrist_hinge": 80,  # degrees
        "posture_score": 0.8  # 0-1 scale
    }

    return analysis_data


def create_llm_prompt(analysis_data):
    """
    Create a prompt for the LLM based on swing analysis data

    Args:
        analysis_data (dict): Processed swing analysis data

    Returns:
        str: Prompt for LLM
    """
    prompt = """
I've analyzed a golf swing and extracted the following data:

## Swing Phases
"""

    # Add swing phases information
    for phase, data in analysis_data["swing_phases"].items():
        prompt += f"- {phase.capitalize()}: Frame {data['frame_index']}, Duration: {data['duration_frames']} frames\n"

    # Add trajectory information
    prompt += "\n## Trajectory Data\n"
    if "trajectory" in analysis_data and "club_speed_mph" in analysis_data["trajectory"]:
        prompt += f"- Club Speed: {analysis_data['trajectory']['club_speed_mph']:.1f} mph\n"

    # Add metrics
    prompt += "\n## Swing Metrics\n"
    for metric, value in analysis_data["metrics"].items():
        # Format metric name for readability
        metric_name = metric.replace("_", " ").title()

        # Format value based on type
        if isinstance(value, float):
            if 0 <= value <= 1:
                # Format as percentage for 0-1 scale metrics
                formatted_value = f"{value * 100:.0f}%"
            else:
                # Format as decimal for other floats
                formatted_value = f"{value:.1f}"
        else:
            # Use as is for integers and other types
            formatted_value = str(value)

        prompt += f"- {metric_name}: {formatted_value}\n"

    prompt += """
Based on this data, please provide:
1. A detailed analysis of the golf swing
2. Key strengths and weaknesses
3. Specific recommendations for improvement
4. Drills or exercises that could help address the identified issues

Please be specific and actionable in your feedback.
"""

    return prompt
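The value-formatting branch inside `create_llm_prompt` renders floats in the 0-1 range as percentages and everything else as a plain number. Extracted into a standalone helper (the `format_metric` name is illustrative; the module inlines this logic), it behaves like this:

```python
def format_metric(value):
    """Same rule as create_llm_prompt: floats in [0, 1] become
    percentages, other floats get one decimal place, everything
    else goes through str()."""
    if isinstance(value, float):
        if 0 <= value <= 1:
            return f"{value * 100:.0f}%"
        return f"{value:.1f}"
    return str(value)


print(format_metric(0.85))  # 85%
print(format_metric(3.0))   # 3.0
print(format_metric(45))    # 45
```

One edge worth noting: an integer like `45` is never treated as a percentage even if it is between 0 and 1 conceptually, because the branch checks `isinstance(value, float)` first.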
app/models/pose_estimator.py
ADDED
@@ -0,0 +1,156 @@
"""
Pose estimation module for golf swing analysis
"""

import cv2
import numpy as np
import mediapipe as mp
from tqdm import tqdm


class PoseEstimator:
    """MediaPipe-based pose estimator for golf swing analysis"""

    def __init__(self):
        """Initialize the pose estimator"""
        self.mp_pose = mp.solutions.pose
        self.pose = self.mp_pose.Pose(static_image_mode=False,
                                      model_complexity=2,
                                      enable_segmentation=False,
                                      min_detection_confidence=0.5,
                                      min_tracking_confidence=0.5)

    def process_frame(self, frame):
        """
        Process a single frame and extract pose landmarks

        Args:
            frame (numpy.ndarray): Input frame

        Returns:
            list: List of keypoints [x, y, visibility] or None if not detected
        """
        # Convert BGR to RGB
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Process the frame
        results = self.pose.process(frame_rgb)

        if not results.pose_landmarks:
            return None

        # Extract keypoints
        keypoints = []
        h, w, _ = frame.shape
        for landmark in results.pose_landmarks.landmark:
            # Convert normalized coordinates to pixel coordinates
            x, y = int(landmark.x * w), int(landmark.y * h)
            visibility = landmark.visibility
            keypoints.append([x, y, visibility])

        return keypoints

    def close(self):
        """Release resources"""
        self.pose.close()


def analyze_pose(frames):
    """
    Analyze pose in video frames

    Args:
        frames (list): List of video frames

    Returns:
        dict: Dictionary mapping frame indices to pose keypoints
    """
    pose_estimator = PoseEstimator()
    pose_data = {}

    for i, frame in enumerate(tqdm(frames, desc="Analyzing pose")):
        keypoints = pose_estimator.process_frame(frame)
        if keypoints:
            pose_data[i] = keypoints

    pose_estimator.close()

    return pose_data


def calculate_joint_angles(keypoints):
    """
    Calculate joint angles from pose keypoints

    Args:
        keypoints (list): List of keypoints [x, y, visibility]

    Returns:
        dict: Dictionary of joint angles in degrees
    """
    # Define joint connections for angle calculation:
    # each entry is [outer point, vertex joint, outer point]
    lm = mp.solutions.pose.PoseLandmark
    joint_connections = {
        "right_shoulder": [lm.RIGHT_ELBOW.value, lm.RIGHT_SHOULDER.value, lm.RIGHT_HIP.value],
        "left_shoulder": [lm.LEFT_ELBOW.value, lm.LEFT_SHOULDER.value, lm.LEFT_HIP.value],
        "right_elbow": [lm.RIGHT_WRIST.value, lm.RIGHT_ELBOW.value, lm.RIGHT_SHOULDER.value],
        "left_elbow": [lm.LEFT_WRIST.value, lm.LEFT_ELBOW.value, lm.LEFT_SHOULDER.value],
        "right_hip": [lm.RIGHT_KNEE.value, lm.RIGHT_HIP.value, lm.RIGHT_SHOULDER.value],
        "left_hip": [lm.LEFT_KNEE.value, lm.LEFT_HIP.value, lm.LEFT_SHOULDER.value],
        "right_knee": [lm.RIGHT_ANKLE.value, lm.RIGHT_KNEE.value, lm.RIGHT_HIP.value],
        "left_knee": [lm.LEFT_ANKLE.value, lm.LEFT_KNEE.value, lm.LEFT_HIP.value]
    }

    angles = {}

    for joint_name, landmarks in joint_connections.items():
        # Get the three points that form the angle
        if all(landmarks[i] < len(keypoints) for i in range(3)):
            p1 = np.array(keypoints[landmarks[0]][:2])
            p2 = np.array(keypoints[landmarks[1]][:2])
            p3 = np.array(keypoints[landmarks[2]][:2])

            # Calculate vectors from the vertex joint to the outer points
            v1 = p1 - p2
            v2 = p3 - p2

            # Calculate the angle between the vectors
            cosine_angle = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
            angle = np.arccos(np.clip(cosine_angle, -1.0, 1.0))
            angles[joint_name] = np.degrees(angle)

    return angles
app/models/swing_analyzer.py
ADDED
@@ -0,0 +1,174 @@
"""
Swing analysis module for golf swing segmentation and trajectory analysis
"""

import numpy as np
import cv2
from app.models.pose_estimator import calculate_joint_angles


def segment_swing(pose_data, detections, sample_rate=5):
    """
    Segment the golf swing into key phases

    Args:
        pose_data (dict): Dictionary mapping frame indices to pose keypoints
        detections (list): List of Detection objects
        sample_rate (int): The frame sampling rate used during processing

    Returns:
        dict: Dictionary mapping phase names to lists of frame indices
    """
    # Initialize swing phases
    swing_phases = {
        "setup": [],
        "backswing": [],
        "downswing": [],
        "impact": [],
        "follow_through": []
    }

    # Get frame indices with pose data
    frame_indices = sorted(pose_data.keys())

    if not frame_indices:
        return swing_phases

    # Calculate joint angles for each frame
    angles_by_frame = {}
    for idx in frame_indices:
        keypoints = pose_data[idx]
        angles_by_frame[idx] = calculate_joint_angles(keypoints)

    # Analyze shoulder rotation to identify swing phases.
    # This is a simplified approach - a more sophisticated algorithm
    # would be needed for production.

    # Find the frame with the maximum right shoulder angle (top of backswing)
    max_shoulder_angle = -1
    top_backswing_frame = frame_indices[0]

    for idx in frame_indices:
        angles = angles_by_frame[idx]
        if "right_shoulder" in angles and angles["right_shoulder"] > max_shoulder_angle:
            max_shoulder_angle = angles["right_shoulder"]
            top_backswing_frame = idx

    # Find impact frame (when club meets ball).
    # In a real implementation, this would use club and ball detection.
    impact_frame = None
    person_positions = {}

    # Extract person positions from detections
    for detection in detections:
        if detection.class_name == "person":
            frame_idx = detection.frame_idx // sample_rate  # Convert to processed frame index
            if frame_idx in frame_indices:
                person_positions[frame_idx] = detection.bbox

    # Find the frame with the most forward position (impact)
    if person_positions:
        min_x = float('inf')
        for idx, bbox in person_positions.items():
            if idx > top_backswing_frame and bbox[0] < min_x:
                min_x = bbox[0]
                impact_frame = idx

    # If impact frame not found, estimate it as 2/3 of the way between
    # the top of the backswing and the last frame
    if impact_frame is None:
        impact_frame = top_backswing_frame + int(
            (frame_indices[-1] - top_backswing_frame) * 2 / 3)

    # Assign frames to phases
    for idx in frame_indices:
        if idx < frame_indices[len(frame_indices) // 5]:
            # First 20% of frames are setup
            swing_phases["setup"].append(idx)
        elif idx < top_backswing_frame:
            # Frames before top of backswing are backswing
            swing_phases["backswing"].append(idx)
        elif idx < impact_frame:
            # Frames between top of backswing and impact are downswing
            swing_phases["downswing"].append(idx)
        elif idx < impact_frame + 5:
            # Frames around impact
            swing_phases["impact"].append(idx)
        else:
            # Remaining frames are follow-through
            swing_phases["follow_through"].append(idx)

    return swing_phases


def analyze_trajectory(frames, detections, swing_phases, sample_rate=5):
    """
    Analyze club and ball trajectory and speed

    Args:
        frames (list): List of video frames
        detections (list): List of Detection objects
        swing_phases (dict): Dictionary mapping phase names to lists of frame indices
        sample_rate (int): The frame sampling rate used during processing

    Returns:
        dict: Dictionary mapping frame indices to trajectory data
    """
    trajectory_data = {}

    # Extract ball detections
    ball_detections = [d for d in detections if d.class_name == "sports ball"]

    # Get impact frame index
    impact_frames = swing_phases.get("impact", [])
    if not impact_frames:
        return trajectory_data

    impact_frame_idx = impact_frames[len(impact_frames) // 2]

    # Track ball trajectory after impact
    ball_trajectory = []
    ball_positions = {}

    for detection in ball_detections:
        frame_idx = detection.frame_idx // sample_rate  # Convert to processed frame index
        if frame_idx >= impact_frame_idx:
            # Calculate ball center
            x1, y1, x2, y2 = detection.bbox
            center_x = (x1 + x2) / 2
            center_y = (y1 + y2) / 2
            ball_positions[frame_idx] = (center_x, center_y)

    # Sort ball positions by frame index
    for idx in sorted(ball_positions.keys()):
        ball_trajectory.append(ball_positions[idx])

    # Estimate club speed at impact.
    # In a real implementation, this would use more sophisticated tracking.
    club_speed = None
    if len(swing_phases.get("downswing", [])) >= 2:
        # Simplified club speed calculation.
        # In reality, this would require tracking the club head specifically.
        downswing_frames = swing_phases["downswing"]
        time_diff = (downswing_frames[-1] - downswing_frames[0]) / 30  # Assuming 30 fps
        if time_diff > 0:
            # Simplified speed calculation (just an example)
            club_speed = 100 * (1 / time_diff)  # Arbitrary scaling

    # Populate trajectory data
    for phase in sorted(swing_phases.keys()):
        for frame_idx in swing_phases[phase]:
            trajectory_data[frame_idx] = {
                "phase": phase,
                "club_speed": club_speed if phase == "impact" else None,
                "ball_trajectory": ball_trajectory
                if phase in ("impact", "follow_through") else None
            }

    return trajectory_data
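The placeholder club-speed estimate in `analyze_trajectory` reduces to one formula: the downswing span in frames, converted to seconds at the assumed 30 fps, then inverted with the arbitrary scale of 100. As a standalone sketch (the `estimate_club_speed` helper is illustrative; the module inlines this logic):

```python
def estimate_club_speed(downswing_frames, fps=30, scale=100):
    """Placeholder estimate mirroring analyze_trajectory: shorter
    downswings yield higher speeds. Returns None when undefined."""
    if len(downswing_frames) < 2:
        return None
    time_diff = (downswing_frames[-1] - downswing_frames[0]) / fps
    if time_diff <= 0:
        return None
    return scale * (1 / time_diff)


# A 5-frame downswing at 30 fps lasts 1/6 s, giving roughly 600 in
# these arbitrary units
speed = estimate_club_speed([10, 11, 12, 13, 14, 15])
```

Because the scale is arbitrary and the fps is assumed rather than read from the video, these numbers are only useful for comparing swings processed with the same settings, not as real mph values.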
app/streamlit_app.py
ADDED
@@ -0,0 +1,281 @@
```python
"""
Streamlit web UI for Golf Swing Analysis
"""

import os
import sys
import streamlit as st
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Add the app directory to the path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from app.utils.video_downloader import download_youtube_video
from app.utils.video_processor import process_video
from app.models.pose_estimator import analyze_pose
from app.models.swing_analyzer import segment_swing, analyze_trajectory
from app.models.llm_analyzer import generate_swing_analysis, create_llm_prompt, prepare_data_for_llm
from app.utils.visualizer import create_annotated_video

# Set page config
st.set_page_config(page_title="Golf Swing Analysis",
                   page_icon="🏌️",
                   layout="wide",
                   initial_sidebar_state="expanded")


def validate_youtube_url(url):
    """Validate if the URL is a YouTube URL"""
    return "youtube.com" in url or "youtu.be" in url


def process_uploaded_video(uploaded_file):
    """Process an uploaded video file"""
    # Create downloads directory if it doesn't exist
    os.makedirs("downloads", exist_ok=True)

    # Save uploaded file to the downloads directory
    file_path = os.path.join("downloads", uploaded_file.name)
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getvalue())

    return file_path


def display_video(video_path):
    """Display a video with a download option"""
    # Read video bytes
    with open(video_path, "rb") as file:
        video_bytes = file.read()

    # Display video using st.video with bytes
    st.video(video_bytes)

    # Show download button
    st.download_button(label="Download Video",
                       data=video_bytes,
                       file_name=os.path.basename(video_path),
                       mime="video/mp4")


def main():
    """Main Streamlit application"""
    st.title("🏌️ Golf Swing Analysis")
    st.write("Analyze your golf swing using computer vision and AI")

    # Initialize session state for storing analysis results
    if 'video_analyzed' not in st.session_state:
        st.session_state.video_analyzed = False
    if 'analysis_data' not in st.session_state:
        st.session_state.analysis_data = {
            'video_path': None,
            'frames': None,
            'detections': None,
            'pose_data': None,
            'swing_phases': None,
            'trajectory_data': None,
            'sample_rate': None
        }

    # Sidebar for configuration
    st.sidebar.title("Configuration")

    # Option to enable/disable GPT analysis
    enable_gpt = st.sidebar.checkbox("Enable GPT Analysis", value=True)

    # Frame skip rate for YOLO
    sample_rate = st.sidebar.slider(
        "Frame Skip Rate (YOLO)",
        min_value=1,
        max_value=10,
        value=5,
        help="Process every Nth frame. Higher values = faster but less accurate.")

    # Video input options
    st.header("Video Input")
    input_option = st.radio("Choose input method:",
                            ["YouTube URL", "Upload Video"])

    video_path = None
    analyze_clicked = False

    if input_option == "YouTube URL":
        youtube_url = st.text_input("Enter YouTube URL of golf swing:")

        analyze_clicked = st.button("Analyze Swing", key="analyze_youtube")
        if youtube_url and analyze_clicked:
            if validate_youtube_url(youtube_url):
                with st.spinner("Downloading video..."):
                    try:
                        video_path = download_youtube_video(youtube_url)
                        st.success("Video downloaded successfully!")
                        display_video(video_path)
                    except Exception as e:
                        st.error(f"Error downloading video: {str(e)}")
                        st.session_state.video_analyzed = False
                        return
            else:
                st.error("Please enter a valid YouTube URL")
                st.session_state.video_analyzed = False
                return

    else:  # Upload Video
        uploaded_file = st.file_uploader("Upload a golf swing video",
                                         type=["mp4", "mov", "avi"])

        analyze_clicked = st.button("Analyze Swing", key="analyze_upload")
        if uploaded_file and analyze_clicked:
            with st.spinner("Processing uploaded video..."):
                try:
                    video_path = process_uploaded_video(uploaded_file)
                    st.success("Video uploaded successfully!")
                    display_video(video_path)
                except Exception as e:
                    st.error(f"Error processing video: {str(e)}")
                    st.session_state.video_analyzed = False
                    return

    # Process video if available and the analyze button was clicked
    if video_path and analyze_clicked:
        try:
            # Step 1: Process video and detect objects
            with st.spinner("Processing video and detecting objects..."):
                frames, detections = process_video(video_path,
                                                   sample_rate=sample_rate)
                st.success(f"Processed {len(frames)} frames")

            # Step 2: Analyze pose
            with st.spinner("Analyzing golfer's pose..."):
                pose_data = analyze_pose(frames)
                st.success("Pose analysis complete")

            # Step 3: Segment swing into phases
            with st.spinner("Segmenting swing phases..."):
                swing_phases = segment_swing(pose_data,
                                             detections,
                                             sample_rate=sample_rate)

                # Display swing phases
                st.subheader("Swing Phases")
                phase_cols = st.columns(5)
                for i, (phase, frames_in_phase) in enumerate(
                        swing_phases.items()):
                    with phase_cols[i]:
                        st.metric(label=phase.capitalize(),
                                  value=f"{len(frames_in_phase)} frames")

            # Step 4: Analyze trajectory and speed
            with st.spinner("Analyzing trajectory and speed..."):
                trajectory_data = analyze_trajectory(frames,
                                                     detections,
                                                     swing_phases,
                                                     sample_rate=sample_rate)

                # Display club speed if available
                impact_frames = swing_phases.get("impact", [])
                if impact_frames:
                    impact_frame = impact_frames[len(impact_frames) // 2]
                    if impact_frame in trajectory_data and trajectory_data[
                            impact_frame].get("club_speed"):
                        st.subheader("Club Speed")
                        st.metric(
                            label="Estimated Club Speed",
                            value=
                            f"{trajectory_data[impact_frame]['club_speed']:.1f} mph"
                        )

            # Step 5: Generate swing analysis using the LLM (if enabled)
            # Prepare data for the LLM regardless of whether GPT is enabled
            analysis_data = prepare_data_for_llm(pose_data, swing_phases,
                                                 trajectory_data)
            prompt = create_llm_prompt(analysis_data)

            # Display the GPT prompt
            with st.expander("View GPT Prompt"):
                st.code(prompt, language="text")

            if enable_gpt:
                with st.spinner(
                        "Generating swing analysis and coaching tips..."):
                    analysis = generate_swing_analysis(pose_data, swing_phases,
                                                       trajectory_data)

                # Display analysis
                st.subheader("Swing Analysis")
                st.write(analysis)
            else:
                st.info(
                    "GPT Analysis is disabled. Enable it in the sidebar to generate coaching tips."
                )

            # Store analysis data in session state
            st.session_state.video_analyzed = True
            st.session_state.analysis_data = {
                'video_path': video_path,
                'frames': frames,
                'detections': detections,
                'pose_data': pose_data,
                'swing_phases': swing_phases,
                'trajectory_data': trajectory_data,
                'sample_rate': sample_rate
            }

        except Exception as e:
            st.error(f"Error during analysis: {str(e)}")
            st.session_state.video_analyzed = False

    # Annotated video section (only shown once analysis is complete)
    if st.session_state.video_analyzed:
        st.header("Create Annotated Video")
        st.write("Create a video with annotations showing the analysis results")

        if st.button("Generate Annotated Video", key="create_annotated"):
            try:
                with st.spinner("Creating annotated video..."):
                    # Create downloads directory if it doesn't exist
                    os.makedirs("downloads", exist_ok=True)

                    # Get data from session state
                    data = st.session_state.analysis_data

                    # Create the annotated video
                    output_path = create_annotated_video(
                        data['video_path'],
                        data['frames'],
                        data['detections'],
                        data['pose_data'],
                        data['swing_phases'],
                        data['trajectory_data'],
                        sample_rate=data['sample_rate'])

                    # Verify the file exists
                    if not os.path.exists(output_path):
                        raise FileNotFoundError(
                            f"Annotated video file not found at {output_path}")

                    st.success("Annotated video created successfully!")

                    # Display the video with a download option
                    display_video(output_path)

            except Exception as e:
                st.error(f"Error creating annotated video: {str(e)}")
                st.error(
                    "Please check if the downloads directory exists and is writable"
                )


if __name__ == "__main__":
    main()
```
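The URL check in `validate_youtube_url` is a plain substring test, so it can be sanity-checked without running Streamlit. A minimal sketch (same logic, extracted for testing):

```python
def validate_youtube_url(url):
    """Accept both full youtube.com links and youtu.be short links."""
    return "youtube.com" in url or "youtu.be" in url


assert validate_youtube_url("https://www.youtube.com/watch?v=abc123")
assert validate_youtube_url("https://youtu.be/abc123")
assert not validate_youtube_url("https://vimeo.com/12345")
```

Note that a substring check is deliberately loose: it accepts any string containing `youtube.com`, and actual reachability of the video is only discovered later when yt-dlp attempts the download.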
app/utils/__init__.py
ADDED
File without changes
app/utils/video_downloader.py
ADDED
@@ -0,0 +1,79 @@
```python
"""
YouTube video downloader module using yt-dlp
"""

import os
import yt_dlp


def download_youtube_video(url, output_dir="downloads"):
    """
    Download a YouTube video from the provided URL using yt-dlp

    Args:
        url (str): YouTube video URL
        output_dir (str): Directory to save the downloaded video

    Returns:
        str: Path to the downloaded video file

    Raises:
        ValueError: If the URL is invalid or the video is unavailable
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Set output template for the downloaded file
    output_template = os.path.join(output_dir, "%(title)s.%(ext)s")

    # Configure yt-dlp options
    ydl_opts = {
        'format': 'best[ext=mp4]/best',  # Prefer mp4 format
        'outtmpl': output_template,
        'noplaylist': True,
        'quiet': False,
        'no_warnings': False,
        'ignoreerrors': False,
    }

    try:
        # Create yt-dlp object and download the video
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info = ydl.extract_info(url, download=True)

        # Guard against playlists (should not happen with noplaylist=True)
        if 'entries' in info:
            raise ValueError("Playlists are not supported")

        # Get video title and extension
        title = info.get('title', 'video')
        ext = info.get('ext', 'mp4')

        # Construct the file path
        video_path = os.path.join(output_dir, f"{title}.{ext}")

        # Check if the file exists
        if not os.path.exists(video_path):
            # Try with a sanitized filename
            sanitized_title = ''.join(c for c in title
                                      if c.isalnum() or c in ' ._-')
            video_path = os.path.join(output_dir,
                                      f"{sanitized_title}.{ext}")

        if not os.path.exists(video_path):
            # If still not found, look for any mp4 file in the directory
            mp4_files = [
                f for f in os.listdir(output_dir) if f.endswith('.mp4')
            ]
            if mp4_files:
                video_path = os.path.join(output_dir, mp4_files[0])
            else:
                raise ValueError("Downloaded file not found")

        return video_path

    except yt_dlp.utils.DownloadError as e:
        raise ValueError(f"Error downloading video: {str(e)}") from e
    except Exception as e:
        raise ValueError(f"Error: {str(e)}") from e
```
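The fallback path above re-derives the filename by stripping characters yt-dlp may have sanitized away. That character filter is pure Python and easy to verify on its own; the `sanitize_title` name below is a hypothetical extraction of that one-liner, not a function in the module:

```python
def sanitize_title(title):
    """Keep only alphanumerics plus space, dot, underscore, and hyphen,
    mirroring the fallback in download_youtube_video."""
    return ''.join(c for c in title if c.isalnum() or c in ' ._-')


print(sanitize_title("Pro Swing: Driver | Slow-Mo!"))
```

Dropped characters are removed rather than replaced, so the result can contain doubled spaces; the module tolerates this because it only needs the name to match what yt-dlp actually wrote to disk, and it still falls back to scanning the directory for any `.mp4` if the guess misses.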
app/utils/video_processor.py
ADDED
@@ -0,0 +1,91 @@
```python
"""
Video processing and object detection module
"""

import cv2
from tqdm import tqdm
from ultralytics import YOLO


class Detection:
    """Class to store detection results"""

    def __init__(self, frame_idx, class_id, class_name, bbox, confidence):
        self.frame_idx = frame_idx
        self.class_id = class_id
        self.class_name = class_name
        self.bbox = bbox  # [x1, y1, x2, y2]
        self.confidence = confidence


def process_video(video_path, sample_rate=5):
    """
    Process video and detect golfer, club, and ball

    Args:
        video_path (str): Path to the video file
        sample_rate (int): Process every nth frame

    Returns:
        tuple: (frames, detections)
            - frames: List of processed frames
            - detections: List of Detection objects
    """
    # Load YOLOv8 model
    model = YOLO("yolov8n.pt")

    # Class names from the COCO-pretrained model (no golf-specific classes)
    class_names = model.names

    # Open video file
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError("Error opening video file")

    # Get video properties
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    frames = []
    detections = []

    # Process frames
    for frame_idx in tqdm(range(0, frame_count, sample_rate),
                          desc="Processing frames"):
        # Set frame position
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)

        # Read frame
        ret, frame = cap.read()
        if not ret:
            break

        # Store original frame
        frames.append(frame)

        # Run YOLOv8 detection
        results = model(frame)

        # Process detection results
        for result in results:
            boxes = result.boxes
            for box in boxes:
                # Get detection information
                class_id = int(box.cls.item())
                class_name = class_names[class_id]

                # Filter for relevant objects (person, sports ball)
                if class_name in ["person", "sports ball"]:
                    bbox = box.xyxy[0].tolist()  # [x1, y1, x2, y2]
                    confidence = box.conf.item()

                    # Create Detection object
                    detection = Detection(frame_idx, class_id, class_name,
                                          bbox, confidence)
                    detections.append(detection)

    # Release video capture
    cap.release()

    return frames, detections
```
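Because only every nth frame is processed, each `Detection` carries the *original* frame index (`frame_idx`), while downstream code indexes the returned `frames` list by position. A minimal sketch of grouping detections per original frame, using a stripped-down stand-in for the `Detection` class so it runs without OpenCV or YOLO:

```python
from collections import defaultdict


class Detection:
    """Stand-in with the same fields as app/utils/video_processor.py."""

    def __init__(self, frame_idx, class_id, class_name, bbox, confidence):
        self.frame_idx = frame_idx
        self.class_id = class_id
        self.class_name = class_name
        self.bbox = bbox
        self.confidence = confidence


def group_by_frame(detections):
    """Map original frame index -> list of detections in that frame."""
    grouped = defaultdict(list)
    for d in detections:
        grouped[d.frame_idx].append(d)
    return grouped


# With sample_rate=5, sampled frames carry original indices 0, 5, 10, ...
dets = [
    Detection(0, 0, "person", [10, 10, 50, 200], 0.90),
    Detection(0, 32, "sports ball", [60, 180, 70, 190], 0.40),
    Detection(5, 0, "person", [12, 10, 52, 200], 0.88),
]
grouped = group_by_frame(dets)
print(sorted(grouped))   # original frame indices that had detections
print(len(grouped[0]))   # detections found in frame 0
```

This is the relationship the visualizer relies on when it matches detections to sampled frames via `d.frame_idx == i * sample_rate`.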
app/utils/visualizer.py
ADDED
@@ -0,0 +1,225 @@
```python
"""
Visualization module for creating annotated videos
"""

import os
import cv2
from tqdm import tqdm
import mediapipe as mp

# Define body part groups and their colors (BGR order, as used by OpenCV)
BODY_PART_COLORS = {
    "head": (255, 0, 0),  # Blue
    "torso": (0, 255, 0),  # Green
    "arms": (255, 165, 0),  # Light blue
    "hands": (255, 0, 255),  # Magenta
    "legs": (0, 255, 255),  # Yellow
    "feet": (255, 255, 0)  # Cyan
}

# Define which MediaPipe Pose landmarks belong to which body part groups
BODY_PARTS_MAPPING = {
    "head": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],  # Nose, eyes, ears, mouth
    "torso": [11, 12, 23, 24],  # Shoulders and hips
    "arms": [11, 12, 13, 14],  # Shoulders and elbows
    "hands": [15, 16, 17, 18, 19, 20, 21, 22],  # Wrists, pinkies, indices, thumbs
    "legs": [23, 24, 25, 26],  # Hips and knees
    "feet": [27, 28, 29, 30, 31, 32]  # Ankles, heels, foot indices
}


def create_annotated_video(video_path,
                           frames,
                           detections,
                           pose_data,
                           swing_phases,
                           trajectory_data,
                           output_dir="downloads",
                           sample_rate=5):
    """
    Create an annotated video with swing analysis visualizations

    Args:
        video_path (str): Path to the original video
        frames (list): List of video frames
        detections (list): List of Detection objects
        pose_data (dict): Pose estimation data
        swing_phases (dict): Swing phase segmentation data
        trajectory_data (dict): Trajectory and speed analysis data
        output_dir (str): Directory to save the output video
        sample_rate (int): The frame sampling rate used during processing

    Returns:
        str: Path to the annotated video
    """
    try:
        # Create output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)

        # Get original video filename without extension
        video_name = os.path.splitext(os.path.basename(video_path))[0]
        output_path = os.path.join(output_dir, f"{video_name}_annotated.mp4")

        # Get video properties
        if not frames:
            raise ValueError("No frames provided for annotation")

        height, width = frames[0].shape[:2]
        fps = 30  # Default fps

        # Create video writer
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        if not out.isOpened():
            raise IOError(
                f"Failed to create video writer for {output_path}. Check directory permissions."
            )

        # Process each frame
        for i, frame in enumerate(tqdm(frames,
                                       desc="Creating annotated video")):
            # Create a copy of the frame for annotations
            annotated_frame = frame.copy()

            # Draw detections for this sampled frame
            frame_detections = [
                d for d in detections if d.frame_idx == i * sample_rate
            ]
            for detection in frame_detections:
                x1, y1, x2, y2 = map(int, detection.bbox)

                # Draw bounding box (green for person, red otherwise)
                color = (0, 255, 0) if detection.class_name == "person" \
                    else (0, 0, 255)
                cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), color, 2)

                # Draw label
                label = f"{detection.class_name}: {detection.confidence:.2f}"
                cv2.putText(annotated_frame, label, (x1, y1 - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

            # Draw pose keypoints with different colors for different body parts
            if i in pose_data:
                keypoints = pose_data[i]

                # Draw each keypoint with its corresponding body part color
                for part_name, part_indices in BODY_PARTS_MAPPING.items():
                    color = BODY_PART_COLORS[part_name]
                    for idx in part_indices:
                        if (idx < len(keypoints)
                                and keypoints[idx] is not None
                                and len(keypoints[idx]) >= 2):
                            x, y = int(keypoints[idx][0]), int(keypoints[idx][1])
                            cv2.circle(annotated_frame, (x, y), 5, color, -1)

                # Draw connections between keypoints
                mp_pose = mp.solutions.pose
                connections = mp_pose.POSE_CONNECTIONS

                for connection in connections:
                    start_idx, end_idx = connection

                    if (start_idx < len(keypoints) and end_idx < len(keypoints)
                            and keypoints[start_idx] is not None
                            and keypoints[end_idx] is not None
                            and len(keypoints[start_idx]) >= 2
                            and len(keypoints[end_idx]) >= 2):

                        # Color the line by the body part of the start point
                        color = None
                        for part_name, part_indices in BODY_PARTS_MAPPING.items():
                            if start_idx in part_indices:
                                color = BODY_PART_COLORS[part_name]
                                break

                        # If no color found, use white
                        if color is None:
                            color = (255, 255, 255)

                        start_point = (int(keypoints[start_idx][0]),
                                       int(keypoints[start_idx][1]))
                        end_point = (int(keypoints[end_idx][0]),
                                     int(keypoints[end_idx][1]))

                        cv2.line(annotated_frame, start_point, end_point,
                                 color, 2)

            # Draw swing phase information
            phase = None
            for phase_name, phase_frames in swing_phases.items():
                if i in phase_frames:
                    phase = phase_name
                    break

            if phase:
                cv2.putText(annotated_frame, f"Phase: {phase}", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

            # Draw trajectory information if available
            if i in trajectory_data:
                traj_info = trajectory_data[i]
                if traj_info.get("club_speed"):
                    cv2.putText(
                        annotated_frame,
                        f"Club Speed: {traj_info['club_speed']:.1f} mph",
                        (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0),
                        2)

                if traj_info.get("ball_trajectory"):
                    points = traj_info["ball_trajectory"]
                    for j in range(1, len(points)):
                        pt1 = (int(points[j - 1][0]), int(points[j - 1][1]))
                        pt2 = (int(points[j][0]), int(points[j][1]))
                        cv2.line(annotated_frame, pt1, pt2, (0, 255, 255), 2)

            # Add legend for body part colors
            legend_y_start = 110
            legend_y_spacing = 30
            legend_x = 10
            legend_box_size = 20

            # Draw legend title
            cv2.putText(annotated_frame, "Body Parts Legend:",
                        (legend_x, legend_y_start - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

            # Draw color boxes and labels for each body part
            for idx, (part_name, color) in enumerate(BODY_PART_COLORS.items()):
                y_pos = legend_y_start + idx * legend_y_spacing

                # Draw color box
                cv2.rectangle(annotated_frame,
                              (legend_x, y_pos - legend_box_size + 5),
                              (legend_x + legend_box_size, y_pos + 5), color,
                              -1)

                # Draw part name
                cv2.putText(annotated_frame, part_name.capitalize(),
                            (legend_x + legend_box_size + 10, y_pos + 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

            # Write the annotated frame to the output video
            out.write(annotated_frame)

        # Release video writer
        out.release()

        # Verify the file was created
        if not os.path.exists(output_path) or os.path.getsize(output_path) == 0:
            raise IOError(f"Failed to create video file at {output_path}")

        print(f"Annotated video saved to: {output_path}")
        return output_path

    except Exception as e:
        print(f"Error creating annotated video: {str(e)}")
        raise
```
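The per-connection coloring above scans `BODY_PARTS_MAPPING` and takes the first group containing the start landmark (so landmarks 11 and 12, which appear in both "torso" and "arms", resolve to the torso color). That lookup is pure Python and can be exercised without OpenCV or MediaPipe; `color_for_landmark` is a hypothetical extraction for illustration:

```python
# BGR tuples and landmark groups as in app/utils/visualizer.py
BODY_PART_COLORS = {
    "head": (255, 0, 0),
    "torso": (0, 255, 0),
    "arms": (255, 165, 0),
    "hands": (255, 0, 255),
    "legs": (0, 255, 255),
    "feet": (255, 255, 0),
}
BODY_PARTS_MAPPING = {
    "head": list(range(11)),
    "torso": [11, 12, 23, 24],
    "arms": [11, 12, 13, 14],
    "hands": [15, 16, 17, 18, 19, 20, 21, 22],
    "legs": [23, 24, 25, 26],
    "feet": [27, 28, 29, 30, 31, 32],
}


def color_for_landmark(idx):
    """First matching group wins (dicts preserve insertion order);
    white for any landmark outside the mapping."""
    for part_name, part_indices in BODY_PARTS_MAPPING.items():
        if idx in part_indices:
            return BODY_PART_COLORS[part_name]
    return (255, 255, 255)


print(color_for_landmark(0))   # a head landmark (nose)
print(color_for_landmark(15))  # a hand landmark (left wrist)
```

Relying on dict insertion order for the tie-break is fine in Python 3.7+, but it does mean reordering `BODY_PARTS_MAPPING` would silently recolor the shared shoulder landmarks.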
requirements.txt
ADDED
@@ -0,0 +1,12 @@
```text
opencv-python==4.8.1.78
yt-dlp==2025.2.19
ultralytics==8.1.0
mediapipe==0.10.13
numpy==1.26.4
matplotlib==3.8.0
torch==2.2.0
torchvision==0.17.0
openai==1.6.0
python-dotenv==1.0.0
tqdm==4.66.1
streamlit==1.30.0
```
run_streamlit.sh
ADDED
@@ -0,0 +1,10 @@
```bash
#!/bin/bash

# Activate the virtual environment
source .venv/bin/activate

# Run the Streamlit app
streamlit run app/streamlit_app.py

# Deactivate the virtual environment when done
deactivate
```
setup_directories.sh
ADDED
@@ -0,0 +1,19 @@
```bash
#!/bin/bash

# Create downloads directory for the application
mkdir -p downloads

# Set permissions
chmod 755 downloads

echo "Directory created and permissions set:"
echo "- downloads: for storing downloaded YouTube videos and annotated videos"

# Create .env file if it doesn't exist
if [ ! -f .env ]; then
    echo "Creating .env file template..."
    echo "OPENAI_API_KEY=your_api_key_here" > .env
    echo ".env file created. Please edit it to add your OpenAI API key."
fi

echo "Setup complete!"
```