Spaces:
Sleeping
Sleeping
| import os | |
| import cv2 | |
| import numpy as np | |
| import pandas as pd | |
| import tensorflow as tf | |
| from tensorflow.keras.models import load_model | |
| import torch | |
| from ultralytics import YOLO | |
| from itertools import combinations | |
| import gradio as gr | |
| import traceback | |
| import time | |
| from typing import Tuple, List, Dict | |
| import logging | |
# Force CPU mode for TensorFlow: hide CUDA devices through the environment
# and clear TensorFlow's list of visible GPU devices.
os.environ.update({"CUDA_VISIBLE_DEVICES": "-1"})
tf.config.set_visible_devices(devices=[], device_type="GPU")
# Import spaces for the ZeroGPU wrapper
try:
    import spaces
except ImportError:
    class spaces:
        """
        Stand-in for the `spaces` package, used for local development
        when the real package cannot be imported.
        """

        @staticmethod
        def GPU(func=None, **_options):
            """
            No-op replacement for the ``spaces.GPU`` decorator.

            Works both as a bare decorator (``@spaces.GPU``) and in the
            parameterised form (``@spaces.GPU(duration=...)``); in either
            case the decorated function is returned unchanged.
            """
            if func is None:
                # Called with keyword options only: hand back a decorator.
                return lambda wrapped: wrapped
            return func
| # ============================================================================= | |
| # LOGGING CONFIGURATION | |
| # ============================================================================= | |
# Every record carries a timestamp, the logger name, the level and the message.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("set_detector")
| # ============================================================================= | |
| # MODEL LOADING | |
| # ============================================================================= | |
# Global variables for model caching (populated on the first successful load)
_CARD_DETECTOR = None
_SHAPE_DETECTOR = None
_SHAPE_CLASSIFIER = None
_FILL_CLASSIFIER = None


def load_models():
    """
    Load all models needed for SET detection in CPU-only mode.

    The models are downloaded from the Hugging Face Hub on the first call and
    cached in module-level globals; later calls return the cached objects.

    Returns:
        Tuple of (card_detector, shape_detector, shape_classifier, fill_classifier).

    Raises:
        ValueError: If any model fails to download or load. The original
            exception is chained as the cause.
    """
    global _CARD_DETECTOR, _SHAPE_DETECTOR, _SHAPE_CLASSIFIER, _FILL_CLASSIFIER

    # Return cached models if already loaded. Compare against None explicitly
    # instead of relying on the truthiness of the model objects.
    cached = (_CARD_DETECTOR, _SHAPE_DETECTOR, _SHAPE_CLASSIFIER, _FILL_CLASSIFIER)
    if all(model is not None for model in cached):
        logger.info("Using cached models")
        return cached

    try:
        from huggingface_hub import hf_hub_download

        logger.info("Loading models from Hugging Face Hub...")

        # Load Shape Classification Model (TensorFlow)
        logger.info("Loading shape classification model...")
        shape_classifier = load_model(
            hf_hub_download("Oamitai/shape-classification", "shape_model.keras")
        )

        # Load Fill Classification Model (TensorFlow)
        logger.info("Loading fill classification model...")
        fill_classifier = load_model(
            hf_hub_download("Oamitai/fill-classification", "fill_model.keras")
        )

        # NOTE(review): `conf` is set as a plain attribute on the YOLO objects
        # below; confirm the installed ultralytics version honours it, since
        # its documented way to set a threshold is the `conf=` call argument.

        # Load YOLO Card Detection Model (PyTorch)
        logger.info("Loading card detection model...")
        card_model_path = hf_hub_download("Oamitai/card-detection", "best.pt")
        card_detector = YOLO(card_model_path)
        card_detector.conf = 0.5

        # Load YOLO Shape Detection Model (PyTorch)
        logger.info("Loading shape detection model...")
        shape_model_path = hf_hub_download("Oamitai/shape-detection", "best.pt")
        shape_detector = YOLO(shape_model_path)
        shape_detector.conf = 0.5

        # Explicitly set to CPU mode
        logger.info("Setting models to CPU mode...")
        card_detector.to("cpu")
        shape_detector.to("cpu")

        # Cache the models only once every one of them loaded successfully
        _CARD_DETECTOR = card_detector
        _SHAPE_DETECTOR = shape_detector
        _SHAPE_CLASSIFIER = shape_classifier
        _FILL_CLASSIFIER = fill_classifier

        logger.info("All models loaded successfully in CPU mode!")
        return card_detector, shape_detector, shape_classifier, fill_classifier
    except Exception as e:
        error_msg = f"Error loading models: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        # Chain the cause so callers can still inspect the underlying failure.
        raise ValueError(error_msg) from e
