"""Flask service that detects pool balls and a cue stick in an image.

The ``/predict`` endpoint accepts a base64-encoded image, runs colour-threshold
based detection with OpenCV and returns the detected balls, the cue line and a
simple simulated cue-ball trajectory as JSON.
"""

import base64
import io
import json
import math
import threading
import time
from collections import deque

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from flask import Flask, request, jsonify
from PIL import Image

app = Flask(__name__)


class PoolBallDetector:
    """Detect the table, balls and cue stick in BGR frames using HSV thresholds.

    State kept between calls:
      * ``ball_history`` - the ball lists of the last 10 ``detect_balls`` calls.
      * ``cue_history``  - the last 5 successful cue detections.
      * ``table_bounds`` - ``(x1, y1, x2, y2)`` of the table, or ``None`` until
        a table has been found.
    """

    def __init__(self):
        self.ball_history = deque(maxlen=10)  # Track ball positions over time
        self.cue_history = deque(maxlen=5)  # Track cue stick positions
        self.table_bounds = None

        # Initialize ball detection parameters
        self.setup_detection_params()

    def setup_detection_params(self):
        """Define the HSV thresholds used for ball and cue stick masks."""
        # HSV ranges for different colored balls
        self.ball_colors = {
            # White
            'cue': {'lower': np.array([0, 0, 200]),
                    'upper': np.array([180, 30, 255])},
            # Black (8-ball)
            'black': {'lower': np.array([0, 0, 0]),
                      'upper': np.array([180, 255, 50])},
            # Red/solid colors
            'solid': {'lower': np.array([0, 50, 50]),
                      'upper': np.array([10, 255, 255])},
            # Yellow/stripe colors
            'stripe': {'lower': np.array([20, 50, 50]),
                       'upper': np.array([30, 255, 255])},
        }

        # Cue stick detection (brown/wooden color)
        self.cue_color = {
            'lower': np.array([10, 50, 20]),
            'upper': np.array([20, 255, 200]),
        }

    def detect_table_bounds(self, frame):
        """Detect the pool table boundaries.

        Thresholds the frame for green and stores the bounding rectangle of
        the largest green contour in ``self.table_bounds``.

        Args:
            frame: BGR image (it is converted with ``cv2.COLOR_BGR2HSV``).

        Returns:
            ``(x1, y1, x2, y2)`` when a green region was found, else ``None``
            (``self.table_bounds`` is left untouched in that case).
        """
        hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)

        # Green table detection
        green_lower = np.array([40, 50, 50])
        green_upper = np.array([80, 255, 255])
        mask = cv2.inRange(hsv, green_lower, green_upper)

        contours, _ = cv2.findContours(
            mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        if not contours:
            return None

        largest_contour = max(contours, key=cv2.contourArea)
        x, y, w, h = cv2.boundingRect(largest_contour)
        self.table_bounds = (x, y, x + w, y + h)
        return self.table_bounds

    def detect_balls(self, frame):
        """Detect all balls on the table.

        For every colour class a mask is built, cleaned with morphological
        open/close, and searched for circles with ``cv2.HoughCircles``.
        Detections outside the table bounds are dropped, near-duplicates are
        merged, and the result is appended to ``self.ball_history``.

        Args:
            frame: BGR image.

        Returns:
            List of dicts with keys ``type``, ``x``, ``y``, ``radius`` and
            ``confidence``.
        """
        if self.table_bounds is None:
            self.detect_table_bounds(frame)

        hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
        kernel = np.ones((5, 5), np.uint8)  # same kernel for every colour
        balls = []

        # Detect each type of ball
        for ball_type, color_range in self.ball_colors.items():
            mask = cv2.inRange(hsv, color_range['lower'], color_range['upper'])

            # Apply morphological operations to clean up the mask
            mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
            mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

            # Find circles using HoughCircles
            circles = cv2.HoughCircles(
                mask,
                cv2.HOUGH_GRADIENT,
                dp=1,
                minDist=30,
                param1=50,
                param2=30,
                minRadius=10,
                maxRadius=50,
            )
            if circles is None:
                continue

            for x, y, r in np.round(circles[0, :]).astype(int):
                # Plain Python ints keep cv2 drawing calls and JSON happy.
                x, y, r = int(x), int(y), int(r)

                # is_within_table() accepts every point while no table has
                # been found, so detections are not discarded wholesale when
                # table detection fails.
                if not self.is_within_table(x, y):
                    continue

                balls.append({
                    'type': ball_type,
                    'x': float(x),
                    'y': float(y),
                    'radius': float(r),
                    'confidence': self.calculate_ball_confidence(mask, x, y, r),
                })

        # Filter out duplicate detections
        balls = self.filter_duplicate_balls(balls)

        # Update history
        self.ball_history.append(balls)
        return balls

    def detect_cue_stick(self, frame):
        """Detect the cue stick position and angle.

        The contour with the greatest perimeter in the cue-colour mask is
        taken as the stick if its area exceeds 500 pixels; its minimum-area
        rectangle provides the centre, the long-axis angle and the endpoints.

        Args:
            frame: BGR image.

        Returns:
            ``{'detected': False}`` when no stick is found, otherwise a dict
            with ``detected``, ``center_x``, ``center_y``, ``angle`` (degrees,
            along the stick's long axis), ``start_x``, ``start_y``, ``end_x``,
            ``end_y`` and ``length``. Successful detections are also appended
            to ``self.cue_history``.
        """
        hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)

        # Create mask for cue stick color
        mask = cv2.inRange(hsv, self.cue_color['lower'], self.cue_color['upper'])

        # Apply morphological operations
        kernel = np.ones((3, 3), np.uint8)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

        # Find contours
        contours, _ = cv2.findContours(
            mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        if not contours:
            return {'detected': False}

        # Find the longest contour (likely the cue stick)
        longest_contour = max(contours, key=lambda c: cv2.arcLength(c, False))
        if cv2.contourArea(longest_contour) <= 500:  # Minimum area threshold
            return {'detected': False}

        # Get the minimum area rectangle
        (center_x, center_y), (width, height), angle = cv2.minAreaRect(
            longest_contour
        )

        # minAreaRect reports the rotation of the rectangle's *width* edge.
        # When the height is the longer side the stick runs perpendicular to
        # that edge, so rotate by 90 degrees to follow the long axis.
        if width < height:
            angle += 90.0

        # Get the two endpoints of the cue stick
        half_length = max(width, height) / 2
        angle_rad = math.radians(angle)
        offset_x = half_length * math.cos(angle_rad)
        offset_y = half_length * math.sin(angle_rad)

        cue_data = {
            'detected': True,
            'center_x': float(center_x),
            'center_y': float(center_y),
            'angle': float(angle),
            'start_x': float(center_x - offset_x),
            'start_y': float(center_y - offset_y),
            'end_x': float(center_x + offset_x),
            'end_y': float(center_y + offset_y),
            'length': float(half_length * 2),
        }
        self.cue_history.append(cue_data)
        return cue_data

    def calculate_trajectory(self, cue_data, balls):
        """Calculate the predicted cue-ball trajectory.

        Simulates up to 50 steps of a point starting at the cue ball and
        moving along the cue stick axis, with per-step friction, damped
        bounces off the table bounds, and termination on the first overlap
        with a non-cue ball or when the speed drops below 0.5.

        Args:
            cue_data: Result of ``detect_cue_stick``.
            balls: Result of ``detect_balls``.

        Returns:
            List of ``{'x': float, 'y': float}`` points; empty when the cue
            stick or the cue ball is missing.
        """
        if not cue_data.get('detected') or not balls:
            return []

        # Use the first cue-type detection, matching what predict() uses for
        # the reported shot power.
        cue_ball = next((ball for ball in balls if ball['type'] == 'cue'), None)
        if not cue_ball:
            return []
        target_balls = [ball for ball in balls if ball['type'] != 'cue']

        cue_x, cue_y = cue_ball['x'], cue_ball['y']

        # Direction of the stick axis. The axis angle alone is ambiguous by
        # 180 degrees, so orient it from the stick centre toward the cue ball.
        cue_angle_rad = math.radians(cue_data['angle'])
        direction_x = math.cos(cue_angle_rad)
        direction_y = math.sin(cue_angle_rad)
        to_ball_x = cue_x - cue_data.get('center_x', cue_x)
        to_ball_y = cue_y - cue_data.get('center_y', cue_y)
        if direction_x * to_ball_x + direction_y * to_ball_y < 0:
            direction_x, direction_y = -direction_x, -direction_y

        # Calculate power based on cue stick proximity to cue ball
        power = self.calculate_shot_power(cue_data, cue_ball)

        dt = 0.1  # Time step
        friction = 0.98  # Friction coefficient
        velocity_x = power * direction_x * 10  # Scale factor
        velocity_y = power * direction_y * 10
        x, y = cue_x, cue_y

        trajectory = []
        for _ in range(50):  # Maximum trajectory points
            x += velocity_x * dt
            y += velocity_y * dt

            # Apply friction
            velocity_x *= friction
            velocity_y *= friction

            # Check for table boundaries
            if self.table_bounds:
                x1, y1, x2, y2 = self.table_bounds
                if x <= x1 or x >= x2:
                    velocity_x *= -0.8  # Bounce with energy loss
                    x = max(x1, min(x2, x))
                if y <= y1 or y >= y2:
                    velocity_y *= -0.8
                    y = max(y1, min(y2, y))

            # Check for collisions with other balls
            collision_detected = any(
                math.hypot(x - target['x'], y - target['y'])
                < cue_ball['radius'] + target['radius']
                for target in target_balls
            )

            trajectory.append({'x': float(x), 'y': float(y)})

            # Stop if velocity is too low or collision detected
            if math.hypot(velocity_x, velocity_y) < 0.5 or collision_detected:
                break

        return trajectory

    def calculate_shot_power(self, cue_data, cue_ball):
        """Calculate shot power from the cue stick's distance to the cue ball.

        The distance is measured from whichever stick endpoint is nearer to
        the ball, because the detector cannot tell which end is the tip.

        Args:
            cue_data: Result of ``detect_cue_stick``.
            cue_ball: Ball dict with ``x`` and ``y``.

        Returns:
            Power in ``[0, 1]``: 1 at distance 0, falling linearly to 0 at a
            distance of 200 or more. ``0.0`` when no cue stick was detected.
        """
        if not cue_data.get('detected'):
            return 0.0

        ball_x, ball_y = cue_ball['x'], cue_ball['y']

        endpoints = [(cue_data['end_x'], cue_data['end_y'])]
        if 'start_x' in cue_data and 'start_y' in cue_data:
            endpoints.append((cue_data['start_x'], cue_data['start_y']))
        distance = min(
            math.hypot(tip_x - ball_x, tip_y - ball_y)
            for tip_x, tip_y in endpoints
        )

        # Convert distance to power (closer = more power)
        max_distance = 200  # Maximum meaningful distance
        return max(0, 1 - (distance / max_distance))

    def is_within_table(self, x, y):
        """Check if a point is within the table bounds.

        Returns ``True`` for every point while no table bounds are known.
        """
        if not self.table_bounds:
            return True

        x1, y1, x2, y2 = self.table_bounds
        return x1 <= x <= x2 and y1 <= y <= y2

    def calculate_ball_confidence(self, mask, x, y, r):
        """Calculate a confidence score for a ball detection.

        The score is the number of set (255) mask pixels inside the circle
        divided by the circle's analytic area, capped at 1.0.

        Args:
            mask: Single-channel uint8 mask the circle was detected in.
            x, y: Circle centre in pixels.
            r: Circle radius in pixels.

        Returns:
            Float in ``[0, 1]``; 0 for a zero-radius circle.
        """
        x, y, r = int(x), int(y), int(r)

        # Check the percentage of white pixels in the circle area
        circle_mask = np.zeros(mask.shape, dtype=np.uint8)
        cv2.circle(circle_mask, (x, y), r, 255, -1)

        intersection = cv2.bitwise_and(mask, circle_mask)
        circle_area = np.pi * r * r
        if circle_area <= 0:
            return 0.0

        white_pixels = int(np.sum(intersection == 255))
        return float(min(white_pixels / circle_area, 1.0))

    def filter_duplicate_balls(self, balls):
        """Remove duplicate ball detections.

        Balls are processed in order. A ball closer than 30 pixels to an
        already kept ball is a duplicate of the first such ball: the one with
        the higher confidence survives (the kept ball wins ties).

        Args:
            balls: List of ball dicts with ``x``, ``y`` and ``confidence``.

        Returns:
            New list with duplicates removed.
        """
        filtered_balls = []

        for ball in balls:
            clash_index = next(
                (
                    index
                    for index, kept in enumerate(filtered_balls)
                    if math.hypot(ball['x'] - kept['x'],
                                  ball['y'] - kept['y']) < 30
                ),
                None,
            )

            if clash_index is None:
                filtered_balls.append(ball)
            elif ball['confidence'] > filtered_balls[clash_index]['confidence']:
                # Replace with higher confidence detection
                del filtered_balls[clash_index]
                filtered_balls.append(ball)

        return filtered_balls


# Global detector instance
detector = PoolBallDetector()


@app.route('/predict', methods=['POST'])
def predict():
    """Run detection on a base64-encoded image posted as JSON.

    Expects a JSON object with an ``image`` key (base64 image bytes) and an
    optional ``timestamp``. Responds with the detected balls, cue data,
    trajectory, power, angle and table bounds.

    Returns 400 for a missing/invalid body or image, 500 when detection
    itself fails.
    """
    # silent=True: malformed JSON or a wrong content type is a client error
    # and must not surface as a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'image' not in data:
        return jsonify({'error': 'No image data provided'}), 400

    # Decode base64 image
    try:
        image_data = base64.b64decode(data['image'])
        # Normalise RGBA / greyscale / palette images to 3-channel RGB so the
        # RGB->BGR conversion below always receives the layout it expects.
        image = Image.open(io.BytesIO(image_data)).convert('RGB')
    except (ValueError, TypeError, OSError) as e:
        return jsonify({'error': f'Invalid image data: {str(e)}'}), 400

    try:
        # Convert PIL image to OpenCV format
        frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

        # Detect balls and cue stick
        balls = detector.detect_balls(frame)
        cue_data = detector.detect_cue_stick(frame)
        cue_detected = bool(cue_data.get('detected'))

        # Calculate trajectory and additional metrics if cue is detected
        trajectory = []
        shot_angle = 0
        shot_power = 0
        if cue_detected:
            trajectory = detector.calculate_trajectory(cue_data, balls)
            shot_angle = cue_data.get('angle', 0)

            cue_ball = next(
                (ball for ball in balls if ball['type'] == 'cue'), None
            )
            if cue_ball:
                shot_power = detector.calculate_shot_power(cue_data, cue_ball)

        # Prepare response
        response = {
            'timestamp': data.get('timestamp', int(time.time() * 1000)),
            'cue_detected': cue_detected,
            'balls': balls,
            'trajectory': trajectory,
            'power': shot_power,
            'angle': shot_angle,
            'table_bounds': detector.table_bounds,
        }

        # Add cue line data if detected
        if cue_detected:
            response['cue_line'] = {
                'start_x': cue_data['start_x'],
                'start_y': cue_data['start_y'],
                'end_x': cue_data['end_x'],
                'end_y': cue_data['end_y'],
                'center_x': cue_data['center_x'],
                'center_y': cue_data['center_y'],
                'length': cue_data['length'],
            }

        return jsonify(response)

    except Exception as e:
        # Top-level request boundary: log with traceback, report to client.
        app.logger.exception("Error in prediction: %s", e)
        return jsonify({'error': f'Prediction failed: {str(e)}'}), 500


@app.route('/health', methods=['GET'])
def health():
    """Report service liveness."""
    return jsonify({'status': 'healthy', 'service': '8-ball-pool-predictor'})


@app.route('/reset', methods=['POST'])
def reset():
    """Reset the detector state"""
    global detector
    detector = PoolBallDetector()
    return jsonify({'status': 'reset_complete'})


if __name__ == '__main__':
    # Port 7860 for Hugging Face Spaces
    app.run(host='0.0.0.0', port=7860, debug=False)