# Set-Game-Solver / app.py
# Hugging Face Space file by Oamitai; last commit: "Update app.py" (844e48a, verified)
import os
import cv2
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import load_model
import torch
from ultralytics import YOLO
from itertools import combinations
import gradio as gr
import traceback
import time
from typing import Tuple, List, Dict
import logging
# Force CPU mode for TensorFlow
# Setting CUDA_VISIBLE_DEVICES to "-1" hides CUDA devices from libraries that
# honour this variable.
# NOTE(review): this assignment runs after tensorflow and torch were imported
# above - confirm that is early enough for both libraries.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Additionally tell TensorFlow itself to expose no GPU devices.
tf.config.set_visible_devices([], 'GPU')
# Import spaces for ZeroGPU wrapper
try:
    import spaces
except ImportError:
    # Create a dummy spaces class for local development
    class spaces:
        """Local stand-in used when the ``spaces`` package is not installed.

        Only the ``GPU`` decorator is provided, and it is a no-op.
        """

        @staticmethod
        def GPU(func=None, **_options):
            """Return the decorated function unchanged.

            Works both as a bare decorator (``@spaces.GPU``) and in call
            form (``@spaces.GPU()`` or ``@spaces.GPU(name=value)``); any
            keyword options are accepted and ignored.
            NOTE(review): the call form is assumed to mirror the real
            package's API - confirm against the ZeroGPU documentation.
            """
            if callable(func):
                # Bare usage: @spaces.GPU
                return func

            # Call usage: @spaces.GPU(...) must hand back a decorator.
            def passthrough(wrapped):
                return wrapped

            return passthrough
# =============================================================================
# LOGGING CONFIGURATION
# =============================================================================
# Configure the root logger: INFO level, timestamped "name - level - message" lines.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Named logger used by the model-loading and image-processing code below.
logger = logging.getLogger("set_detector")
# =============================================================================
# MODEL LOADING
# =============================================================================
# Global variables for model caching
_CARD_DETECTOR = None
_SHAPE_DETECTOR = None
_SHAPE_CLASSIFIER = None
_FILL_CLASSIFIER = None
def load_models():
    """
    Load (once) and return every model needed for SET detection, on CPU.

    The first successful call downloads the four model files from the
    Hugging Face Hub, loads them and stores the loaded objects in the
    module-level cache variables; later calls return the cached objects.

    Returns:
        Tuple ``(card_detector, shape_detector, shape_classifier,
        fill_classifier)``: two YOLO detectors followed by two Keras
        classifiers.

    Raises:
        ValueError: If any download or load step fails. The triggering
            exception is attached as ``__cause__``.
    """
    global _CARD_DETECTOR, _SHAPE_DETECTOR, _SHAPE_CLASSIFIER, _FILL_CLASSIFIER
    # Return cached models if already loaded. Compare with None explicitly so
    # the check never depends on how a model object defines truthiness.
    cached = (_CARD_DETECTOR, _SHAPE_DETECTOR, _SHAPE_CLASSIFIER, _FILL_CLASSIFIER)
    if all(model is not None for model in cached):
        logger.info("Using cached models")
        return cached
    try:
        # Imported lazily so the module can be imported without the hub client.
        from huggingface_hub import hf_hub_download
        logger.info("Loading models from Hugging Face Hub...")
        # Load Shape Classification Model (TensorFlow)
        logger.info("Loading shape classification model...")
        shape_classifier = load_model(
            hf_hub_download("Oamitai/shape-classification", "shape_model.keras")
        )
        # Load Fill Classification Model (TensorFlow)
        logger.info("Loading fill classification model...")
        fill_classifier = load_model(
            hf_hub_download("Oamitai/fill-classification", "fill_model.keras")
        )
        # Load YOLO Card Detection Model (PyTorch)
        logger.info("Loading card detection model...")
        card_model_path = hf_hub_download("Oamitai/card-detection", "best.pt")
        card_detector = YOLO(card_model_path)
        # NOTE(review): confirm the installed ultralytics version honours this
        # attribute; if not, the threshold must be passed as conf=... per call.
        card_detector.conf = 0.5
        # Load YOLO Shape Detection Model (PyTorch)
        logger.info("Loading shape detection model...")
        shape_model_path = hf_hub_download("Oamitai/shape-detection", "best.pt")
        shape_detector = YOLO(shape_model_path)
        shape_detector.conf = 0.5  # same caveat as for the card detector
        # Explicitly set to CPU mode
        logger.info("Setting models to CPU mode...")
        card_detector.to("cpu")
        shape_detector.to("cpu")
        # Cache the models only after every one of them loaded successfully
        _CARD_DETECTOR = card_detector
        _SHAPE_DETECTOR = shape_detector
        _SHAPE_CLASSIFIER = shape_classifier
        _FILL_CLASSIFIER = fill_classifier
        logger.info("All models loaded successfully in CPU mode!")
        return card_detector, shape_detector, shape_classifier, fill_classifier
    except Exception as e:
        error_msg = f"Error loading models: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        # Chain explicitly so callers can inspect the root cause.
        raise ValueError(error_msg) from e
