"""8-Ball Pool Trajectory Predictor.

Gradio app that analyses a screenshot of an 8-ball pool game.  The table,
balls and cue stick are located with HSV colour thresholds, the cue ball's
path is simulated, and the annotated image is returned.  ``predict_api``
exposes the same analysis as a JSON string for a mobile client.
"""

import base64
import io
import json
import logging
import math
import time
from collections import deque

import cv2
import gradio as gr
import numpy as np
from PIL import Image

logger = logging.getLogger(__name__)


class PoolBallDetector:
    """Colour-threshold based detector for the pool table, balls and cue stick."""

    def __init__(self):
        self.ball_history = deque(maxlen=10)  # ball lists from the last 10 detect_balls() calls
        self.cue_history = deque(maxlen=5)  # the last 5 successful cue-stick detections
        self.table_bounds = None  # (x1, y1, x2, y2) of the most recently detected table, or None
        # Initialize ball detection parameters
        self.setup_detection_params()

    def setup_detection_params(self):
        """Initialise the HSV colour ranges used by the detectors."""
        # HSV ranges for different coloured balls (OpenCV scale: H 0-179, S/V 0-255).
        self.ball_colors = {
            'cue': {'lower': np.array([0, 0, 200]), 'upper': np.array([180, 30, 255])},  # White
            'black': {'lower': np.array([0, 0, 0]), 'upper': np.array([180, 255, 50])},  # Black (8-ball)
            'solid': {'lower': np.array([0, 50, 50]), 'upper': np.array([10, 255, 255])},  # Red/solid colours
            'stripe': {'lower': np.array([20, 50, 50]), 'upper': np.array([30, 255, 255])},  # Yellow/stripe colours
        }
        # Cue stick detection (brown/wooden colour)
        self.cue_color = {
            'lower': np.array([10, 50, 20]),
            'upper': np.array([20, 255, 200]),
        }

    def detect_table_bounds(self, frame):
        """Detect the pool table's bounding box from its green felt.

        Args:
            frame: BGR image.

        Returns:
            ``(x1, y1, x2, y2)`` bounding box of the largest green region, or
            ``None`` when no green region is found.  The result is also stored
            in ``self.table_bounds``; it is reset to ``None`` on failure so the
            bounds of a previous image are never reused.
        """
        hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)

        # Green table detection
        green_lower = np.array([40, 50, 50])
        green_upper = np.array([80, 255, 255])
        mask = cv2.inRange(hsv, green_lower, green_upper)

        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            self.table_bounds = None
            return None

        largest_contour = max(contours, key=cv2.contourArea)
        x, y, w, h = cv2.boundingRect(largest_contour)
        self.table_bounds = (x, y, x + w, y + h)
        return self.table_bounds

    def detect_balls(self, frame):
        """Detect all balls on the table.

        The table bounds are re-detected for every frame, because one global
        detector serves unrelated uploaded screenshots.

        Args:
            frame: BGR image.

        Returns:
            List of dicts with keys ``type`` (a key of ``self.ball_colors``),
            ``x``, ``y``, ``radius`` and ``confidence`` (all floats), after
            duplicate suppression.  The list is also appended to
            ``self.ball_history``.
        """
        self.detect_table_bounds(frame)

        hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
        kernel = np.ones((5, 5), np.uint8)
        balls = []

        for ball_type, color_range in self.ball_colors.items():
            mask = cv2.inRange(hsv, color_range['lower'], color_range['upper'])
            # Opening removes speckle noise; closing fills small holes inside balls.
            mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
            mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

            circles = cv2.HoughCircles(
                mask,
                cv2.HOUGH_GRADIENT,
                dp=1,
                minDist=30,
                param1=50,
                param2=30,
                minRadius=10,
                maxRadius=50,
            )
            if circles is None:
                continue

            for cx, cy, cr in np.round(circles[0, :]).astype(int):
                x, y, r = int(cx), int(cy), int(cr)
                # is_within_table() accepts every point when no table was found,
                # so balls are still reported in that case.
                if not self.is_within_table(x, y):
                    continue
                balls.append({
                    'type': ball_type,
                    'x': float(x),
                    'y': float(y),
                    'radius': float(r),
                    'confidence': self.calculate_ball_confidence(mask, x, y, r),
                })

        balls = self.filter_duplicate_balls(balls)
        self.ball_history.append(balls)
        return balls

    def detect_cue_stick(self, frame):
        """Detect the cue stick's position and orientation.

        Args:
            frame: BGR image.

        Returns:
            ``{'detected': False}`` when no sufficiently large cue-coloured
            region exists.  Otherwise a dict with ``detected=True``,
            ``center_x/center_y``, ``angle`` (degrees, along the stick's long
            axis), the two axis endpoints ``start_x/start_y`` and
            ``end_x/end_y`` (no tip/butt ordering implied), and ``length``.
        """
        hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
        mask = cv2.inRange(hsv, self.cue_color['lower'], self.cue_color['upper'])
        kernel = np.ones((3, 3), np.uint8)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return {'detected': False}

        # Find the longest contour (likely the cue stick)
        longest_contour = max(contours, key=lambda c: cv2.arcLength(c, False))
        if cv2.contourArea(longest_contour) <= 500:  # Minimum area threshold
            return {'detected': False}

        (center_x, center_y), (width, height), angle = cv2.minAreaRect(longest_contour)
        # minAreaRect's angle is the direction of the ``width`` side.  When the
        # rectangle is taller than wide, the stick's long axis is perpendicular.
        if height > width:
            angle += 90.0

        half_length = max(width, height) / 2
        angle_rad = math.radians(angle)
        dx = half_length * math.cos(angle_rad)
        dy = half_length * math.sin(angle_rad)

        cue_data = {
            'detected': True,
            'center_x': float(center_x),
            'center_y': float(center_y),
            'angle': float(angle),
            'start_x': float(center_x - dx),
            'start_y': float(center_y - dy),
            'end_x': float(center_x + dx),
            'end_y': float(center_y + dy),
            'length': float(half_length * 2),
        }
        self.cue_history.append(cue_data)
        return cue_data

    def calculate_trajectory(self, cue_data, balls):
        """Simulate the cue ball's path for the detected shot.

        The cue ball is launched along the cue stick's axis, oriented from the
        stick towards the ball, with a speed proportional to
        ``calculate_shot_power``.  Each step applies friction and reflects off
        the table bounds, keeping 80% of that axis' speed.  The simulation
        stops on contact with another ball or once the speed drops below 0.5.

        Args:
            cue_data: result of ``detect_cue_stick``.
            balls: result of ``detect_balls``.

        Returns:
            List of ``{'x': float, 'y': float}`` points (at most 50).  The list
            is empty when the stick or the cue ball was not detected.
        """
        if not cue_data.get('detected') or not balls:
            return []

        # Use the first cue-ball detection, the same one the shot power uses.
        cue_ball = next((ball for ball in balls if ball['type'] == 'cue'), None)
        if cue_ball is None:
            return []
        target_balls = [ball for ball in balls if ball['type'] != 'cue']

        cue_x, cue_y = cue_ball['x'], cue_ball['y']
        angle_rad = math.radians(cue_data['angle'])
        dir_x, dir_y = math.cos(angle_rad), math.sin(angle_rad)
        # The stick angle only fixes an axis (ambiguous by 180 degrees).  Flip
        # it so the shot travels from the stick towards the cue ball.
        to_ball_x = cue_x - cue_data['center_x']
        to_ball_y = cue_y - cue_data['center_y']
        if to_ball_x * dir_x + to_ball_y * dir_y < 0:
            dir_x, dir_y = -dir_x, -dir_y

        power = self.calculate_shot_power(cue_data, cue_ball)

        dt = 0.1  # time step
        friction = 0.98  # per-step velocity retention
        bounce_damping = 0.8  # fraction of speed kept on a cushion bounce
        min_speed = 0.5
        max_steps = 50

        velocity_x = power * dir_x * 10  # scale factor
        velocity_y = power * dir_y * 10
        x, y = cue_x, cue_y
        trajectory = []

        for _ in range(max_steps):
            x += velocity_x * dt
            y += velocity_y * dt
            velocity_x *= friction
            velocity_y *= friction

            if self.table_bounds:
                x1, y1, x2, y2 = self.table_bounds
                if x <= x1 or x >= x2:
                    velocity_x *= -bounce_damping
                    x = max(x1, min(x2, x))
                if y <= y1 or y >= y2:
                    velocity_y *= -bounce_damping
                    y = max(y1, min(y2, y))

            collision_detected = any(
                math.hypot(x - target['x'], y - target['y']) < cue_ball['radius'] + target['radius']
                for target in target_balls
            )
            trajectory.append({'x': float(x), 'y': float(y)})

            if collision_detected or math.hypot(velocity_x, velocity_y) < min_speed:
                break

        return trajectory

    def calculate_shot_power(self, cue_data, cue_ball):
        """Estimate shot power from how close the cue stick is to the cue ball.

        The stick's tip is taken to be whichever endpoint is nearer the ball,
        because the detected endpoints carry no tip/butt ordering.

        Returns:
            A float in ``[0, 1]``.  It is 1.0 when the tip touches the ball
            centre and falls linearly to 0.0 at 200 px or more.  It is 0.0 when
            the stick was not detected.
        """
        if not cue_data.get('detected'):
            return 0.0

        ball_x, ball_y = cue_ball['x'], cue_ball['y']
        distance = min(
            math.hypot(cue_data['start_x'] - ball_x, cue_data['start_y'] - ball_y),
            math.hypot(cue_data['end_x'] - ball_x, cue_data['end_y'] - ball_y),
        )

        max_distance = 200  # Maximum meaningful distance (closer = more power)
        return max(0.0, 1.0 - distance / max_distance)

    def is_within_table(self, x, y):
        """Return True if (x, y) lies inside the table bounds (or none are known)."""
        if not self.table_bounds:
            return True
        x1, y1, x2, y2 = self.table_bounds
        return x1 <= x <= x2 and y1 <= y <= y2

    def calculate_ball_confidence(self, mask, x, y, r):
        """Return the fraction (capped at 1.0) of the circle's area set in *mask*."""
        circle_mask = np.zeros(mask.shape, dtype=np.uint8)
        cv2.circle(circle_mask, (int(x), int(y)), int(r), 255, -1)
        intersection = cv2.bitwise_and(mask, circle_mask)

        circle_area = math.pi * r * r
        if circle_area <= 0:
            return 0.0
        white_pixels = int(np.count_nonzero(intersection == 255))
        return float(min(white_pixels / circle_area, 1.0))

    def filter_duplicate_balls(self, balls, min_distance=30):
        """Remove duplicate detections with greedy non-max suppression.

        Balls are visited in descending confidence.  A ball is kept only if it
        is at least *min_distance* pixels from every ball already kept.  Kept
        balls are returned in their original detection order.
        """
        kept = []
        for ball in sorted(balls, key=lambda b: b['confidence'], reverse=True):
            if all(
                math.hypot(ball['x'] - other['x'], ball['y'] - other['y']) >= min_distance
                for other in kept
            ):
                kept.append(ball)

        kept_ids = {id(ball) for ball in kept}
        return [ball for ball in balls if id(ball) in kept_ids]


