anthony01 commited on
Commit
a1c7cdc
·
verified ·
1 Parent(s): 2e43f98

Create server.py

Browse files
Files changed (1) hide show
  1. server.py +368 -0
server.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import base64
4
+ import json
5
+ import math
6
+ from flask import Flask, request, jsonify
7
+ from PIL import Image
8
+ import io
9
+ import torch
10
+ import torchvision.transforms as transforms
11
+ from collections import deque
12
+ import threading
13
+ import time
14
+
15
+ app = Flask(__name__)
16
+
17
+ class PoolBallDetector:
18
+ def __init__(self):
19
+ self.ball_history = deque(maxlen=10) # Track ball positions over time
20
+ self.cue_history = deque(maxlen=5) # Track cue stick positions
21
+ self.table_bounds = None
22
+
23
+ # Initialize ball detection parameters
24
+ self.setup_detection_params()
25
+
26
+ def setup_detection_params(self):
27
+ # HSV ranges for different colored balls
28
+ self.ball_colors = {
29
+ 'cue': {'lower': np.array([0, 0, 200]), 'upper': np.array([180, 30, 255])}, # White
30
+ 'black': {'lower': np.array([0, 0, 0]), 'upper': np.array([180, 255, 50])}, # Black (8-ball)
31
+ 'solid': {'lower': np.array([0, 50, 50]), 'upper': np.array([10, 255, 255])}, # Red/solid colors
32
+ 'stripe': {'lower': np.array([20, 50, 50]), 'upper': np.array([30, 255, 255])} # Yellow/stripe colors
33
+ }
34
+
35
+ # Cue stick detection (brown/wooden color)
36
+ self.cue_color = {
37
+ 'lower': np.array([10, 50, 20]),
38
+ 'upper': np.array([20, 255, 200])
39
+ }
40
+
41
+ def detect_table_bounds(self, frame):
42
+ """Detect the pool table boundaries"""
43
+ hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
44
+
45
+ # Green table detection
46
+ green_lower = np.array([40, 50, 50])
47
+ green_upper = np.array([80, 255, 255])
48
+
49
+ mask = cv2.inRange(hsv, green_lower, green_upper)
50
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
51
+
52
+ if contours:
53
+ largest_contour = max(contours, key=cv2.contourArea)
54
+ x, y, w, h = cv2.boundingRect(largest_contour)
55
+ self.table_bounds = (x, y, x + w, y + h)
56
+ return self.table_bounds
57
+
58
+ return None
59
+
60
+ def detect_balls(self, frame):
61
+ """Detect all balls on the table"""
62
+ if self.table_bounds is None:
63
+ self.detect_table_bounds(frame)
64
+
65
+ hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
66
+ balls = []
67
+
68
+ # Detect each type of ball
69
+ for ball_type, color_range in self.ball_colors.items():
70
+ mask = cv2.inRange(hsv, color_range['lower'], color_range['upper'])
71
+
72
+ # Apply morphological operations to clean up the mask
73
+ kernel = np.ones((5, 5), np.uint8)
74
+ mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
75
+ mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
76
+
77
+ # Find circles using HoughCircles
78
+ circles = cv2.HoughCircles(
79
+ mask, cv2.HOUGH_GRADIENT, dp=1, minDist=30,
80
+ param1=50, param2=30, minRadius=10, maxRadius=50
81
+ )
82
+
83
+ if circles is not None:
84
+ circles = np.round(circles[0, :]).astype("int")
85
+ for (x, y, r) in circles:
86
+ # Verify the ball is within table bounds
87
+ if self.table_bounds and self.is_within_table(x, y):
88
+ balls.append({
89
+ 'type': ball_type,
90
+ 'x': float(x),
91
+ 'y': float(y),
92
+ 'radius': float(r),
93
+ 'confidence': self.calculate_ball_confidence(mask, x, y, r)
94
+ })
95
+
96
+ # Filter out duplicate detections
97
+ balls = self.filter_duplicate_balls(balls)
98
+
99
+ # Update history
100
+ self.ball_history.append(balls)
101
+
102
+ return balls
103
+
104
+ def detect_cue_stick(self, frame):
105
+ """Detect the cue stick position and angle"""
106
+ hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
107
+
108
+ # Create mask for cue stick color
109
+ mask = cv2.inRange(hsv, self.cue_color['lower'], self.cue_color['upper'])
110
+
111
+ # Apply morphological operations
112
+ kernel = np.ones((3, 3), np.uint8)
113
+ mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
114
+
115
+ # Find contours
116
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
117
+
118
+ cue_data = None
119
+
120
+ if contours:
121
+ # Find the longest contour (likely the cue stick)
122
+ longest_contour = max(contours, key=lambda c: cv2.arcLength(c, False))
123
+
124
+ if cv2.contourArea(longest_contour) > 500: # Minimum area threshold
125
+ # Get the minimum area rectangle
126
+ rect = cv2.minAreaRect(longest_contour)
127
+ box = cv2.boxPoints(rect)
128
+ box = np.int0(box)
129
+
130
+ # Calculate cue stick line
131
+ center_x, center_y = rect[0]
132
+ angle = rect[2]
133
+
134
+ # Get the two endpoints of the cue stick
135
+ length = max(rect[1]) / 2
136
+ angle_rad = math.radians(angle)
137
+
138
+ start_x = center_x - length * math.cos(angle_rad)
139
+ start_y = center_y - length * math.sin(angle_rad)
140
+ end_x = center_x + length * math.cos(angle_rad)
141
+ end_y = center_y + length * math.sin(angle_rad)
142
+
143
+ cue_data = {
144
+ 'detected': True,
145
+ 'center_x': float(center_x),
146
+ 'center_y': float(center_y),
147
+ 'angle': float(angle),
148
+ 'start_x': float(start_x),
149
+ 'start_y': float(start_y),
150
+ 'end_x': float(end_x),
151
+ 'end_y': float(end_y),
152
+ 'length': float(length * 2)
153
+ }
154
+
155
+ self.cue_history.append(cue_data)
156
+
157
+ return cue_data or {'detected': False}
158
+
159
+ def calculate_trajectory(self, cue_data, balls):
160
+ """Calculate the predicted trajectory based on cue position and ball positions"""
161
+ if not cue_data.get('detected') or not balls:
162
+ return []
163
+
164
+ # Find the cue ball
165
+ cue_ball = None
166
+ target_balls = []
167
+
168
+ for ball in balls:
169
+ if ball['type'] == 'cue':
170
+ cue_ball = ball
171
+ else:
172
+ target_balls.append(ball)
173
+
174
+ if not cue_ball:
175
+ return []
176
+
177
+ # Calculate trajectory from cue stick direction
178
+ cue_angle_rad = math.radians(cue_data['angle'])
179
+ cue_x, cue_y = cue_ball['x'], cue_ball['y']
180
+
181
+ # Calculate power based on cue stick proximity to cue ball
182
+ power = self.calculate_shot_power(cue_data, cue_ball)
183
+
184
+ # Generate trajectory points
185
+ trajectory = []
186
+ dt = 0.1 # Time step
187
+ velocity_x = power * math.cos(cue_angle_rad) * 10 # Scale factor
188
+ velocity_y = power * math.sin(cue_angle_rad) * 10
189
+
190
+ x, y = cue_x, cue_y
191
+ friction = 0.98 # Friction coefficient
192
+
193
+ for i in range(50): # Maximum trajectory points
194
+ x += velocity_x * dt
195
+ y += velocity_y * dt
196
+
197
+ # Apply friction
198
+ velocity_x *= friction
199
+ velocity_y *= friction
200
+
201
+ # Check for table boundaries
202
+ if self.table_bounds:
203
+ x1, y1, x2, y2 = self.table_bounds
204
+ if x <= x1 or x >= x2:
205
+ velocity_x *= -0.8 # Bounce with energy loss
206
+ x = max(x1, min(x2, x))
207
+ if y <= y1 or y >= y2:
208
+ velocity_y *= -0.8
209
+ y = max(y1, min(y2, y))
210
+
211
+ # Check for collisions with other balls
212
+ collision_detected = False
213
+ for target_ball in target_balls:
214
+ dist = math.sqrt((x - target_ball['x'])**2 + (y - target_ball['y'])**2)
215
+ if dist < (cue_ball['radius'] + target_ball['radius']):
216
+ collision_detected = True
217
+ break
218
+
219
+ trajectory.append({'x': float(x), 'y': float(y)})
220
+
221
+ # Stop if velocity is too low or collision detected
222
+ if math.sqrt(velocity_x**2 + velocity_y**2) < 0.5 or collision_detected:
223
+ break
224
+
225
+ return trajectory
226
+
227
+ def calculate_shot_power(self, cue_data, cue_ball):
228
+ """Calculate shot power based on cue stick distance from cue ball"""
229
+ if not cue_data.get('detected'):
230
+ return 0.0
231
+
232
+ # Distance from cue stick end to cue ball
233
+ cue_end_x, cue_end_y = cue_data['end_x'], cue_data['end_y']
234
+ ball_x, ball_y = cue_ball['x'], cue_ball['y']
235
+
236
+ distance = math.sqrt((cue_end_x - ball_x)**2 + (cue_end_y - ball_y)**2)
237
+
238
+ # Convert distance to power (closer = more power)
239
+ max_distance = 200 # Maximum meaningful distance
240
+ power = max(0, 1 - (distance / max_distance))
241
+
242
+ return power
243
+
244
+ def is_within_table(self, x, y):
245
+ """Check if a point is within the table bounds"""
246
+ if not self.table_bounds:
247
+ return True
248
+
249
+ x1, y1, x2, y2 = self.table_bounds
250
+ return x1 <= x <= x2 and y1 <= y <= y2
251
+
252
+ def calculate_ball_confidence(self, mask, x, y, r):
253
+ """Calculate confidence score for ball detection"""
254
+ # Check the percentage of white pixels in the circle area
255
+ circle_mask = np.zeros(mask.shape, dtype=np.uint8)
256
+ cv2.circle(circle_mask, (x, y), r, 255, -1)
257
+
258
+ intersection = cv2.bitwise_and(mask, circle_mask)
259
+ circle_area = np.pi * r * r
260
+ white_pixels = np.sum(intersection == 255)
261
+
262
+ confidence = white_pixels / circle_area if circle_area > 0 else 0
263
+ return min(confidence, 1.0)
264
+
265
+ def filter_duplicate_balls(self, balls):
266
+ """Remove duplicate ball detections"""
267
+ filtered_balls = []
268
+
269
+ for ball in balls:
270
+ is_duplicate = False
271
+ for existing_ball in filtered_balls:
272
+ distance = math.sqrt(
273
+ (ball['x'] - existing_ball['x'])**2 +
274
+ (ball['y'] - existing_ball['y'])**2
275
+ )
276
+ if distance < 30: # If balls are too close, consider them duplicates
277
+ if ball['confidence'] > existing_ball['confidence']:
278
+ # Replace with higher confidence detection
279
+ filtered_balls.remove(existing_ball)
280
+ break
281
+ else:
282
+ is_duplicate = True
283
+ break
284
+
285
+ if not is_duplicate:
286
+ filtered_balls.append(ball)
287
+
288
+ return filtered_balls
289
+
290
+ # Global detector instance
291
+ detector = PoolBallDetector()
292
+
293
+ @app.route('/predict', methods=['POST'])
294
+ def predict():
295
+ try:
296
+ # Parse JSON request
297
+ data = request.get_json()
298
+
299
+ if not data or 'image' not in data:
300
+ return jsonify({'error': 'No image data provided'}), 400
301
+
302
+ # Decode base64 image
303
+ image_data = base64.b64decode(data['image'])
304
+ image = Image.open(io.BytesIO(image_data))
305
+
306
+ # Convert PIL image to OpenCV format
307
+ frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
308
+
309
+ # Detect balls and cue stick
310
+ balls = detector.detect_balls(frame)
311
+ cue_data = detector.detect_cue_stick(frame)
312
+
313
+ # Calculate trajectory if cue is detected
314
+ trajectory = []
315
+ if cue_data.get('detected'):
316
+ trajectory = detector.calculate_trajectory(cue_data, balls)
317
+
318
+ # Calculate additional metrics
319
+ shot_angle = cue_data.get('angle', 0) if cue_data.get('detected') else 0
320
+ shot_power = 0
321
+
322
+ if cue_data.get('detected') and balls:
323
+ cue_ball = next((ball for ball in balls if ball['type'] == 'cue'), None)
324
+ if cue_ball:
325
+ shot_power = detector.calculate_shot_power(cue_data, cue_ball)
326
+
327
+ # Prepare response
328
+ response = {
329
+ 'timestamp': data.get('timestamp', int(time.time() * 1000)),
330
+ 'cue_detected': cue_data.get('detected', False),
331
+ 'balls': balls,
332
+ 'trajectory': trajectory,
333
+ 'power': shot_power,
334
+ 'angle': shot_angle,
335
+ 'table_bounds': detector.table_bounds
336
+ }
337
+
338
+ # Add cue line data if detected
339
+ if cue_data.get('detected'):
340
+ response['cue_line'] = {
341
+ 'start_x': cue_data['start_x'],
342
+ 'start_y': cue_data['start_y'],
343
+ 'end_x': cue_data['end_x'],
344
+ 'end_y': cue_data['end_y'],
345
+ 'center_x': cue_data['center_x'],
346
+ 'center_y': cue_data['center_y'],
347
+ 'length': cue_data['length']
348
+ }
349
+
350
+ return jsonify(response)
351
+
352
+ except Exception as e:
353
+ print(f"Error in prediction: {str(e)}")
354
+ return jsonify({'error': f'Prediction failed: {str(e)}'}), 500
355
+
356
+ @app.route('/health', methods=['GET'])
357
+ def health():
358
+ return jsonify({'status': 'healthy', 'service': '8-ball-pool-predictor'})
359
+
360
+ @app.route('/reset', methods=['POST'])
361
+ def reset():
362
+ """Reset the detector state"""
363
+ global detector
364
+ detector = PoolBallDetector()
365
+ return jsonify({'status': 'reset_complete'})
366
+
367
+ if __name__ == '__main__':
368
+ app.run(host='0.0.0.0', port=7860, debug=False) # Port 7860 for Hugging Face Spaces