File size: 4,136 Bytes
a422282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4416c36
 
 
95ad9d6
4416c36
 
a422282
 
 
 
 
 
 
 
 
 
 
 
 
 
bf56017
 
a422282
 
 
ed08beb
a422282
 
 
 
 
 
 
 
 
 
 
 
36d65da
a422282
 
 
95ad9d6
a422282
34aaec8
 
a422282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be88a1a
c6988d2
a422282
 
 
 
ed08beb
 
 
 
 
a422282
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#!/usr/bin/env python3
"""
Golf Swing Analysis - Main Application
"""

import os
import sys
from dotenv import load_dotenv

# Load environment variables from a .env file via python-dotenv. This runs
# before the project-local imports below, in case any of them read
# configuration at import time (not verifiable from this file).
load_dotenv()

# Append the parent of this file's directory (dirname of dirname of
# __file__) to the module search path, presumably so the project-local
# `utils` and `models` packages imported below resolve when this script is
# run directly. Because it is appended, entries already on sys.path are
# searched before it.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from utils.video_downloader import download_youtube_video, cleanup_video_file
from utils.video_processor import process_video
from models.pose_estimator import analyze_pose
from models.swing_analyzer import segment_swing_pose_based, analyze_trajectory
from models.llm_analyzer import generate_swing_analysis
from utils.visualizer import create_annotated_video


def _ask_yes_no(prompt, default):
    """Ask a yes/no question on stdin and return the answer as a bool.

    "y"/"yes" and "n"/"no" are accepted case-insensitively, with surrounding
    whitespace ignored. Any other reply (including just pressing Enter)
    returns *default*.
    """
    answer = input(prompt).strip().lower()
    if answer in ("y", "yes"):
        return True
    if answer in ("n", "no"):
        return False
    return default


def _ask_sample_rate(prompt, default=1, low=1, high=10):
    """Ask for the frame sampling rate and clamp it to ``[low, high]``.

    Empty or non-numeric input falls back to *default*. ``str.isdecimal`` is
    used instead of ``isdigit`` because ``isdigit`` also accepts characters
    such as superscript digits ("²") that ``int()`` rejects with ValueError.
    """
    text = input(prompt).strip()
    if not text.isdecimal():
        return default
    return max(low, min(high, int(text)))


def main():
    """Run the interactive golf swing analysis pipeline.

    Prompts for a YouTube URL and analysis options, then downloads the video,
    detects objects, estimates pose, segments the swing, analyzes trajectory,
    optionally generates an LLM-based analysis and optionally renders an
    annotated video. Once a video has been downloaded, ``cleanup_video_file``
    is always called on it, even if a later step fails.

    Returns:
        int: 0 on success; 1 if no URL was entered or any pipeline step
        raised an exception.
    """
    print("\n===== Golf Swing Analysis =====\n")

    # Step 1: Get YouTube URL from user
    youtube_url = input("Enter YouTube URL of golf swing: ").strip()
    if not youtube_url:
        print("\nError: no YouTube URL provided.", file=sys.stderr)
        return 1

    # Step 2: Configure analysis options
    enable_gpt = _ask_yes_no("\nEnable GPT analysis? (y/n, default: y): ",
                             default=True)
    # Per the prompt, 1 means every frame; higher values presumably sample
    # every Nth frame to save time.
    sample_rate = _ask_sample_rate(
        "\nFrame processing rate for YOLO (1-10, default: 1 for all frames): ")

    video_path = None  # set once the download succeeds; checked in `finally`
    try:
        # Step 3: Download the video
        print("\nDownloading video...")
        video_path = download_youtube_video(youtube_url)
        print(f"Video downloaded to: {video_path}")

        # Step 4: Process video and detect golfer, club, and ball
        print("\nProcessing video and detecting objects...")
        frames, detections = process_video(video_path, sample_rate=sample_rate)

        # Step 5: Analyze pose throughout the swing. The second returned
        # value (world landmarks) is not used by this pipeline.
        print("\nAnalyzing golfer's pose...")
        pose_data, _ = analyze_pose(frames)

        # Step 6: Segment swing into phases
        print("\nSegmenting swing phases...")
        # NOTE(review): fps is hard-coded to 30; the downloaded video's real
        # frame rate is not consulted here. Confirm whether any timing logic
        # in segment_swing_pose_based depends on it.
        swing_phases = segment_swing_pose_based(pose_data,
                                                detections,
                                                sample_rate=sample_rate,
                                                fps=30.0)

        # Step 7: Analyze trajectory and speed
        print("\nAnalyzing trajectory and speed...")
        trajectory_data = analyze_trajectory(frames,
                                             detections,
                                             swing_phases,
                                             sample_rate=sample_rate)

        # Step 8: Generate swing analysis using LLM (if enabled)
        if enable_gpt:
            print("\nGenerating swing analysis and coaching tips...")
            analysis = generate_swing_analysis(pose_data, swing_phases,
                                               trajectory_data)

            # Display results
            print("\n===== Swing Analysis Results =====\n")
            print(analysis)
        else:
            print("\nGPT analysis disabled. Skipping swing evaluation.")

        # Step 9: Create annotated video (optional)
        if _ask_yes_no("\nCreate annotated video? (y/n): ", default=False):
            print("\nCreating annotated video...")
            output_path = create_annotated_video(video_path,
                                                 frames,
                                                 detections,
                                                 pose_data,
                                                 swing_phases,
                                                 trajectory_data,
                                                 sample_rate=sample_rate)
            print(f"Annotated video saved to: {output_path}")

        print("\nAnalysis complete!")
        return 0

    except Exception as e:  # top-level boundary: report and signal failure
        print(f"\nError: {e}", file=sys.stderr)
        return 1
    finally:
        # Clean up the original downloaded video file after processing,
        # whether or not the pipeline succeeded.
        if video_path:
            print("\nCleaning up downloaded video file...")
            cleanup_video_file(video_path)


if __name__ == "__main__":
    # Use main()'s return value as the process exit status so failures are
    # visible to shells and CI; a None return maps to exit status 0.
    sys.exit(main())