| import gradio as gr |
| import cv2 |
| import numpy as np |
| import base64 |
| import json |
| import math |
| from PIL import Image |
| import io |
| import time |
| from collections import deque |
|
|
class PoolBallDetector:
    """Heuristic detector for pool balls, the cue stick and the table.

    Balls are found with per-colour HSV masks followed by Hough circles. The
    table is the bounding box of the largest green region. The cue is the
    longest contour in the configured cue colour range. All ``detect_*``
    methods take BGR frames (OpenCV channel order).

    Attributes:
        ball_history: the last 10 ball lists returned by ``detect_balls``.
        cue_history: the last 5 values found by ``detect_cue_stick`` (``None``
            for frames in which no cue was detected).
        table_bounds: ``(x1, y1, x2, y2)`` table box in pixels, or ``None``.
            It is detected on the first ``detect_balls`` call and then reused
            for every later frame.
    """

    def __init__(self):
        self.ball_history = deque(maxlen=10)
        self.cue_history = deque(maxlen=5)
        self.table_bounds = None

        self.setup_detection_params()

    def setup_detection_params(self):
        """Initialise the HSV colour ranges used for ball and cue detection."""
        # OpenCV 8-bit HSV: hue 0-179, saturation/value 0-255.
        # 'solid' only spans hue 0-10 and 'stripe' hue 20-30, so other ball
        # colours are not matched by these ranges.
        self.ball_colors = {
            'cue': {'lower': np.array([0, 0, 200]), 'upper': np.array([180, 30, 255])},
            'black': {'lower': np.array([0, 0, 0]), 'upper': np.array([180, 255, 50])},
            'solid': {'lower': np.array([0, 50, 50]), 'upper': np.array([10, 255, 255])},
            'stripe': {'lower': np.array([20, 50, 50]), 'upper': np.array([30, 255, 255])}
        }

        self.cue_color = {
            'lower': np.array([10, 50, 20]),
            'upper': np.array([20, 255, 200])
        }

    def detect_table_bounds(self, frame):
        """Detect the table as the bounding box of the largest green region.

        Stores the box in ``self.table_bounds`` and returns it as
        ``(x1, y1, x2, y2)``. Returns ``None`` (leaving the attribute
        unchanged) when no green region is found.
        """
        hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)

        green_lower = np.array([40, 50, 50])
        green_upper = np.array([80, 255, 255])

        mask = cv2.inRange(hsv, green_lower, green_upper)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return None

        largest_contour = max(contours, key=cv2.contourArea)
        x, y, w, h = cv2.boundingRect(largest_contour)
        self.table_bounds = (x, y, x + w, y + h)
        return self.table_bounds

    def detect_balls(self, frame):
        """Detect balls in a BGR frame.

        Returns a list of dicts with keys 'type', 'x', 'y', 'radius' and
        'confidence'. Near-duplicate detections are merged. The result is also
        appended to ``ball_history``.
        """
        if self.table_bounds is None:
            self.detect_table_bounds(frame)

        hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
        kernel = np.ones((5, 5), np.uint8)  # loop-invariant
        balls = []

        for ball_type, color_range in self.ball_colors.items():
            mask = cv2.inRange(hsv, color_range['lower'], color_range['upper'])

            # Open removes speckle noise; close fills small holes in blobs.
            mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
            mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

            circles = cv2.HoughCircles(
                mask, cv2.HOUGH_GRADIENT, dp=1, minDist=30,
                param1=50, param2=30, minRadius=10, maxRadius=50
            )
            if circles is None:
                continue

            for cx, cy, cr in np.round(circles[0, :]).astype(int):
                # Plain ints: OpenCV drawing calls expect native integers.
                x, y, r = int(cx), int(cy), int(cr)
                # is_within_table() accepts every point when no table was
                # found, so frames without detectable felt still yield balls.
                if not self.is_within_table(x, y):
                    continue
                balls.append({
                    'type': ball_type,
                    'x': float(x),
                    'y': float(y),
                    'radius': float(r),
                    'confidence': self.calculate_ball_confidence(mask, x, y, r)
                })

        balls = self.filter_duplicate_balls(balls)
        self.ball_history.append(balls)
        return balls

    def detect_cue_stick(self, frame):
        """Detect the cue stick in a BGR frame.

        Returns ``{'detected': False}`` when nothing qualifies. Otherwise
        returns the dict built by ``_cue_line_from_rect`` (centre, angle in
        degrees, both end points and length). The contour must have an area
        above 500 px to qualify. The detection (or ``None``) is appended to
        ``cue_history``.
        """
        hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
        mask = cv2.inRange(hsv, self.cue_color['lower'], self.cue_color['upper'])
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8))

        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        cue_data = None
        if contours:
            # The stick is assumed to be the contour with the longest outline.
            longest_contour = max(contours, key=lambda c: cv2.arcLength(c, False))
            if cv2.contourArea(longest_contour) > 500:
                cue_data = self._cue_line_from_rect(cv2.minAreaRect(longest_contour))

        self.cue_history.append(cue_data)
        return cue_data or {'detected': False}

    @staticmethod
    def _cue_line_from_rect(rect):
        """Build the cue-stick result dict from a ``cv2.minAreaRect`` tuple.

        ``rect`` is ``((cx, cy), (width, height), angle_degrees)``. OpenCV's
        angle is the direction of the *width* edge. When height is the longer
        side, the stick's long axis is perpendicular to that edge, so 90
        degrees are added before the end points are placed along the axis.
        Which end is the tip cannot be told from the rectangle alone.
        """
        (center_x, center_y), (width, height), angle = rect
        angle = float(angle)
        if width < height:
            angle += 90.0

        half_length = max(width, height) / 2
        angle_rad = math.radians(angle)
        dx = half_length * math.cos(angle_rad)
        dy = half_length * math.sin(angle_rad)

        return {
            'detected': True,
            'center_x': float(center_x),
            'center_y': float(center_y),
            'angle': angle,
            'start_x': float(center_x - dx),
            'start_y': float(center_y - dy),
            'end_x': float(center_x + dx),
            'end_y': float(center_y + dy),
            'length': float(half_length * 2)
        }

    def calculate_trajectory(self, cue_data, balls):
        """Simulate the cue ball's path along the cue angle.

        The simulation runs at most 50 steps of ``dt = 0.1`` with velocity
        friction 0.98 per step. Velocity is reflected (scaled by 0.8) at
        ``table_bounds`` when known. It stops early when speed drops below 0.5
        or the cue ball touches a non-cue ball.

        Returns a list of ``{'x', 'y'}`` points. The list is empty when no cue
        was detected, ``balls`` is empty, or there is no cue ball.
        """
        if not cue_data.get('detected') or not balls:
            return []

        # Use the first cue-ball detection, the same one the module's request
        # handlers pick when they compute shot power.
        cue_ball = next((ball for ball in balls if ball['type'] == 'cue'), None)
        if cue_ball is None:
            return []
        target_balls = [ball for ball in balls if ball['type'] != 'cue']

        cue_angle_rad = math.radians(cue_data['angle'])
        power = self.calculate_shot_power(cue_data, cue_ball)

        dt = 0.1
        friction = 0.98
        velocity_x = power * math.cos(cue_angle_rad) * 10
        velocity_y = power * math.sin(cue_angle_rad) * 10
        x, y = cue_ball['x'], cue_ball['y']

        trajectory = []
        for _ in range(50):
            x += velocity_x * dt
            y += velocity_y * dt

            velocity_x *= friction
            velocity_y *= friction

            if self.table_bounds:
                x1, y1, x2, y2 = self.table_bounds
                if x <= x1 or x >= x2:
                    velocity_x *= -0.8
                    x = max(x1, min(x2, x))
                if y <= y1 or y >= y2:
                    velocity_y *= -0.8
                    y = max(y1, min(y2, y))

            collision_detected = any(
                math.hypot(x - target['x'], y - target['y']) < cue_ball['radius'] + target['radius']
                for target in target_balls
            )

            trajectory.append({'x': float(x), 'y': float(y)})

            if math.hypot(velocity_x, velocity_y) < 0.5 or collision_detected:
                break

        return trajectory

    def calculate_shot_power(self, cue_data, cue_ball):
        """Return shot power in [0, 1] from the cue end point's distance to the cue ball.

        Power is ``1 - distance / 200``, floored at 0. A distance of 0 gives
        full power. Returns 0.0 if no cue was detected.
        """
        if not cue_data.get('detected'):
            return 0.0

        distance = math.hypot(cue_data['end_x'] - cue_ball['x'],
                              cue_data['end_y'] - cue_ball['y'])
        max_distance = 200
        return max(0.0, 1 - (distance / max_distance))

    def is_within_table(self, x, y):
        """Return True if (x, y) lies inside ``table_bounds`` (inclusive), or if no bounds are known."""
        if not self.table_bounds:
            return True

        x1, y1, x2, y2 = self.table_bounds
        return x1 <= x <= x2 and y1 <= y <= y2

    def calculate_ball_confidence(self, mask, x, y, r):
        """Return the fraction of the circle's ideal area that is set in ``mask``.

        The value is capped at 1.0, because the rasterised disc can contain
        slightly more pixels than pi*r^2. Returns 0.0 for a non-positive radius.
        """
        x, y, r = int(x), int(y), int(r)
        circle_mask = np.zeros(mask.shape, dtype=np.uint8)
        cv2.circle(circle_mask, (x, y), r, 255, -1)

        circle_area = np.pi * r * r
        if circle_area <= 0:
            return 0.0

        intersection = cv2.bitwise_and(mask, circle_mask)
        white_pixels = int(np.count_nonzero(intersection == 255))
        return min(white_pixels / circle_area, 1.0)

    def filter_duplicate_balls(self, balls, min_distance=30):
        """Merge detections whose centres are closer than ``min_distance`` pixels.

        A ball is dropped if any already-kept ball within ``min_distance`` has
        a confidence greater than or equal to its own. Otherwise every kept
        ball within that radius is discarded, and the new ball is appended.
        Surviving balls keep their relative order.
        """
        kept = []
        for ball in balls:
            neighbours = [
                other for other in kept
                if math.hypot(ball['x'] - other['x'], ball['y'] - other['y']) < min_distance
            ]
            if any(other['confidence'] >= ball['confidence'] for other in neighbours):
                continue
            # Remove by identity so equal-valued dicts elsewhere are untouched.
            neighbour_ids = {id(other) for other in neighbours}
            kept = [other for other in kept if id(other) not in neighbour_ids]
            kept.append(ball)

        return kept
