Oamitai committed on
Commit
f2cc874
·
verified ·
1 Parent(s): 284127c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +657 -204
app.py CHANGED
@@ -1,273 +1,726 @@
1
- # IMPORTANT: Import spaces first to prevent premature CUDA initialization.
2
- import spaces
3
-
4
- import gradio as gr
5
- from huggingface_hub import hf_hub_download
6
- import torch
7
  import cv2
8
  import numpy as np
 
9
  import tensorflow as tf
10
  from tensorflow.keras.models import load_model
 
11
  from ultralytics import YOLO
 
 
12
  from PIL import Image
 
 
13
  import traceback
14
- import json
15
- import pandas as pd
16
- from itertools import combinations
17
 
18
  # =============================================================================
19
- # MODEL LOADING
20
  # =============================================================================
 
 
 
21
 
22
- # Load YOLO Card Detection Model from HuggingFace Hub
23
- card_model_path = hf_hub_download("Oamitai/card-detection", "best.pt")
24
- card_detection_model = YOLO(card_model_path)
25
- card_detection_model.conf = 0.5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- # Load YOLO Shape Detection Model from HuggingFace Hub
28
- shape_model_path = hf_hub_download("Oamitai/shape-detection", "best.pt")
29
- shape_detection_model = YOLO(shape_model_path)
30
- shape_detection_model.conf = 0.5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- # Load Shape Classification Model (Keras) from HuggingFace Hub
33
- shape_classification_model = load_model(
34
- hf_hub_download("Oamitai/shape-classification", "shape_model.keras")
35
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- # Load Fill Classification Model (Keras) from HuggingFace Hub
38
- fill_classification_model = load_model(
39
- hf_hub_download("Oamitai/fill-classification", "fill_model.keras")
40
- )
 
 
 
 
 
 
41
 
42
  # =============================================================================
43
- # UTILITY & PROCESSING FUNCTIONS
44
  # =============================================================================
45
-
46
- def check_and_rotate_input_image(board_image: np.ndarray, detector) -> (np.ndarray, bool):
47
  """
48
- Detect card regions and determine if the image needs to be rotated.
 
 
49
  """
50
- card_results = detector(board_image)
51
- card_boxes = card_results[0].boxes.xyxy.cpu().numpy().astype(int)
52
- if card_boxes.size == 0:
53
  return board_image, False
54
 
55
- widths = card_boxes[:, 2] - card_boxes[:, 0]
56
- heights = card_boxes[:, 3] - card_boxes[:, 1]
 
 
57
  if np.mean(heights) > np.mean(widths):
58
  return cv2.rotate(board_image, cv2.ROTATE_90_CLOCKWISE), True
59
- return board_image, False
 
60
 
61
- def restore_original_orientation(image: np.ndarray, was_rotated: bool) -> np.ndarray:
62
  """
63
- Restore the original orientation of the image if it was rotated.
64
  """
65
  if was_rotated:
66
- return cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
67
- return image
68
 
69
- def predict_color(shape_image: np.ndarray) -> str:
70
  """
71
- Determine the dominant color in a shape image using HSV thresholds.
72
  """
73
- hsv_image = cv2.cvtColor(shape_image, cv2.COLOR_BGR2HSV)
74
- green_mask = cv2.inRange(hsv_image, np.array([40, 50, 50]), np.array([80, 255, 255]))
75
- purple_mask = cv2.inRange(hsv_image, np.array([120, 50, 50]), np.array([160, 255, 255]))
76
- red_mask1 = cv2.inRange(hsv_image, np.array([0, 50, 50]), np.array([10, 255, 255]))
77
- red_mask2 = cv2.inRange(hsv_image, np.array([170, 50, 50]), np.array([180, 255, 255]))
78
- red_mask = cv2.bitwise_or(red_mask1, red_mask2)
79
-
80
- color_counts = {
81
- 'green': cv2.countNonZero(green_mask),
82
- 'purple': cv2.countNonZero(purple_mask),
83
- 'red': cv2.countNonZero(red_mask)
 
 
84
  }
85
- return max(color_counts, key=color_counts.get)
86
 
87
- def predict_card_features(card_image: np.ndarray, shape_detector, fill_model, shape_model, box: list) -> dict:
88
  """
89
- Detect and classify features on a card image.
 
90
  """
91
- shape_results = shape_detector(card_image)
92
- card_h, card_w = card_image.shape[:2]
93
- card_area = card_w * card_h
94
-
95
- filtered_boxes = [
96
- [int(x1), int(y1), int(x2), int(y2)]
97
- for x1, y1, x2, y2 in shape_results[0].boxes.xyxy.cpu().numpy()
98
- if (x2 - x1) * (y2 - y1) > 0.03 * card_area
99
- ]
100
-
101
- if not filtered_boxes:
102
- return {'count': 0, 'color': 'unknown', 'fill': 'unknown', 'shape': 'unknown', 'box': box}
103
-
104
- fill_input_shape = fill_model.input_shape[1:3]
105
- shape_input_shape = shape_model.input_shape[1:3]
106
- fill_imgs, shape_imgs, color_list = [], [], []
107
-
108
- for fb in filtered_boxes:
109
- x1, y1, x2, y2 = fb
110
- shape_img = card_image[y1:y2, x1:x2]
111
- fill_img = cv2.resize(shape_img, tuple(fill_input_shape)) / 255.0
112
- shape_img_resized = cv2.resize(shape_img, tuple(shape_input_shape)) / 255.0
113
- fill_imgs.append(fill_img)
114
- shape_imgs.append(shape_img_resized)
115
- color_list.append(predict_color(shape_img))
116
-
117
- fill_imgs = np.array(fill_imgs)
118
- shape_imgs = np.array(shape_imgs)
119
-
120
- fill_preds = fill_model.predict(fill_imgs, batch_size=len(fill_imgs))
121
- shape_preds = shape_model.predict(shape_imgs, batch_size=len(shape_imgs))
122
-
123
- fill_labels_list = ['empty', 'full', 'striped']
124
- shape_labels_list = ['diamond', 'oval', 'squiggle']
125
-
126
- predicted_fill = [fill_labels_list[np.argmax(pred)] for pred in fill_preds]
127
- predicted_shape = [shape_labels_list[np.argmax(pred)] for pred in shape_preds]
128
 
129
- color_label = max(set(color_list), key=color_list.count)
130
- fill_label = max(set(predicted_fill), key=predicted_fill.count)
131
- shape_label = max(set(predicted_shape), key=predicted_shape.count)
132
-
133
- return {'count': len(filtered_boxes), 'color': color_label,
134
- 'fill': fill_label, 'shape': shape_label, 'box': box}
135
 
136
- def is_set(cards: list) -> bool:
 
 
 
 
 
 
137
  """
138
- Check if a group of cards forms a valid set. For each feature,
139
- values must be all identical or all distinct.
140
  """
141
- for feature in ['Count', 'Color', 'Fill', 'Shape']:
142
- if len({card[feature] for card in cards}) not in [1, 3]:
143
- return False
144
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
- def find_sets(card_df: pd.DataFrame) -> list:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  """
148
- Iterate over all combinations of three cards to identify valid sets.
 
149
  """
150
- sets_found = []
151
- for combo in combinations(card_df.iterrows(), 3):
152
- cards = [entry[1] for entry in combo]
153
- if is_set(cards):
154
- sets_found.append({
155
- 'set_indices': [entry[0] for entry in combo],
156
- 'cards': [{feature: card[feature] for feature in
157
- ['Count', 'Color', 'Fill', 'Shape', 'Coordinates']} for card in cards]
158
- })
159
- return sets_found
160
 
161
- def detect_cards_from_image(board_image: np.ndarray, detector) -> list:
 
 
 
 
 
 
 
 
 
 
 
 
162
  """
163
- Extract card regions from the board image using the YOLO card detection model.
164
  """
165
- card_results = detector(board_image)
166
- card_boxes = card_results[0].boxes.xyxy.cpu().numpy().astype(int)
167
- return [(board_image[y1:y2, x1:x2], [x1, y1, x2, y2]) for x1, y1, x2, y2 in card_boxes]
 
168
 
169
- def classify_cards_from_board_image(board_image: np.ndarray, card_detector, shape_detector, fill_model, shape_model) -> pd.DataFrame:
170
  """
171
- Detect cards from the board image and classify their features.
 
172
  """
173
- cards = detect_cards_from_image(board_image, card_detector)
174
- card_data = []
175
- for card_image, box in cards:
176
- features = predict_card_features(card_image, shape_detector, fill_model, shape_model, box)
177
- card_data.append({
178
- "Count": features['count'],
179
- "Color": features['color'],
180
- "Fill": features['fill'],
181
- "Shape": features['shape'],
182
- "Coordinates": f"{box[0]}, {box[1]}, {box[2]}, {box[3]}"
183
- })
184
- return pd.DataFrame(card_data)
185
 
186
- def draw_sets_on_image(board_image: np.ndarray, sets_info: list) -> np.ndarray:
187
  """
188
- Draw bounding boxes and labels for each detected set on the board image.
 
 
189
  """
190
- colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255),
191
- (255, 255, 0), (255, 0, 255), (0, 255, 255)]
 
 
 
