Oamitai commited on
Commit
844e48a
·
verified ·
1 Parent(s): c15a3d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -104
app.py CHANGED
@@ -7,24 +7,17 @@ from tensorflow.keras.models import load_model
7
  import torch
8
  from ultralytics import YOLO
9
  from itertools import combinations
10
- from pathlib import Path
11
  import gradio as gr
12
  import traceback
13
  import time
14
  from typing import Tuple, List, Dict
15
  import logging
16
 
17
- # Configure TensorFlow GPU memory growth to prevent memory conflicts
18
- try:
19
- gpus = tf.config.list_physical_devices('GPU')
20
- if gpus:
21
- for gpu in gpus:
22
- tf.config.experimental.set_memory_growth(gpu, True)
23
- print(f"Found {len(gpus)} GPU(s), memory growth enabled")
24
- except Exception as e:
25
- print(f"Error configuring GPU: {e}")
26
-
27
- # Import spaces correctly for ZeroGPU
28
  try:
29
  import spaces
30
  except ImportError:
@@ -52,7 +45,7 @@ _FILL_CLASSIFIER = None
52
 
53
  def load_models():
54
  """
55
- Load all models needed for SET detection.
56
  Returns tuple of (card_detector, shape_detector, shape_classifier, fill_classifier)
57
  """
58
  global _CARD_DETECTOR, _SHAPE_DETECTOR, _SHAPE_CLASSIFIER, _FILL_CLASSIFIER
@@ -67,58 +60,42 @@ def load_models():
67
 
68
  logger.info("Loading models from Hugging Face Hub...")
69
 
70
- # Load TensorFlow models first to better manage GPU memory
71
- try:
72
- # Load Shape Classification Model
73
- logger.info("Loading shape classification model...")
74
- shape_classifier = tf.keras.models.load_model(
75
- hf_hub_download("Oamitai/shape-classification", "shape_model.keras")
76
- )
77
-
78
- # Load Fill Classification Model
79
- logger.info("Loading fill classification model...")
80
- fill_classifier = tf.keras.models.load_model(
81
- hf_hub_download("Oamitai/fill-classification", "fill_model.keras")
82
- )
83
-
84
- # Cache TensorFlow models
85
- _SHAPE_CLASSIFIER = shape_classifier
86
- _FILL_CLASSIFIER = fill_classifier
87
-
88
- except Exception as tf_error:
89
- logger.error(f"Error loading TensorFlow models: {tf_error}")
90
- raise
91
 
92
- # Add a small delay to ensure TensorFlow releases GPU resources
93
- time.sleep(0.5)
 
 
 
94
 
95
- # Now load YOLO models
96
- try:
97
- # Load YOLO Card Detection Model
98
- logger.info("Loading card detection model...")
99
- card_model_path = hf_hub_download("Oamitai/card-detection", "best.pt")
100
- card_detector = YOLO(card_model_path)
101
- card_detector.conf = 0.5
102
-
103
- # Load YOLO Shape Detection Model
104
- logger.info("Loading shape detection model...")
105
- shape_model_path = hf_hub_download("Oamitai/shape-detection", "best.pt")
106
- shape_detector = YOLO(shape_model_path)
107
- shape_detector.conf = 0.5
108
-
109
- # Use CPU initially for YOLO models to avoid GPU conflicts
110
- card_detector.to("cpu")
111
- shape_detector.to("cpu")
112
-
113
- # Cache PyTorch models
114
- _CARD_DETECTOR = card_detector
115
- _SHAPE_DETECTOR = shape_detector
116
-
117
- except Exception as pt_error:
118
- logger.error(f"Error loading PyTorch models: {pt_error}")
119
- raise
120
 
