Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,24 +7,17 @@ from tensorflow.keras.models import load_model
|
|
| 7 |
import torch
|
| 8 |
from ultralytics import YOLO
|
| 9 |
from itertools import combinations
|
| 10 |
-
from pathlib import Path
|
| 11 |
import gradio as gr
|
| 12 |
import traceback
|
| 13 |
import time
|
| 14 |
from typing import Tuple, List, Dict
|
| 15 |
import logging
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
tf.config.experimental.set_memory_growth(gpu, True)
|
| 23 |
-
print(f"Found {len(gpus)} GPU(s), memory growth enabled")
|
| 24 |
-
except Exception as e:
|
| 25 |
-
print(f"Error configuring GPU: {e}")
|
| 26 |
-
|
| 27 |
-
# Import spaces correctly for ZeroGPU
|
| 28 |
try:
|
| 29 |
import spaces
|
| 30 |
except ImportError:
|
|
@@ -52,7 +45,7 @@ _FILL_CLASSIFIER = None
|
|
| 52 |
|
| 53 |
def load_models():
|
| 54 |
"""
|
| 55 |
-
Load all models needed for SET detection.
|
| 56 |
Returns tuple of (card_detector, shape_detector, shape_classifier, fill_classifier)
|
| 57 |
"""
|
| 58 |
global _CARD_DETECTOR, _SHAPE_DETECTOR, _SHAPE_CLASSIFIER, _FILL_CLASSIFIER
|
|
@@ -67,58 +60,42 @@ def load_models():
|
|
| 67 |
|
| 68 |
logger.info("Loading models from Hugging Face Hub...")
|
| 69 |
|
| 70 |
-
# Load
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
hf_hub_download("Oamitai/shape-classification", "shape_model.keras")
|
| 76 |
-
)
|
| 77 |
-
|
| 78 |
-
# Load Fill Classification Model
|
| 79 |
-
logger.info("Loading fill classification model...")
|
| 80 |
-
fill_classifier = tf.keras.models.load_model(
|
| 81 |
-
hf_hub_download("Oamitai/fill-classification", "fill_model.keras")
|
| 82 |
-
)
|
| 83 |
-
|
| 84 |
-
# Cache TensorFlow models
|
| 85 |
-
_SHAPE_CLASSIFIER = shape_classifier
|
| 86 |
-
_FILL_CLASSIFIER = fill_classifier
|
| 87 |
-
|
| 88 |
-
except Exception as tf_error:
|
| 89 |
-
logger.error(f"Error loading TensorFlow models: {tf_error}")
|
| 90 |
-
raise
|
| 91 |
|
| 92 |
-
#
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
-
#
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
shape_detector = YOLO(shape_model_path)
|
| 107 |
-
shape_detector.conf = 0.5
|
| 108 |
-
|
| 109 |
-
# Use CPU initially for YOLO models to avoid GPU conflicts
|
| 110 |
-
card_detector.to("cpu")
|
| 111 |
-
shape_detector.to("cpu")
|
| 112 |
-
|
| 113 |
-
# Cache PyTorch models
|
| 114 |
-
_CARD_DETECTOR = card_detector
|
| 115 |
-
_SHAPE_DETECTOR = shape_detector
|
| 116 |
-
|
| 117 |
-
except Exception as pt_error:
|
| 118 |
-
logger.error(f"Error loading PyTorch models: {pt_error}")
|
| 119 |
-
raise
|
| 120 |
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
return card_detector, shape_detector, shape_classifier, fill_classifier
|
| 123 |
|
| 124 |
except Exception as e:
|
|
@@ -241,32 +218,25 @@ def predict_card_features(
|
|
| 241 |
shape_imgs.append(shape_crop_resized)
|
| 242 |
color_candidates.append(predict_color(shape_crop))
|
| 243 |
|
| 244 |
-
#
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
for img in shape_imgs:
|
| 264 |
-
try:
|
| 265 |
-
pred = shape_model.predict(np.array([img]), verbose=0)
|
| 266 |
-
shape_preds.append(pred[0])
|
| 267 |
-
except Exception as e2:
|
| 268 |
-
logger.error(f"Shape prediction error: {e2}")
|
| 269 |
-
shape_preds.append(np.array([0.33, 0.33, 0.34])) # Fallback with uniform probabilities
|
| 270 |
|
| 271 |
fill_labels = ['empty', 'full', 'striped']
|
| 272 |
shape_labels = ['diamond', 'oval', 'squiggle']
|
|
@@ -412,11 +382,9 @@ def optimize_image_size(image_array: np.ndarray, max_dim=1200) -> np.ndarray:
|
|
| 412 |
return cv2.resize(image_array, (new_width, new_height), interpolation=cv2.INTER_AREA)
|
| 413 |
return image_array
|
| 414 |
|
| 415 |
-
# Make the process_image function use CPU for PyTorch models
|
| 416 |
def process_image(input_image):
|
| 417 |
"""
|
| 418 |
-
|
| 419 |
-
and TensorFlow to avoid GPU conflicts.
|
| 420 |
"""
|
| 421 |
if input_image is None:
|
| 422 |
return None, "Please upload an image."
|
|
@@ -424,17 +392,13 @@ def process_image(input_image):
|
|
| 424 |
try:
|
| 425 |
start_time = time.time()
|
| 426 |
|
| 427 |
-
# Load models
|
| 428 |
card_detector, shape_detector, shape_model, fill_model = load_models()
|
| 429 |
|
| 430 |
-
# Force CPU mode for YOLO models (PyTorch)
|
| 431 |
-
card_detector.to("cpu")
|
| 432 |
-
shape_detector.to("cpu")
|
| 433 |
-
|
| 434 |
# Optimize image size
|
| 435 |
optimized_img = optimize_image_size(input_image)
|
| 436 |
|
| 437 |
-
# Convert to BGR
|
| 438 |
if len(optimized_img.shape) == 3 and optimized_img.shape[2] == 4: # RGBA
|
| 439 |
optimized_img = cv2.cvtColor(optimized_img, cv2.COLOR_RGBA2BGR)
|
| 440 |
elif len(optimized_img.shape) == 3 and optimized_img.shape[2] == 3:
|
|
@@ -474,17 +438,17 @@ def process_image(input_image):
|
|
| 474 |
logger.error(traceback.format_exc())
|
| 475 |
return input_image, error_message
|
| 476 |
|
| 477 |
-
#
|
| 478 |
@spaces.GPU
|
| 479 |
-
def
|
| 480 |
"""
|
| 481 |
-
Wrapper
|
| 482 |
-
|
| 483 |
"""
|
| 484 |
return process_image(input_image)
|
| 485 |
|
| 486 |
# =============================================================================
|
| 487 |
-
# GRADIO INTERFACE
|
| 488 |
# =============================================================================
|
| 489 |
with gr.Blocks(title="SET Game Detector") as demo:
|
| 490 |
gr.HTML("""
|
|
@@ -516,9 +480,9 @@ with gr.Blocks(title="SET Game Detector") as demo:
|
|
| 516 |
interactive=False
|
| 517 |
)
|
| 518 |
|
| 519 |
-
# Function bindings
|
| 520 |
find_sets_btn.click(
|
| 521 |
-
fn=
|
| 522 |
inputs=[input_image],
|
| 523 |
outputs=[output_image, status]
|
| 524 |
)
|
|
|
|
| 7 |
import torch
|
| 8 |
from ultralytics import YOLO
|
| 9 |
from itertools import combinations
|
|
|
|
| 10 |
import gradio as gr
|
| 11 |
import traceback
|
| 12 |
import time
|
| 13 |
from typing import Tuple, List, Dict
|
| 14 |
import logging
|
| 15 |
|
| 16 |
+
# Force CPU mode for TensorFlow
|
| 17 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
| 18 |
+
tf.config.set_visible_devices([], 'GPU')
|
| 19 |
+
|
| 20 |
+
# Import spaces for ZeroGPU wrapper
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
try:
|
| 22 |
import spaces
|
| 23 |
except ImportError:
|
|
|
|
| 45 |
|
| 46 |
def load_models():
|
| 47 |
"""
|
| 48 |
+
Load all models needed for SET detection in CPU-only mode.
|
| 49 |
Returns tuple of (card_detector, shape_detector, shape_classifier, fill_classifier)
|
| 50 |
"""
|
| 51 |
global _CARD_DETECTOR, _SHAPE_DETECTOR, _SHAPE_CLASSIFIER, _FILL_CLASSIFIER
|
|
|
|
| 60 |
|
| 61 |
logger.info("Loading models from Hugging Face Hub...")
|
| 62 |
|
| 63 |
+
# Load Shape Classification Model (TensorFlow)
|
| 64 |
+
logger.info("Loading shape classification model...")
|
| 65 |
+
shape_classifier = load_model(
|
| 66 |
+
hf_hub_download("Oamitai/shape-classification", "shape_model.keras")
|
| 67 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
+
# Load Fill Classification Model (TensorFlow)
|
| 70 |
+
logger.info("Loading fill classification model...")
|
| 71 |
+
fill_classifier = load_model(
|
| 72 |
+
hf_hub_download("Oamitai/fill-classification", "fill_model.keras")
|
| 73 |
+
)
|
| 74 |
|
| 75 |
+
# Load YOLO Card Detection Model (PyTorch)
|
| 76 |
+
logger.info("Loading card detection model...")
|
| 77 |
+
card_model_path = hf_hub_download("Oamitai/card-detection", "best.pt")
|
| 78 |
+
card_detector = YOLO(card_model_path)
|
| 79 |
+
card_detector.conf = 0.5
|
| 80 |
+
|
| 81 |
+
# Load YOLO Shape Detection Model (PyTorch)
|
| 82 |
+
logger.info("Loading shape detection model...")
|
| 83 |
+
shape_model_path = hf_hub_download("Oamitai/shape-detection", "best.pt")
|
| 84 |
+
shape_detector = YOLO(shape_model_path)
|
| 85 |
+
shape_detector.conf = 0.5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
+
# Explicitly set to CPU mode
|
| 88 |
+
logger.info("Setting models to CPU mode...")
|
| 89 |
+
card_detector.to("cpu")
|
| 90 |
+
shape_detector.to("cpu")
|
| 91 |
+
|
| 92 |
+
# Cache the models
|
| 93 |
+
_CARD_DETECTOR = card_detector
|
| 94 |
+
_SHAPE_DETECTOR = shape_detector
|
| 95 |
+
_SHAPE_CLASSIFIER = shape_classifier
|
| 96 |
+
_FILL_CLASSIFIER = fill_classifier
|
| 97 |
+
|
| 98 |
+
logger.info("All models loaded successfully in CPU mode!")
|
| 99 |
return card_detector, shape_detector, shape_classifier, fill_classifier
|
| 100 |
|
| 101 |
except Exception as e:
|
|
|
|
| 218 |
shape_imgs.append(shape_crop_resized)
|
| 219 |
color_candidates.append(predict_color(shape_crop))
|
| 220 |
|
| 221 |
+
# Handle TensorFlow prediction - process one image at a time to avoid memory issues
|
| 222 |
+
fill_preds = []
|
| 223 |
+
shape_preds = []
|
| 224 |
+
|
| 225 |
+
for img in fill_imgs:
|
| 226 |
+
try:
|
| 227 |
+
pred = fill_model.predict(np.array([img]), verbose=0)
|
| 228 |
+
fill_preds.append(pred[0])
|
| 229 |
+
except Exception as e:
|
| 230 |
+
logger.error(f"Fill prediction error: {e}")
|
| 231 |
+
fill_preds.append(np.array([0.33, 0.33, 0.34])) # Fallback
|
| 232 |
+
|
| 233 |
+
for img in shape_imgs:
|
| 234 |
+
try:
|
| 235 |
+
pred = shape_model.predict(np.array([img]), verbose=0)
|
| 236 |
+
shape_preds.append(pred[0])
|
| 237 |
+
except Exception as e:
|
| 238 |
+
logger.error(f"Shape prediction error: {e}")
|
| 239 |
+
shape_preds.append(np.array([0.33, 0.33, 0.34])) # Fallback
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
|
| 241 |
fill_labels = ['empty', 'full', 'striped']
|
| 242 |
shape_labels = ['diamond', 'oval', 'squiggle']
|
|
|
|
| 382 |
return cv2.resize(image_array, (new_width, new_height), interpolation=cv2.INTER_AREA)
|
| 383 |
return image_array
|
| 384 |
|
|
|
|
| 385 |
def process_image(input_image):
|
| 386 |
"""
|
| 387 |
+
CPU-only processing function for SET detection.
|
|
|
|
| 388 |
"""
|
| 389 |
if input_image is None:
|
| 390 |
return None, "Please upload an image."
|
|
|
|
| 392 |
try:
|
| 393 |
start_time = time.time()
|
| 394 |
|
| 395 |
+
# Load models (CPU-only)
|
| 396 |
card_detector, shape_detector, shape_model, fill_model = load_models()
|
| 397 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 398 |
# Optimize image size
|
| 399 |
optimized_img = optimize_image_size(input_image)
|
| 400 |
|
| 401 |
+
# Convert to BGR (OpenCV format)
|
| 402 |
if len(optimized_img.shape) == 3 and optimized_img.shape[2] == 4: # RGBA
|
| 403 |
optimized_img = cv2.cvtColor(optimized_img, cv2.COLOR_RGBA2BGR)
|
| 404 |
elif len(optimized_img.shape) == 3 and optimized_img.shape[2] == 3:
|
|
|
|
| 438 |
logger.error(traceback.format_exc())
|
| 439 |
return input_image, error_message
|
| 440 |
|
| 441 |
+
# Keep the spaces.GPU decorator for ZeroGPU API but use CPU internally
|
| 442 |
@spaces.GPU
|
| 443 |
+
def process_image_wrapper(input_image):
|
| 444 |
"""
|
| 445 |
+
Wrapper for process_image that uses the spaces.GPU decorator
|
| 446 |
+
but internally works in CPU-only mode.
|
| 447 |
"""
|
| 448 |
return process_image(input_image)
|
| 449 |
|
| 450 |
# =============================================================================
|
| 451 |
+
# SIMPLIFIED GRADIO INTERFACE
|
| 452 |
# =============================================================================
|
| 453 |
with gr.Blocks(title="SET Game Detector") as demo:
|
| 454 |
gr.HTML("""
|
|
|
|
| 480 |
interactive=False
|
| 481 |
)
|
| 482 |
|
| 483 |
+
# Function bindings
|
| 484 |
find_sets_btn.click(
|
| 485 |
+
fn=process_image_wrapper,
|
| 486 |
inputs=[input_image],
|
| 487 |
outputs=[output_image, status]
|
| 488 |
)
|