# =============================================================================
# UTILITY & DETECTION FUNCTIONS
# =============================================================================
def verify_and_rotate_image(board_image: np.ndarray, card_detector: YOLO) -> Tuple[np.ndarray, bool]:
    """
    Normalise board orientation based on the shape of the detected card boxes.

    Runs the card detector once; when the mean box height exceeds the mean box
    width, the board is rotated 90 degrees clockwise. With no detections the
    image is returned untouched.
    Returns (image, was_rotated).
    """
    results = card_detector(board_image)
    card_boxes = results[0].boxes.xyxy.cpu().numpy().astype(int)
    # Without any detections there is nothing to judge the orientation by
    if card_boxes.size == 0:
        return board_image, False
    mean_width = np.mean(card_boxes[:, 2] - card_boxes[:, 0])
    mean_height = np.mean(card_boxes[:, 3] - card_boxes[:, 1])
    # Boxes that are not taller than wide on average: keep the image as is
    if mean_height <= mean_width:
        return board_image, False
    return cv2.rotate(board_image, cv2.ROTATE_90_CLOCKWISE), True
def restore_orientation(img: np.ndarray, was_rotated: bool) -> np.ndarray:
    """
    Undo an earlier 90-degree clockwise rotation.

    When was_rotated is falsy the very same image object is handed back;
    otherwise the image is rotated 90 degrees counter-clockwise.
    """
    if not was_rotated:
        return img
    return cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
def predict_color(img_bgr: np.ndarray) -> str:
    """
    Classify a BGR crop as 'green', 'purple' or 'red' by counting HSV pixels.

    For each colour, pixels whose hue lies in that colour's band (with
    saturation and value of at least 50) are counted; the colour with the
    highest count wins. On a tie, including the all-zero case, the earliest
    of green, purple, red is returned.
    """
    hsv_img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)

    def hue_mask(hue_lo, hue_hi):
        # Binary mask of pixels inside the hue band with S and V in [50, 255]
        return cv2.inRange(hsv_img, np.array([hue_lo, 50, 50]), np.array([hue_hi, 255, 255]))

    green_mask = hue_mask(40, 80)
    purple_mask = hue_mask(120, 160)
    # Red straddles the hue wrap-around point, so it needs both ends of the scale
    red_mask = cv2.bitwise_or(hue_mask(0, 10), hue_mask(170, 180))
    pixel_counts = {
        "green": cv2.countNonZero(green_mask),
        "purple": cv2.countNonZero(purple_mask),
        "red": cv2.countNonZero(red_mask),
    }
    return max(pixel_counts, key=pixel_counts.get)
def detect_cards(board_img: np.ndarray, card_detector: YOLO) -> List[Tuple[np.ndarray, List[int]]]:
    """
    Detect cards on the board with the YOLO card detector.

    Returns one (card_crop, [x1, y1, x2, y2]) pair per detected box, where
    card_crop is the slice of board_img covered by that box.
    """
    detections = card_detector(board_img)
    xyxy_boxes = detections[0].boxes.xyxy.cpu().numpy().astype(int)
    return [
        (board_img[top:bottom, left:right], [left, top, right, bottom])
        for left, top, right, bottom in xyxy_boxes
    ]