192
  base_thickness = 8
193
  base_expansion = 5
194
- for index, set_info in enumerate(sets_info):
195
- color = colors[index % len(colors)]
196
- thickness = base_thickness + 2 * index
197
- expansion = base_expansion + 15 * index
198
- for i, card in enumerate(set_info['cards']):
199
- coordinates = list(map(int, card['Coordinates'].split(',')))
200
- x1, y1, x2, y2 = coordinates
201
- x1_exp = max(0, x1 - expansion)
202
- y1_exp = max(0, y1 - expansion)
203
- x2_exp = min(board_image.shape[1], x2 + expansion)
204
- y2_exp = min(board_image.shape[0], y2 + expansion)
205
- cv2.rectangle(board_image, (x1_exp, y1_exp), (x2_exp, y2_exp), color, thickness)
 
 
 
 
 
206
  if i == 0:
207
- cv2.putText(board_image, f"Set {index + 1}", (x1_exp, y1_exp - 10),
208
- cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, thickness)
209
- return board_image
 
 
 
 
 
 
 
210
 
211
- def classify_and_find_sets_from_array(board_image: np.ndarray, card_detector, shape_detector, fill_model, shape_model) -> (list, np.ndarray):
 
 
212
  """
213
- Process the input image: adjust orientation, classify card features, detect sets, and annotate the image.
 
214
  """
215
- processed_image, was_rotated = check_and_rotate_input_image(board_image, card_detector)
216
- card_df = classify_cards_from_board_image(processed_image, card_detector, shape_detector, fill_model, shape_model)
217
- sets_found = find_sets(card_df)
218
- annotated_image = draw_sets_on_image(processed_image.copy(), sets_found)
219
- final_image = restore_original_orientation(annotated_image, was_rotated)
220
- return sets_found, final_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
- # =============================================================================
223
- # GRADIO INFERENCE FUNCTION
224
- # =============================================================================
 
 
 
 
 
