Oamitai commited on
Commit
9b24efb
·
1 Parent(s): 956e5ab

Updated app.py with new improvements

Browse files
Files changed (1) hide show
  1. app.py +16 -10
app.py CHANGED
@@ -11,10 +11,10 @@ from itertools import combinations
11
  from huggingface_hub import hf_hub_download
12
  import spaces # For ZeroGPU support
13
 
14
- # -----------------------------------------------------------------------------
15
- # Global variables for models (loaded lazily inside the GPU function)
16
- # Keras models can be loaded globally if they don't trigger CUDA initialization.
17
- # We'll load the Keras models here since they typically run on CPU and use GPU later.
18
  shape_classification_model = tf.keras.models.load_model(
19
  hf_hub_download("Oamitai/shape-classification", "shape_model.keras")
20
  )
@@ -22,7 +22,7 @@ fill_classification_model = tf.keras.models.load_model(
22
  hf_hub_download("Oamitai/fill-classification", "fill_model.keras")
23
  )
24
 
25
- # Global YOLO models will be loaded inside the GPU function.
26
  global_card_detection_model = None
27
  global_shape_detection_model = None
28
 
@@ -55,10 +55,10 @@ def restore_original_orientation(image, was_rotated):
55
  # =============================================================================
56
  def predict_color(shape_image):
57
  hsv_image = cv2.cvtColor(shape_image, cv2.COLOR_BGR2HSV)
58
- green_mask = cv2.inRange(hsv_image, np.array([40,50,50]), np.array([80,255,255]))
59
- purple_mask = cv2.inRange(hsv_image, np.array([120,50,50]), np.array([160,255,255]))
60
- red_mask1 = cv2.inRange(hsv_image, np.array([0,50,50]), np.array([10,255,255]))
61
- red_mask2 = cv2.inRange(hsv_image, np.array([170,50,50]), np.array([180,255,255]))
62
  red_mask = cv2.bitwise_or(red_mask1, red_mask2)
