Oamitai commited on
Commit
e702eec
·
1 Parent(s): 9e7f164

Updated app.py with new improvements

Browse files
Files changed (2) hide show
  1. app.py +291 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ import torch
5
+ from ultralytics import YOLO
6
+ from PIL import Image
7
+ import gradio as gr
8
+ import traceback
9
+ import pandas as pd
10
+ from itertools import combinations
11
+ from huggingface_hub import hf_hub_download
12
+
13
# =============================================================================
# MODEL LOADING
# =============================================================================
# NOTE: all four models are downloaded and loaded at import time, so the first
# startup requires network access to the Hugging Face Hub.

# YOLO card detector: locates card bounding boxes on the board image.
card_model_path = hf_hub_download("Oamitai/card-detection", "best.pt")
card_detection_model = YOLO(card_model_path)
card_detection_model.conf = 0.5  # minimum confidence for card detections

# YOLO shape detector: locates individual symbol bounding boxes within a card.
shape_model_path = hf_hub_download("Oamitai/shape-detection", "best.pt")
shape_detection_model = YOLO(shape_model_path)
shape_detection_model.conf = 0.5  # minimum confidence for shape detections

# Keras classifier: predicts the symbol shape (diamond / oval / squiggle).
shape_classification_model = tf.keras.models.load_model(
    hf_hub_download("Oamitai/shape-classification", "shape_model.keras")
)

# Keras classifier: predicts the symbol fill (empty / full / striped).
fill_classification_model = tf.keras.models.load_model(
    hf_hub_download("Oamitai/fill-classification", "fill_model.keras")
)
36
+
37
+ # =============================================================================
38
+ # ORIENTATION CORRECTION FUNCTIONS
39
+ # =============================================================================
40
def check_and_rotate_input_image(board_image, card_detection_model):
    """
    Normalize the board orientation using detected card bounding boxes.

    If the average detected card height exceeds the average card width,
    the image is rotated 90 degrees clockwise so the cards lie landscape.

    Args:
        board_image: BGR board image.
        card_detection_model: YOLO model used to find card boxes.

    Returns:
        tuple: (possibly rotated image, bool flag — True iff a rotation
        was applied).
    """
    detections = card_detection_model(board_image)
    boxes = detections[0].boxes.xyxy.cpu().numpy().astype(int)

    # Nothing detected: we cannot measure orientation, keep the image as-is.
    if len(boxes) == 0:
        return board_image, False

    widths = [x2 - x1 for x1, _, x2, _ in boxes]
    heights = [y2 - y1 for _, y1, _, y2 in boxes]
    mean_width = sum(widths) / len(widths)
    mean_height = sum(heights) / len(heights)

    if mean_height <= mean_width:
        return board_image, False
    return cv2.rotate(board_image, cv2.ROTATE_90_CLOCKWISE), True
68
+
69
def restore_original_orientation(image, was_rotated):
    """
    Undo the temporary clockwise rotation applied for processing.

    Args:
        image: the (possibly annotated) board image.
        was_rotated: flag returned by check_and_rotate_input_image.

    Returns:
        The image rotated 90 degrees counter-clockwise when was_rotated is
        truthy, otherwise the image unchanged.
    """
    if not was_rotated:
        return image
    return cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
76
+
77
+ # =============================================================================
78
+ # PREDICTION FUNCTIONS
79
+ # =============================================================================
80
def predict_color(shape_image):
    """
    Predict the dominant Set color (green, purple, or red) of a symbol crop
    by counting pixels that fall inside fixed HSV threshold ranges.
    """
    hsv = cv2.cvtColor(shape_image, cv2.COLOR_BGR2HSV)

    def _mask(lo, hi):
        # Binary mask of pixels within the given inclusive HSV bounds.
        return cv2.inRange(hsv, np.array(lo), np.array(hi))

    # Red wraps around the hue axis, so it needs two ranges OR-ed together.
    counts = {
        'green': cv2.countNonZero(_mask([40, 50, 50], [80, 255, 255])),
        'purple': cv2.countNonZero(_mask([120, 50, 50], [160, 255, 255])),
        'red': cv2.countNonZero(cv2.bitwise_or(
            _mask([0, 50, 50], [10, 255, 255]),
            _mask([170, 50, 50], [180, 255, 255]),
        )),
    }
    return max(counts, key=counts.get)
