Oamitai committed on
Commit
7e5ec53
·
verified ·
1 Parent(s): 2da7fc0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -73
app.py CHANGED
@@ -8,9 +8,7 @@ import torch
8
  from ultralytics import YOLO
9
  from itertools import combinations
10
  from pathlib import Path
11
- from PIL import Image
12
  import gradio as gr
13
- import functools
14
  import traceback
15
  import time
16
  from typing import Tuple, List, Dict
@@ -34,78 +32,50 @@ logging.basicConfig(level=logging.INFO,
34
  logger = logging.getLogger("set_detector")
35
 
36
  # =============================================================================
37
- # MODEL PATHS
38
  # =============================================================================
39
- # For loading models from Hugging Face Hub
40
- HF_MODEL_REPO_PREFIX = "Omamitai"
41
- CARD_DETECTION_REPO = f"{HF_MODEL_REPO_PREFIX}/card-detection"
42
- SHAPE_DETECTION_REPO = f"{HF_MODEL_REPO_PREFIX}/shape-detection"
43
- SHAPE_CLASSIFICATION_REPO = f"{HF_MODEL_REPO_PREFIX}/shape-classification"
44
- FILL_CLASSIFICATION_REPO = f"{HF_MODEL_REPO_PREFIX}/fill-classification"
45
-
46
- CARD_MODEL_FILENAME = "best.pt"
47
- SHAPE_MODEL_FILENAME = "best.pt"
48
- SHAPE_CLASS_MODEL_FILENAME = "shape_model.keras"
49
- FILL_CLASS_MODEL_FILENAME = "fill_model.keras"
50
-
51
  # Global variables for model caching
52
- _MODEL_SHAPE = None
53
- _MODEL_FILL = None
54
- _DETECTOR_CARD = None
55
- _DETECTOR_SHAPE = None
56
 
57
  def load_models():
58
  """
59
  Load all models needed for SET detection.
60
- Returns tuple of (card_detector, shape_detector, shape_model, fill_model)
61
  """
62
- global _MODEL_SHAPE, _MODEL_FILL, _DETECTOR_CARD, _DETECTOR_SHAPE
63
 
64
  # Return cached models if already loaded
65
- if all([_MODEL_SHAPE, _MODEL_FILL, _DETECTOR_CARD, _DETECTOR_SHAPE]):
66
- return _DETECTOR_CARD, _DETECTOR_SHAPE, _MODEL_SHAPE, _MODEL_FILL
 
67
 
68
  try:
69
  from huggingface_hub import hf_hub_download
70
 
71
- # Download and load YOLO models
72
- logger.info("Downloading detection models...")
73
- card_model_path = hf_hub_download(
74
- repo_id=CARD_DETECTION_REPO,
75
- filename=CARD_MODEL_FILENAME,
76
- cache_dir="./hf_cache"
77
- )
78
-
79
- shape_model_path = hf_hub_download(
80
- repo_id=SHAPE_DETECTION_REPO,
81
- filename=SHAPE_MODEL_FILENAME,
82
- cache_dir="./hf_cache"
83
- )
84
-
85
- # Download and load classification models
86
- logger.info("Downloading classification models...")
87
- shape_class_path = hf_hub_download(
88
- repo_id=SHAPE_CLASSIFICATION_REPO,
89
- filename=SHAPE_CLASS_MODEL_FILENAME,
90
- cache_dir="./hf_cache"
91
- )
92
-
93
- fill_class_path = hf_hub_download(
94
- repo_id=FILL_CLASSIFICATION_REPO,
95
- filename=FILL_CLASS_MODEL_FILENAME,
96
- cache_dir="./hf_cache"
97
- )
98
 
99
- # Load all models
100
- logger.info("Loading models...")
101
- card_detector = YOLO(str(card_model_path))
102
  card_detector.conf = 0.5
103
 
104
- shape_detector = YOLO(str(shape_model_path))
 
 
105
  shape_detector.conf = 0.5
106
 
107
- shape_classifier = load_model(str(shape_class_path))
108
- fill_classifier = load_model(str(fill_class_path))
 
 
 
 
 
 
 
109
 
110
  # Use GPU if available
111
  if torch.cuda.is_available():
@@ -114,10 +84,10 @@ def load_models():
114
  shape_detector.to("cuda")
115
 
116
  # Cache the models
117
- _DETECTOR_CARD = card_detector
118
- _DETECTOR_SHAPE = shape_detector
119
- _MODEL_SHAPE = shape_classifier
120
- _MODEL_FILL = fill_classifier
121
 
122
  logger.info("All models loaded successfully!")
123
  return card_detector, shape_detector, shape_classifier, fill_classifier
@@ -437,10 +407,6 @@ def process_image(input_image):
437
  logger.error(traceback.format_exc())
438
  return input_image, error_message
439
 
440
- # Create examples directory
441
- examples_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "examples")
442
- os.makedirs(examples_dir, exist_ok=True)
443
-
444
  # =============================================================================
