Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -14,6 +14,16 @@ import time
|
|
| 14 |
from typing import Tuple, List, Dict
|
| 15 |
import logging
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
# Import spaces correctly for ZeroGPU
|
| 18 |
try:
|
| 19 |
import spaces
|
|
@@ -57,37 +67,56 @@ def load_models():
|
|
| 57 |
|
| 58 |
logger.info("Loading models from Hugging Face Hub...")
|
| 59 |
|
| 60 |
-
# Load
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
|
|
|
|
|
|
| 79 |
|
| 80 |
-
#
|
| 81 |
-
|
| 82 |
-
logger.info("CUDA is available. Using GPU for inference.")
|
| 83 |
-
card_detector.to("cuda")
|
| 84 |
-
shape_detector.to("cuda")
|
| 85 |
|
| 86 |
-
#
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
logger.info("All models loaded successfully!")
|
| 93 |
return card_detector, shape_detector, shape_classifier, fill_classifier
|
|
@@ -213,8 +242,31 @@ def predict_card_features(
|
|
| 213 |
color_candidates.append(predict_color(shape_crop))
|
| 214 |
|
| 215 |
# Use verbose=0 to suppress progress bar
|
| 216 |
-
|
| 217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
|
| 219 |
fill_labels = ['empty', 'full', 'striped']
|
| 220 |
shape_labels = ['diamond', 'oval', 'squiggle']
|
|
@@ -223,9 +275,20 @@ def predict_card_features(
|
|
| 223 |
shape_result = [shape_labels[np.argmax(sp)] for sp in shape_preds]
|
| 224 |
|
| 225 |
# Take the most common color/fill/shape across all shape detections for the card
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
return {
|
| 231 |
'count': len(shape_boxes),
|
|
@@ -349,11 +412,11 @@ def optimize_image_size(image_array: np.ndarray, max_dim=1200) -> np.ndarray:
|
|
| 349 |
return cv2.resize(image_array, (new_width, new_height), interpolation=cv2.INTER_AREA)
|
| 350 |
return image_array
|
| 351 |
|
| 352 |
-
|
| 353 |
def process_image(input_image):
|
| 354 |
"""
|
| 355 |
-
Main processing function for SET detection
|
| 356 |
-
|
| 357 |
"""
|
| 358 |
if input_image is None:
|
| 359 |
return None, "Please upload an image."
|
|
@@ -364,6 +427,10 @@ def process_image(input_image):
|
|
| 364 |
# Load models
|
| 365 |
card_detector, shape_detector, shape_model, fill_model = load_models()
|
| 366 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
# Optimize image size
|
| 368 |
optimized_img = optimize_image_size(input_image)
|
| 369 |
|
|
@@ -407,6 +474,15 @@ def process_image(input_image):
|
|
| 407 |
logger.error(traceback.format_exc())
|
| 408 |
return input_image, error_message
|
| 409 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
# =============================================================================
|
| 411 |
# GRADIO INTERFACE
|
| 412 |
# =============================================================================
|
|
@@ -442,7 +518,7 @@ with gr.Blocks(title="SET Game Detector") as demo:
|
|
| 442 |
|
| 443 |
# Function bindings inside the Blocks context
|
| 444 |
find_sets_btn.click(
|
| 445 |
-
fn=
|
| 446 |
inputs=[input_image],
|
| 447 |
outputs=[output_image, status]
|
| 448 |
)
|
|
|
|
| 14 |
from typing import Tuple, List, Dict
|
| 15 |
import logging
|
| 16 |
|
| 17 |
+
# Configure TensorFlow GPU memory growth to prevent memory conflicts
|
| 18 |
+
try:
|
| 19 |
+
gpus = tf.config.list_physical_devices('GPU')
|
| 20 |
+
if gpus:
|
| 21 |
+
for gpu in gpus:
|
| 22 |
+
tf.config.experimental.set_memory_growth(gpu, True)
|
| 23 |
+
print(f"Found {len(gpus)} GPU(s), memory growth enabled")
|
| 24 |
+
except Exception as e:
|
| 25 |
+
print(f"Error configuring GPU: {e}")
|
| 26 |
+
|
| 27 |
# Import spaces correctly for ZeroGPU
|
| 28 |
try:
|
| 29 |
import spaces
|
|
|
|
| 67 |
|
| 68 |
logger.info("Loading models from Hugging Face Hub...")
|
| 69 |
|
| 70 |
+
# Load TensorFlow models first to better manage GPU memory
|
| 71 |
+
try:
|
| 72 |
+
# Load Shape Classification Model
|
| 73 |
+
logger.info("Loading shape classification model...")
|
| 74 |
+
shape_classifier = tf.keras.models.load_model(
|
| 75 |
+
hf_hub_download("Oamitai/shape-classification", "shape_model.keras")
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# Load Fill Classification Model
|
| 79 |
+
logger.info("Loading fill classification model...")
|
| 80 |
+
fill_classifier = tf.keras.models.load_model(
|
| 81 |
+
hf_hub_download("Oamitai/fill-classification", "fill_model.keras")
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
# Cache TensorFlow models
|
| 85 |
+
_SHAPE_CLASSIFIER = shape_classifier
|
| 86 |
+
_FILL_CLASSIFIER = fill_classifier
|
| 87 |
+
|
| 88 |
+
except Exception as tf_error:
|
| 89 |
+
logger.error(f"Error loading TensorFlow models: {tf_error}")
|
| 90 |
+
raise
|
| 91 |
|
| 92 |
+
# Pause briefly before loading the YOLO models (best-effort only: a fixed sleep does not guarantee that TensorFlow has released any GPU resources)
|
| 93 |
+
time.sleep(0.5)
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
+
# Now load YOLO models
|
| 96 |
+
try:
|
| 97 |
+
# Load YOLO Card Detection Model
|
| 98 |
+
logger.info("Loading card detection model...")
|
| 99 |
+
card_model_path = hf_hub_download("Oamitai/card-detection", "best.pt")
|
| 100 |
+
card_detector = YOLO(card_model_path)
|
| 101 |
+
card_detector.conf = 0.5
|
| 102 |
+
|
| 103 |
+
# Load YOLO Shape Detection Model
|
| 104 |
+
logger.info("Loading shape detection model...")
|
| 105 |
+
shape_model_path = hf_hub_download("Oamitai/shape-detection", "best.pt")
|
| 106 |
+
shape_detector = YOLO(shape_model_path)
|
| 107 |
+
shape_detector.conf = 0.5
|
| 108 |
+
|
| 109 |
+
# Keep the YOLO models on the CPU to avoid GPU conflicts with TensorFlow
|
| 110 |
+
card_detector.to("cpu")
|
| 111 |
+
shape_detector.to("cpu")
|
| 112 |
+
|
| 113 |
+
# Cache PyTorch models
|
| 114 |
+
_CARD_DETECTOR = card_detector
|
| 115 |
+
_SHAPE_DETECTOR = shape_detector
|
| 116 |
+
|
| 117 |
+
except Exception as pt_error:
|
| 118 |
+
logger.error(f"Error loading PyTorch models: {pt_error}")
|
| 119 |
+
raise
|
| 120 |
|
| 121 |
logger.info("All models loaded successfully!")
|
| 122 |
return card_detector, shape_detector, shape_classifier, fill_classifier
|
|
|
|
| 242 |
color_candidates.append(predict_color(shape_crop))
|
| 243 |
|
| 244 |
# Use verbose=0 to suppress progress bar
|
| 245 |
+
# Run batched TensorFlow predictions; if the batched call fails, fall back to predicting one image at a time
|
| 246 |
+
try:
|
| 247 |
+
fill_preds = fill_model.predict(np.array(fill_imgs), batch_size=len(fill_imgs), verbose=0)
|
| 248 |
+
shape_preds = shape_model.predict(np.array(shape_imgs), batch_size=len(shape_imgs), verbose=0)
|
| 249 |
+
except Exception as e:
|
| 250 |
+
logger.error(f"Error during TensorFlow prediction: {e}")
|
| 251 |
+
# Try with batch size of 1 as fallback
|
| 252 |
+
fill_preds = []
|
| 253 |
+
shape_preds = []
|
| 254 |
+
|
| 255 |
+
for img in fill_imgs:
|
| 256 |
+
try:
|
| 257 |
+
pred = fill_model.predict(np.array([img]), verbose=0)
|
| 258 |
+
fill_preds.append(pred[0])
|
| 259 |
+
except Exception as e2:
|
| 260 |
+
logger.error(f"Fill prediction error: {e2}")
|
| 261 |
+
fill_preds.append(np.array([0.33, 0.33, 0.34])) # Fallback placeholder: near-uniform, but index 2 is slightly highest, so an argmax would select the last label
|
| 262 |
+
|
| 263 |
+
for img in shape_imgs:
|
| 264 |
+
try:
|
| 265 |
+
pred = shape_model.predict(np.array([img]), verbose=0)
|
| 266 |
+
shape_preds.append(pred[0])
|
| 267 |
+
except Exception as e2:
|
| 268 |
+
logger.error(f"Shape prediction error: {e2}")
|
| 269 |
+
shape_preds.append(np.array([0.33, 0.33, 0.34])) # Fallback placeholder: near-uniform, but index 2 is slightly highest, so argmax selects the last label
|
| 270 |
|
| 271 |
fill_labels = ['empty', 'full', 'striped']
|
| 272 |
shape_labels = ['diamond', 'oval', 'squiggle']
|
|
|
|
| 275 |
shape_result = [shape_labels[np.argmax(sp)] for sp in shape_preds]
|
| 276 |
|
| 277 |
# Take the most common color/fill/shape across all shape detections for the card
|
| 278 |
+
if color_candidates:
|
| 279 |
+
final_color = max(set(color_candidates), key=color_candidates.count)
|
| 280 |
+
else:
|
| 281 |
+
final_color = "unknown"
|
| 282 |
+
|
| 283 |
+
if fill_result:
|
| 284 |
+
final_fill = max(set(fill_result), key=fill_result.count)
|
| 285 |
+
else:
|
| 286 |
+
final_fill = "unknown"
|
| 287 |
+
|
| 288 |
+
if shape_result:
|
| 289 |
+
final_shape = max(set(shape_result), key=shape_result.count)
|
| 290 |
+
else:
|
| 291 |
+
final_shape = "unknown"
|
| 292 |
|
| 293 |
return {
|
| 294 |
'count': len(shape_boxes),
|
|
|
|
| 412 |
return cv2.resize(image_array, (new_width, new_height), interpolation=cv2.INTER_AREA)
|
| 413 |
return image_array
|
| 414 |
|
| 415 |
+
# Make the process_image function use CPU for PyTorch models
|
| 416 |
def process_image(input_image):
|
| 417 |
"""
|
| 418 |
+
Main processing function for SET detection. Runs the PyTorch (YOLO)
|
| 419 |
+
detectors on the CPU to avoid GPU conflicts with TensorFlow.
|
| 420 |
"""
|
| 421 |
if input_image is None:
|
| 422 |
return None, "Please upload an image."
|
|
|
|
| 427 |
# Load models
|
| 428 |
card_detector, shape_detector, shape_model, fill_model = load_models()
|
| 429 |
|
| 430 |
+
# Force CPU mode for YOLO models (PyTorch)
|
| 431 |
+
card_detector.to("cpu")
|
| 432 |
+
shape_detector.to("cpu")
|
| 433 |
+
|
| 434 |
# Optimize image size
|
| 435 |
optimized_img = optimize_image_size(input_image)
|
| 436 |
|
|
|
|
| 474 |
logger.error(traceback.format_exc())
|
| 475 |
return input_image, error_message
|
| 476 |
|
| 477 |
+
# Wrap the CPU-based function with spaces.GPU
|
| 478 |
+
@spaces.GPU
|
| 479 |
+
def process_image_with_gpu(input_image):
|
| 480 |
+
"""
|
| 481 |
+
Wrapper function that uses spaces.GPU decorator but internally
|
| 482 |
+
uses CPU processing to avoid GPU conflicts.
|
| 483 |
+
"""
|
| 484 |
+
return process_image(input_image)
|
| 485 |
+
|
| 486 |
# =============================================================================
|
| 487 |
# GRADIO INTERFACE
|
| 488 |
# =============================================================================
|
|
|
|
| 518 |
|
| 519 |
# Function bindings inside the Blocks context
|
| 520 |
find_sets_btn.click(
|
| 521 |
+
fn=process_image_with_gpu, # Use the wrapper function
|
| 522 |
inputs=[input_image],
|
| 523 |
outputs=[output_image, status]
|
| 524 |
)
|