| # ============================================================================= | |
| # UTILITY & DETECTION FUNCTIONS | |
| # ============================================================================= | |
def verify_and_rotate_image(board_image: np.ndarray, card_detector: YOLO) -> Tuple[np.ndarray, bool]:
    """
    Decides whether the board needs rotating, based on the detected card boxes.

    The card detector is run on board_image; when the mean box height exceeds
    the mean box width, the image is rotated 90 degrees clockwise.
    Returns (image, was_rotated). With no detections the image is returned
    as-is with was_rotated=False.
    """
    results = card_detector(board_image)
    xyxy = results[0].boxes.xyxy.cpu().numpy().astype(int)
    if xyxy.size == 0:
        return board_image, False

    mean_width = np.mean(xyxy[:, 2] - xyxy[:, 0])
    mean_height = np.mean(xyxy[:, 3] - xyxy[:, 1])

    # Keep the image untouched unless the boxes are taller than wide on average
    if mean_height <= mean_width:
        return board_image, False
    return cv2.rotate(board_image, cv2.ROTATE_90_CLOCKWISE), True
def restore_orientation(img: np.ndarray, was_rotated: bool) -> np.ndarray:
    """
    Undoes the 90-degree clockwise rotation when was_rotated is set;
    otherwise hands the image back unchanged.
    """
    if not was_rotated:
        return img
    return cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
def predict_color(img_bgr: np.ndarray) -> str:
    """
    Classifies a BGR crop as 'green', 'purple' or 'red' by counting the pixels
    that fall inside fixed HSV ranges and returning the colour with the most hits.
    """
    hsv_img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)

    def hue_mask(hue_low, hue_high):
        # All colours share the same saturation/value bounds; only hue varies.
        return cv2.inRange(
            hsv_img,
            np.array([hue_low, 50, 50]),
            np.array([hue_high, 255, 255]),
        )

    # Red wraps around hue=0, so both ends of the hue axis are merged
    red_mask = cv2.bitwise_or(hue_mask(0, 10), hue_mask(170, 180))

    pixel_counts = {
        "green": cv2.countNonZero(hue_mask(40, 80)),
        "purple": cv2.countNonZero(hue_mask(120, 160)),
        "red": cv2.countNonZero(red_mask),
    }
    return max(pixel_counts, key=pixel_counts.get)
def detect_cards(board_img: np.ndarray, card_detector: YOLO) -> List[Tuple[np.ndarray, List[int]]]:
    """
    Runs the card detector on board_img and crops every detected box.
    Returns one (card_image, [x1, y1, x2, y2]) pair per detection.
    """
    detections = card_detector(board_img)
    xyxy = detections[0].boxes.xyxy.cpu().numpy().astype(int)
    return [
        (board_img[top:bottom, left:right], [left, top, right, bottom])
        for left, top, right, bottom in xyxy
    ]
def _most_common_label(labels: List[str]) -> str:
    """Return the most frequent label; ties go to the label seen first."""
    return max(labels, key=labels.count)


def _classify_crops(model, crops: List[np.ndarray], task_name: str) -> List[np.ndarray]:
    """Run `model` on each crop separately and return one score vector per crop."""
    scores = []
    # Process one image at a time to avoid memory issues
    for crop in crops:
        try:
            scores.append(model.predict(np.array([crop]), verbose=0)[0])
        except Exception as e:
            logger.error(f"{task_name} prediction error: {e}")
            scores.append(np.array([0.33, 0.33, 0.34]))  # Fallback
    return scores


