anthony01 commited on
Commit
c72a920
·
verified ·
1 Parent(s): 2ab4877

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +190 -0
app.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import numpy as np
4
+ import base64
5
+ import json
6
+ import math
7
+ from PIL import Image
8
+ import io
9
+ import time
10
+ from collections import deque
11
+
12
+ # Import your pool detector class (you can copy the class definition here)
13
+ # from pool_detector import PoolBallDetector
14
+
15
+ class PoolBallDetector:
16
+ # ... (copy the entire PoolBallDetector class from the previous artifact)
17
+ def __init__(self):
18
+ self.ball_history = deque(maxlen=10)
19
+ self.cue_history = deque(maxlen=5)
20
+ self.table_bounds = None
21
+ self.setup_detection_params()
22
+
23
+ def setup_detection_params(self):
24
+ self.ball_colors = {
25
+ 'cue': {'lower': np.array([0, 0, 200]), 'upper': np.array([180, 30, 255])},
26
+ 'black': {'lower': np.array([0, 0, 0]), 'upper': np.array([180, 255, 50])},
27
+ 'solid': {'lower': np.array([0, 50, 50]), 'upper': np.array([10, 255, 255])},
28
+ 'stripe': {'lower': np.array([20, 50, 50]), 'upper': np.array([30, 255, 255])}
29
+ }
30
+ self.cue_color = {'lower': np.array([10, 50, 20]), 'upper': np.array([20, 255, 200])}
31
+
32
+ # ... (include all other methods from PoolBallDetector)
33
+
34
+ # Global detector instance
35
+ detector = PoolBallDetector()
36
+
37
+ def process_pool_image(image_data):
38
+ """Process image and return predictions"""
39
+ try:
40
+ # Decode base64 image if it's a string
41
+ if isinstance(image_data, str):
42
+ image_data = base64.b64decode(image_data)
43
+ image = Image.open(io.BytesIO(image_data))
44
+ else:
45
+ image = image_data
46
+
47
+ # Convert to OpenCV format
48
+ frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
49
+
50
+ # Detect components
51
+ balls = detector.detect_balls(frame)
52
+ cue_data = detector.detect_cue_stick(frame)
53
+
54
+ # Calculate trajectory
55
+ trajectory = []
56
+ if cue_data.get('detected'):
57
+ trajectory = detector.calculate_trajectory(cue_data, balls)
58
+
59
+ # Calculate metrics
60
+ shot_angle = cue_data.get('angle', 0) if cue_data.get('detected') else 0
61
+ shot_power = 0
62
+
63
+ if cue_data.get('detected') and balls:
64
+ cue_ball = next((ball for ball in balls if ball['type'] == 'cue'), None)
65
+ if cue_ball:
66
+ shot_power = detector.calculate_shot_power(cue_data, cue_ball)
67
+
68
+ # Create visualization
69
+ result_frame = frame.copy()
70
+
71
+ # Draw balls
72
+ for ball in balls:
73
+ color = (255, 255, 255) if ball['type'] == 'cue' else (0, 255, 0)
74
+ cv2.circle(result_frame, (int(ball['x']), int(ball['y'])), int(ball['radius']), color, 2)
75
+ cv2.putText(result_frame, ball['type'], (int(ball['x']-20), int(ball['y']-30)),
76
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
77
+
78
+ # Draw cue stick
79
+ if cue_data.get('detected'):
80
+ cv2.line(result_frame,
81
+ (int(cue_data['start_x']), int(cue_data['start_y'])),
82
+ (int(cue_data['end_x']), int(cue_data['end_y'])),
83
+ (0, 255, 255), 3)
84
+
85
+ # Draw trajectory
86
+ if trajectory:
87
+ trajectory_points = [(int(p['x']), int(p['y'])) for p in trajectory]
88
+ for i in range(len(trajectory_points) - 1):
89
+ cv2.line(result_frame, trajectory_points[i], trajectory_points[i+1], (0, 0, 255), 2)
90
+
91
+ # Convert back to PIL for Gradio
92
+ result_image = Image.fromarray(cv2.cvtColor(result_frame, cv2.COLOR_BGR2RGB))
93
+
94
+ # Prepare text output
95
+ info_text = f"""
96
+ Detection Results:
97
+ - Cue Detected: {cue_data.get('detected', False)}
98
+ - Balls Found: {len(balls)}
99
+ - Shot Power: {shot_power:.2f}
100
+ - Shot Angle: {shot_angle:.1f}°
101
+ - Trajectory Points: {len(trajectory)}
102
+ """
103
+
104
+ return result_image, info_text
105
+
106
+ except Exception as e:
107
+ return None, f"Error: {str(e)}"
108
+
109
+ def predict_api(image_b64):
110
+ """API endpoint for mobile app"""
111
+ try:
112
+ # Process the image
113
+ image_data = base64.b64decode(image_b64)
114
+ image = Image.open(io.BytesIO(image_data))
115
+ frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
116
+
117
+ # Detect components
118
+ balls = detector.detect_balls(frame)
119
+ cue_data = detector.detect_cue_stick(frame)
120
+ trajectory = []
121
+
122
+ if cue_data.get('detected'):
123
+ trajectory = detector.calculate_trajectory(cue_data, balls)
124
+
125
+ # Calculate metrics
126
+ shot_angle = cue_data.get('angle', 0) if cue_data.get('detected') else 0
127
+ shot_power = 0
128
+
129
+ if cue_data.get('detected') and balls:
130
+ cue_ball = next((ball for ball in balls if ball['type'] == 'cue'), None)
131
+ if cue_ball:
132
+ shot_power = detector.calculate_shot_power(cue_data, cue_ball)
133
+
134
+ response = {
135
+ 'timestamp': int(time.time() * 1000),
136
+ 'cue_detected': cue_data.get('detected', False),
137
+ 'balls': balls,
138
+ 'trajectory': trajectory,
139
+ 'power': shot_power,
140
+ 'angle': shot_angle,
141
+ 'table_bounds': detector.table_bounds
142
+ }
143
+
144
+ if cue_data.get('detected'):
145
+ response['cue_line'] = {
146
+ 'start_x': cue_data['start_x'],
147
+ 'start_y': cue_data['start_y'],
148
+ 'end_x': cue_data['end_x'],
149
+ 'end_y': cue_data['end_y'],
150
+ 'center_x': cue_data['center_x'],
151
+ 'center_y': cue_data['center_y'],
152
+ 'length': cue_data['length']
153
+ }
154
+
155
+ return json.dumps(response)
156
+
157
+ except Exception as e:
158
+ return json.dumps({'error': str(e)})
159
+
160
+ # Create Gradio interface
161
+ with gr.Blocks(title="8-Ball Pool Trajectory Predictor") as demo:
162
+ gr.Markdown("# 🎱 8-Ball Pool Trajectory Predictor")
163
+ gr.Markdown("Upload a screenshot of your 8-ball pool game to get trajectory predictions!")
164
+
165
+ with gr.Row():
166
+ with gr.Column():
167
+ input_image = gr.Image(type="pil", label="Pool Table Screenshot")
168
+ predict_btn = gr.Button("Predict Trajectory", variant="primary")
169
+
170
+ with gr.Column():
171
+ output_image = gr.Image(label="Prediction Results")
172
+ output_text = gr.Textbox(label="Detection Info", lines=8)
173
+
174
+ predict_btn.click(
175
+ fn=process_pool_image,
176
+ inputs=[input_image],
177
+ outputs=[output_image, output_text]
178
+ )
179
+
180
+ # API endpoint for mobile app
181
+ gr.Interface(
182
+ fn=predict_api,
183
+ inputs=gr.Textbox(label="Base64 Image Data"),
184
+ outputs=gr.Textbox(label="JSON Response"),
185
+ title="API Endpoint",
186
+ description="For mobile app integration - send base64 encoded image"
187
+ )
188
+
189
+ if __name__ == "__main__":
190
+ demo.launch(server_name="0.0.0.0", server_port=7860)