Oamitai committed on
Commit
c15a3d6
·
verified ·
1 Parent(s): 7e5ec53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -38
app.py CHANGED
@@ -14,6 +14,16 @@ import time
14
  from typing import Tuple, List, Dict
15
  import logging
16
 
 
 
 
 
 
 
 
 
 
 
17
  # Import spaces correctly for ZeroGPU
18
  try:
19
  import spaces
@@ -57,37 +67,56 @@ def load_models():
57
 
58
  logger.info("Loading models from Hugging Face Hub...")
59
 
60
- # Load YOLO Card Detection Model
61
- card_model_path = hf_hub_download("Oamitai/card-detection", "best.pt")
62
- card_detector = YOLO(card_model_path)
63
- card_detector.conf = 0.5
64
-
65
- # Load YOLO Shape Detection Model
66
- shape_model_path = hf_hub_download("Oamitai/shape-detection", "best.pt")
67
- shape_detector = YOLO(shape_model_path)
68
- shape_detector.conf = 0.5
69
-
70
- # Load Shape Classification Model
71
- shape_classifier = tf.keras.models.load_model(
72
- hf_hub_download("Oamitai/shape-classification", "shape_model.keras")
73
- )
74
-
75
- # Load Fill Classification Model
76
- fill_classifier = tf.keras.models.load_model(
77
- hf_hub_download("Oamitai/fill-classification", "fill_model.keras")
78
- )
 
 
79
 
80
- # Use GPU if available
81
- if torch.cuda.is_available():
82
- logger.info("CUDA is available. Using GPU for inference.")
83
- card_detector.to("cuda")
84
- shape_detector.to("cuda")
85
 
86
- # Cache the models
87
- _CARD_DETECTOR = card_detector
88
- _SHAPE_DETECTOR = shape_detector
89
- _SHAPE_CLASSIFIER = shape_classifier
90
- _FILL_CLASSIFIER = fill_classifier
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  logger.info("All models loaded successfully!")
93
  return card_detector, shape_detector, shape_classifier, fill_classifier