225
 
226
- @spaces.GPU()
227
- def detect_sets(input_image: Image.Image):
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  """
229
- Process an uploaded image and return the annotated image along with detected sets info.
230
  """
231
- try:
232
- # Convert the PIL image to OpenCV BGR format
233
- image_cv = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
234
- # Run the detection pipeline
235
- sets_info, annotated_image = classify_and_find_sets_from_array(
236
- image_cv,
237
- card_detection_model,
238
- shape_detection_model,
239
- fill_classification_model,
240
- shape_classification_model
241
- )
242
- # Convert annotated image back to RGB for display
243
- annotated_image_rgb = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
244
- return annotated_image_rgb, json.dumps(sets_info, indent=2)
245
- except Exception:
246
- return None, f"Error occurred: {traceback.format_exc()}"
247
 
248
  # =============================================================================
249
- # GRADIO INTERFACE
250
  # =============================================================================
 
 
 
 
 
 
 
 
 
 
 
251
 
252
- with gr.Blocks(css="#col-container { margin: 0 auto; max-width: 800px; }") as demo:
253
- gr.Markdown("# Set Game Detector\nUpload an image of a Set game board to detect valid sets.")
254
-
255
- with gr.Row(elem_id="col-container"):
256
- image_input = gr.Image(label="Upload Set Game Board", type="pil")
257
- detect_button = gr.Button("Detect Sets")
258
 
259
- with gr.Row():
260
- result_image = gr.Image(label="Annotated Image")
261
- result_info = gr.JSON(label="Detected Sets Info")
 
262
 