def predict_card_features(
    card_img: np.ndarray,
    shape_detector: YOLO,
    fill_model: tf.keras.Model,
    shape_model: tf.keras.Model,
    card_box: List[int]
) -> Dict:
    """
    Predicts the 'count', 'color', 'fill', 'shape' features for a single card.

    The shape_detector locates the shapes on the card; each shape crop is then
    passed to fill_model and shape_model, and to predict_color for the colour.
    The per-shape answers are combined by majority vote (ties go to the first
    shape seen).

    Returns a dict with keys 'count', 'color', 'fill', 'shape' and 'box'.
    When no shape survives the size filter, 'count' is 0 and the other
    features are 'unknown'.
    """
    # Detect shapes on the card
    shape_detections = shape_detector(card_img)
    c_h, c_w = card_img.shape[:2]
    card_area = c_w * c_h

    # Filter out spurious shape detections: boxes covering 3% of the card
    # area or less are dropped
    shape_boxes = []
    for coords in shape_detections[0].boxes.xyxy.cpu().numpy():
        x1, y1, x2, y2 = coords.astype(int)
        if (x2 - x1) * (y2 - y1) > 0.03 * card_area:
            shape_boxes.append([x1, y1, x2, y2])

    if not shape_boxes:
        return {
            'count': 0,
            'color': 'unknown',
            'fill': 'unknown',
            'shape': 'unknown',
            'box': card_box
        }

    # input_shape[1:3] is read as (height, width) -- assumes channels-last
    # models, TODO confirm. cv2.resize expects (width, height), so swap them.
    fill_h, fill_w = fill_model.input_shape[1:3]
    shape_h, shape_w = shape_model.input_shape[1:3]

    fill_imgs = []
    shape_imgs = []
    color_candidates = []

    # Prepare each detected shape region for classification
    for sx1, sy1, sx2, sy2 in shape_boxes:
        shape_crop = card_img[sy1:sy2, sx1:sx2]
        fill_imgs.append(cv2.resize(shape_crop, (fill_w, fill_h)) / 255.0)
        shape_imgs.append(cv2.resize(shape_crop, (shape_w, shape_h)) / 255.0)
        color_candidates.append(predict_color(shape_crop))

    fill_preds = _classify_crops(fill_model, fill_imgs, "Fill")
    shape_preds = _classify_crops(shape_model, shape_imgs, "Shape")

    fill_labels = ['empty', 'full', 'striped']
    shape_labels = ['diamond', 'oval', 'squiggle']
    fill_result = [fill_labels[np.argmax(fp)] for fp in fill_preds]
    shape_result = [shape_labels[np.argmax(sp)] for sp in shape_preds]

    # Take the most common color/fill/shape across all shape detections for
    # the card. shape_boxes is non-empty here, so every list has an entry.
    return {
        'count': len(shape_boxes),
        'color': _most_common_label(color_candidates),
        'fill': _most_common_label(fill_result),
        'shape': _most_common_label(shape_result),
        'box': card_box
    }
def classify_cards_on_board(
    board_img: np.ndarray,
    card_detector: YOLO,
    shape_detector: YOLO,
    fill_model: tf.keras.Model,
    shape_model: tf.keras.Model
) -> pd.DataFrame:
    """
    Detects cards on the board, then classifies each card's features.

    Returns a DataFrame with columns: 'Count', 'Color', 'Fill', 'Shape',
    'Coordinates', one row per detected card. The columns are present even
    when no card is detected (the frame is then empty).
    """
    columns = ["Count", "Color", "Fill", "Shape", "Coordinates"]
    card_rows = []
    for card_img, box in detect_cards(board_img, card_detector):
        card_feats = predict_card_features(card_img, shape_detector, fill_model, shape_model, box)
        card_rows.append({
            "Count": card_feats['count'],
            "Color": card_feats['color'],
            "Fill": card_feats['fill'],
            "Shape": card_feats['shape'],
            "Coordinates": card_feats['box']
        })
    # Passing the columns explicitly keeps the schema stable for an empty board.
    return pd.DataFrame(card_rows, columns=columns)
def valid_set(cards: List[dict]) -> bool:
    """
    Tells whether the given cards form a valid SET: for each of 'Count',
    'Color', 'Fill' and 'Shape', the number of distinct values across the
    cards must be 1 or 3.
    """
    return all(
        len({card[feature] for card in cards}) in (1, 3)
        for feature in ("Count", "Color", "Fill", "Shape")
    )