@@ -213,8 +242,31 @@ def predict_card_features(
213
  color_candidates.append(predict_color(shape_crop))
214
 
215
  # Use verbose=0 to suppress progress bar
216
- fill_preds = fill_model.predict(np.array(fill_imgs), batch_size=len(fill_imgs), verbose=0)
217
- shape_preds = shape_model.predict(np.array(shape_imgs), batch_size=len(shape_imgs), verbose=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
  fill_labels = ['empty', 'full', 'striped']
220
  shape_labels = ['diamond', 'oval', 'squiggle']
@@ -223,9 +275,20 @@ def predict_card_features(
223
  shape_result = [shape_labels[np.argmax(sp)] for sp in shape_preds]
224
 
225
  # Take the most common color/fill/shape across all shape detections for the card
226
- final_color = max(set(color_candidates), key=color_candidates.count)
227
- final_fill = max(set(fill_result), key=fill_result.count)
228
- final_shape = max(set(shape_result), key=shape_result.count)
 
 
 
 
 
 
 
 
 
 
 
229
 
230
  return {
231
  'count': len(shape_boxes),
@@ -349,11 +412,11 @@ def optimize_image_size(image_array: np.ndarray, max_dim=1200) -> np.ndarray:
349
  return cv2.resize(image_array, (new_width, new_height), interpolation=cv2.INTER_AREA)
350
  return image_array
351
 
352
- @spaces.GPU
353
  def process_image(input_image):
354
  """
355
- Main processing function for SET detection.
356
- Takes an input image, processes it, and returns the annotated image and status.
357
  """
358
  if input_image is None:
359
  return None, "Please upload an image."
@@ -364,6 +427,10 @@ def process_image(input_image):
364
  # Load models
365
  card_detector, shape_detector, shape_model, fill_model = load_models()
366
 
 
 
 
 
367
  # Optimize image size
368
  optimized_img = optimize_image_size(input_image)
369
 
@@ -407,6 +474,15 @@ def process_image(input_image):
407
  logger.error(traceback.format_exc())
408
  return input_image, error_message
409
 
 
 
 
 
 
 
 
 
 
410
  # =============================================================================
411
  # GRADIO INTERFACE
412
  # =============================================================================
@@ -442,7 +518,7 @@ with gr.Blocks(title="SET Game Detector") as demo:
442
 
443
  # Function bindings inside the Blocks context
444
  find_sets_btn.click(
445
- fn=process_image,
446
  inputs=[input_image],
447
  outputs=[output_image, status]
448
  )
 
14
  from typing import Tuple, List, Dict
15
  import logging
16
 
17
# Turn on per-GPU memory growth so TensorFlow grabs VRAM incrementally
# instead of pre-allocating the whole device (which would starve the
# PyTorch/YOLO models loaded later in this process).
try:
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        # Memory growth must be set before any TF op touches the device.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"Found {len(gpus)} GPU(s), memory growth enabled")
except Exception as e:
    # Best-effort: a mis-configured GPU should not prevent startup.
    print(f"Error configuring GPU: {e}")
26
+
27
  # Import spaces correctly for ZeroGPU
28
  try:
29
  import spaces
 
67
 
68
  logger.info("Loading models from Hugging Face Hub...")
69
 
70
+ # Load TensorFlow models first to better manage GPU memory
71
+ try:
72
+ # Load Shape Classification Model
73
+ logger.info("Loading shape classification model...")
74
+ shape_classifier = tf.keras.models.load_model(
75
+ hf_hub_download("Oamitai/shape-classification", "shape_model.keras")
76
+ )
77
+
78
+ # Load Fill Classification Model
79
+ logger.info("Loading fill classification model...")
80
+ fill_classifier = tf.keras.models.load_model(
81
+ hf_hub_download("Oamitai/fill-classification", "fill_model.keras")
82
+ )
83
+
84
+ # Cache TensorFlow models
85
+ _SHAPE_CLASSIFIER = shape_classifier
86
+ _FILL_CLASSIFIER = fill_classifier
87
+
88
+ except Exception as tf_error:
89
+ logger.error(f"Error loading TensorFlow models: {tf_error}")
90
+ raise
91
 
92
+ # Add a small delay to ensure TensorFlow releases GPU resources
93
+ time.sleep(0.5)
 
 
 
94
 
95
+ # Now load YOLO models
96
+ try:
97
+ # Load YOLO Card Detection Model
98
+ logger.info("Loading card detection model...")
99
+ card_model_path = hf_hub_download("Oamitai/card-detection", "best.pt")
100
+ card_detector = YOLO(card_model_path)
101
+ card_detector.conf = 0.5
102
+
103
+ # Load YOLO Shape Detection Model
104
+ logger.info("Loading shape detection model...")
105
+ shape_model_path = hf_hub_download("Oamitai/shape-detection", "best.pt")
106
+ shape_detector = YOLO(shape_model_path)
107
+ shape_detector.conf = 0.5
108
+
109
+ # Use CPU initially for YOLO models to avoid GPU conflicts
110
+ card_detector.to("cpu")
111
+ shape_detector.to("cpu")
112
+
113
+ # Cache PyTorch models
114
+ _CARD_DETECTOR = card_detector
115
+ _SHAPE_DETECTOR = shape_detector
116
+
117
+ except Exception as pt_error:
118
+ logger.error(f"Error loading PyTorch models: {pt_error}")
119
+ raise
120
 
121
  logger.info("All models loaded successfully!")
122
  return card_detector, shape_detector, shape_classifier, fill_classifier
 
242
  color_candidates.append(predict_color(shape_crop))
243
 
244
  # Use verbose=0 to suppress progress bar
245
+ # Add error handling for TensorFlow prediction
246
+ try:
247
+ fill_preds = fill_model.predict(np.array(fill_imgs), batch_size=len(fill_imgs), verbose=0)
248
+ shape_preds = shape_model.predict(np.array(shape_imgs), batch_size=len(shape_imgs), verbose=0)
249
+ except Exception as e:
250
+ logger.error(f"Error during TensorFlow prediction: {e}")
251
+ # Try with batch size of 1 as fallback
252
+ fill_preds = []
253
+ shape_preds = []
254
+
255
+ for img in fill_imgs:
256
+ try:
257
+ pred = fill_model.predict(np.array([img]), verbose=0)
258
+ fill_preds.append(pred[0])
259
+ except Exception as e2:
260
+ logger.error(f"Fill prediction error: {e2}")
261
+ fill_preds.append(np.array([0.33, 0.33, 0.34])) # Fallback with uniform probabilities
262
+
263
+ for img in shape_imgs:
264
+ try:
265
+ pred = shape_model.predict(np.array([img]), verbose=0)
266
+ shape_preds.append(pred[0])
267
+ except Exception as e2:
268
+ logger.error(f"Shape prediction error: {e2}")
269
+ shape_preds.append(np.array([0.33, 0.33, 0.34])) # Fallback with uniform probabilities
270
 
271
  fill_labels = ['empty', 'full', 'striped']
272
  shape_labels = ['diamond', 'oval', 'squiggle']
 
275
  shape_result = [shape_labels[np.argmax(sp)] for sp in shape_preds]
276
 
277
  # Take the most common color/fill/shape across all shape detections for the card
278
+ if color_candidates:
279
+ final_color = max(set(color_candidates), key=color_candidates.count)
280
+ else:
281
+ final_color = "unknown"
282
+
283
+ if fill_result:
284
+ final_fill = max(set(fill_result), key=fill_result.count)
285
+ else:
286
+ final_fill = "unknown"
287
+
288
+ if shape_result:
289
+ final_shape = max(set(shape_result), key=shape_result.count)
290
+ else:
291
+ final_shape = "unknown"
292
 
293
  return {
294
  'count': len(shape_boxes),
 
412
  return cv2.resize(image_array, (new_width, new_height), interpolation=cv2.INTER_AREA)
413
  return image_array
414
 
415
+ # Make the process_image function use CPU for PyTorch models
416
  def process_image(input_image):
417
  """
418
+ Main processing function for SET detection using CPU for both PyTorch
419
+ and TensorFlow to avoid GPU conflicts.
420
  """
421
  if input_image is None:
422
  return None, "Please upload an image."
 
427
  # Load models
428
  card_detector, shape_detector, shape_model, fill_model = load_models()
429
 
430
+ # Force CPU mode for YOLO models (PyTorch)
431
+ card_detector.to("cpu")
432
+ shape_detector.to("cpu")
433
+
434
  # Optimize image size
435
  optimized_img = optimize_image_size(input_image)
436
 
 
474
  logger.error(traceback.format_exc())
475
  return input_image, error_message
476
 
477
# Thin GPU-decorated entry point; the actual work happens in process_image.
@spaces.GPU
def process_image_with_gpu(input_image):
    """
    ZeroGPU entry point for the Gradio click handler.

    Carries the spaces.GPU decorator (required so the Space is scheduled
    onto a GPU worker) while delegating all processing to the CPU-based
    process_image pipeline to sidestep TF/PyTorch GPU conflicts.
    """
    return process_image(input_image)
485
+
486
  # =============================================================================
487
  # GRADIO INTERFACE
488
  # =============================================================================
 
518
 
519
  # Function bindings inside the Blocks context
520
  find_sets_btn.click(
521
+ fn=process_image_with_gpu, # Use the wrapper function
522
  inputs=[input_image],
523
  outputs=[output_image, status]
524
  )