445
  # GRADIO INTERFACE
446
  # =============================================================================
@@ -474,15 +440,6 @@ with gr.Blocks(title="SET Game Detector") as demo:
474
  interactive=False
475
  )
476
 
477
- # Examples - simplified
478
- gr.Examples(
479
- examples=[
480
- os.path.join(examples_dir, "set_example1.jpg"),
481
- os.path.join(examples_dir, "set_example2.jpg")
482
- ],
483
- inputs=input_image
484
- )
485
-
486
  # Function bindings inside the Blocks context
487
  find_sets_btn.click(
488
  fn=process_image,
 
8
  from ultralytics import YOLO
9
  from itertools import combinations
10
  from pathlib import Path
 
11
  import gradio as gr
 
12
  import traceback
13
  import time
14
  from typing import Tuple, List, Dict
 
32
  logger = logging.getLogger("set_detector")
33
 
34
  # =============================================================================
35
+ # MODEL LOADING
36
  # =============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
37
  # Global variables for model caching
38
+ _CARD_DETECTOR = None
39
+ _SHAPE_DETECTOR = None
40
+ _SHAPE_CLASSIFIER = None
41
+ _FILL_CLASSIFIER = None
42
 
43
  def load_models():
44
  """
45
  Load all models needed for SET detection.
46
+ Returns tuple of (card_detector, shape_detector, shape_classifier, fill_classifier)
47
  """
48
+ global _CARD_DETECTOR, _SHAPE_DETECTOR, _SHAPE_CLASSIFIER, _FILL_CLASSIFIER
49
 
50
  # Return cached models if already loaded
51
+ if all([_CARD_DETECTOR, _SHAPE_DETECTOR, _SHAPE_CLASSIFIER, _FILL_CLASSIFIER]):
52
+ logger.info("Using cached models")
53
+ return _CARD_DETECTOR, _SHAPE_DETECTOR, _SHAPE_CLASSIFIER, _FILL_CLASSIFIER
54
 
55
  try:
56
  from huggingface_hub import hf_hub_download
57
 
58
+ logger.info("Loading models from Hugging Face Hub...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
+ # Load YOLO Card Detection Model
61
+ card_model_path = hf_hub_download("Oamitai/card-detection", "best.pt")
62
+ card_detector = YOLO(card_model_path)
63
  card_detector.conf = 0.5
64
 
65
+ # Load YOLO Shape Detection Model
66
+ shape_model_path = hf_hub_download("Oamitai/shape-detection", "best.pt")
67
+ shape_detector = YOLO(shape_model_path)
68
  shape_detector.conf = 0.5
69
 
70
+ # Load Shape Classification Model
71
+ shape_classifier = tf.keras.models.load_model(
72
+ hf_hub_download("Oamitai/shape-classification", "shape_model.keras")
73
+ )
74
+
75
+ # Load Fill Classification Model
76
+ fill_classifier = tf.keras.models.load_model(
77
+ hf_hub_download("Oamitai/fill-classification", "fill_model.keras")
78
+ )
79
 
80
  # Use GPU if available
81
  if torch.cuda.is_available():
 
84
  shape_detector.to("cuda")
85
 
86
  # Cache the models
87
+ _CARD_DETECTOR = card_detector
88
+ _SHAPE_DETECTOR = shape_detector
89
+ _SHAPE_CLASSIFIER = shape_classifier
90
+ _FILL_CLASSIFIER = fill_classifier
91
 
92
  logger.info("All models loaded successfully!")
93
  return card_detector, shape_detector, shape_classifier, fill_classifier
 
407
  logger.error(traceback.format_exc())
408
  return input_image, error_message
409
 
 
 
 
 
410
  # =============================================================================
411
  # GRADIO INTERFACE
412
  # =============================================================================
 
440
  interactive=False
441
  )
442
 
 
 
 
 
 
 
 
 
 
443
  # Function bindings inside the Blocks context
444
  find_sets_btn.click(
445
  fn=process_image,