121
- logger.info("All models loaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
122
  return card_detector, shape_detector, shape_classifier, fill_classifier
123
 
124
  except Exception as e:
@@ -241,32 +218,25 @@ def predict_card_features(
241
  shape_imgs.append(shape_crop_resized)
242
  color_candidates.append(predict_color(shape_crop))
243
 
244
- # Use verbose=0 to suppress progress bar
245
- # Add error handling for TensorFlow prediction
246
- try:
247
- fill_preds = fill_model.predict(np.array(fill_imgs), batch_size=len(fill_imgs), verbose=0)
248
- shape_preds = shape_model.predict(np.array(shape_imgs), batch_size=len(shape_imgs), verbose=0)
249
- except Exception as e:
250
- logger.error(f"Error during TensorFlow prediction: {e}")
251
- # Try with batch size of 1 as fallback
252
- fill_preds = []
253
- shape_preds = []
254
-
255
- for img in fill_imgs:
256
- try:
257
- pred = fill_model.predict(np.array([img]), verbose=0)
258
- fill_preds.append(pred[0])
259
- except Exception as e2:
260
- logger.error(f"Fill prediction error: {e2}")
261
- fill_preds.append(np.array([0.33, 0.33, 0.34])) # Fallback with uniform probabilities
262
-
263
- for img in shape_imgs:
264
- try:
265
- pred = shape_model.predict(np.array([img]), verbose=0)
266
- shape_preds.append(pred[0])
267
- except Exception as e2:
268
- logger.error(f"Shape prediction error: {e2}")
269
- shape_preds.append(np.array([0.33, 0.33, 0.34])) # Fallback with uniform probabilities
270
 
271
  fill_labels = ['empty', 'full', 'striped']
272
  shape_labels = ['diamond', 'oval', 'squiggle']
@@ -412,11 +382,9 @@ def optimize_image_size(image_array: np.ndarray, max_dim=1200) -> np.ndarray:
412
  return cv2.resize(image_array, (new_width, new_height), interpolation=cv2.INTER_AREA)
413
  return image_array
414
 
415
- # Make the process_image function use CPU for PyTorch models
416
  def process_image(input_image):
417
  """
418
- Main processing function for SET detection using CPU for both PyTorch
419
- and TensorFlow to avoid GPU conflicts.
420
  """
421
  if input_image is None:
422
  return None, "Please upload an image."
@@ -424,17 +392,13 @@ def process_image(input_image):
424
  try:
425
  start_time = time.time()
426
 
427
- # Load models
428
  card_detector, shape_detector, shape_model, fill_model = load_models()
429
 
430
- # Force CPU mode for YOLO models (PyTorch)
431
- card_detector.to("cpu")
432
- shape_detector.to("cpu")
433
-
434
  # Optimize image size
435
  optimized_img = optimize_image_size(input_image)
436
 
437
- # Convert to BGR if needed (OpenCV format)
438
  if len(optimized_img.shape) == 3 and optimized_img.shape[2] == 4: # RGBA
439
  optimized_img = cv2.cvtColor(optimized_img, cv2.COLOR_RGBA2BGR)
440
  elif len(optimized_img.shape) == 3 and optimized_img.shape[2] == 3:
@@ -474,17 +438,17 @@ def process_image(input_image):
474
  logger.error(traceback.format_exc())
475
  return input_image, error_message
476
 
477
- # Wrap the CPU-based function with spaces.GPU
478
  @spaces.GPU
479
- def process_image_with_gpu(input_image):
480
  """
481
- Wrapper function that uses spaces.GPU decorator but internally
482
- uses CPU processing to avoid GPU conflicts.
483
  """
484
  return process_image(input_image)
485
 
486
  # =============================================================================
487
- # GRADIO INTERFACE
488
  # =============================================================================
489
  with gr.Blocks(title="SET Game Detector") as demo:
490
  gr.HTML("""
@@ -516,9 +480,9 @@ with gr.Blocks(title="SET Game Detector") as demo:
516
  interactive=False
517
  )
518
 
519
- # Function bindings inside the Blocks context
520
  find_sets_btn.click(
521
- fn=process_image_with_gpu, # Use the wrapper function
522
  inputs=[input_image],
523
  outputs=[output_image, status]
524
  )
 
7
  import torch
8
  from ultralytics import YOLO
9
  from itertools import combinations
 
10
  import gradio as gr
11
  import traceback
12
  import time
13
  from typing import Tuple, List, Dict
14
  import logging
15
 
16
+ # Force CPU mode for TensorFlow
17
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
18
+ tf.config.set_visible_devices([], 'GPU')
19
+
20
+ # Import spaces for ZeroGPU wrapper
 
 
 
 
 
 
21
  try:
22
  import spaces
23
  except ImportError:
 
45
 
46
  def load_models():
47
  """
48
+ Load all models needed for SET detection in CPU-only mode.
49
  Returns tuple of (card_detector, shape_detector, shape_classifier, fill_classifier)
50
  """
51
  global _CARD_DETECTOR, _SHAPE_DETECTOR, _SHAPE_CLASSIFIER, _FILL_CLASSIFIER
 
60
 
61
  logger.info("Loading models from Hugging Face Hub...")
62
 
63
+ # Load Shape Classification Model (TensorFlow)
64
+ logger.info("Loading shape classification model...")
65
+ shape_classifier = load_model(
66
+ hf_hub_download("Oamitai/shape-classification", "shape_model.keras")
67
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
+ # Load Fill Classification Model (TensorFlow)
70
+ logger.info("Loading fill classification model...")
71
+ fill_classifier = load_model(
72
+ hf_hub_download("Oamitai/fill-classification", "fill_model.keras")
73
+ )
74
 
75
+ # Load YOLO Card Detection Model (PyTorch)
76
+ logger.info("Loading card detection model...")
77
+ card_model_path = hf_hub_download("Oamitai/card-detection", "best.pt")
78
+ card_detector = YOLO(card_model_path)
79
+ card_detector.conf = 0.5
80
+
81
+ # Load YOLO Shape Detection Model (PyTorch)
82
+ logger.info("Loading shape detection model...")
83
+ shape_model_path = hf_hub_download("Oamitai/shape-detection", "best.pt")
84
+ shape_detector = YOLO(shape_model_path)
85
+ shape_detector.conf = 0.5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
+ # Explicitly set to CPU mode
88
+ logger.info("Setting models to CPU mode...")
89
+ card_detector.to("cpu")
90
+ shape_detector.to("cpu")
91
+
92
+ # Cache the models
93
+ _CARD_DETECTOR = card_detector
94
+ _SHAPE_DETECTOR = shape_detector
95
+ _SHAPE_CLASSIFIER = shape_classifier
96
+ _FILL_CLASSIFIER = fill_classifier
97
+
98
+ logger.info("All models loaded successfully in CPU mode!")
99
  return card_detector, shape_detector, shape_classifier, fill_classifier
100
 
101
  except Exception as e:
 
218
  shape_imgs.append(shape_crop_resized)
219
  color_candidates.append(predict_color(shape_crop))
220
 
221
+ # Handle TensorFlow prediction - process one image at a time to avoid memory issues
222
+ fill_preds = []
223
+ shape_preds = []
224
+
225
+ for img in fill_imgs:
226
+ try:
227
+ pred = fill_model.predict(np.array([img]), verbose=0)
228
+ fill_preds.append(pred[0])
229
+ except Exception as e:
230
+ logger.error(f"Fill prediction error: {e}")
231
+ fill_preds.append(np.array([0.33, 0.33, 0.34])) # Fallback
232
+
233
+ for img in shape_imgs:
234
+ try:
235
+ pred = shape_model.predict(np.array([img]), verbose=0)
236
+ shape_preds.append(pred[0])
237
+ except Exception as e:
238
+ logger.error(f"Shape prediction error: {e}")
239
+ shape_preds.append(np.array([0.33, 0.33, 0.34])) # Fallback
 
 
 
 
 
 
 
240
 
241
  fill_labels = ['empty', 'full', 'striped']
242
  shape_labels = ['diamond', 'oval', 'squiggle']
 
382
  return cv2.resize(image_array, (new_width, new_height), interpolation=cv2.INTER_AREA)
383
  return image_array
384
 
 
385
  def process_image(input_image):
386
  """
387
+ CPU-only processing function for SET detection.
 
388
  """
389
  if input_image is None:
390
  return None, "Please upload an image."
 
392
  try:
393
  start_time = time.time()
394
 
395
+ # Load models (CPU-only)
396
  card_detector, shape_detector, shape_model, fill_model = load_models()
397
 
 
 
 
 
398
  # Optimize image size
399
  optimized_img = optimize_image_size(input_image)
400
 
401
+ # Convert to BGR (OpenCV format)
402
  if len(optimized_img.shape) == 3 and optimized_img.shape[2] == 4: # RGBA
403
  optimized_img = cv2.cvtColor(optimized_img, cv2.COLOR_RGBA2BGR)
404
  elif len(optimized_img.shape) == 3 and optimized_img.shape[2] == 3:
 
438
  logger.error(traceback.format_exc())
439
  return input_image, error_message
440
 
441
+ # Keep the spaces.GPU decorator for ZeroGPU API but use CPU internally
442
  @spaces.GPU
443
+ def process_image_wrapper(input_image):
444
  """
445
+ Wrapper for process_image that uses the spaces.GPU decorator
446
+ but internally works in CPU-only mode.
447
  """
448
  return process_image(input_image)
449
 
450
  # =============================================================================
451
+ # SIMPLIFIED GRADIO INTERFACE
452
  # =============================================================================
453
  with gr.Blocks(title="SET Game Detector") as demo:
454
  gr.HTML("""
 
480
  interactive=False
481
  )
482
 
483
+ # Function bindings
484
  find_sets_btn.click(
485
+ fn=process_image_wrapper,
486
  inputs=[input_image],
487
  outputs=[output_image, status]
488
  )