98
+
99
def predict_card_features(card_image, shape_detection_model, fill_model, shape_model, box):
    """
    Detect the symbols on a single card crop and predict its Set attributes.

    Args:
        card_image: BGR crop of one card.
        shape_detection_model: YOLO model that locates symbols on the card.
        fill_model: Keras classifier for fill (empty / full / striped).
        shape_model: Keras classifier for shape (diamond / oval / squiggle).
        box: the card's bounding box on the board, passed through unchanged.

    Returns:
        dict with keys 'count', 'color', 'fill', 'shape', 'box'. When no
        symbol survives filtering, the three attribute labels are 'unknown'.
    """
    shape_results = shape_detection_model(card_image)
    card_height, card_width = card_image.shape[:2]
    card_area = card_width * card_height

    # Drop detections smaller than 3% of the card area (noise / partial hits).
    filtered_boxes = []
    for detected_box in shape_results[0].boxes.xyxy.cpu().numpy():
        x1, y1, x2, y2 = detected_box.astype(int)
        if (x2 - x1) * (y2 - y1) > 0.03 * card_area:
            filtered_boxes.append(detected_box)

    # A Set card shows at most 3 symbols; clamp spurious extra detections.
    count = min(len(filtered_boxes), 3)

    # Hoisted out of the loop: classifier input sizes never change per card.
    # NOTE(review): input_shape[1:3] is (height, width) while cv2.resize
    # expects dsize as (width, height); this only matters for non-square
    # model inputs — confirm both classifiers use square inputs.
    fill_input_shape = fill_model.input_shape[1:3]
    shape_input_shape = shape_model.input_shape[1:3]

    color_labels, fill_labels, shape_labels = [], [], []
    for shape_box in filtered_boxes:
        shape_box = shape_box.astype(int)
        shape_img = card_image[shape_box[1]:shape_box[3], shape_box[0]:shape_box[2]]

        # Normalize to [0, 1] and add a batch dimension for Keras.
        fill_img = np.expand_dims(cv2.resize(shape_img, fill_input_shape) / 255.0, axis=0)
        shape_img_resized = np.expand_dims(cv2.resize(shape_img, shape_input_shape) / 255.0, axis=0)

        fill_pred = fill_model.predict(fill_img)
        shape_pred = shape_model.predict(shape_img_resized)
        fill_labels.append(['empty', 'full', 'striped'][np.argmax(fill_pred)])
        shape_labels.append(['diamond', 'oval', 'squiggle'][np.argmax(shape_pred)])
        color_labels.append(predict_color(shape_img))

    if count > 0:
        # Majority vote across the detected symbols for each attribute.
        color_label = max(set(color_labels), key=color_labels.count)
        fill_label = max(set(fill_labels), key=fill_labels.count)
        shape_label = max(set(shape_labels), key=shape_labels.count)
    else:
        color_label = fill_label = shape_label = 'unknown'

    return {
        'count': count,
        'color': color_label,
        'fill': fill_label,
        'shape': shape_label,
        'box': box
    }
150
+
151
def is_set(cards):
    """
    Return True when the three cards form a valid Set.

    For each attribute (Count, Color, Fill, Shape) the three values must be
    either all identical (1 distinct value) or all different (3 distinct).
    """
    features = ('Count', 'Color', 'Fill', 'Shape')
    return all(len({card[f] for card in cards}) in (1, 3) for f in features)
160
+
161
def find_sets(card_df):
    """
    Enumerate every 3-card combination in the DataFrame and collect the
    valid Sets.

    Args:
        card_df: DataFrame with columns Count, Color, Fill, Shape,
            Coordinates (one row per detected card).

    Returns:
        list of dicts, each with 'set_indices' (the DataFrame indices of the
        three cards) and 'cards' (their attributes plus Coordinates).
    """
    feature_keys = ['Count', 'Color', 'Fill', 'Shape', 'Coordinates']
    found = []
    for trio in combinations(card_df.iterrows(), 3):
        rows = [row for _, row in trio]
        if not is_set(rows):
            continue
        found.append({
            'set_indices': [idx for idx, _ in trio],
            'cards': [{key: row[key] for key in feature_keys} for row in rows],
        })
    return found
175
+
176
def detect_cards_from_image(board_image, card_detection_model):
    """
    Run the YOLO card detector on the board image and crop each detection.

    Returns:
        list of (card crop, bounding box) tuples; each box is an int array
        [x1, y1, x2, y2] in board coordinates.
    """
    results = card_detection_model(board_image)
    boxes = results[0].boxes.xyxy.cpu().numpy().astype(int)
    # Crop rows y1:y2 and columns x1:x2 for every detected box.
    return [(board_image[b[1]:b[3], b[0]:b[2]], b) for b in boxes]