def _majority_label(labels: List[str]) -> str:
    """Return the most frequent label ('unknown' for an empty list); ties go to the label seen first."""
    if not labels:
        return "unknown"
    from collections import Counter
    # most_common() orders equal counts by first occurrence, so the result is
    # deterministic (unlike iterating a set of strings).
    return Counter(labels).most_common(1)[0][0]


def _predict_one_at_a_time(model, images, label: str) -> list:
    """Run model.predict on each image separately, substituting a fallback vector on failure."""
    predictions = []
    for img in images:
        try:
            pred = model.predict(np.array([img]), verbose=0)
            predictions.append(pred[0])
        except Exception as e:
            logger.error(f"{label} prediction error: {e}")
            # Fallback: argmax of this vector selects the last label
            predictions.append(np.array([0.33, 0.33, 0.34]))
    return predictions


def predict_card_features(
    card_img: np.ndarray,
    shape_detector: YOLO,
    fill_model: tf.keras.Model,
    shape_model: tf.keras.Model,
    card_box: List[int]
) -> Dict:
    """
    Predicts the 'count', 'color', 'fill', 'shape' features for a single card.

    A YOLO shape detector locates the shapes on the card; each sufficiently
    large shape crop is classified by fill_model and shape_model, and its
    colour is estimated with predict_color. The card-level colour/fill/shape
    is the majority vote over all shapes (ties resolved in favour of the
    value seen first).

    Args:
        card_img: Crop of one card.
            NOTE(review): crops are passed to the classifiers in whatever
            channel order card_img has - confirm it matches training.
        shape_detector: YOLO model that returns shape bounding boxes.
        fill_model: Keras classifier; outputs are indexed as
            ['empty', 'full', 'striped'].
        shape_model: Keras classifier; outputs are indexed as
            ['diamond', 'oval', 'squiggle'].
        card_box: [x1, y1, x2, y2] of the card, passed through as 'box'.

    Returns:
        Dict with keys 'count', 'color', 'fill', 'shape', 'box'. When no
        shape survives filtering, count is 0 and the three labels are
        'unknown'.
    """
    # Detect shapes on the card
    shape_detections = shape_detector(card_img)
    c_h, c_w = card_img.shape[:2]
    card_area = c_w * c_h
    # Filter out spurious shape detections (boxes covering <= 3% of the card)
    shape_boxes = []
    for coords in shape_detections[0].boxes.xyxy.cpu().numpy():
        x1, y1, x2, y2 = coords.astype(int)
        if (x2 - x1) * (y2 - y1) > 0.03 * card_area:
            shape_boxes.append([x1, y1, x2, y2])
    if not shape_boxes:
        return {
            'count': 0,
            'color': 'unknown',
            'fill': 'unknown',
            'shape': 'unknown',
            'box': card_box
        }
    # Keras reports (batch, height, width, channels) while cv2.resize expects
    # dsize as (width, height), so the two axes must be swapped.
    fill_h, fill_w = fill_model.input_shape[1:3]
    shape_h, shape_w = shape_model.input_shape[1:3]
    fill_imgs = []
    shape_imgs = []
    color_candidates = []
    # Prepare each detected shape region for classification
    for sx1, sy1, sx2, sy2 in shape_boxes:
        shape_crop = card_img[sy1:sy2, sx1:sx2]
        # Scale pixel values to [0, 1] for the classifiers
        fill_imgs.append(cv2.resize(shape_crop, (fill_w, fill_h)) / 255.0)
        shape_imgs.append(cv2.resize(shape_crop, (shape_w, shape_h)) / 255.0)
        color_candidates.append(predict_color(shape_crop))
    # Handle TensorFlow prediction - process one image at a time to avoid memory issues
    fill_preds = _predict_one_at_a_time(fill_model, fill_imgs, "Fill")
    shape_preds = _predict_one_at_a_time(shape_model, shape_imgs, "Shape")
    fill_labels = ['empty', 'full', 'striped']
    shape_labels = ['diamond', 'oval', 'squiggle']
    fill_result = [fill_labels[np.argmax(fp)] for fp in fill_preds]
    shape_result = [shape_labels[np.argmax(sp)] for sp in shape_preds]
    # Take the most common color/fill/shape across all shape detections for the card
    return {
        'count': len(shape_boxes),
        'color': _majority_label(color_candidates),
        'fill': _majority_label(fill_result),
        'shape': _majority_label(shape_result),
        'box': card_box
    }