63
  color_counts = {
64
  'green': cv2.countNonZero(green_mask),
@@ -71,14 +71,17 @@ def predict_card_features(card_image, shape_detection_model, fill_model, shape_m
71
  shape_results = shape_detection_model(card_image)
72
  card_height, card_width = card_image.shape[:2]
73
  card_area = card_width * card_height
 
74
  filtered_boxes = []
75
  for detected_box in shape_results[0].boxes.xyxy.cpu().numpy():
76
  x1, y1, x2, y2 = detected_box.astype(int)
77
  shape_area = (x2 - x1) * (y2 - y1)
78
  if shape_area > 0.03 * card_area:
79
  filtered_boxes.append([x1, y1, x2, y2])
 
80
  if len(filtered_boxes) == 0:
81
  return {'count': 0, 'color': 'unknown', 'fill': 'unknown', 'shape': 'unknown', 'box': box}
 
82
  fill_input_shape = fill_model.input_shape[1:3]
83
  shape_input_shape = shape_model.input_shape[1:3]
84
  fill_imgs = []
@@ -94,16 +97,19 @@ def predict_card_features(card_image, shape_detection_model, fill_model, shape_m
94
  color_list.append(predict_color(shape_img))
95
  fill_imgs = np.array(fill_imgs)
96
  shape_imgs = np.array(shape_imgs)
 
97
  fill_preds = fill_model.predict(fill_imgs, batch_size=len(fill_imgs))
98
  shape_preds = shape_model.predict(shape_imgs, batch_size=len(shape_imgs))
99
  fill_labels_list = ['empty', 'full', 'striped']
100
  shape_labels_list = ['diamond', 'oval', 'squiggle']
101
  predicted_fill = [fill_labels_list[np.argmax(pred)] for pred in fill_preds]
102
  predicted_shape = [shape_labels_list[np.argmax(pred)] for pred in shape_preds]
 
103
  count = min(len(filtered_boxes), 3)
104
  final_color = max(set(color_list), key=color_list.count)
105
  final_fill = max(set(predicted_fill), key=predicted_fill.count)
106
  final_shape = max(set(predicted_shape), key=predicted_shape.count)
 
107
  return {'count': count, 'color': final_color, 'fill': final_fill, 'shape': final_shape, 'box': box}
108
 
109
  def is_set(cards):
@@ -194,7 +200,7 @@ def detect_and_display_sets_interface(input_image):
194
  try:
195
  image_cv = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
196
  global global_card_detection_model, global_shape_detection_model
197
- # Load YOLO models on GPU only once after allocation.
198
  if global_card_detection_model is None:
199
  card_model_path = hf_hub_download("Oamitai/card-detection", "best.pt")
200
  global_card_detection_model = YOLO(card_model_path)
 
11
  from huggingface_hub import hf_hub_download
12
  import spaces # For ZeroGPU support
13
 
14
+ # =============================================================================
15
+ # MODEL LOADING (Keras Models on CPU)
16
+ # =============================================================================
17
+ # These models can be loaded globally.
18
  shape_classification_model = tf.keras.models.load_model(
19
  hf_hub_download("Oamitai/shape-classification", "shape_model.keras")
20
  )
 
22
  hf_hub_download("Oamitai/fill-classification", "fill_model.keras")
23
  )
24
 
25
+ # Global YOLO models will be loaded lazily inside the GPU function.
26
  global_card_detection_model = None
27
  global_shape_detection_model = None
28
 
 
55
  # =============================================================================
56
  def predict_color(shape_image):
57
  hsv_image = cv2.cvtColor(shape_image, cv2.COLOR_BGR2HSV)
58
+ green_mask = cv2.inRange(hsv_image, np.array([40, 50, 50]), np.array([80, 255, 255]))
59
+ purple_mask = cv2.inRange(hsv_image, np.array([120, 50, 50]), np.array([160, 255, 255]))
60
+ red_mask1 = cv2.inRange(hsv_image, np.array([0, 50, 50]), np.array([10, 255, 255]))
61
+ red_mask2 = cv2.inRange(hsv_image, np.array([170, 50, 50]), np.array([180, 255, 255]))
62
  red_mask = cv2.bitwise_or(red_mask1, red_mask2)
63
  color_counts = {
64
  'green': cv2.countNonZero(green_mask),
 
71
  shape_results = shape_detection_model(card_image)
72
  card_height, card_width = card_image.shape[:2]
73
  card_area = card_width * card_height
74
+
75
  filtered_boxes = []
76
  for detected_box in shape_results[0].boxes.xyxy.cpu().numpy():
77
  x1, y1, x2, y2 = detected_box.astype(int)
78
  shape_area = (x2 - x1) * (y2 - y1)
79
  if shape_area > 0.03 * card_area:
80
  filtered_boxes.append([x1, y1, x2, y2])
81
+
82
  if len(filtered_boxes) == 0:
83
  return {'count': 0, 'color': 'unknown', 'fill': 'unknown', 'shape': 'unknown', 'box': box}
84
+
85
  fill_input_shape = fill_model.input_shape[1:3]
86
  shape_input_shape = shape_model.input_shape[1:3]
87
  fill_imgs = []
 
97
  color_list.append(predict_color(shape_img))
98
  fill_imgs = np.array(fill_imgs)
99
  shape_imgs = np.array(shape_imgs)
100
+
101
  fill_preds = fill_model.predict(fill_imgs, batch_size=len(fill_imgs))
102
  shape_preds = shape_model.predict(shape_imgs, batch_size=len(shape_imgs))
103
  fill_labels_list = ['empty', 'full', 'striped']
104
  shape_labels_list = ['diamond', 'oval', 'squiggle']
105
  predicted_fill = [fill_labels_list[np.argmax(pred)] for pred in fill_preds]
106
  predicted_shape = [shape_labels_list[np.argmax(pred)] for pred in shape_preds]
107
+
108
  count = min(len(filtered_boxes), 3)
109
  final_color = max(set(color_list), key=color_list.count)
110
  final_fill = max(set(predicted_fill), key=predicted_fill.count)
111
  final_shape = max(set(predicted_shape), key=predicted_shape.count)
112
+
113
  return {'count': count, 'color': final_color, 'fill': final_fill, 'shape': final_shape, 'box': box}
114
 
115
  def is_set(cards):
 
200
  try:
201
  image_cv = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
202
  global global_card_detection_model, global_shape_detection_model
203
+ # Lazy load YOLO models on GPU after allocation.
204
  if global_card_detection_model is None:
205
  card_model_path = hf_hub_download("Oamitai/card-detection", "best.pt")
206
  global_card_detection_model = YOLO(card_model_path)