| import cv2 |
| import numpy as np |
| import base64 |
| import json |
| import math |
| from flask import Flask, request, jsonify |
| from PIL import Image |
| import io |
| import torch |
| import torchvision.transforms as transforms |
| from collections import deque |
| import threading |
| import time |
|
|
# WSGI application object; the route decorators further down register on it.
app = Flask(import_name=__name__)
|
|
class PoolBallDetector:
    """Detect a pool table, its balls and the cue stick via HSV thresholding.

    State kept between frames:

    * ``table_bounds`` -- ``(x1, y1, x2, y2)`` of the table, or ``None`` until
      a table has been found. Once set it is reused for later frames.
    * ``ball_history`` / ``cue_history`` -- bounded deques holding the most
      recent ball lists and cue detections.
    """

    # Two ball detections closer than this (pixels) count as the same ball.
    DUPLICATE_DISTANCE = 30
    # Contours with an area at or below this (pixels^2) are not a cue stick.
    MIN_CUE_AREA = 500
    # Cue-to-ball distance (pixels) at which the estimated power reaches zero.
    MAX_POWER_DISTANCE = 200

    def __init__(self):
        self.ball_history = deque(maxlen=10)
        self.cue_history = deque(maxlen=5)
        self.table_bounds = None

        self.setup_detection_params()

    def setup_detection_params(self):
        """Define the HSV lower/upper bounds used with ``cv2.inRange``."""
        self.ball_colors = {
            'cue': {'lower': np.array([0, 0, 200]), 'upper': np.array([180, 30, 255])},
            'black': {'lower': np.array([0, 0, 0]), 'upper': np.array([180, 255, 50])},
            'solid': {'lower': np.array([0, 50, 50]), 'upper': np.array([10, 255, 255])},
            'stripe': {'lower': np.array([20, 50, 50]), 'upper': np.array([30, 255, 255])},
        }

        self.cue_color = {
            'lower': np.array([10, 50, 20]),
            'upper': np.array([20, 255, 200]),
        }

    def detect_table_bounds(self, frame):
        """Locate the table as the largest green region of a BGR frame.

        Returns ``(x1, y1, x2, y2)`` and stores it in ``self.table_bounds``,
        or returns ``None`` (leaving the stored value untouched) when no green
        contour is found.
        """
        hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)

        green_lower = np.array([40, 50, 50])
        green_upper = np.array([80, 255, 255])

        mask = cv2.inRange(hsv, green_lower, green_upper)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return None

        largest_contour = max(contours, key=cv2.contourArea)
        x, y, w, h = cv2.boundingRect(largest_contour)
        self.table_bounds = (x, y, x + w, y + h)
        return self.table_bounds

    def detect_balls(self, frame):
        """Detect balls of every configured colour class in a BGR frame.

        Returns a list of dicts with keys ``type``, ``x``, ``y``, ``radius``
        and ``confidence``; the same list is appended to ``ball_history``.
        """
        if self.table_bounds is None:
            self.detect_table_bounds(frame)

        hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
        kernel = np.ones((5, 5), np.uint8)  # same kernel for every colour class
        balls = []

        for ball_type, color_range in self.ball_colors.items():
            mask = cv2.inRange(hsv, color_range['lower'], color_range['upper'])

            # Open removes speckles, close fills small holes in the blobs.
            mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
            mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

            circles = cv2.HoughCircles(
                mask, cv2.HOUGH_GRADIENT, dp=1, minDist=30,
                param1=50, param2=30, minRadius=10, maxRadius=50
            )
            if circles is None:
                continue

            for (x, y, r) in np.round(circles[0, :]).astype("int"):
                # Detections are only kept once a table has been located and
                # the circle centre lies inside it.
                if self.table_bounds and self.is_within_table(x, y):
                    balls.append({
                        'type': ball_type,
                        'x': float(x),
                        'y': float(y),
                        'radius': float(r),
                        'confidence': self.calculate_ball_confidence(mask, x, y, r),
                    })

        balls = self.filter_duplicate_balls(balls)
        self.ball_history.append(balls)
        return balls

    def detect_cue_stick(self, frame):
        """Detect the cue stick in a BGR frame.

        Returns a dict with ``detected: True`` plus the stick's centre, axis
        angle (degrees), both endpoints and length, or ``{'detected': False}``
        when no sufficiently large cue-coloured contour exists. Successful
        detections are appended to ``cue_history``.
        """
        hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
        mask = cv2.inRange(hsv, self.cue_color['lower'], self.cue_color['upper'])
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8))

        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return {'detected': False}

        longest_contour = max(contours, key=lambda c: cv2.arcLength(c, False))
        if cv2.contourArea(longest_contour) <= self.MIN_CUE_AREA:
            return {'detected': False}

        cue_data = self._cue_line_from_rect(cv2.minAreaRect(longest_contour))
        self.cue_history.append(cue_data)
        return cue_data

    @staticmethod
    def _cue_line_from_rect(rect):
        """Turn a ``((cx, cy), (w, h), angle)`` rotated rect into cue-line data."""
        (center_x, center_y), (width, height), angle = rect

        # The rect's angle describes its *width* side. When the height is the
        # longer side, the stick's axis is perpendicular to that angle.
        if width < height:
            angle += 90.0

        half_length = max(width, height) / 2
        angle_rad = math.radians(angle)
        dx = half_length * math.cos(angle_rad)
        dy = half_length * math.sin(angle_rad)

        return {
            'detected': True,
            'center_x': float(center_x),
            'center_y': float(center_y),
            'angle': float(angle),
            'start_x': float(center_x - dx),
            'start_y': float(center_y - dy),
            'end_x': float(center_x + dx),
            'end_y': float(center_y + dy),
            'length': float(half_length * 2),
        }

    def calculate_trajectory(self, cue_data, balls):
        """Predict the cue ball's path as a list of ``{'x', 'y'}`` points.

        Returns ``[]`` when the cue is not detected, ``balls`` is empty or no
        ball of type ``'cue'`` is present. The simulation runs for at most 50
        steps and stops early when the ball slows below 0.5 or touches
        another ball.
        """
        if not cue_data or not cue_data.get('detected') or not balls:
            return []

        cue_ball = None
        target_balls = []
        for ball in balls:
            if ball['type'] == 'cue':
                cue_ball = ball
            else:
                target_balls.append(ball)

        if not cue_ball:
            return []

        cue_angle_rad = math.radians(cue_data['angle'])
        dir_x, dir_y = math.cos(cue_angle_rad), math.sin(cue_angle_rad)

        # The stick's axis angle is ambiguous by 180 degrees. The shot travels
        # away from the stick, i.e. from the stick's centre towards the ball.
        if 'center_x' in cue_data and 'center_y' in cue_data:
            to_ball_x = cue_ball['x'] - cue_data['center_x']
            to_ball_y = cue_ball['y'] - cue_data['center_y']
            if dir_x * to_ball_x + dir_y * to_ball_y < 0:
                dir_x, dir_y = -dir_x, -dir_y

        power = self.calculate_shot_power(cue_data, cue_ball)

        trajectory = []
        dt = 0.1
        friction = 0.98
        velocity_x = power * dir_x * 10
        velocity_y = power * dir_y * 10
        x, y = cue_ball['x'], cue_ball['y']

        for _ in range(50):
            x += velocity_x * dt
            y += velocity_y * dt

            velocity_x *= friction
            velocity_y *= friction

            # Bounce off the table edges, losing 20% of the speed on that axis.
            if self.table_bounds:
                x1, y1, x2, y2 = self.table_bounds
                if x <= x1 or x >= x2:
                    velocity_x *= -0.8
                    x = max(x1, min(x2, x))
                if y <= y1 or y >= y2:
                    velocity_y *= -0.8
                    y = max(y1, min(y2, y))

            collision_detected = any(
                math.hypot(x - target['x'], y - target['y'])
                < cue_ball['radius'] + target['radius']
                for target in target_balls
            )

            trajectory.append({'x': float(x), 'y': float(y)})

            if collision_detected or math.hypot(velocity_x, velocity_y) < 0.5:
                break

        return trajectory

    def calculate_shot_power(self, cue_data, cue_ball, max_distance=MAX_POWER_DISTANCE):
        """Estimate shot power in ``[0, 1]`` from the cue-to-ball distance.

        The distance is measured from whichever stick endpoint is nearer to
        the cue ball, because the detector cannot tell which end is the tip.
        Power is 1 at distance 0 and falls linearly to 0 at ``max_distance``
        (which must be positive). Returns ``0.0`` if the cue is not detected.
        """
        if not cue_data or not cue_data.get('detected'):
            return 0.0

        endpoints = [(cue_data['end_x'], cue_data['end_y'])]
        if 'start_x' in cue_data and 'start_y' in cue_data:
            endpoints.append((cue_data['start_x'], cue_data['start_y']))

        ball_x, ball_y = cue_ball['x'], cue_ball['y']
        distance = min(math.hypot(px - ball_x, py - ball_y) for px, py in endpoints)

        return max(0.0, 1 - (distance / max_distance))

    def is_within_table(self, x, y):
        """Return whether ``(x, y)`` is inside the table (True if no table is known)."""
        if not self.table_bounds:
            return True

        x1, y1, x2, y2 = self.table_bounds
        return x1 <= x <= x2 and y1 <= y <= y2

    def calculate_ball_confidence(self, mask, x, y, r):
        """Return the fraction (capped at 1.0) of the circle covered by ``mask``."""
        # cv2 drawing calls are given plain Python ints rather than NumPy scalars.
        x, y, r = int(x), int(y), int(r)

        circle_mask = np.zeros(mask.shape, dtype=np.uint8)
        cv2.circle(circle_mask, (x, y), r, 255, -1)

        intersection = cv2.bitwise_and(mask, circle_mask)
        circle_area = math.pi * r * r
        if circle_area <= 0:
            return 0.0

        white_pixels = int(np.count_nonzero(intersection == 255))
        return float(min(white_pixels / circle_area, 1.0))

    def filter_duplicate_balls(self, balls):
        """Collapse detections closer than ``DUPLICATE_DISTANCE`` pixels.

        Balls are processed in order; each is compared with the first
        already-kept ball it overlaps, and the one with the higher confidence
        survives.
        """
        filtered_balls = []

        for ball in balls:
            is_duplicate = False
            for index, existing_ball in enumerate(filtered_balls):
                distance = math.hypot(
                    ball['x'] - existing_ball['x'],
                    ball['y'] - existing_ball['y'],
                )
                if distance >= self.DUPLICATE_DISTANCE:
                    continue
                if ball['confidence'] > existing_ball['confidence']:
                    # The newcomer wins: drop the weaker one, keep this ball.
                    del filtered_balls[index]
                else:
                    is_duplicate = True
                break

            if not is_duplicate:
                filtered_balls.append(ball)

        return filtered_balls