def classify_cards_on_board(
    board_img: np.ndarray,
    card_detector: YOLO,
    shape_detector: YOLO,
    fill_model: tf.keras.Model,
    shape_model: tf.keras.Model
) -> pd.DataFrame:
    """
    Detect every card on the board and classify its features.

    Returns a DataFrame with one row per detected card and the columns
    'Count', 'Color', 'Fill', 'Shape', 'Coordinates'.
    """
    # DataFrame column name -> key in the dict returned by predict_card_features
    column_source = {
        "Count": 'count',
        "Color": 'color',
        "Fill": 'fill',
        "Shape": 'shape',
        "Coordinates": 'box',
    }
    records = []
    for card_crop, card_box in detect_cards(board_img, card_detector):
        features = predict_card_features(card_crop, shape_detector, fill_model, shape_model, card_box)
        records.append({column: features[key] for column, key in column_source.items()})
    return pd.DataFrame(records)
def valid_set(cards: List[dict]) -> bool:
    """
    Checks if the given 3 cards collectively form a valid SET.

    For each of 'Count', 'Color', 'Fill' and 'Shape' the three cards must be
    either all equal or all different. Anything other than exactly three
    cards is rejected, and so is any group containing a card that carries the
    placeholder values used for an unclassified card (Count 0 or the label
    'unknown'); otherwise three unrecognised cards would count as a SET.

    Args:
        cards: Mappings (dicts or DataFrame rows) with the four feature keys.

    Returns:
        True if the cards form a valid SET, False otherwise.
    """
    if len(cards) != 3:
        return False
    for feature in ("Count", "Color", "Fill", "Shape"):
        values = {card[feature] for card in cards}
        # Placeholder marking a card whose features could not be determined
        placeholder = 0 if feature == "Count" else "unknown"
        if placeholder in values:
            return False
        # Two distinct values means "two alike, one different": not a SET
        if len(values) not in (1, 3):
            return False
    return True
def locate_all_sets(cards_df: pd.DataFrame) -> List[dict]:
    """
    Enumerate every 3-card combination of the DataFrame and keep the valid SETs.

    Each result is a dict with 'set_indices' (the three row indices) and
    'cards' (the three rows reduced to their feature and coordinate fields).
    """
    output_fields = ('Count', 'Color', 'Fill', 'Shape', 'Coordinates')
    indexed_rows = list(cards_df.iterrows())  # (index, row) pairs
    matches = []
    for triple in combinations(indexed_rows, 3):
        rows = [row for _, row in triple]
        if not valid_set(rows):
            continue
        matches.append({
            'set_indices': [index for index, _ in triple],
            'cards': [
                {field: row[field] for field in output_fields}
                for row in rows
            ]
        })
    return matches