189
+
190
def classify_cards_from_board_image(board_image, card_detection_model, shape_detection_model, fill_model, shape_model):
    """
    Detect every card on the board and predict its attributes.

    Returns:
        pandas DataFrame with one row per card and columns Count, Color,
        Fill, Shape, and Coordinates ("x1, y1, x2, y2" as a string).
    """
    rows = []
    for card_image, box in detect_cards_from_image(board_image, card_detection_model):
        features = predict_card_features(card_image, shape_detection_model, fill_model, shape_model, box)
        rows.append({
            "Count": features['count'],
            "Color": features['color'],
            "Fill": features['fill'],
            "Shape": features['shape'],
            "Coordinates": f"{box[0]}, {box[1]}, {box[2]}, {box[3]}",
        })
    return pd.DataFrame(rows)
207
+
208
def classify_and_find_sets_from_array(board_image, card_detection_model, shape_detection_model, fill_model, shape_model):
    """
    Full pipeline for a BGR board image: normalize orientation, detect and
    classify the cards, search for valid Sets, annotate them, and restore
    the original orientation.

    Returns:
        tuple: (list of found sets, annotated image in the input orientation).
    """
    # Work in landscape orientation; remember whether we had to rotate.
    oriented, was_rotated = check_and_rotate_input_image(board_image, card_detection_model)
    card_df = classify_cards_from_board_image(
        oriented, card_detection_model, shape_detection_model, fill_model, shape_model
    )
    sets_found = find_sets(card_df)
    # Draw on a copy so the working image stays untouched.
    annotated = draw_sets_on_image(oriented.copy(), sets_found)
    return sets_found, restore_original_orientation(annotated, was_rotated)
220
+
221
+ # =============================================================================
222
+ # DRAWING FUNCTIONS
223
+ # =============================================================================
224
def draw_sets_on_image(board_image, sets_info):
    """
    Annotate each detected Set on the board image.

    Every card in a set gets an expanded, colored rectangle; the first card
    also gets a "Set N" label. Later sets use thicker and wider rectangles
    so overlapping sets remain distinguishable.
    """
    palette = [(255, 0, 0), (0, 255, 0), (0, 0, 255),
               (255, 255, 0), (255, 0, 255), (0, 255, 255)]
    img_height, img_width = board_image.shape[:2]

    for set_idx, set_info in enumerate(sets_info):
        color = palette[set_idx % len(palette)]
        thickness = 8 + 2 * set_idx       # grows per set for visibility
        expansion = 5 + 15 * set_idx      # grows per set to avoid overlap

        for card_idx, card in enumerate(set_info['cards']):
            x1, y1, x2, y2 = (int(v) for v in card['Coordinates'].split(','))
            # Expand the box, clamped to the image bounds.
            top_left = (max(0, x1 - expansion), max(0, y1 - expansion))
            bottom_right = (min(img_width, x2 + expansion), min(img_height, y2 + expansion))
            cv2.rectangle(board_image, top_left, bottom_right, color, thickness)
            # Label only once per set, anchored above the first card.
            if card_idx == 0:
                cv2.putText(board_image, f"Set {set_idx + 1}",
                            (top_left[0], top_left[1] - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, thickness)
    return board_image
248
+
249
+ # =============================================================================
250
+ # GRADIO INTERFACE FUNCTION
251
+ # =============================================================================
252
def detect_and_display_sets_interface(input_image):
    """
    Gradio entry point.

    Converts the uploaded PIL image to OpenCV BGR, runs the full Set
    detection pipeline, and returns the annotated image (as PIL) together
    with a status string. On any failure a blank image and the error text
    (with traceback) are returned instead of raising.
    """
    try:
        bgr = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
        _, annotated_bgr = classify_and_find_sets_from_array(
            bgr,
            card_detection_model,
            shape_detection_model,
            fill_classification_model,
            shape_classification_model
        )
        # Back to RGB so the Gradio image component displays it correctly.
        annotated_rgb = cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)
        return Image.fromarray(annotated_rgb), "Sets detected successfully."
    except Exception as e:
        err = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
        # Surface the failure in the UI instead of crashing the app.
        return Image.fromarray(np.zeros((100, 100, 3), dtype=np.uint8)), err
277
+
278
+ # =============================================================================
279
+ # LAUNCH GRADIO
280
+ # =============================================================================
281
# Wire the detection function into a simple single-image Gradio UI:
# one image input, annotated image + status text as outputs.
iface = gr.Interface(
    fn=detect_and_display_sets_interface,
    inputs=gr.Image(type="pil", label="Upload Board Image"),
    outputs=[gr.Image(type="pil", label="Annotated Image"), gr.Textbox(label="Status")],
    title="Set Game Detector",
    description=("Upload an image of a Set game board to detect cards, "
                 "classify their features, and highlight valid sets.")
)

# Launch the web app only when run as a script (not when imported).
if __name__ == "__main__":
    iface.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ tensorflow
3
+ torch
4
+ ultralytics
5
+ opencv-python-headless
6
+ numpy
7
+ Pillow
8
+ huggingface_hub
9
+ pandas