Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,9 +8,7 @@ import torch
|
|
| 8 |
from ultralytics import YOLO
|
| 9 |
from itertools import combinations
|
| 10 |
from pathlib import Path
|
| 11 |
-
from PIL import Image
|
| 12 |
import gradio as gr
|
| 13 |
-
import functools
|
| 14 |
import traceback
|
| 15 |
import time
|
| 16 |
from typing import Tuple, List, Dict
|
|
@@ -34,78 +32,50 @@ logging.basicConfig(level=logging.INFO,
|
|
| 34 |
logger = logging.getLogger("set_detector")
|
| 35 |
|
| 36 |
# =============================================================================
|
| 37 |
-
# MODEL
|
| 38 |
# =============================================================================
|
| 39 |
-
# For loading models from Hugging Face Hub
|
| 40 |
-
HF_MODEL_REPO_PREFIX = "Omamitai"
|
| 41 |
-
CARD_DETECTION_REPO = f"{HF_MODEL_REPO_PREFIX}/card-detection"
|
| 42 |
-
SHAPE_DETECTION_REPO = f"{HF_MODEL_REPO_PREFIX}/shape-detection"
|
| 43 |
-
SHAPE_CLASSIFICATION_REPO = f"{HF_MODEL_REPO_PREFIX}/shape-classification"
|
| 44 |
-
FILL_CLASSIFICATION_REPO = f"{HF_MODEL_REPO_PREFIX}/fill-classification"
|
| 45 |
-
|
| 46 |
-
CARD_MODEL_FILENAME = "best.pt"
|
| 47 |
-
SHAPE_MODEL_FILENAME = "best.pt"
|
| 48 |
-
SHAPE_CLASS_MODEL_FILENAME = "shape_model.keras"
|
| 49 |
-
FILL_CLASS_MODEL_FILENAME = "fill_model.keras"
|
| 50 |
-
|
| 51 |
# Global variables for model caching
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
|
| 57 |
def load_models():
|
| 58 |
"""
|
| 59 |
Load all models needed for SET detection.
|
| 60 |
-
Returns tuple of (card_detector, shape_detector,
|
| 61 |
"""
|
| 62 |
-
global
|
| 63 |
|
| 64 |
# Return cached models if already loaded
|
| 65 |
-
if all([
|
| 66 |
-
|
|
|
|
| 67 |
|
| 68 |
try:
|
| 69 |
from huggingface_hub import hf_hub_download
|
| 70 |
|
| 71 |
-
|
| 72 |
-
logger.info("Downloading detection models...")
|
| 73 |
-
card_model_path = hf_hub_download(
|
| 74 |
-
repo_id=CARD_DETECTION_REPO,
|
| 75 |
-
filename=CARD_MODEL_FILENAME,
|
| 76 |
-
cache_dir="./hf_cache"
|
| 77 |
-
)
|
| 78 |
-
|
| 79 |
-
shape_model_path = hf_hub_download(
|
| 80 |
-
repo_id=SHAPE_DETECTION_REPO,
|
| 81 |
-
filename=SHAPE_MODEL_FILENAME,
|
| 82 |
-
cache_dir="./hf_cache"
|
| 83 |
-
)
|
| 84 |
-
|
| 85 |
-
# Download and load classification models
|
| 86 |
-
logger.info("Downloading classification models...")
|
| 87 |
-
shape_class_path = hf_hub_download(
|
| 88 |
-
repo_id=SHAPE_CLASSIFICATION_REPO,
|
| 89 |
-
filename=SHAPE_CLASS_MODEL_FILENAME,
|
| 90 |
-
cache_dir="./hf_cache"
|
| 91 |
-
)
|
| 92 |
-
|
| 93 |
-
fill_class_path = hf_hub_download(
|
| 94 |
-
repo_id=FILL_CLASSIFICATION_REPO,
|
| 95 |
-
filename=FILL_CLASS_MODEL_FILENAME,
|
| 96 |
-
cache_dir="./hf_cache"
|
| 97 |
-
)
|
| 98 |
|
| 99 |
-
# Load
|
| 100 |
-
|
| 101 |
-
card_detector = YOLO(
|
| 102 |
card_detector.conf = 0.5
|
| 103 |
|
| 104 |
-
|
|
|
|
|
|
|
| 105 |
shape_detector.conf = 0.5
|
| 106 |
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
# Use GPU if available
|
| 111 |
if torch.cuda.is_available():
|
|
@@ -114,10 +84,10 @@ def load_models():
|
|
| 114 |
shape_detector.to("cuda")
|
| 115 |
|
| 116 |
# Cache the models
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
|
| 122 |
logger.info("All models loaded successfully!")
|
| 123 |
return card_detector, shape_detector, shape_classifier, fill_classifier
|
|
@@ -437,10 +407,6 @@ def process_image(input_image):
|
|
| 437 |
logger.error(traceback.format_exc())
|
| 438 |
return input_image, error_message
|
| 439 |
|
| 440 |
-
# Create examples directory
|
| 441 |
-
examples_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "examples")
|
| 442 |
-
os.makedirs(examples_dir, exist_ok=True)
|
| 443 |
-
|
| 444 |
# =============================================================================
|
| 445 |
# GRADIO INTERFACE
|
| 446 |
# =============================================================================
|
|
@@ -474,15 +440,6 @@ with gr.Blocks(title="SET Game Detector") as demo:
|
|
| 474 |
interactive=False
|
| 475 |
)
|
| 476 |
|
| 477 |
-
# Examples - simplified
|
| 478 |
-
gr.Examples(
|
| 479 |
-
examples=[
|
| 480 |
-
os.path.join(examples_dir, "set_example1.jpg"),
|
| 481 |
-
os.path.join(examples_dir, "set_example2.jpg")
|
| 482 |
-
],
|
| 483 |
-
inputs=input_image
|
| 484 |
-
)
|
| 485 |
-
|
| 486 |
# Function bindings inside the Blocks context
|
| 487 |
find_sets_btn.click(
|
| 488 |
fn=process_image,
|
|
|
|
| 8 |
from ultralytics import YOLO
|
| 9 |
from itertools import combinations
|
| 10 |
from pathlib import Path
|
|
|
|
| 11 |
import gradio as gr
|
|
|
|
| 12 |
import traceback
|
| 13 |
import time
|
| 14 |
from typing import Tuple, List, Dict
|
|
|
|
| 32 |
logger = logging.getLogger("set_detector")
|
| 33 |
|
| 34 |
# =============================================================================
|
| 35 |
+
# MODEL LOADING
|
| 36 |
# =============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
# Global variables for model caching
|
| 38 |
+
_CARD_DETECTOR = None
|
| 39 |
+
_SHAPE_DETECTOR = None
|
| 40 |
+
_SHAPE_CLASSIFIER = None
|
| 41 |
+
_FILL_CLASSIFIER = None
|
| 42 |
|
| 43 |
def load_models():
|
| 44 |
"""
|
| 45 |
Load all models needed for SET detection.
|
| 46 |
+
Returns tuple of (card_detector, shape_detector, shape_classifier, fill_classifier)
|
| 47 |
"""
|
| 48 |
+
global _CARD_DETECTOR, _SHAPE_DETECTOR, _SHAPE_CLASSIFIER, _FILL_CLASSIFIER
|
| 49 |
|
| 50 |
# Return cached models if already loaded
|
| 51 |
+
if all([_CARD_DETECTOR, _SHAPE_DETECTOR, _SHAPE_CLASSIFIER, _FILL_CLASSIFIER]):
|
| 52 |
+
logger.info("Using cached models")
|
| 53 |
+
return _CARD_DETECTOR, _SHAPE_DETECTOR, _SHAPE_CLASSIFIER, _FILL_CLASSIFIER
|
| 54 |
|
| 55 |
try:
|
| 56 |
from huggingface_hub import hf_hub_download
|
| 57 |
|
| 58 |
+
logger.info("Loading models from Hugging Face Hub...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
+
# Load YOLO Card Detection Model
|
| 61 |
+
card_model_path = hf_hub_download("Oamitai/card-detection", "best.pt")
|
| 62 |
+
card_detector = YOLO(card_model_path)
|
| 63 |
card_detector.conf = 0.5
|
| 64 |
|
| 65 |
+
# Load YOLO Shape Detection Model
|
| 66 |
+
shape_model_path = hf_hub_download("Oamitai/shape-detection", "best.pt")
|
| 67 |
+
shape_detector = YOLO(shape_model_path)
|
| 68 |
shape_detector.conf = 0.5
|
| 69 |
|
| 70 |
+
# Load Shape Classification Model
|
| 71 |
+
shape_classifier = tf.keras.models.load_model(
|
| 72 |
+
hf_hub_download("Oamitai/shape-classification", "shape_model.keras")
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
# Load Fill Classification Model
|
| 76 |
+
fill_classifier = tf.keras.models.load_model(
|
| 77 |
+
hf_hub_download("Oamitai/fill-classification", "fill_model.keras")
|
| 78 |
+
)
|
| 79 |
|
| 80 |
# Use GPU if available
|
| 81 |
if torch.cuda.is_available():
|
|
|
|
| 84 |
shape_detector.to("cuda")
|
| 85 |
|
| 86 |
# Cache the models
|
| 87 |
+
_CARD_DETECTOR = card_detector
|
| 88 |
+
_SHAPE_DETECTOR = shape_detector
|
| 89 |
+
_SHAPE_CLASSIFIER = shape_classifier
|
| 90 |
+
_FILL_CLASSIFIER = fill_classifier
|
| 91 |
|
| 92 |
logger.info("All models loaded successfully!")
|
| 93 |
return card_detector, shape_detector, shape_classifier, fill_classifier
|
|
|
|
| 407 |
logger.error(traceback.format_exc())
|
| 408 |
return input_image, error_message
|
| 409 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
# =============================================================================
|
| 411 |
# GRADIO INTERFACE
|
| 412 |
# =============================================================================
|
|
|
|
| 440 |
interactive=False
|
| 441 |
)
|
| 442 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
# Function bindings inside the Blocks context
|
| 444 |
find_sets_btn.click(
|
| 445 |
fn=process_image,
|