def draw_detected_sets(board_img: np.ndarray, sets_detected: List[dict]) -> np.ndarray:
    """
    Annotates the board image with bounding boxes for each detected SET.

    Each SET is drawn in a different color and offset (thickness & expansion)
    so that overlapping sets are visible. The first card of every SET is
    labelled "Set <number>"; the label sits above the box, or just inside its
    top edge when there is no room above (so it is never drawn off-canvas).

    Args:
        board_img: BGR image to draw on. It is modified in place.
        sets_detected: Dicts with a "cards" list whose items carry
            "Coordinates" as [x1, y1, x2, y2].

    Returns:
        The same board_img object, annotated.
    """
    # Some distinct BGR colors
    colors = [
        (255, 0, 0), (0, 255, 0), (0, 0, 255),
        (255, 255, 0), (255, 0, 255), (0, 255, 255)
    ]
    base_thickness = 8
    base_expansion = 5
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.9
    label_gap = 10  # pixels between the label baseline and the box edge
    img_h, img_w = board_img.shape[:2]
    for idx, single_set in enumerate(sets_detected):
        color = colors[idx % len(colors)]
        # Later sets get thicker, wider boxes so they do not hide earlier ones.
        # NOTE(review): both values grow without bound as more sets are found.
        thickness = base_thickness + 2 * idx
        expansion = base_expansion + 15 * idx
        for i, card_info in enumerate(single_set["cards"]):
            x1, y1, x2, y2 = card_info["Coordinates"]
            # Expand the bounding box slightly, clamped to the image
            x1e = max(0, x1 - expansion)
            y1e = max(0, y1 - expansion)
            x2e = min(img_w, x2 + expansion)
            y2e = min(img_h, y2 + expansion)
            cv2.rectangle(board_img, (x1e, y1e), (x2e, y2e), color, thickness)
            # Label only the first card's box with "Set <number>"
            if i == 0:
                label = f"Set {idx + 1}"
                (_, text_h), _ = cv2.getTextSize(label, font, font_scale, thickness)
                label_y = y1e - label_gap
                if label_y - text_h < 0:
                    # No room above the box: move the label inside its top edge
                    label_y = y1e + text_h + label_gap
                cv2.putText(
                    board_img,
                    label,
                    (x1e, label_y),
                    font,
                    font_scale,
                    color,
                    thickness
                )
    return board_img
def optimize_image_size(image_array: np.ndarray, max_dim=1200) -> np.ndarray:
    """
    Resizes an image if its largest dimension exceeds max_dim, to reduce processing time.

    The aspect ratio is preserved: the longer side becomes max_dim and the
    shorter side is scaled by the same factor, but never below 1 pixel (a
    zero-sized target would make cv2.resize fail on extreme aspect ratios).

    Args:
        image_array: Image array whose first two axes are height and width,
            or None.
        max_dim: Maximum allowed size of the longer side, in pixels.

    Returns:
        None if image_array is None; the input object itself when it already
        fits; otherwise a resized copy.
    """
    if image_array is None:
        return None
    height, width = image_array.shape[:2]
    longest = max(width, height)
    if longest <= max_dim:
        return image_array
    scale = max_dim / longest
    if width > height:
        new_width = max_dim
        new_height = max(1, int(height * scale))
    else:
        new_height = max_dim
        new_width = max(1, int(width * scale))
    # INTER_AREA is the interpolation used for this downscale
    return cv2.resize(image_array, (new_width, new_height), interpolation=cv2.INTER_AREA)