|
|
| |
# One shared detector instance serves every request handler in this module, so
# its state (cached table bounds and detection histories) persists between calls.
detector: PoolBallDetector = PoolBallDetector()
|
|
def _image_to_bgr_frame(image):
    """Convert a PIL image (any mode) or an image array into a 3-channel BGR frame.

    PIL inputs are converted to RGB first, because screenshots are often RGBA
    or palette PNGs and ``cv2.COLOR_RGB2BGR`` only accepts 3-channel input.
    Array inputs are assumed to be RGB, RGBA or single-channel uint8 data
    (TODO confirm for non-Gradio callers).
    """
    if hasattr(image, "convert"):
        image = image.convert("RGB")
    array = np.array(image)
    if array.ndim == 2:
        return cv2.cvtColor(array, cv2.COLOR_GRAY2BGR)
    if array.ndim == 3 and array.shape[2] == 4:
        return cv2.cvtColor(array, cv2.COLOR_RGBA2BGR)
    return cv2.cvtColor(array, cv2.COLOR_RGB2BGR)


# BGR outline/label colour per ball type; unknown types are drawn green.
_BALL_DRAW_COLORS = {
    'cue': (255, 255, 255),
    'black': (128, 128, 128),
    'solid': (0, 0, 255),
    'stripe': (0, 255, 255),
}


