Spaces:
Sleeping
Sleeping
Updated app.py with new improvements
Browse files
app.py
CHANGED
|
@@ -11,10 +11,10 @@ from itertools import combinations
|
|
| 11 |
from huggingface_hub import hf_hub_download
|
| 12 |
import spaces # For ZeroGPU support
|
| 13 |
|
| 14 |
-
#
|
| 15 |
-
#
|
| 16 |
-
#
|
| 17 |
-
#
|
| 18 |
shape_classification_model = tf.keras.models.load_model(
|
| 19 |
hf_hub_download("Oamitai/shape-classification", "shape_model.keras")
|
| 20 |
)
|
|
@@ -22,7 +22,7 @@ fill_classification_model = tf.keras.models.load_model(
|
|
| 22 |
hf_hub_download("Oamitai/fill-classification", "fill_model.keras")
|
| 23 |
)
|
| 24 |
|
| 25 |
-
# Global YOLO models will be loaded inside the GPU function.
|
| 26 |
global_card_detection_model = None
|
| 27 |
global_shape_detection_model = None
|
| 28 |
|
|
@@ -55,10 +55,10 @@ def restore_original_orientation(image, was_rotated):
|
|
| 55 |
# =============================================================================
|
| 56 |
def predict_color(shape_image):
|
| 57 |
hsv_image = cv2.cvtColor(shape_image, cv2.COLOR_BGR2HSV)
|
| 58 |
-
green_mask = cv2.inRange(hsv_image, np.array([40,50,50]), np.array([80,255,255]))
|
| 59 |
-
purple_mask = cv2.inRange(hsv_image, np.array([120,50,50]), np.array([160,255,255]))
|
| 60 |
-
red_mask1 = cv2.inRange(hsv_image, np.array([0,50,50]), np.array([10,255,255]))
|
| 61 |
-
red_mask2 = cv2.inRange(hsv_image, np.array([170,50,50]), np.array([180,255,255]))
|
| 62 |
red_mask = cv2.bitwise_or(red_mask1, red_mask2)
|
| 63 |
color_counts = {
|
| 64 |
'green': cv2.countNonZero(green_mask),
|
|
@@ -71,14 +71,17 @@ def predict_card_features(card_image, shape_detection_model, fill_model, shape_m
|
|
| 71 |
shape_results = shape_detection_model(card_image)
|
| 72 |
card_height, card_width = card_image.shape[:2]
|
| 73 |
card_area = card_width * card_height
|
|
|
|
| 74 |
filtered_boxes = []
|
| 75 |
for detected_box in shape_results[0].boxes.xyxy.cpu().numpy():
|
| 76 |
x1, y1, x2, y2 = detected_box.astype(int)
|
| 77 |
shape_area = (x2 - x1) * (y2 - y1)
|
| 78 |
if shape_area > 0.03 * card_area:
|
| 79 |
filtered_boxes.append([x1, y1, x2, y2])
|
|
|
|
| 80 |
if len(filtered_boxes) == 0:
|
| 81 |
return {'count': 0, 'color': 'unknown', 'fill': 'unknown', 'shape': 'unknown', 'box': box}
|
|
|
|
| 82 |
fill_input_shape = fill_model.input_shape[1:3]
|
| 83 |
shape_input_shape = shape_model.input_shape[1:3]
|
| 84 |
fill_imgs = []
|
|
@@ -94,16 +97,19 @@ def predict_card_features(card_image, shape_detection_model, fill_model, shape_m
|
|
| 94 |
color_list.append(predict_color(shape_img))
|
| 95 |
fill_imgs = np.array(fill_imgs)
|
| 96 |
shape_imgs = np.array(shape_imgs)
|
|
|
|
| 97 |
fill_preds = fill_model.predict(fill_imgs, batch_size=len(fill_imgs))
|
| 98 |
shape_preds = shape_model.predict(shape_imgs, batch_size=len(shape_imgs))
|
| 99 |
fill_labels_list = ['empty', 'full', 'striped']
|
| 100 |
shape_labels_list = ['diamond', 'oval', 'squiggle']
|
| 101 |
predicted_fill = [fill_labels_list[np.argmax(pred)] for pred in fill_preds]
|
| 102 |
predicted_shape = [shape_labels_list[np.argmax(pred)] for pred in shape_preds]
|
|
|
|
| 103 |
count = min(len(filtered_boxes), 3)
|
| 104 |
final_color = max(set(color_list), key=color_list.count)
|
| 105 |
final_fill = max(set(predicted_fill), key=predicted_fill.count)
|
| 106 |
final_shape = max(set(predicted_shape), key=predicted_shape.count)
|
|
|
|
| 107 |
return {'count': count, 'color': final_color, 'fill': final_fill, 'shape': final_shape, 'box': box}
|
| 108 |
|
| 109 |
def is_set(cards):
|
|
@@ -194,7 +200,7 @@ def detect_and_display_sets_interface(input_image):
|
|
| 194 |
try:
|
| 195 |
image_cv = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
|
| 196 |
global global_card_detection_model, global_shape_detection_model
|
| 197 |
-
#
|
| 198 |
if global_card_detection_model is None:
|
| 199 |
card_model_path = hf_hub_download("Oamitai/card-detection", "best.pt")
|
| 200 |
global_card_detection_model = YOLO(card_model_path)
|
|
|
|
| 11 |
from huggingface_hub import hf_hub_download
|
| 12 |
import spaces # For ZeroGPU support
|
| 13 |
|
| 14 |
+
# =============================================================================
|
| 15 |
+
# MODEL LOADING (Keras Models on CPU)
|
| 16 |
+
# =============================================================================
|
| 17 |
+
# These models can be loaded globally.
|
| 18 |
shape_classification_model = tf.keras.models.load_model(
|
| 19 |
hf_hub_download("Oamitai/shape-classification", "shape_model.keras")
|
| 20 |
)
|
|
|
|
| 22 |
hf_hub_download("Oamitai/fill-classification", "fill_model.keras")
|
| 23 |
)
|
| 24 |
|
| 25 |
+
# Global YOLO models will be loaded lazily inside the GPU function.
|
| 26 |
global_card_detection_model = None
|
| 27 |
global_shape_detection_model = None
|
| 28 |
|
|
|
|
| 55 |
# =============================================================================
|
| 56 |
def predict_color(shape_image):
|
| 57 |
hsv_image = cv2.cvtColor(shape_image, cv2.COLOR_BGR2HSV)
|
| 58 |
+
green_mask = cv2.inRange(hsv_image, np.array([40, 50, 50]), np.array([80, 255, 255]))
|
| 59 |
+
purple_mask = cv2.inRange(hsv_image, np.array([120, 50, 50]), np.array([160, 255, 255]))
|
| 60 |
+
red_mask1 = cv2.inRange(hsv_image, np.array([0, 50, 50]), np.array([10, 255, 255]))
|
| 61 |
+
red_mask2 = cv2.inRange(hsv_image, np.array([170, 50, 50]), np.array([180, 255, 255]))
|
| 62 |
red_mask = cv2.bitwise_or(red_mask1, red_mask2)
|
| 63 |
color_counts = {
|
| 64 |
'green': cv2.countNonZero(green_mask),
|
|
|
|
| 71 |
shape_results = shape_detection_model(card_image)
|
| 72 |
card_height, card_width = card_image.shape[:2]
|
| 73 |
card_area = card_width * card_height
|
| 74 |
+
|
| 75 |
filtered_boxes = []
|
| 76 |
for detected_box in shape_results[0].boxes.xyxy.cpu().numpy():
|
| 77 |
x1, y1, x2, y2 = detected_box.astype(int)
|
| 78 |
shape_area = (x2 - x1) * (y2 - y1)
|
| 79 |
if shape_area > 0.03 * card_area:
|
| 80 |
filtered_boxes.append([x1, y1, x2, y2])
|
| 81 |
+
|
| 82 |
if len(filtered_boxes) == 0:
|
| 83 |
return {'count': 0, 'color': 'unknown', 'fill': 'unknown', 'shape': 'unknown', 'box': box}
|
| 84 |
+
|
| 85 |
fill_input_shape = fill_model.input_shape[1:3]
|
| 86 |
shape_input_shape = shape_model.input_shape[1:3]
|
| 87 |
fill_imgs = []
|
|
|
|
| 97 |
color_list.append(predict_color(shape_img))
|
| 98 |
fill_imgs = np.array(fill_imgs)
|
| 99 |
shape_imgs = np.array(shape_imgs)
|
| 100 |
+
|
| 101 |
fill_preds = fill_model.predict(fill_imgs, batch_size=len(fill_imgs))
|
| 102 |
shape_preds = shape_model.predict(shape_imgs, batch_size=len(shape_imgs))
|
| 103 |
fill_labels_list = ['empty', 'full', 'striped']
|
| 104 |
shape_labels_list = ['diamond', 'oval', 'squiggle']
|
| 105 |
predicted_fill = [fill_labels_list[np.argmax(pred)] for pred in fill_preds]
|
| 106 |
predicted_shape = [shape_labels_list[np.argmax(pred)] for pred in shape_preds]
|
| 107 |
+
|
| 108 |
count = min(len(filtered_boxes), 3)
|
| 109 |
final_color = max(set(color_list), key=color_list.count)
|
| 110 |
final_fill = max(set(predicted_fill), key=predicted_fill.count)
|
| 111 |
final_shape = max(set(predicted_shape), key=predicted_shape.count)
|
| 112 |
+
|
| 113 |
return {'count': count, 'color': final_color, 'fill': final_fill, 'shape': final_shape, 'box': box}
|
| 114 |
|
| 115 |
def is_set(cards):
|
|
|
|
| 200 |
try:
|
| 201 |
image_cv = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
|
| 202 |
global global_card_detection_model, global_shape_detection_model
|
| 203 |
+
# Lazy load YOLO models on GPU after allocation.
|
| 204 |
if global_card_detection_model is None:
|
| 205 |
card_model_path = hf_hub_download("Oamitai/card-detection", "best.pt")
|
| 206 |
global_card_detection_model = YOLO(card_model_path)
|