|
|
| |
# Module-wide detector instance shared by the request handlers below.
detector: PoolBallDetector = PoolBallDetector()
|
|
@app.route('/predict', methods=['POST'])
def predict():
    """Run ball, cue and trajectory detection on a base64-encoded image.

    Expects a JSON object with an ``image`` key (base64 image bytes) and an
    optional ``timestamp``. Responds with the detected balls, cue state,
    predicted trajectory, shot power/angle and table bounds.

    Returns HTTP 400 when the body is not a JSON object, has no ``image`` key,
    or the image cannot be decoded; HTTP 500 for any other failure.
    """
    try:
        # silent=True: a malformed or non-JSON body yields None instead of
        # raising, so it is reported as a client error below.
        data = request.get_json(silent=True)

        if not isinstance(data, dict) or 'image' not in data:
            return jsonify({'error': 'No image data provided'}), 400

        try:
            image_data = base64.b64decode(data['image'])
            # Normalise to 3-channel RGB so grayscale, palette and RGBA
            # uploads all survive the RGB -> BGR conversion below.
            image = Image.open(io.BytesIO(image_data)).convert('RGB')
        except (TypeError, ValueError, OSError) as e:
            return jsonify({'error': f'Invalid image data: {str(e)}'}), 400

        frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

        balls = detector.detect_balls(frame)
        cue_data = detector.detect_cue_stick(frame)
        cue_detected = bool(cue_data.get('detected', False))

        trajectory = []
        shot_angle = 0
        shot_power = 0

        if cue_detected:
            trajectory = detector.calculate_trajectory(cue_data, balls)
            shot_angle = cue_data.get('angle', 0)

            cue_ball = next((ball for ball in balls if ball['type'] == 'cue'), None)
            if cue_ball:
                shot_power = detector.calculate_shot_power(cue_data, cue_ball)

        response = {
            'timestamp': data.get('timestamp', int(time.time() * 1000)),
            'cue_detected': cue_detected,
            'balls': balls,
            'trajectory': trajectory,
            'power': shot_power,
            'angle': shot_angle,
            'table_bounds': detector.table_bounds
        }

        if cue_detected:
            response['cue_line'] = {
                'start_x': cue_data['start_x'],
                'start_y': cue_data['start_y'],
                'end_x': cue_data['end_x'],
                'end_y': cue_data['end_y'],
                'center_x': cue_data['center_x'],
                'center_y': cue_data['center_y'],
                'length': cue_data['length']
            }

        return jsonify(response)

    except Exception as e:
        # Top-level boundary: log with traceback and report a JSON 500.
        app.logger.exception("Error in prediction")
        return jsonify({'error': f'Prediction failed: {str(e)}'}), 500
|
|
@app.route('/health', methods=['GET'])
def health():
    """Liveness probe: report a fixed status and service name as JSON."""
    return jsonify(status='healthy', service='8-ball-pool-predictor')
|
|
@app.route('/reset', methods=['POST'])
def reset():
    """Replace the shared detector with a brand-new instance and confirm it."""
    global detector
    fresh_detector = PoolBallDetector()
    detector = fresh_detector
    return jsonify(status='reset_complete')
|
|
if __name__ == '__main__':
    # Script entry point: listen on every interface, port 7860, debugger off.
    app.run(debug=False, port=7860, host='0.0.0.0')