def locate_all_sets(cards_df: pd.DataFrame) -> List[dict]:
    """
    Tries every 3-row combination of cards_df and keeps those accepted by
    valid_set. Each result is a dict with 'set_indices' (the row index labels)
    and 'cards' (the feature dict of each of the three cards).
    """
    field_names = ['Count', 'Color', 'Fill', 'Shape', 'Coordinates']
    sets_found = []
    for triple in combinations(cards_df.iterrows(), 3):
        # Each element of the triple is an (index, row) pair
        indices = [index for index, _ in triple]
        rows = [row for _, row in triple]
        if not valid_set(rows):
            continue
        sets_found.append({
            'set_indices': indices,
            'cards': [{name: row[name] for name in field_names} for row in rows]
        })
    return sets_found
def draw_detected_sets(board_img: np.ndarray, sets_detected: List[dict]) -> np.ndarray:
    """
    Annotates the board image with bounding boxes for each detected SET.

    Each SET is drawn in a different color and offset (thickness & expansion)
    so that overlapping sets are visible. The first card of every SET is
    labelled "Set <number>"; the label is kept inside the image even when the
    card touches the top edge.

    The image is modified in place and also returned.
    """
    # Some distinct BGR colors
    colors = [
        (255, 0, 0), (0, 255, 0), (0, 0, 255),
        (255, 255, 0), (255, 0, 255), (0, 255, 255)
    ]
    base_thickness = 8
    base_expansion = 5
    font_scale = 0.9
    img_h, img_w = board_img.shape[:2]

    for idx, single_set in enumerate(sets_detected):
        color = colors[idx % len(colors)]
        thickness = base_thickness + 2 * idx
        expansion = base_expansion + 15 * idx
        for i, card_info in enumerate(single_set["cards"]):
            # Plain ints: the coordinates may be NumPy scalars.
            x1, y1, x2, y2 = (int(v) for v in card_info["Coordinates"])

            # Expand the bounding box slightly, clamped to the image
            x1e = max(0, x1 - expansion)
            y1e = max(0, y1 - expansion)
            x2e = min(img_w, x2 + expansion)
            y2e = min(img_h, y2 + expansion)
            cv2.rectangle(board_img, (x1e, y1e), (x2e, y2e), color, thickness)

            # Label only the first card's box with "Set <number>"
            if i == 0:
                label = f"Set {idx + 1}"
                font = cv2.FONT_HERSHEY_SIMPLEX
                (_, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness)
                # putText anchors at the bottom-left of the text, so keep the
                # baseline at least one text height below the top edge.
                label_y = max(y1e - 10, text_height)
                cv2.putText(
                    board_img,
                    label,
                    (x1e, label_y),
                    font,
                    font_scale,
                    color,
                    thickness
                )
    return board_img
def optimize_image_size(image_array: np.ndarray, max_dim=1200) -> np.ndarray:
    """
    Resizes an image if its largest dimension exceeds max_dim, to reduce processing time.

    The aspect ratio is preserved and the shorter side is never scaled below
    one pixel. Returns None when image_array is None, and the input object
    itself when no resize is needed.
    """
    if image_array is None:
        return None

    height, width = image_array.shape[:2]
    if max(width, height) <= max_dim:
        return image_array

    if width > height:
        new_width = max_dim
        # max(1, ...) guards against a zero-sized target for very elongated images
        new_height = max(1, int(height * (max_dim / width)))
    else:
        new_height = max_dim
        new_width = max(1, int(width * (max_dim / height)))
    return cv2.resize(image_array, (new_width, new_height), interpolation=cv2.INTER_AREA)