def process_image(input_image):
    """
    CPU-only processing function for SET detection.

    Pipeline: load models, downscale the image, convert it to BGR, normalise
    the board orientation, classify all cards, search for SETs, draw them and
    convert the result back to RGB in the original orientation.

    Args:
        input_image: Image as a numpy array (grayscale, RGB or RGBA), or None.

    Returns:
        Tuple (image, status_message). The image is the annotated board on
        success, the un-annotated board when no cards or no SETs are found,
        the untouched input on error, and None when no image was given.
        Exceptions are logged and reported through the status message rather
        than raised.
    """
    if input_image is None:
        return None, "Please upload an image."
    try:
        start_time = time.time()
        # Load models (CPU-only)
        card_detector, shape_detector, shape_model, fill_model = load_models()
        # Optimize image size
        optimized_img = optimize_image_size(input_image)
        # Convert to BGR (OpenCV format)
        if optimized_img.ndim == 2:  # Grayscale: later steps need 3 channels
            optimized_img = cv2.cvtColor(optimized_img, cv2.COLOR_GRAY2BGR)
        elif optimized_img.ndim == 3 and optimized_img.shape[2] == 4:  # RGBA
            optimized_img = cv2.cvtColor(optimized_img, cv2.COLOR_RGBA2BGR)
        elif optimized_img.ndim == 3 and optimized_img.shape[2] == 3:
            # RGB to BGR
            optimized_img = cv2.cvtColor(optimized_img, cv2.COLOR_RGB2BGR)
        # Check and fix orientation
        processed_img, was_rotated = verify_and_rotate_image(optimized_img, card_detector)
        # Detect and classify cards in one pass; an empty frame means the
        # detector found nothing, so no separate detection run is needed.
        df_cards = classify_cards_on_board(processed_img, card_detector, shape_detector, fill_model, shape_model)
        if len(df_cards) == 0:
            return cv2.cvtColor(optimized_img, cv2.COLOR_BGR2RGB), "No cards detected. Please check that it's a SET game board."
        found_sets = locate_all_sets(df_cards)
        if not found_sets:
            # Undo the working rotation so the user sees the image as uploaded
            unrotated = restore_orientation(processed_img, was_rotated)
            return cv2.cvtColor(unrotated, cv2.COLOR_BGR2RGB), "Cards detected, but no valid SETs found!"
        # Draw sets on a copy so processed_img itself stays clean
        annotated = draw_detected_sets(processed_img.copy(), found_sets)
        # Restore original orientation if needed
        final_output = restore_orientation(annotated, was_rotated)
        # Convert back to RGB for display
        final_output_rgb = cv2.cvtColor(final_output, cv2.COLOR_BGR2RGB)
        process_time = time.time() - start_time
        return final_output_rgb, f"Found {len(found_sets)} SET(s) in {process_time:.2f} seconds."
    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing
        error_message = f"Error processing image: {str(e)}"
        logger.error(error_message)
        logger.error(traceback.format_exc())
        return input_image, error_message
# The spaces.GPU decorator is retained for the ZeroGPU API; the work itself stays on CPU
@spaces.GPU
def process_image_wrapper(input_image):
    """
    Entry point bound to the UI button: forwards the image to process_image
    and passes its (image, status message) result straight through.
    """
    image_and_status = process_image(input_image)
    return image_and_status
# =============================================================================
# SIMPLIFIED GRADIO INTERFACE
# =============================================================================
# Build the UI: a header, an input column (image + button), an output column
# (annotated image + status text) and a footer. Components are created in
# display order inside their layout contexts.
with gr.Blocks(title="SET Game Detector") as demo:
    # Page header
    gr.HTML("""
<div style="text-align: center; margin-bottom: 1rem;">
<h1 style="margin-bottom: 0.5rem;">🎴 SET Game Detector</h1>
<p>Upload an image of a SET game board to find all valid sets</p>
</div>
""")
    with gr.Row():
        # Left column: image upload and the trigger button
        with gr.Column():
            # type="numpy" hands the upload to the callback as a numpy array
            input_image = gr.Image(
                label="Upload SET Board Image",
                type="numpy"
            )
            find_sets_btn = gr.Button(
                "🔎 Find Sets",
                variant="primary"
            )
        # Right column: annotated result and a read-only status line
        with gr.Column():
            output_image = gr.Image(
                label="Detected Sets"
            )
            status = gr.Textbox(
                label="Status",
                value="Upload an image and click 'Find Sets'",
                interactive=False
            )
    # Function bindings: the callback's (image, message) pair fills both outputs
    find_sets_btn.click(
        fn=process_image_wrapper,
        inputs=[input_image],
        outputs=[output_image, status]
    )
    # Page footer with attribution
    gr.HTML("""
<div style="text-align: center; margin-top: 1rem; padding: 0.5rem; font-size: 0.8rem;">
<p>SET Game Detector by <a href="https://github.com/omamitai" target="_blank">omamitai</a> |
Gradio version adapted for Hugging Face Spaces</p>
</div>
""")
# =============================================================================
# MAIN EXECUTION
# =============================================================================
# Only start the server when this file is run as a script, not when imported.
if __name__ == "__main__":
    # Launch the app: enable the request queue on the Blocks app, then launch it
    demo.queue().launch()