263
- detect_button.click(
264
- detect_sets,
265
- inputs=[image_input],
266
- outputs=[result_image, result_info]
267
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
 
269
  # =============================================================================
270
- # LAUNCH THE APP
271
  # =============================================================================
272
-
273
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
 
 
 
 
 
2
  import cv2
3
  import numpy as np
4
+ import pandas as pd
5
  import tensorflow as tf
6
  from tensorflow.keras.models import load_model
7
+ import torch
8
  from ultralytics import YOLO
9
+ from itertools import combinations
10
+ from pathlib import Path
11
  from PIL import Image
12
+ import gradio as gr
13
+ import functools
14
  import traceback
15
+ import time
16
+ from typing import Tuple, List, Dict
17
+ import logging
18
 
19
  # =============================================================================
20
+ # LOGGING CONFIGURATION
21
  # =============================================================================
22
+ logging.basicConfig(level=logging.INFO,
23
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
24
+ logger = logging.getLogger("set_detector")
25
 
26
+ # =============================================================================
27
+ # MODEL PATHS & LOADING
28
+ # =============================================================================
29
+ # For loading models from Hugging Face Hub
30
+ HF_MODEL_REPO_PREFIX = "Omamitai"
31
+
32
+ # Define model repos and paths
33
+ CARD_DETECTION_REPO = f"{HF_MODEL_REPO_PREFIX}/card-detection"
34
+ SHAPE_DETECTION_REPO = f"{HF_MODEL_REPO_PREFIX}/shape-detection"
35
+ SHAPE_CLASSIFICATION_REPO = f"{HF_MODEL_REPO_PREFIX}/shape-classification"
36
+ FILL_CLASSIFICATION_REPO = f"{HF_MODEL_REPO_PREFIX}/fill-classification"
37
+
38
+ # Model filenames
39
+ CARD_MODEL_FILENAME = "best.pt"
40
+ SHAPE_MODEL_FILENAME = "best.pt"
41
+ SHAPE_CLASS_MODEL_FILENAME = "shape_model.keras"
42
+ FILL_CLASS_MODEL_FILENAME = "fill_model.keras"
43
+
44
+ # For local testing: fallback to local models if HF downloads fail
45
+ # Use the local directory structure as fallback
46
+ if os.path.exists("/home/user"): # Check if we're on HF Spaces
47
+ local_base_dir = Path("/home/user/app/models")
48
+ else:
49
+ local_base_dir = Path(os.path.dirname(os.path.abspath(__file__))) / "models"
50
+
51
+ local_char_path = local_base_dir / "Characteristics" / "11022025"
52
+ local_shape_path = local_base_dir / "Shape" / "15052024"
53
+ local_card_path = local_base_dir / "Card" / "16042024"
54
+
55
+ # Global variables for model caching
56
+ _MODEL_SHAPE = None
57
+ _MODEL_FILL = None
58
+ _DETECTOR_CARD = None
59
+ _DETECTOR_SHAPE = None
60
+ _MODELS_LOADED = False
61
+ _MODEL_LOADING_ERROR = None
62
+
63
+ def load_classification_models() -> Tuple[tf.keras.Model, tf.keras.Model]:
64
+ """
65
+ Loads the Keras models for 'shape' and 'fill' classification from HuggingFace Hub.
66
+ Returns (shape_model, fill_model).
67
+ """
68
+ global _MODEL_SHAPE, _MODEL_FILL, _MODEL_LOADING_ERROR
69
+
70
+ # If models are already loaded, return them
71
+ if _MODEL_SHAPE is not None and _MODEL_FILL is not None:
72
+ return _MODEL_SHAPE, _MODEL_FILL
73
+
74
+ try:
75
+ from huggingface_hub import hf_hub_download
76
+
77
+ # Try to download from HuggingFace Hub
78
+ logger.info(f"Downloading shape classification model from {SHAPE_CLASSIFICATION_REPO}...")
79
+ shape_model_path = hf_hub_download(
80
+ repo_id=SHAPE_CLASSIFICATION_REPO,
81
+ filename=SHAPE_CLASS_MODEL_FILENAME,
82
+ cache_dir="./hf_cache"
83
+ )
84
+
85
+ logger.info(f"Downloading fill classification model from {FILL_CLASSIFICATION_REPO}...")
86
+ fill_model_path = hf_hub_download(
87
+ repo_id=FILL_CLASSIFICATION_REPO,
88
+ filename=FILL_CLASS_MODEL_FILENAME,
89
+ cache_dir="./hf_cache"
90
+ )
91
+
92
+ # Load the models
93
+ logger.info("Loading classification models...")
94
+ model_shape = load_model(str(shape_model_path))
95
+ model_fill = load_model(str(fill_model_path))
96
+
97
+ logger.info("Classification models loaded successfully")
98
+ _MODEL_SHAPE, _MODEL_FILL = model_shape, model_fill
99
+ return model_shape, model_fill
100
+
101
+ except Exception as e:
102
+ error_msg = f"Error downloading classification models from HF Hub: {str(e)}"
103
+ logger.error(error_msg)
104
+
105
+ # Try fallback to local files
106
+ try:
107
+ logger.info("Trying fallback to local model files...")
108
+ shape_model_path = local_char_path / SHAPE_CLASS_MODEL_FILENAME
109
+ fill_model_path = local_char_path / FILL_CLASS_MODEL_FILENAME
110
+
111
+ if not shape_model_path.exists() or not fill_model_path.exists():
112
+ raise FileNotFoundError("Local model files not found")
113
+
114
+ # Load the models
115
+ model_shape = load_model(str(shape_model_path))
116
+ model_fill = load_model(str(fill_model_path))
117
+
118
+ logger.info("Classification models loaded successfully from local files")
119
+ _MODEL_SHAPE, _MODEL_FILL = model_shape, model_fill
120
+ return model_shape, model_fill
121
+
122
+ except Exception as fallback_error:
123
+ error_msg = f"{error_msg}\nFallback to local files also failed: {str(fallback_error)}"
124
+ logger.error(error_msg)
125
+ _MODEL_LOADING_ERROR = error_msg
126
+ return None, None
127
 
128
+ def load_detection_models() -> Tuple[YOLO, YOLO]:
129
+ """
130
+ Loads the YOLO detection models for cards and shapes from HuggingFace Hub.
131
+ Returns (card_detector, shape_detector).
132
+ """
133
+ global _DETECTOR_CARD, _DETECTOR_SHAPE, _MODEL_LOADING_ERROR
134
+
135
+ # If models are already loaded, return them
136
+ if _DETECTOR_CARD is not None and _DETECTOR_SHAPE is not None:
137
+ return _DETECTOR_CARD, _DETECTOR_SHAPE
138
+
139
+ try:
140
+ from huggingface_hub import hf_hub_download
141
+
142
+ # Try to download from HuggingFace Hub
143
+ logger.info(f"Downloading card detection model from {CARD_DETECTION_REPO}...")
144
+ card_model_path = hf_hub_download(
145
+ repo_id=CARD_DETECTION_REPO,
146
+ filename=CARD_MODEL_FILENAME,
147
+ cache_dir="./hf_cache"
148
+ )
149
+
150
+ logger.info(f"Downloading shape detection model from {SHAPE_DETECTION_REPO}...")
151
+ shape_model_path = hf_hub_download(
152
+ repo_id=SHAPE_DETECTION_REPO,
153
+ filename=SHAPE_MODEL_FILENAME,
154
+ cache_dir="./hf_cache"
155
+ )
156
+
157
+ # Load the models
158
+ logger.info("Loading detection models...")
159
+ detector_shape = YOLO(str(shape_model_path))
160
+ detector_shape.conf = 0.5
161
+ detector_card = YOLO(str(card_model_path))
162
+ detector_card.conf = 0.5
163
+
164
+ # Use GPU if available
165
+ if torch.cuda.is_available():
166
+ logger.info("CUDA is available. Using GPU for inference.")
167
+ detector_card.to("cuda")
168
+ detector_shape.to("cuda")
169
+ else:
170
+ logger.info("CUDA is not available. Using CPU for inference.")
171
+
172
+ logger.info("Detection models loaded successfully")
173
+ _DETECTOR_CARD, _DETECTOR_SHAPE = detector_card, detector_shape
174
+ return detector_card, detector_shape
175
+
176
+ except Exception as e:
177
+ error_msg = f"Error downloading detection models from HF Hub: {str(e)}"
178
+ logger.error(error_msg)
179
+
180
+ # Try fallback to local files
181
+ try:
182
+ logger.info("Trying fallback to local model files...")
183
+ shape_model_path = local_shape_path / SHAPE_MODEL_FILENAME
184
+ card_model_path = local_card_path / CARD_MODEL_FILENAME
185
+
186
+ if not shape_model_path.exists() or not card_model_path.exists():
187
+ raise FileNotFoundError("Local model files not found")
188
+
189
+ # Load the models
190
+ detector_shape = YOLO(str(shape_model_path))
191
+ detector_shape.conf = 0.5
192
+ detector_card = YOLO(str(card_model_path))
193
+ detector_card.conf = 0.5
194
+
195
+ # Use GPU if available
196
+ if torch.cuda.is_available():
197
+ logger.info("CUDA is available. Using GPU for inference.")
198
+ detector_card.to("cuda")
199
+ detector_shape.to("cuda")
200
+
201
+ logger.info("Detection models loaded successfully from local files")
202
+ _DETECTOR_CARD, _DETECTOR_SHAPE = detector_card, detector_shape
203
+ return detector_card, detector_shape
204
+
205
+ except Exception as fallback_error:
206
+ error_msg = f"{error_msg}\nFallback to local files also failed: {str(fallback_error)}"
207
+ logger.error(error_msg)
208
+ _MODEL_LOADING_ERROR = error_msg
209
+ return None, None
210
 
211
+ def load_all_models() -> bool:
212
+ """
213
+ Loads all required models and returns True if successful.
214
+ """
215
+ global _MODELS_LOADED, _MODEL_LOADING_ERROR
216
+
217
+ if _MODELS_LOADED:
218
+ return True
219
+
220
+ try:
221
+ model_shape, model_fill = load_classification_models()
222
+ detector_card, detector_shape = load_detection_models()
223
+
224
+ models_loaded = all([model_shape, model_fill, detector_card, detector_shape])
225
+ _MODELS_LOADED = models_loaded
226
+
227
+ if not models_loaded and _MODEL_LOADING_ERROR is None:
228
+ _MODEL_LOADING_ERROR = "Unknown error loading models"
229
+
230
+ return models_loaded
231
+ except Exception as e:
232
+ error_msg = f"Error loading models: {str(e)}"
233
+ logger.error(error_msg)
234
+ _MODEL_LOADING_ERROR = error_msg
235
+ return False
236
 
237
+ def get_model_status() -> str:
238
+ """
239
+ Returns a status message about the models.
240
+ """
241
+ if _MODELS_LOADED:
242
+ return "All models loaded successfully!"
243
+ elif _MODEL_LOADING_ERROR:
244
+ return f"Error: {_MODEL_LOADING_ERROR}"
245
+ else:
246
+ return "Models not loaded yet. Click 'Load Models' to preload them."
247
 
248
  # =============================================================================
249
+ # UTILITY & DETECTION FUNCTIONS
250
  # =============================================================================
251
+ def verify_and_rotate_image(board_image: np.ndarray, card_detector: YOLO) -> Tuple[np.ndarray, bool]:
 
252
  """
253
+ Checks if the detected cards are oriented primarily vertically or horizontally.
254
+ If they're vertical, rotates the board_image 90 degrees clockwise for consistent processing.
255
+ Returns (possibly_rotated_image, was_rotated_flag).
256
  """
257
+ detection = card_detector(board_image)
258
+ boxes = detection[0].boxes.xyxy.cpu().numpy().astype(int)
259
+ if boxes.size == 0:
260
  return board_image, False
261
 
262
+ widths = boxes[:, 2] - boxes[:, 0]
263
+ heights = boxes[:, 3] - boxes[:, 1]
264
+
265
+ # Rotate if average height > average width
266
  if np.mean(heights) > np.mean(widths):
267
  return cv2.rotate(board_image, cv2.ROTATE_90_CLOCKWISE), True
268
+ else:
269
+ return board_image, False
270
 
271
+ def restore_orientation(img: np.ndarray, was_rotated: bool) -> np.ndarray:
272
  """
273
+ Restores original orientation if the image was previously rotated.
274
  """
275
  if was_rotated:
276
+ return cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
277
+ return img
278
 
279
+ def predict_color(img_bgr: np.ndarray) -> str:
280
  """
281
+ Rough color classification using HSV thresholds to differentiate 'red', 'green', 'purple'.
282
  """
283
+ hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
284
+ mask_green = cv2.inRange(hsv, np.array([40, 50, 50]), np.array([80, 255, 255]))
285
+ mask_purple = cv2.inRange(hsv, np.array([120, 50, 50]), np.array([160, 255, 255]))
286
+
287
+ # Red can wrap around hue=0, so we combine both ends
288
+ mask_red1 = cv2.inRange(hsv, np.array([0, 50, 50]), np.array([10, 255, 255]))
289
+ mask_red2 = cv2.inRange(hsv, np.array([170, 50, 50]), np.array([180, 255, 255]))
290
+ mask_red = cv2.bitwise_or(mask_red1, mask_red2)
291
+
292
+ counts = {
293
+ "green": cv2.countNonZero(mask_green),
294
+ "purple": cv2.countNonZero(mask_purple),
295
+ "red": cv2.countNonZero(mask_red),
296
  }
297
+ return max(counts, key=counts.get)
298
 
299
+ def detect_cards(board_img: np.ndarray, card_detector: YOLO) -> List[Tuple[np.ndarray, List[int]]]:
300
  """
301
+ Runs YOLO on the board_img to detect card bounding boxes.
302
+ Returns a list of (card_image, [x1, y1, x2, y2]) for each detected card.
303
  """
304
+ result = card_detector(board_img)
305
+ boxes = result[0].boxes.xyxy.cpu().numpy().astype(int)
306
+ detected_cards = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
 
308
+ for x1, y1, x2, y2 in boxes:
309
+ detected_cards.append((board_img[y1:y2, x1:x2], [x1, y1, x2, y2]))
310
+ return detected_cards
 
 
 
311
 
312
+ def predict_card_features(
313
+ card_img: np.ndarray,
314
+ shape_detector: YOLO,
315
+ fill_model: tf.keras.Model,
316
+ shape_model: tf.keras.Model,
317
+ card_box: List[int]
318
+ ) -> Dict:
319
  """
320
+ Predicts the 'count', 'color', 'fill', 'shape' features for a single card.
321
+ It uses a shape_detector YOLO model to locate shapes, then passes them to fill_model and shape_model.
322
  """
323
+ # Detect shapes on the card
324
+ shape_detections = shape_detector(card_img)
325
+ c_h, c_w = card_img.shape[:2]
326
+ card_area = c_w * c_h
327
+
328
+ # Filter out spurious shape detections
329
+ shape_boxes = []
330
+ for coords in shape_detections[0].boxes.xyxy.cpu().numpy():
331
+ x1, y1, x2, y2 = coords.astype(int)
332
+ if (x2 - x1) * (y2 - y1) > 0.03 * card_area:
333
+ shape_boxes.append([x1, y1, x2, y2])
334
+
335
+ if not shape_boxes:
336
+ return {
337
+ 'count': 0,
338
+ 'color': 'unknown',
339
+ 'fill': 'unknown',
340
+ 'shape': 'unknown',
341
+ 'box': card_box
342
+ }
343
+
344
+ fill_input_size = fill_model.input_shape[1:3]
345
+ shape_input_size = shape_model.input_shape[1:3]
346
+ fill_imgs = []
347
+ shape_imgs = []
348
+ color_candidates = []
349
+
350
+ # Prepare each detected shape region for classification
351
+ for sb in shape_boxes:
352
+ sx1, sy1, sx2, sy2 = sb
353
+ shape_crop = card_img[sy1:sy2, sx1:sx2]
354
 
355
+ fill_crop = cv2.resize(shape_crop, fill_input_size) / 255.0
356
+ shape_crop_resized = cv2.resize(shape_crop, shape_input_size) / 255.0
357
+
358
+ fill_imgs.append(fill_crop)
359
+ shape_imgs.append(shape_crop_resized)
360
+ color_candidates.append(predict_color(shape_crop))
361
+
362
+ # Use verbose=0 to suppress progress bar
363
+ fill_preds = fill_model.predict(np.array(fill_imgs), batch_size=len(fill_imgs), verbose=0)
364
+ shape_preds = shape_model.predict(np.array(shape_imgs), batch_size=len(shape_imgs), verbose=0)
365
+
366
+ fill_labels = ['empty', 'full', 'striped']
367
+ shape_labels = ['diamond', 'oval', 'squiggle']
368
+
369
+ fill_result = [fill_labels[np.argmax(fp)] for fp in fill_preds]
370
+ shape_result = [shape_labels[np.argmax(sp)] for sp in shape_preds]
371
+
372
+ # Take the most common color/fill/shape across all shape detections for the card
373
+ final_color = max(set(color_candidates), key=color_candidates.count)
374
+ final_fill = max(set(fill_result), key=fill_result.count)
375
+ final_shape = max(set(shape_result), key=shape_result.count)
376
+
377
+ return {
378
+ 'count': len(shape_boxes),
379
+ 'color': final_color,
380
+ 'fill': final_fill,
381
+ 'shape': final_shape,
382
+ 'box': card_box
383
+ }
384
+
385
+ def classify_cards_on_board(
386
+ board_img: np.ndarray,
387
+ card_detector: YOLO,
388
+ shape_detector: YOLO,
389
+ fill_model: tf.keras.Model,
390
+ shape_model: tf.keras.Model
391
+ ) -> pd.DataFrame:
392
  """
393
+ Detects cards on the board, then classifies each card's features.
394
+ Returns a DataFrame with columns: 'Count', 'Color', 'Fill', 'Shape', 'Coordinates'.
395
  """
396
+ detected_cards = detect_cards(board_img, card_detector)
397
+ card_rows = []
 
 
 
 
 
 
 
 
398
 
399
+ for (card_img, box) in detected_cards:
400
+ card_feats = predict_card_features(card_img, shape_detector, fill_model, shape_model, box)
401
+ card_rows.append({
402
+ "Count": card_feats['count'],
403
+ "Color": card_feats['color'],
404
+ "Fill": card_feats['fill'],
405
+ "Shape": card_feats['shape'],
406
+ "Coordinates": card_feats['box']
407
+ })
408
+
409
+ return pd.DataFrame(card_rows)
410
+
411
+ def valid_set(cards: List[dict]) -> bool:
412
  """
413
+ Checks if the given 3 cards collectively form a valid SET.
414
  """
415
+ for feature in ["Count", "Color", "Fill", "Shape"]:
416
+ if len({card[feature] for card in cards}) not in (1, 3):
417
+ return False
418
+ return True
419
 
420
+ def locate_all_sets(cards_df: pd.DataFrame) -> List[dict]:
421
  """
422
+ Finds all possible SETs from the card DataFrame.
423
+ Each SET is a dictionary with 'set_indices' and 'cards' fields.
424
  """
425
+ found_sets = []
426
+ for combo in combinations(cards_df.iterrows(), 3):
427
+ cards = [c[1] for c in combo] # c is (index, row)
428
+ if valid_set(cards):
429
+ found_sets.append({
430
+ 'set_indices': [c[0] for c in combo],
431
+ 'cards': [
432
+ {f: card[f] for f in ['Count', 'Color', 'Fill', 'Shape', 'Coordinates']}
433
+ for card in cards
434
+ ]
435
+ })
436
+ return found_sets
437
 
438
+ def draw_detected_sets(board_img: np.ndarray, sets_detected: List[dict]) -> np.ndarray:
439
  """
440
+ Annotates the board image with bounding boxes for each detected SET.
441
+ Each SET is drawn in a different color and offset (thickness & expansion)
442
+ so that overlapping sets are visible.
443
  """
444
+ # Some distinct BGR colors
445
+ colors = [
446
+ (255, 0, 0), (0, 255, 0), (0, 0, 255),
447
+ (255, 255, 0), (255, 0, 255), (0, 255, 255)
448
+ ]
449
  base_thickness = 8
450
  base_expansion = 5
451
+
452
+ for idx, single_set in enumerate(sets_detected):
453
+ color = colors[idx % len(colors)]
454
+ thickness = base_thickness + 2 * idx
455
+ expansion = base_expansion + 15 * idx
456
+
457
+ for i, card_info in enumerate(single_set["cards"]):
458
+ x1, y1, x2, y2 = card_info["Coordinates"]
459
+ # Expand the bounding box slightly
460
+ x1e = max(0, x1 - expansion)
461
+ y1e = max(0, y1 - expansion)
462
+ x2e = min(board_img.shape[1], x2 + expansion)
463
+ y2e = min(board_img.shape[0], y2 + expansion)
464
+
465
+ cv2.rectangle(board_img, (x1e, y1e), (x2e, y2e), color, thickness)
466
+
467
+ # Label only the first card's box with "Set <number>"
468
  if i == 0:
469
+ cv2.putText(
470
+ board_img,
471
+ f"Set {idx + 1}",
472
+ (x1e, y1e - 10),
473
+ cv2.FONT_HERSHEY_SIMPLEX,
474
+ 0.9,
475
+ color,
476
+ thickness
477
+ )
478
+ return board_img
479
 
480
def identify_sets_from_image(
    board_img: np.ndarray
) -> Tuple[List[dict], np.ndarray, str]:
    """
    Run the full board-analysis pipeline on a single image.

    Steps: ensure models are loaded, normalise the colour layout to BGR,
    fix board orientation, detect and classify the cards, search for valid
    SETs, draw them, and undo any rotation before returning.

    Returns:
        (found_sets, display_image_rgb, status_message) where
        ``found_sets`` is a list of set dicts (empty on failure) and the
        image is always RGB-ordered for display.
    """
    # Models must be available before anything else can run.
    if not load_all_models():
        return [], board_img, _MODEL_LOADING_ERROR or "Error: Could not load models."

    detector_cards = _DETECTOR_CARD
    detector_shapes = _DETECTOR_SHAPE
    shape_net = _MODEL_SHAPE
    fill_net = _MODEL_FILL

    # Normalise to BGR (OpenCV's channel order). Gradio delivers RGB(A)
    # arrays, so 3/4-channel inputs are converted; anything else is rejected.
    if board_img.ndim == 3 and board_img.shape[2] == 4:
        board_img = cv2.cvtColor(board_img, cv2.COLOR_RGBA2BGR)
    elif board_img.ndim == 3 and board_img.shape[2] == 3:
        board_img = cv2.cvtColor(board_img, cv2.COLOR_RGB2BGR)
    else:
        return [], board_img, "Error: Unsupported image format. Please upload a color image."

    # Rotate to the canonical orientation if the detector prefers it.
    processed, was_rotated = verify_and_rotate_image(board_img, detector_cards)

    # Bail out early when no cards are visible at all.
    if not detect_cards(processed, detector_cards):
        return (
            [],
            cv2.cvtColor(board_img, cv2.COLOR_BGR2RGB),
            "No cards detected in the image. Please check that it's a SET game board.",
        )

    # Classify every card's attributes, then search for valid SETs.
    card_table = classify_cards_on_board(
        processed, detector_cards, detector_shapes, fill_net, shape_net
    )
    sets_found = locate_all_sets(card_table)

    if not sets_found:
        return (
            [],
            cv2.cvtColor(processed, cv2.COLOR_BGR2RGB),
            "Cards detected, but no valid SETs found. You may need to add more cards to the table!",
        )

    # Annotate a copy, restore the original orientation, and hand back RGB.
    annotated = draw_detected_sets(processed.copy(), sets_found)
    restored = restore_orientation(annotated, was_rotated)
    display_rgb = cv2.cvtColor(restored, cv2.COLOR_BGR2RGB)

    return sets_found, display_rgb, f"Found {len(sets_found)} SET(s) in the image."
530
+
531
def optimize_image_size(image_array: np.ndarray, max_dim=1200) -> np.ndarray:
    """
    Downscale an image so its longest side does not exceed ``max_dim``.

    The aspect ratio is preserved. Images that are already small enough
    (and ``None`` inputs) are returned unchanged; otherwise a resized
    copy is produced with area interpolation, which suits shrinking.
    """
    if image_array is None:
        return None

    h, w = image_array.shape[:2]
    if max(w, h) <= max_dim:
        # Nothing to do — hand back the original array untouched.
        return image_array

    if w > h:
        target_size = (max_dim, int(h * (max_dim / w)))
    else:
        target_size = (int(w * (max_dim / h)), max_dim)

    return cv2.resize(image_array, target_size, interpolation=cv2.INTER_AREA)
 
 
549
 
550
  # =============================================================================
551
+ # MAIN PROCESSING FUNCTIONS FOR GRADIO
552
  # =============================================================================
553
def preload_models():
    """
    Eagerly load all models and report a human-readable status string.

    Never raises: loader failures (and any unexpected exception) are
    folded into the returned message so the UI can display them.
    """
    try:
        if not load_all_models():
            return f"Error loading models: {_MODEL_LOADING_ERROR or 'Unknown error'}"
        return "Models loaded successfully! Ready to detect SETs."
    except Exception as exc:
        return f"Error loading models: {str(exc)}"
564
 
565
@spaces.GPU
def process_set_image(input_image):
    """
    Main processing function for the Gradio interface.

    Takes an uploaded board image, downsizes it if needed, runs SET
    detection, and returns ``(annotated_image, status_message)``. On
    error the original image is returned alongside the error text, so
    the UI never goes blank.

    Decorated with @spaces.GPU for Hugging Face Spaces zero-GPU
    scheduling.
    """
    if input_image is None:
        return None, "Please upload an image."

    try:
        start_time = time.time()
        logger.info("Processing image...")

        # Shrink oversized uploads to keep inference time bounded.
        optimized_image = optimize_image_size(input_image)

        # The list of sets is not needed here — only the annotated image
        # and the status text are surfaced to the UI.
        _, annotated_image, status_message = identify_sets_from_image(optimized_image)

        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logger.info("Image processed in %.2f seconds.", time.time() - start_time)

        return annotated_image, status_message

    except Exception as e:
        error_message = f"Error processing image: {str(e)}"
        logger.error(error_message)
        logger.error(traceback.format_exc())
        return input_image, error_message
 
597
  # =============================================================================
598
+ # GRADIO INTERFACE
599
  # =============================================================================
600
+ def create_gradio_interface():
601
+ """
602
+ Creates and returns the Gradio interface for the SET Game Detector.
603
+ """
604
+ # CSS for styling the Gradio interface
605
+ css = """
606
+ @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@400;500;600;700&display=swap');
607
+
608
+ .gradio-container {
609
+ font-family: 'Poppins', sans-serif;
610
+ }
611
+ .app-header {
612
+ text-align: center;
613
+ margin-bottom: 20px;
614
+ background: linear-gradient(90deg, rgba(124, 58, 237, 0.1) 0%, rgba(236, 72, 153, 0.1) 100%);
615
+ padding: 1rem;
616
+ border-radius: 12px;
617
+ }
618
+ .app-header h1 {
619
+ font-size: 2.5rem;
620
+ background: linear-gradient(90deg, #8B5CF6 0%, #7C3AED 50%, #EC4899 100%);
621
+ -webkit-background-clip: text;
622
+ background-clip: text;
623
+ -webkit-text-fill-color: transparent;
624
+ margin-bottom: 5px;
625
+ }
626
+ .app-header p {
627
+ font-size: 1.1rem;
628
+ opacity: 0.8;
629
+ margin-top: 0;
630
+ }
631
+ .footer {
632
+ text-align: center;
633
+ margin-top: 20px;
634
+ padding: 10px;
635
+ background: linear-gradient(90deg, rgba(124, 58, 237, 0.05) 0%, rgba(236, 72, 153, 0.05) 100%);
636
+ border-radius: 12px;
637
+ }
638
+
639
+ /* Responsive design for mobile */
640
+ @media (max-width: 600px) {
641
+ .app-header h1 {
642
+ font-size: 1.8rem;
643
+ }
644
+ .app-header p {
645
+ font-size: 0.9rem;
646
+ }
647
+ }
648
+
649
+ /* Custom styling for buttons */
650
+ #find-sets-btn {
651
+ background: linear-gradient(90deg, #7C3AED 0%, #EC4899 100%);
652
+ color: white !important;
653
+ }
654
+ #find-sets-btn:hover {
655
+ opacity: 0.9;
656
+ }
657
+
658
+ /* Image containers */
659
+ .image-container {
660
+ border-radius: 12px;
661
+ overflow: hidden;
662
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
663
+ }
664
+
665
+ /* Status box styling */
666
+ #status-box {
667
+ font-weight: 500;
668
+ border-radius: 8px;
669
+ }
670
+ """
671
+
672
+ # Create the Gradio interface
673
+ with gr.Blocks(css=css, title="SET Game Detector") as demo:
674
+ # Header
675
+ gr.HTML("""
676
+ <div class="app-header">
677
+ <h1>๐ŸŽด SET Game Detector</h1>
678
+ <p>Upload an image of a SET board to find all valid sets</p>
679
+ </div>
680
+ """)
681
+
682
+ # Model status display
683
+ model_status = gr.Textbox(
684
+ label="Model Status",
685
+ value=get_model_status(),
686
+ interactive=False
687
+ )
688
+ load_models_btn = gr.Button("๐Ÿ”„ Load Models", visible=not _MODELS_LOADED)
689
+
690
+ # Main layout
691
+ with gr.Row():
692
+ with gr.Column():
693
+ input_image = gr.Image(
694
+ label="Upload SET Board Image",
695
+ tool="upload",
696
+ type="numpy",
697
+ elem_id="input-image",
698
+ elem_classes="image-container"
699
+ )
700
+
701
+ with gr.Row():
702
+ process_btn = gr.Button(
703
+ "๐Ÿ”Ž Find Sets",
704
+ variant="primary",
705
+ elem_id="find-sets-btn",
706
+ interactive=_MODELS_LOADED
707
+ )
708
+ clear_btn = gr.Button("๐Ÿ—‘๏ธ Clear", variant="secondary")
709
+
710
+ with gr.Column():
711
+ output_image = gr.Image(
712
+ label="Detected Sets",
713
+ elem_id="output-image",
714
+ elem_classes="image-container",
715
+ interactive=False
716
+ )
717
+ status = gr.Textbox(
718
+ label="Status",
719
+ placeholder="Upload an image and click 'Find Sets'",
720
+ elem_id="status-box",
721
+ interactive=False
722
+ )
723
+
724
+ # Example images section - Create an examples directory for deployment
725
+ examples_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "examples")
726
+ os.makedirs(examples_dir, exist_ok=True)