# Global detector instance
detector = PoolBallDetector()

# BGR drawing colours per ball type; unknown types are drawn green.
_BALL_DRAW_COLORS = {
    'cue': (255, 255, 255),
    'black': (128, 128, 128),
    'solid': (0, 0, 255),  # Red for solid balls
    'stripe': (0, 255, 255),  # Yellow for stripe balls
}


def _image_from_base64(image_b64):
    """Decode a base64 string (optionally a ``data:...;base64,`` URL) into a PIL image."""
    if image_b64.lstrip().startswith('data:'):
        # Drop the "data:image/png;base64," header.  Its letters would
        # otherwise be decoded as base64 payload.
        image_b64 = image_b64.partition(',')[2]
    return Image.open(io.BytesIO(base64.b64decode(image_b64)))


def _to_bgr_frame(image):
    """Convert a PIL image (any mode) or RGB array into an OpenCV BGR frame."""
    if isinstance(image, Image.Image):
        # Normalises RGBA / palette / grayscale screenshots to 3-channel RGB.
        image = image.convert('RGB')
    return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)


def _analyze_frame(frame):
    """Run all detectors on *frame* and derive the shot metrics."""
    balls = detector.detect_balls(frame)
    cue_data = detector.detect_cue_stick(frame)

    trajectory = []
    shot_angle = 0
    shot_power = 0
    if cue_data.get('detected'):
        trajectory = detector.calculate_trajectory(cue_data, balls)
        shot_angle = cue_data.get('angle', 0)
        cue_ball = next((ball for ball in balls if ball['type'] == 'cue'), None)
        if cue_ball is not None:
            shot_power = detector.calculate_shot_power(cue_data, cue_ball)

    return {
        'balls': balls,
        'cue_data': cue_data,
        'trajectory': trajectory,
        'shot_angle': shot_angle,
        'shot_power': shot_power,
        'table_bounds': detector.table_bounds,
    }