def process_image(input_image):
    """
    CPU-only processing function for SET detection.

    Takes the uploaded image as a NumPy array (or None) and returns a pair
    (image, status_message). The returned image is in RGB and in the same
    orientation as the upload. Any exception is logged and reported through
    the status message, with the original input image returned alongside it.
    """
    if input_image is None:
        return None, "Please upload an image."

    try:
        start_time = time.time()

        # Load models (CPU-only)
        card_detector, shape_detector, shape_model, fill_model = load_models()

        # Optimize image size
        optimized_img = optimize_image_size(input_image)

        # Convert to BGR (OpenCV format)
        if optimized_img.ndim == 2 or (optimized_img.ndim == 3 and optimized_img.shape[2] == 1):
            # Grayscale: expand to 3 channels so the later BGR->RGB step works
            optimized_img = cv2.cvtColor(optimized_img, cv2.COLOR_GRAY2BGR)
        elif optimized_img.ndim == 3 and optimized_img.shape[2] == 4:  # RGBA
            optimized_img = cv2.cvtColor(optimized_img, cv2.COLOR_RGBA2BGR)
        elif optimized_img.ndim == 3 and optimized_img.shape[2] == 3:
            # RGB to BGR
            optimized_img = cv2.cvtColor(optimized_img, cv2.COLOR_RGB2BGR)

        # Check and fix orientation
        processed_img, was_rotated = verify_and_rotate_image(optimized_img, card_detector)

        # Detect and classify cards in one pass; an empty frame means that no
        # card was detected, so no separate detection run is needed.
        df_cards = classify_cards_on_board(processed_img, card_detector, shape_detector, fill_model, shape_model)
        if df_cards.empty:
            return cv2.cvtColor(optimized_img, cv2.COLOR_BGR2RGB), "No cards detected. Please check that it's a SET game board."

        found_sets = locate_all_sets(df_cards)
        if not found_sets:
            # Undo the working rotation so the user gets the upload's orientation back
            unrotated = restore_orientation(processed_img, was_rotated)
            return cv2.cvtColor(unrotated, cv2.COLOR_BGR2RGB), "Cards detected, but no valid SETs found!"

        # Draw sets on a copy of the image
        annotated = draw_detected_sets(processed_img.copy(), found_sets)

        # Restore original orientation if needed
        final_output = restore_orientation(annotated, was_rotated)

        # Convert back to RGB for display
        final_output_rgb = cv2.cvtColor(final_output, cv2.COLOR_BGR2RGB)

        process_time = time.time() - start_time
        return final_output_rgb, f"Found {len(found_sets)} SET(s) in {process_time:.2f} seconds."
    except Exception as e:
        error_message = f"Error processing image: {str(e)}"
        logger.error(error_message)
        logger.error(traceback.format_exc())
        return input_image, error_message
# Keep the spaces.GPU decorator for ZeroGPU API but use CPU internally
@spaces.GPU
def process_image_wrapper(input_image):
    """
    Wrapper for process_image that uses the spaces.GPU decorator
    but internally works in CPU-only mode.

    Takes the uploaded image and returns whatever process_image returns
    for it, i.e. an (image, status_message) pair.
    """
    return process_image(input_image)
| # ============================================================================= | |
| # SIMPLIFIED GRADIO INTERFACE | |
| # ============================================================================= | |
# Static HTML shown above and below the interactive part of the page.
_HEADER_HTML = """
<div style="text-align: center; margin-bottom: 1rem;">
<h1 style="margin-bottom: 0.5rem;">🎴 SET Game Detector</h1>
<p>Upload an image of a SET game board to find all valid sets</p>
</div>
"""

_FOOTER_HTML = """
<div style="text-align: center; margin-top: 1rem; padding: 0.5rem; font-size: 0.8rem;">
<p>SET Game Detector by <a href="https://github.com/omamitai" target="_blank">omamitai</a> |
Gradio version adapted for Hugging Face Spaces</p>
</div>
"""

with gr.Blocks(title="SET Game Detector") as demo:
    gr.HTML(_HEADER_HTML)

    with gr.Row():
        # Left column: upload widget and trigger button
        with gr.Column():
            input_image = gr.Image(label="Upload SET Board Image", type="numpy")
            find_sets_btn = gr.Button("🔎 Find Sets", variant="primary")

        # Right column: annotated result and status line
        with gr.Column():
            output_image = gr.Image(label="Detected Sets")
            status = gr.Textbox(
                label="Status",
                value="Upload an image and click 'Find Sets'",
                interactive=False,
            )

    # Clicking the button runs the detector and fills both outputs
    find_sets_btn.click(
        fn=process_image_wrapper,
        inputs=[input_image],
        outputs=[output_image, status],
    )

    gr.HTML(_FOOTER_HTML)
| # ============================================================================= | |
| # MAIN EXECUTION | |
| # ============================================================================= | |
if __name__ == "__main__":
    # Enable request queuing, then start the app
    queued_demo = demo.queue()
    queued_demo.launch()