def _draw_detections(frame, table_bounds, balls, cue_data, trajectory, shot_power, shot_angle):
    """Return a copy of the BGR ``frame`` annotated with all detection results.

    Draws the table box, each ball (outline, type label, confidence), the cue
    line and centre, the fading trajectory with start/end markers, a power
    bar and the shot angle. ``frame`` itself is not modified.
    """
    canvas = frame.copy()

    if table_bounds:
        x1, y1, x2, y2 = table_bounds
        cv2.rectangle(canvas, (x1, y1), (x2, y2), (255, 0, 0), 2)

    for ball in balls:
        color = _BALL_DRAW_COLORS.get(ball['type'], (0, 255, 0))
        cv2.circle(canvas, (int(ball['x']), int(ball['y'])), int(ball['radius']), color, 2)
        cv2.putText(canvas, ball['type'], (int(ball['x'] - 20), int(ball['y'] - 30)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
        cv2.putText(canvas, f"{ball['confidence']:.2f}",
                    (int(ball['x'] - 10), int(ball['y'] + 40)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.3, color, 1)

    if cue_data.get('detected'):
        cv2.line(canvas,
                 (int(cue_data['start_x']), int(cue_data['start_y'])),
                 (int(cue_data['end_x']), int(cue_data['end_y'])),
                 (0, 255, 255), 3)
        cv2.circle(canvas, (int(cue_data['center_x']), int(cue_data['center_y'])),
                   5, (0, 255, 255), -1)

    if trajectory:
        points = [(int(p['x']), int(p['y'])) for p in trajectory]
        count = len(points)
        for i, (start, end) in enumerate(zip(points, points[1:])):
            # Fade from bright to darker red along the path, floored at 30%.
            intensity = int(255 * max(0.3, 1.0 - (i / count)))
            cv2.line(canvas, start, end, (0, 0, intensity), 2)
        cv2.circle(canvas, points[0], 8, (0, 0, 255), -1)   # start marker (red)
        cv2.circle(canvas, points[-1], 6, (255, 0, 0), -1)  # end marker (blue)

    if shot_power > 0:
        bar_x, bar_y = 50, 50
        bar_width, bar_height = 200, 20
        cv2.rectangle(canvas, (bar_x, bar_y), (bar_x + bar_width, bar_y + bar_height),
                      (64, 64, 64), -1)

        power_width = int(bar_width * min(shot_power, 1.0))
        if shot_power < 0.3:
            power_color = (0, 255, 0)
        elif shot_power < 0.7:
            power_color = (0, 255, 255)
        else:
            power_color = (0, 0, 255)
        cv2.rectangle(canvas, (bar_x, bar_y), (bar_x + power_width, bar_y + bar_height),
                      power_color, -1)

        cv2.putText(canvas, f"Power: {shot_power:.2f}", (bar_x, bar_y - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)

    if shot_angle != 0:
        # Hershey fonts only render ASCII; a degree sign would be drawn as "?".
        cv2.putText(canvas, f"Angle: {shot_angle:.1f} deg", (50, 100),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)

    return canvas


def process_pool_image(image_data):
    """Run detection on one image and return an annotated image plus a text report.

    Args:
        image_data: A PIL image (what ``gr.Image(type="pil")`` delivers), or a
            base64 string of image-file bytes. An optional
            ``data:<mime>;base64,`` prefix on the string is ignored.

    Returns:
        ``(PIL.Image, str)``: the annotated frame and a summary text. On any
        failure, a black 640x480 image showing the error and an error string
        are returned instead; the function does not raise.
    """
    try:
        if image_data is None:
            raise ValueError("No image provided")

        if isinstance(image_data, str):
            encoded = image_data.strip()
            if encoded.startswith("data:"):
                # The data-URL header is not base64; b64decode would silently
                # fold its letters into the payload.
                encoded = encoded.partition(",")[2]
            image = Image.open(io.BytesIO(base64.b64decode(encoded)))
        else:
            image = image_data

        frame = _image_to_bgr_frame(image)

        balls = detector.detect_balls(frame)
        cue_data = detector.detect_cue_stick(frame)
        cue_detected = cue_data.get('detected', False)

        trajectory = detector.calculate_trajectory(cue_data, balls) if cue_detected else []

        shot_angle = cue_data.get('angle', 0) if cue_detected else 0
        shot_power = 0
        if cue_detected and balls:
            cue_ball = next((ball for ball in balls if ball['type'] == 'cue'), None)
            if cue_ball:
                shot_power = detector.calculate_shot_power(cue_data, cue_ball)

        result_frame = _draw_detections(frame, detector.table_bounds, balls, cue_data,
                                        trajectory, shot_power, shot_angle)
        result_image = Image.fromarray(cv2.cvtColor(result_frame, cv2.COLOR_BGR2RGB))

        info_text = f"""
π± Detection Results:
ββββββββββββββββββββ
π― Cue Detected: {cue_data.get('detected', False)}
π Balls Found: {len(balls)}
β‘ Shot Power: {shot_power:.2f} ({get_power_level(shot_power)})
π Shot Angle: {shot_angle:.1f}Β°
π Trajectory Points: {len(trajectory)}
π Table Bounds: {'Detected' if detector.table_bounds else 'Not Detected'}
Ball Details:
{format_ball_details(balls)}
Trajectory Info:
{format_trajectory_info(trajectory)}
"""

        return result_image, info_text

    except Exception as e:
        # UI boundary: report every failure to the user instead of raising.
        error_frame = np.zeros((480, 640, 3), dtype=np.uint8)
        cv2.putText(error_frame, f"Error: {str(e)}", (50, 240),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
        error_image = Image.fromarray(cv2.cvtColor(error_frame, cv2.COLOR_BGR2RGB))
        return error_image, f"β Error: {str(e)}"
|
|
def get_power_level(power):
    """Map a shot power (nominally 0..1) to a descriptive label.

    Bands are half-open: below 0.2 "Gentle", [0.2, 0.4) "Light",
    [0.4, 0.6) "Medium", [0.6, 0.8) "Strong". Anything else, including
    values >= 0.8 and NaN, is "Maximum".
    """
    bands = (
        (0.2, "Gentle"),
        (0.4, "Light"),
        (0.6, "Medium"),
        (0.8, "Strong"),
    )
    for upper_bound, label in bands:
        if power < upper_bound:
            return label
    return "Maximum"
|
|
def format_ball_details(balls):
    """Render one bullet line per ball: capitalised type, rounded position, confidence.

    Returns "No balls detected" for an empty (or falsy) input. Otherwise
    returns the lines joined by newlines.
    """
    if not balls:
        return "No balls detected"

    lines = [
        f" β’ {ball['type'].capitalize()}: ({ball['x']:.0f}, {ball['y']:.0f}) - Confidence: {ball['confidence']:.2f}"
        for ball in balls
    ]
    return "\n".join(lines)
|
|
def format_trajectory_info(trajectory):
    """Summarise a trajectory: total polyline length in pixels and number of points.

    Returns "No trajectory calculated" for an empty (or falsy) input.
    """
    if not trajectory:
        return "No trajectory calculated"

    total_distance = 0
    # Sum the Euclidean length of each consecutive segment.
    for prev, curr in zip(trajectory, trajectory[1:]):
        dx = curr['x'] - prev['x']
        dy = curr['y'] - prev['y']
        total_distance += math.sqrt(dx*dx + dy*dy)

    return f" β’ Total Distance: {total_distance:.1f} pixels\n β’ Path Length: {len(trajectory)} points"
|
|
| |
# Attach the module-level formatting helpers as PoolBallDetector methods. Each
# wrapper ignores `self` and delegates to the same-named module function,
# which is looked up by name when the method is called.
setattr(PoolBallDetector, "get_power_level",
        lambda self, power: get_power_level(power))
setattr(PoolBallDetector, "format_ball_details",
        lambda self, balls: format_ball_details(balls))
setattr(PoolBallDetector, "format_trajectory_info",
        lambda self, trajectory: format_trajectory_info(trajectory))
|
|
def predict_api(image_b64):
    """Analyse a base64-encoded image and return the result as a JSON string.

    Intended for programmatic clients such as the mobile app. ``image_b64``
    is the base64 text of an image file. A leading ``data:<mime>;base64,``
    header is accepted and stripped.

    Returns:
        A JSON string. On success it contains ``status: "success"``, a
        millisecond ``timestamp``, ``cue_detected``, ``balls``,
        ``trajectory``, ``power``, ``angle``, ``table_bounds`` and, when a
        cue was found, ``cue_line``. On any failure it contains
        ``status: "error"``, ``error`` and ``timestamp``; the function does
        not raise.
    """
    try:
        payload = image_b64 or ""
        if isinstance(payload, (bytes, bytearray)):
            payload = bytes(payload).decode("ascii", errors="ignore")
        payload = payload.strip()
        if payload.startswith("data:"):
            # A data-URL header ("data:image/png;base64,") is not base64, and
            # b64decode would silently keep its alphanumeric characters,
            # corrupting the decoded bytes.
            payload = payload.partition(",")[2]
        if not payload:
            raise ValueError("Empty image payload")

        image = Image.open(io.BytesIO(base64.b64decode(payload)))
        # Screenshots are often RGBA or palette PNGs; COLOR_RGB2BGR needs
        # exactly three channels, so normalise the mode first.
        frame = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)

        balls = detector.detect_balls(frame)
        cue_data = detector.detect_cue_stick(frame)
        cue_detected = cue_data.get('detected', False)

        trajectory = detector.calculate_trajectory(cue_data, balls) if cue_detected else []

        shot_angle = cue_data.get('angle', 0) if cue_detected else 0
        shot_power = 0
        if cue_detected and balls:
            cue_ball = next((ball for ball in balls if ball['type'] == 'cue'), None)
            if cue_ball:
                shot_power = detector.calculate_shot_power(cue_data, cue_ball)

        response = {
            'timestamp': int(time.time() * 1000),
            'cue_detected': cue_detected,
            'balls': balls,
            'trajectory': trajectory,
            'power': shot_power,
            'angle': shot_angle,
            'table_bounds': detector.table_bounds,
            'status': 'success'
        }

        if cue_detected:
            cue_line_keys = ('start_x', 'start_y', 'end_x', 'end_y',
                             'center_x', 'center_y', 'length')
            response['cue_line'] = {key: cue_data[key] for key in cue_line_keys}

        return json.dumps(response)

    except Exception as e:
        # API boundary: every failure is reported as a JSON error payload.
        error_response = {
            'error': str(e),
            'timestamp': int(time.time() * 1000),
            'status': 'error'
        }
        return json.dumps(error_response)
|
|
| |
# Gradio UI. The top section uploads an image and shows the annotated prediction
# plus a text report (process_pool_image). The bottom section runs predict_api
# on pasted base64 text. Layout is determined by statement order inside each
# `with` context, so keep the order of these statements.
with gr.Blocks(title="8-Ball Pool Trajectory Predictor", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# π± 8-Ball Pool Trajectory Predictor")
    gr.Markdown("Upload a screenshot of your 8-ball pool game to get real-time trajectory predictions!")

    with gr.Row():
        # Left column: PIL image input, trigger button and usage instructions.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Pool Table Screenshot")
            predict_btn = gr.Button("π― Predict Trajectory", variant="primary", size="lg")

            gr.Markdown("### π Instructions:")
            gr.Markdown("""
1. Take a screenshot of your 8-ball pool game
2. Upload the image above
3. Click 'Predict Trajectory' to see the analysis
4. View the predicted ball path in red lines
5. Check the power and angle indicators
""")

        # Right column: annotated result image and the text report.
        with gr.Column():
            output_image = gr.Image(label="π― Prediction Results")
            output_text = gr.Textbox(label="π Detection Info", lines=12, max_lines=20)

    # process_pool_image returns (image, text), matching these two outputs.
    predict_btn.click(
        fn=process_pool_image,
        inputs=[input_image],
        outputs=[output_image, output_text]
    )

    gr.Markdown("---")

    with gr.Row():
        gr.Markdown("## π± Mobile App API")

    with gr.Row():
        with gr.Column():
            api_input = gr.Textbox(label="Base64 Image Data (for mobile app)", lines=3,
                                   placeholder="Paste base64 encoded image data here...")
            api_btn = gr.Button("π Process API Request", variant="secondary")

        with gr.Column():
            api_output = gr.Textbox(label="JSON Response", lines=10, max_lines=15)

    # predict_api returns a JSON string shown in the single text output.
    api_btn.click(
        fn=predict_api,
        inputs=[api_input],
        outputs=[api_output]
    )

    # NOTE(review): this file defines no `POST /predict` route that accepts
    # {"image": ...}. The handlers are only reachable through Gradio's own API
    # for the click events above, and that API's endpoint name and request
    # format depend on the installed Gradio version. Verify before publishing
    # this text to app developers.
    gr.Markdown("### π API Usage for Android:")
    gr.Markdown("""
```
POST /predict
Content-Type: application/json

{
"image": "base64_encoded_image_data_here"
}
```
""")
|
|
if __name__ == "__main__":
    # Serve the UI on port 7860, listening on all network interfaces.
    demo.launch(server_port=7860, server_name="0.0.0.0")