def _draw_results(frame, analysis):
    """Return a copy of BGR *frame* annotated with the analysis results."""
    result_frame = frame.copy()
    cue_data = analysis['cue_data']
    trajectory = analysis['trajectory']
    shot_power = analysis['shot_power']

    # Table bounds
    if analysis['table_bounds']:
        x1, y1, x2, y2 = analysis['table_bounds']
        cv2.rectangle(result_frame, (x1, y1), (x2, y2), (255, 0, 0), 2)

    # Balls with type label and confidence score
    for ball in analysis['balls']:
        color = _BALL_DRAW_COLORS.get(ball['type'], (0, 255, 0))
        center = (int(ball['x']), int(ball['y']))
        cv2.circle(result_frame, center, int(ball['radius']), color, 2)
        cv2.putText(result_frame, ball['type'], (int(ball['x'] - 20), int(ball['y'] - 30)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
        cv2.putText(result_frame, f"{ball['confidence']:.2f}", (int(ball['x'] - 10), int(ball['y'] + 40)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.3, color, 1)

    # Cue stick axis and centre point
    if cue_data.get('detected'):
        cv2.line(result_frame,
                 (int(cue_data['start_x']), int(cue_data['start_y'])),
                 (int(cue_data['end_x']), int(cue_data['end_y'])),
                 (0, 255, 255), 3)
        cv2.circle(result_frame, (int(cue_data['center_x']), int(cue_data['center_y'])),
                   5, (0, 255, 255), -1)

    # Trajectory: red segments fading with distance, plus start/end markers
    if trajectory:
        points = [(int(p['x']), int(p['y'])) for p in trajectory]
        n_points = len(points)
        for i, (p0, p1) in enumerate(zip(points, points[1:])):
            alpha = max(0.3, 1.0 - i / n_points)
            cv2.line(result_frame, p0, p1, (0, 0, int(255 * alpha)), 2)
        cv2.circle(result_frame, points[0], 8, (0, 0, 255), -1)
        cv2.circle(result_frame, points[-1], 6, (255, 0, 0), -1)

    # Power bar
    if shot_power > 0:
        bar_x, bar_y = 50, 50
        bar_width, bar_height = 200, 20
        cv2.rectangle(result_frame, (bar_x, bar_y), (bar_x + bar_width, bar_y + bar_height),
                      (64, 64, 64), -1)
        power_width = int(bar_width * min(shot_power, 1.0))
        if shot_power < 0.3:
            power_color = (0, 255, 0)  # Green
        elif shot_power < 0.7:
            power_color = (0, 255, 255)  # Yellow
        else:
            power_color = (0, 0, 255)  # Red
        cv2.rectangle(result_frame, (bar_x, bar_y), (bar_x + power_width, bar_y + bar_height),
                      power_color, -1)
        cv2.putText(result_frame, f"Power: {shot_power:.2f}", (bar_x, bar_y - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)

    # Angle indicator.  Hershey fonts only cover ASCII, so a degree sign would
    # render as "??".
    if cue_data.get('detected'):
        cv2.putText(result_frame, f"Angle: {analysis['shot_angle']:.1f} deg", (50, 100),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)

    return result_frame


def _format_info_text(analysis):
    """Build the human-readable summary shown in the Gradio textbox."""
    shot_power = analysis['shot_power']
    return "\n".join([
        "🎱 Detection Results:",
        "━━━━━━━━━━━━━━━━━━━━",
        f"🎯 Cue Detected: {analysis['cue_data'].get('detected', False)}",
        f"🏐 Balls Found: {len(analysis['balls'])}",
        f"⚡ Shot Power: {shot_power:.2f} ({get_power_level(shot_power)})",
        f"📐 Shot Angle: {analysis['shot_angle']:.1f}°",
        f"📈 Trajectory Points: {len(analysis['trajectory'])}",
        f"🏓 Table Bounds: {'Detected' if analysis['table_bounds'] else 'Not Detected'}",
        "",
        "Ball Details:",
        format_ball_details(analysis['balls']),
        "",
        "Trajectory Info:",
        format_trajectory_info(analysis['trajectory']),
    ])


def process_pool_image(image_data):
    """Process an image and return ``(annotated PIL image, info text)``.

    *image_data* may be a PIL image, an RGB array, raw image bytes, or a
    base64 string (optionally a ``data:`` URL).  On any failure, an error
    image and an error message are returned instead of raising.
    """
    try:
        if image_data is None:
            raise ValueError("No image provided")
        if isinstance(image_data, str):
            image = _image_from_base64(image_data)
        elif isinstance(image_data, (bytes, bytearray)):
            image = Image.open(io.BytesIO(image_data))
        else:
            image = image_data

        frame = _to_bgr_frame(image)
        analysis = _analyze_frame(frame)
        result_frame = _draw_results(frame, analysis)

        # Convert back to PIL (RGB) for Gradio
        result_image = Image.fromarray(cv2.cvtColor(result_frame, cv2.COLOR_BGR2RGB))
        return result_image, _format_info_text(analysis)

    except Exception as e:
        logger.exception("Failed to process pool image")
        error_frame = np.zeros((480, 640, 3), dtype=np.uint8)
        cv2.putText(error_frame, f"Error: {str(e)}", (50, 240),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
        error_image = Image.fromarray(cv2.cvtColor(error_frame, cv2.COLOR_BGR2RGB))
        return error_image, f"❌ Error: {str(e)}"


def get_power_level(power):
    """Convert a power value in [0, 1] to descriptive text."""
    if power < 0.2:
        return "Gentle"
    elif power < 0.4:
        return "Light"
    elif power < 0.6:
        return "Medium"
    elif power < 0.8:
        return "Strong"
    else:
        return "Maximum"


def format_ball_details(balls):
    """Format ball information for display, one line per ball."""
    if not balls:
        return "No balls detected"
    return "\n".join(
        f" • {ball['type'].capitalize()}: ({ball['x']:.0f}, {ball['y']:.0f}) - Confidence: {ball['confidence']:.2f}"
        for ball in balls
    )


def format_trajectory_info(trajectory):
    """Format the trajectory's total path length and point count for display."""
    if not trajectory:
        return "No trajectory calculated"
    total_distance = sum(
        math.hypot(b['x'] - a['x'], b['y'] - a['y'])
        for a, b in zip(trajectory, trajectory[1:])
    )
    return f" • Total Distance: {total_distance:.1f} pixels\n • Path Length: {len(trajectory)} points"


# Add the methods to the detector class (kept for callers using them via an instance)
PoolBallDetector.get_power_level = lambda self, power: get_power_level(power)
PoolBallDetector.format_ball_details = lambda self, balls: format_ball_details(balls)
PoolBallDetector.format_trajectory_info = lambda self, trajectory: format_trajectory_info(trajectory)


def predict_api(image_b64):
    """API endpoint for the mobile app.

    Args:
        image_b64: base64-encoded image, optionally as a ``data:`` URL.

    Returns:
        JSON string with ``status='success'`` and the detections, trajectory,
        power, angle and table bounds.  Includes ``cue_line`` when the cue
        stick was detected.  On failure, returns ``status='error'`` with an
        ``error`` message.
    """
    try:
        frame = _to_bgr_frame(_image_from_base64(image_b64))
        analysis = _analyze_frame(frame)
        cue_data = analysis['cue_data']

        response = {
            'timestamp': int(time.time() * 1000),
            'cue_detected': cue_data.get('detected', False),
            'balls': analysis['balls'],
            'trajectory': analysis['trajectory'],
            'power': analysis['shot_power'],
            'angle': analysis['shot_angle'],
            'table_bounds': analysis['table_bounds'],
            'status': 'success',
        }

        if cue_data.get('detected'):
            response['cue_line'] = {
                key: cue_data[key]
                for key in ('start_x', 'start_y', 'end_x', 'end_y', 'center_x', 'center_y', 'length')
            }

        return json.dumps(response)

    except Exception as e:
        logger.exception("API request failed")
        error_response = {
            'error': str(e),
            'timestamp': int(time.time() * 1000),
            'status': 'error',
        }
        return json.dumps(error_response)


# Create Gradio interface
with gr.Blocks(title="8-Ball Pool Trajectory Predictor", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎱 8-Ball Pool Trajectory Predictor")
    gr.Markdown("Upload a screenshot of your 8-ball pool game to get real-time trajectory predictions!")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Pool Table Screenshot")
            predict_btn = gr.Button("🎯 Predict Trajectory", variant="primary", size="lg")

            gr.Markdown("### 📝 Instructions:")
            gr.Markdown("""
            1. Take a screenshot of your 8-ball pool game
            2. Upload the image above
            3. Click 'Predict Trajectory' to see the analysis
            4. View the predicted ball path in red lines
            5. Check the power and angle indicators
            """)

        with gr.Column():
            output_image = gr.Image(label="🎯 Prediction Results")
            output_text = gr.Textbox(label="📊 Detection Info", lines=12, max_lines=20)

    predict_btn.click(
        fn=process_pool_image,
        inputs=[input_image],
        outputs=[output_image, output_text],
    )

    gr.Markdown("---")

    # API endpoint for mobile app
    with gr.Row():
        gr.Markdown("## 📱 Mobile App API")

    with gr.Row():
        with gr.Column():
            api_input = gr.Textbox(
                label="Base64 Image Data (for mobile app)",
                lines=3,
                placeholder="Paste base64 encoded image data here...",
            )
            api_btn = gr.Button("🔄 Process API Request", variant="secondary")

        with gr.Column():
            api_output = gr.Textbox(label="JSON Response", lines=10, max_lines=15)

    api_btn.click(
        fn=predict_api,
        inputs=[api_input],
        outputs=[api_output],
    )

    gr.Markdown("### 🔗 API Usage for Android:")
    gr.Markdown("""
    ```
    POST /predict
    Content-Type: application/json

    {
        "image": "base64_encoded_image_data_here"
    }
    ```
    """)


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)