Oamitai commited on
Commit
2da7fc0
·
verified ·
1 Parent(s): b1c8777

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -433
app.py CHANGED
@@ -34,226 +34,99 @@ logging.basicConfig(level=logging.INFO,
34
  logger = logging.getLogger("set_detector")
35
 
36
  # =============================================================================
37
- # MODEL PATHS & LOADING
38
  # =============================================================================
39
  # For loading models from Hugging Face Hub
40
  HF_MODEL_REPO_PREFIX = "Omamitai"
41
-
42
- # Define model repos and paths
43
  CARD_DETECTION_REPO = f"{HF_MODEL_REPO_PREFIX}/card-detection"
44
  SHAPE_DETECTION_REPO = f"{HF_MODEL_REPO_PREFIX}/shape-detection"
45
  SHAPE_CLASSIFICATION_REPO = f"{HF_MODEL_REPO_PREFIX}/shape-classification"
46
  FILL_CLASSIFICATION_REPO = f"{HF_MODEL_REPO_PREFIX}/fill-classification"
47
 
48
- # Model filenames
49
  CARD_MODEL_FILENAME = "best.pt"
50
  SHAPE_MODEL_FILENAME = "best.pt"
51
  SHAPE_CLASS_MODEL_FILENAME = "shape_model.keras"
52
  FILL_CLASS_MODEL_FILENAME = "fill_model.keras"
53
 
54
- # For local testing: fallback to local models if HF downloads fail
55
- # Use the local directory structure as fallback
56
- if os.path.exists("/home/user"): # Check if we're on HF Spaces
57
- local_base_dir = Path("/home/user/app/models")
58
- else:
59
- local_base_dir = Path(os.path.dirname(os.path.abspath(__file__))) / "models"
60
-
61
- local_char_path = local_base_dir / "Characteristics" / "11022025"
62
- local_shape_path = local_base_dir / "Shape" / "15052024"
63
- local_card_path = local_base_dir / "Card" / "16042024"
64
-
65
  # Global variables for model caching
66
  _MODEL_SHAPE = None
67
  _MODEL_FILL = None
68
  _DETECTOR_CARD = None
69
  _DETECTOR_SHAPE = None
70
- _MODELS_LOADED = False
71
- _MODEL_LOADING_ERROR = None
72
 
73
- def load_classification_models() -> Tuple[tf.keras.Model, tf.keras.Model]:
74
  """
75
- Loads the Keras models for 'shape' and 'fill' classification from HuggingFace Hub.
76
- Returns (shape_model, fill_model).
77
  """
78
- global _MODEL_SHAPE, _MODEL_FILL, _MODEL_LOADING_ERROR
79
 
80
- # If models are already loaded, return them
81
- if _MODEL_SHAPE is not None and _MODEL_FILL is not None:
82
- return _MODEL_SHAPE, _MODEL_FILL
83
 
84
  try:
85
  from huggingface_hub import hf_hub_download
86
 
87
- # Try to download from HuggingFace Hub
88
- logger.info(f"Downloading shape classification model from {SHAPE_CLASSIFICATION_REPO}...")
 
 
 
 
 
 
89
  shape_model_path = hf_hub_download(
 
 
 
 
 
 
 
 
90
  repo_id=SHAPE_CLASSIFICATION_REPO,
91
  filename=SHAPE_CLASS_MODEL_FILENAME,
92
  cache_dir="./hf_cache"
93
  )
94
 
95
- logger.info(f"Downloading fill classification model from {FILL_CLASSIFICATION_REPO}...")
96
- fill_model_path = hf_hub_download(
97
  repo_id=FILL_CLASSIFICATION_REPO,
98
  filename=FILL_CLASS_MODEL_FILENAME,
99
  cache_dir="./hf_cache"
100
  )
101
 
102
- # Load the models
103
- logger.info("Loading classification models...")
104
- model_shape = load_model(str(shape_model_path))
105
- model_fill = load_model(str(fill_model_path))
106
-
107
- logger.info("Classification models loaded successfully")
108
- _MODEL_SHAPE, _MODEL_FILL = model_shape, model_fill
109
- return model_shape, model_fill
110
-
111
- except Exception as e:
112
- error_msg = f"Error downloading classification models from HF Hub: {str(e)}"
113
- logger.error(error_msg)
114
 
115
- # Try fallback to local files
116
- try:
117
- logger.info("Trying fallback to local model files...")
118
- shape_model_path = local_char_path / SHAPE_CLASS_MODEL_FILENAME
119
- fill_model_path = local_char_path / FILL_CLASS_MODEL_FILENAME
120
-
121
- if not shape_model_path.exists() or not fill_model_path.exists():
122
- raise FileNotFoundError("Local model files not found")
123
-
124
- # Load the models
125
- model_shape = load_model(str(shape_model_path))
126
- model_fill = load_model(str(fill_model_path))
127
-
128
- logger.info("Classification models loaded successfully from local files")
129
- _MODEL_SHAPE, _MODEL_FILL = model_shape, model_fill
130
- return model_shape, model_fill
131
-
132
- except Exception as fallback_error:
133
- error_msg = f"{error_msg}\nFallback to local files also failed: {str(fallback_error)}"
134
- logger.error(error_msg)
135
- _MODEL_LOADING_ERROR = error_msg
136
- return None, None
137
-
138
- def load_detection_models() -> Tuple[YOLO, YOLO]:
139
- """
140
- Loads the YOLO detection models for cards and shapes from HuggingFace Hub.
141
- Returns (card_detector, shape_detector).
142
- """
143
- global _DETECTOR_CARD, _DETECTOR_SHAPE, _MODEL_LOADING_ERROR
144
-
145
- # If models are already loaded, return them
146
- if _DETECTOR_CARD is not None and _DETECTOR_SHAPE is not None:
147
- return _DETECTOR_CARD, _DETECTOR_SHAPE
148
-
149
- try:
150
- from huggingface_hub import hf_hub_download
151
 
152
- # Try to download from HuggingFace Hub
153
- logger.info(f"Downloading card detection model from {CARD_DETECTION_REPO}...")
154
- card_model_path = hf_hub_download(
155
- repo_id=CARD_DETECTION_REPO,
156
- filename=CARD_MODEL_FILENAME,
157
- cache_dir="./hf_cache"
158
- )
159
-
160
- logger.info(f"Downloading shape detection model from {SHAPE_DETECTION_REPO}...")
161
- shape_model_path = hf_hub_download(
162
- repo_id=SHAPE_DETECTION_REPO,
163
- filename=SHAPE_MODEL_FILENAME,
164
- cache_dir="./hf_cache"
165
- )
166
-
167
- # Load the models
168
- logger.info("Loading detection models...")
169
- detector_shape = YOLO(str(shape_model_path))
170
- detector_shape.conf = 0.5
171
- detector_card = YOLO(str(card_model_path))
172
- detector_card.conf = 0.5
173
 
174
  # Use GPU if available
175
  if torch.cuda.is_available():
176
  logger.info("CUDA is available. Using GPU for inference.")
177
- detector_card.to("cuda")
178
- detector_shape.to("cuda")
179
- else:
180
- logger.info("CUDA is not available. Using CPU for inference.")
181
-
182
- logger.info("Detection models loaded successfully")
183
- _DETECTOR_CARD, _DETECTOR_SHAPE = detector_card, detector_shape
184
- return detector_card, detector_shape
185
 
186
- except Exception as e:
187
- error_msg = f"Error downloading detection models from HF Hub: {str(e)}"
188
- logger.error(error_msg)
 
 
189
 
190
- # Try fallback to local files
191
- try:
192
- logger.info("Trying fallback to local model files...")
193
- shape_model_path = local_shape_path / SHAPE_MODEL_FILENAME
194
- card_model_path = local_card_path / CARD_MODEL_FILENAME
195
-
196
- if not shape_model_path.exists() or not card_model_path.exists():
197
- raise FileNotFoundError("Local model files not found")
198
-
199
- # Load the models
200
- detector_shape = YOLO(str(shape_model_path))
201
- detector_shape.conf = 0.5
202
- detector_card = YOLO(str(card_model_path))
203
- detector_card.conf = 0.5
204
-
205
- # Use GPU if available
206
- if torch.cuda.is_available():
207
- logger.info("CUDA is available. Using GPU for inference.")
208
- detector_card.to("cuda")
209
- detector_shape.to("cuda")
210
-
211
- logger.info("Detection models loaded successfully from local files")
212
- _DETECTOR_CARD, _DETECTOR_SHAPE = detector_card, detector_shape
213
- return detector_card, detector_shape
214
-
215
- except Exception as fallback_error:
216
- error_msg = f"{error_msg}\nFallback to local files also failed: {str(fallback_error)}"
217
- logger.error(error_msg)
218
- _MODEL_LOADING_ERROR = error_msg
219
- return None, None
220
-
221
- def load_all_models() -> bool:
222
- """
223
- Loads all required models and returns True if successful.
224
- """
225
- global _MODELS_LOADED, _MODEL_LOADING_ERROR
226
-
227
- if _MODELS_LOADED:
228
- return True
229
-
230
- try:
231
- model_shape, model_fill = load_classification_models()
232
- detector_card, detector_shape = load_detection_models()
233
 
234
- models_loaded = all([model_shape, model_fill, detector_card, detector_shape])
235
- _MODELS_LOADED = models_loaded
236
-
237
- if not models_loaded and _MODEL_LOADING_ERROR is None:
238
- _MODEL_LOADING_ERROR = "Unknown error loading models"
239
-
240
- return models_loaded
241
  except Exception as e:
242
  error_msg = f"Error loading models: {str(e)}"
243
  logger.error(error_msg)
244
- _MODEL_LOADING_ERROR = error_msg
245
- return False
246
-
247
- def get_model_status() -> str:
248
- """
249
- Returns a status message about the models.
250
- """
251
- if _MODELS_LOADED:
252
- return "All models loaded successfully!"
253
- elif _MODEL_LOADING_ERROR:
254
- return f"Error: {_MODEL_LOADING_ERROR}"
255
- else:
256
- return "Models not loaded yet. Click 'Load Models' to preload them."
257
 
258
  # =============================================================================
259
  # UTILITY & DETECTION FUNCTIONS
@@ -487,57 +360,6 @@ def draw_detected_sets(board_img: np.ndarray, sets_detected: List[dict]) -> np.n
487
  )
488
  return board_img
489
 
490
- def identify_sets_from_image(
491
- board_img: np.ndarray
492
- ) -> Tuple[List[dict], np.ndarray, str]:
493
- """
494
- End-to-end pipeline to classify cards on the board and detect valid sets.
495
- Returns a tuple of (list of sets, annotated image, status message).
496
- """
497
- # Load models
498
- if not load_all_models():
499
- error_msg = _MODEL_LOADING_ERROR or "Error: Could not load models."
500
- return [], board_img, error_msg
501
-
502
- card_detector, shape_detector = _DETECTOR_CARD, _DETECTOR_SHAPE
503
- model_shape, model_fill = _MODEL_SHAPE, _MODEL_FILL
504
-
505
- # Convert image to BGR if needed (OpenCV format)
506
- if len(board_img.shape) == 3 and board_img.shape[2] == 4: # RGBA
507
- board_img = cv2.cvtColor(board_img, cv2.COLOR_RGBA2BGR)
508
- elif len(board_img.shape) == 3 and board_img.shape[2] == 3:
509
- # We assume the image is already in BGR format (OpenCV standard)
510
- # If it's in RGB format (common from web uploads), we'll convert it
511
- board_img = cv2.cvtColor(board_img, cv2.COLOR_RGB2BGR)
512
- else:
513
- return [], board_img, "Error: Unsupported image format. Please upload a color image."
514
-
515
- # 1. Check and fix orientation if needed
516
- processed, was_rotated = verify_and_rotate_image(board_img, card_detector)
517
-
518
- # 2. Verify that cards are present
519
- cards = detect_cards(processed, card_detector)
520
- if not cards:
521
- return [], cv2.cvtColor(board_img, cv2.COLOR_BGR2RGB), "No cards detected in the image. Please check that it's a SET game board."
522
-
523
- # 3. Classify each card's features, then find sets
524
- df_cards = classify_cards_on_board(processed, card_detector, shape_detector, model_fill, model_shape)
525
- found_sets = locate_all_sets(df_cards)
526
-
527
- if not found_sets:
528
- return [], cv2.cvtColor(processed, cv2.COLOR_BGR2RGB), "Cards detected, but no valid SETs found. You may need to add more cards to the table!"
529
-
530
- # 4. Draw sets on a copy of the image
531
- annotated = draw_detected_sets(processed.copy(), found_sets)
532
-
533
- # 5. Restore orientation if we rotated earlier
534
- final_output = restore_orientation(annotated, was_rotated)
535
-
536
- # Convert back to RGB for display
537
- final_output_rgb = cv2.cvtColor(final_output, cv2.COLOR_BGR2RGB)
538
-
539
- return found_sets, final_output_rgb, f"Found {len(found_sets)} SET(s) in the image."
540
-
541
  def optimize_image_size(image_array: np.ndarray, max_dim=1200) -> np.ndarray:
542
  """
543
  Resizes an image if its largest dimension exceeds max_dim, to reduce processing time.
@@ -557,46 +379,57 @@ def optimize_image_size(image_array: np.ndarray, max_dim=1200) -> np.ndarray:
557
  return cv2.resize(image_array, (new_width, new_height), interpolation=cv2.INTER_AREA)
558
  return image_array
559
 
560
- # =============================================================================
561
- # MAIN PROCESSING FUNCTIONS FOR GRADIO
562
- # =============================================================================
563
- def preload_models():
564
- """
565
- Function to preload models and return status.
566
- """
567
- try:
568
- if load_all_models():
569
- return "Models loaded successfully! Ready to detect SETs."
570
- else:
571
- return f"Error loading models: {_MODEL_LOADING_ERROR or 'Unknown error'}"
572
- except Exception as e:
573
- return f"Error loading models: {str(e)}"
574
-
575
  @spaces.GPU
576
- def process_set_image(input_image):
577
  """
578
- Main processing function for the Gradio interface.
579
- Takes an input image, processes it to find SETs, and returns the output image and status.
580
-
581
- Uses @spaces.GPU for Hugging Face Spaces zero-GPU optimization.
582
  """
583
  if input_image is None:
584
  return None, "Please upload an image."
585
 
586
  try:
587
  start_time = time.time()
588
- logger.info("Processing image...")
589
 
590
- # Check if image needs to be optimized (resized)
591
- optimized_image = optimize_image_size(input_image)
592
 
593
- # Identify sets
594
- found_sets, annotated_image, status_message = identify_sets_from_image(optimized_image)
595
 
596
- process_time = time.time() - start_time
597
- logger.info(f"Image processed in {process_time:.2f} seconds.")
 
 
 
 
598
 
599
- return annotated_image, status_message
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
600
 
601
  except Exception as e:
602
  error_message = f"Error processing image: {str(e)}"
@@ -604,205 +437,69 @@ def process_set_image(input_image):
604
  logger.error(traceback.format_exc())
605
  return input_image, error_message
606
 
 
 
 
 
607
  # =============================================================================
608
  # GRADIO INTERFACE
609
  # =============================================================================
610
- def create_gradio_interface():
611
- """
612
- Creates and returns the Gradio interface for the SET Game Detector.
613
- """
614
- # CSS for styling the Gradio interface
615
- css = """
616
- @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@400;500;600;700&display=swap');
617
-
618
- .gradio-container {
619
- font-family: 'Poppins', sans-serif;
620
- }
621
- .app-header {
622
- text-align: center;
623
- margin-bottom: 20px;
624
- background: linear-gradient(90deg, rgba(124, 58, 237, 0.1) 0%, rgba(236, 72, 153, 0.1) 100%);
625
- padding: 1rem;
626
- border-radius: 12px;
627
- }
628
- .app-header h1 {
629
- font-size: 2.5rem;
630
- background: linear-gradient(90deg, #8B5CF6 0%, #7C3AED 50%, #EC4899 100%);
631
- -webkit-background-clip: text;
632
- background-clip: text;
633
- -webkit-text-fill-color: transparent;
634
- margin-bottom: 5px;
635
- }
636
- .app-header p {
637
- font-size: 1.1rem;
638
- opacity: 0.8;
639
- margin-top: 0;
640
- }
641
- .footer {
642
- text-align: center;
643
- margin-top: 20px;
644
- padding: 10px;
645
- background: linear-gradient(90deg, rgba(124, 58, 237, 0.05) 0%, rgba(236, 72, 153, 0.05) 100%);
646
- border-radius: 12px;
647
- }
648
 
649
- /* Responsive design for mobile */
650
- @media (max-width: 600px) {
651
- .app-header h1 {
652
- font-size: 1.8rem;
653
- }
654
- .app-header p {
655
- font-size: 0.9rem;
656
- }
657
- }
658
-
659
- /* Custom styling for buttons */
660
- #find-sets-btn {
661
- background: linear-gradient(90deg, #7C3AED 0%, #EC4899 100%);
662
- color: white !important;
663
- }
664
- #find-sets-btn:hover {
665
- opacity: 0.9;
666
- }
667
-
668
- /* Image containers */
669
- .image-container {
670
- border-radius: 12px;
671
- overflow: hidden;
672
- box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
673
- }
674
-
675
- /* Status box styling */
676
- #status-box {
677
- font-weight: 500;
678
- border-radius: 8px;
679
- }
680
- """
681
-
682
- # Create the Gradio interface
683
- with gr.Blocks(css=css, title="SET Game Detector") as demo:
684
- # Header
685
- gr.HTML("""
686
- <div class="app-header">
687
- <h1>🎴 SET Game Detector</h1>
688
- <p>Upload an image of a SET board to find all valid sets</p>
689
- </div>
690
- """)
691
-
692
- # Model status display
693
- model_status = gr.Textbox(
694
- label="Model Status",
695
- value=get_model_status(),
696
- interactive=False
697
- )
698
- load_models_btn = gr.Button("🔄 Load Models", visible=not _MODELS_LOADED)
699
-
700
- # Main layout
701
- with gr.Row():
702
- with gr.Column():
703
- # Fixed: Removed 'tool' parameter which is not supported in older Gradio versions
704
- input_image = gr.Image(
705
- label="Upload SET Board Image",
706
- type="numpy",
707
- elem_id="input-image",
708
- elem_classes="image-container"
709
- )
710
-
711
- with gr.Row():
712
- process_btn = gr.Button(
713
- "🔎 Find Sets",
714
- variant="primary",
715
- elem_id="find-sets-btn",
716
- interactive=_MODELS_LOADED
717
- )
718
- clear_btn = gr.Button("🗑️ Clear", variant="secondary")
719
 
720
- with gr.Column():
721
- output_image = gr.Image(
722
- label="Detected Sets",
723
- elem_id="output-image",
724
- elem_classes="image-container",
725
- interactive=False
726
- )
727
- status = gr.Textbox(
728
- label="Status",
729
- placeholder="Upload an image and click 'Find Sets'",
730
- elem_id="status-box",
731
- interactive=False
732
- )
733
-
734
- # Example images section - Create an examples directory for deployment
735
- examples_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "examples")
736
- os.makedirs(examples_dir, exist_ok=True)
737
-
738
- # Add examples section - using simpler format for compatibility
739
- gr.Examples(
740
- examples=[
741
- os.path.join(examples_dir, "set_example1.jpg"),
742
- os.path.join(examples_dir, "set_example2.jpg")
743
- ],
744
- inputs=input_image
745
- )
746
-
747
- # Footer with attribution
748
- gr.HTML(
749
- """
750
- <div class="footer">
751
- <p>SET Game Detector by <a href="https://github.com/omamitai" target="_blank">omamitai</a> |
752
- Gradio version adapted for Hugging Face Spaces</p>
753
- </div>
754
- """
755
- )
756
 
757
- # Function bindings
758
- load_models_btn.click(
759
- fn=preload_models,
760
- outputs=[model_status]
 
 
 
761
  )
762
 
763
- process_btn.click(
764
- fn=process_set_image,
 
765
  inputs=[input_image],
766
  outputs=[output_image, status]
767
  )
768
 
769
- clear_btn.click(
770
- fn=lambda: (None, "Ready for new image"),
771
- outputs=[output_image, status]
772
- )
773
-
774
- # Update button status when models are loaded
775
- if _MODELS_LOADED:
776
- process_btn.interactive = True
777
- load_models_btn.visible = False
778
- else:
779
- # Try to load models on startup
780
- try:
781
- if load_all_models():
782
- model_status.value = "Models loaded successfully! Ready to detect SETs."
783
- process_btn.interactive = True
784
- load_models_btn.visible = False
785
- except Exception as e:
786
- logger.error(f"Error preloading models: {str(e)}")
787
-
788
- return demo
789
 
790
  # =============================================================================
791
  # MAIN EXECUTION
792
  # =============================================================================
793
  if __name__ == "__main__":
794
- # Initialize HF hub download for models when using Hugging Face Spaces
795
- try:
796
- from huggingface_hub import hf_hub_download
797
- except ImportError:
798
- logger.warning("huggingface_hub not available. Will try to use local models.")
799
-
800
- # Create examples directory if it doesn't exist
801
- examples_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "examples")
802
- os.makedirs(examples_dir, exist_ok=True)
803
-
804
- # Create the Gradio interface
805
- demo = create_gradio_interface()
806
-
807
  # Launch the app
808
  demo.queue().launch()
 
34
  logger = logging.getLogger("set_detector")
35
 
36
  # =============================================================================
37
+ # MODEL PATHS
38
  # =============================================================================
39
  # For loading models from Hugging Face Hub
40
  HF_MODEL_REPO_PREFIX = "Omamitai"
 
 
41
  CARD_DETECTION_REPO = f"{HF_MODEL_REPO_PREFIX}/card-detection"
42
  SHAPE_DETECTION_REPO = f"{HF_MODEL_REPO_PREFIX}/shape-detection"
43
  SHAPE_CLASSIFICATION_REPO = f"{HF_MODEL_REPO_PREFIX}/shape-classification"
44
  FILL_CLASSIFICATION_REPO = f"{HF_MODEL_REPO_PREFIX}/fill-classification"
45
 
 
46
  CARD_MODEL_FILENAME = "best.pt"
47
  SHAPE_MODEL_FILENAME = "best.pt"
48
  SHAPE_CLASS_MODEL_FILENAME = "shape_model.keras"
49
  FILL_CLASS_MODEL_FILENAME = "fill_model.keras"
50
 
 
 
 
 
 
 
 
 
 
 
 
51
  # Global variables for model caching
52
  _MODEL_SHAPE = None
53
  _MODEL_FILL = None
54
  _DETECTOR_CARD = None
55
  _DETECTOR_SHAPE = None
 
 
56
 
57
+ def load_models():
58
  """
59
+ Load all models needed for SET detection.
60
+ Returns tuple of (card_detector, shape_detector, shape_model, fill_model)
61
  """
62
+ global _MODEL_SHAPE, _MODEL_FILL, _DETECTOR_CARD, _DETECTOR_SHAPE
63
 
64
+ # Return cached models if already loaded
65
+ if all([_MODEL_SHAPE, _MODEL_FILL, _DETECTOR_CARD, _DETECTOR_SHAPE]):
66
+ return _DETECTOR_CARD, _DETECTOR_SHAPE, _MODEL_SHAPE, _MODEL_FILL
67
 
68
  try:
69
  from huggingface_hub import hf_hub_download
70
 
71
+ # Download and load YOLO models
72
+ logger.info("Downloading detection models...")
73
+ card_model_path = hf_hub_download(
74
+ repo_id=CARD_DETECTION_REPO,
75
+ filename=CARD_MODEL_FILENAME,
76
+ cache_dir="./hf_cache"
77
+ )
78
+
79
  shape_model_path = hf_hub_download(
80
+ repo_id=SHAPE_DETECTION_REPO,
81
+ filename=SHAPE_MODEL_FILENAME,
82
+ cache_dir="./hf_cache"
83
+ )
84
+
85
+ # Download and load classification models
86
+ logger.info("Downloading classification models...")
87
+ shape_class_path = hf_hub_download(
88
  repo_id=SHAPE_CLASSIFICATION_REPO,
89
  filename=SHAPE_CLASS_MODEL_FILENAME,
90
  cache_dir="./hf_cache"
91
  )
92
 
93
+ fill_class_path = hf_hub_download(
 
94
  repo_id=FILL_CLASSIFICATION_REPO,
95
  filename=FILL_CLASS_MODEL_FILENAME,
96
  cache_dir="./hf_cache"
97
  )
98
 
99
+ # Load all models
100
+ logger.info("Loading models...")
101
+ card_detector = YOLO(str(card_model_path))
102
+ card_detector.conf = 0.5
 
 
 
 
 
 
 
 
103
 
104
+ shape_detector = YOLO(str(shape_model_path))
105
+ shape_detector.conf = 0.5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
+ shape_classifier = load_model(str(shape_class_path))
108
+ fill_classifier = load_model(str(fill_class_path))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  # Use GPU if available
111
  if torch.cuda.is_available():
112
  logger.info("CUDA is available. Using GPU for inference.")
113
+ card_detector.to("cuda")
114
+ shape_detector.to("cuda")
 
 
 
 
 
 
115
 
116
+ # Cache the models
117
+ _DETECTOR_CARD = card_detector
118
+ _DETECTOR_SHAPE = shape_detector
119
+ _MODEL_SHAPE = shape_classifier
120
+ _MODEL_FILL = fill_classifier
121
 
122
+ logger.info("All models loaded successfully!")
123
+ return card_detector, shape_detector, shape_classifier, fill_classifier
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
 
 
 
 
 
 
 
125
  except Exception as e:
126
  error_msg = f"Error loading models: {str(e)}"
127
  logger.error(error_msg)
128
+ logger.error(traceback.format_exc())
129
+ raise ValueError(error_msg)
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  # =============================================================================
132
  # UTILITY & DETECTION FUNCTIONS
 
360
  )
361
  return board_img
362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  def optimize_image_size(image_array: np.ndarray, max_dim=1200) -> np.ndarray:
364
  """
365
  Resizes an image if its largest dimension exceeds max_dim, to reduce processing time.
 
379
  return cv2.resize(image_array, (new_width, new_height), interpolation=cv2.INTER_AREA)
380
  return image_array
381
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
  @spaces.GPU
383
+ def process_image(input_image):
384
  """
385
+ Main processing function for SET detection.
386
+ Takes an input image, processes it, and returns the annotated image and status.
 
 
387
  """
388
  if input_image is None:
389
  return None, "Please upload an image."
390
 
391
  try:
392
  start_time = time.time()
 
393
 
394
+ # Load models
395
+ card_detector, shape_detector, shape_model, fill_model = load_models()
396
 
397
+ # Optimize image size
398
+ optimized_img = optimize_image_size(input_image)
399
 
400
+ # Convert to BGR if needed (OpenCV format)
401
+ if len(optimized_img.shape) == 3 and optimized_img.shape[2] == 4: # RGBA
402
+ optimized_img = cv2.cvtColor(optimized_img, cv2.COLOR_RGBA2BGR)
403
+ elif len(optimized_img.shape) == 3 and optimized_img.shape[2] == 3:
404
+ # RGB to BGR
405
+ optimized_img = cv2.cvtColor(optimized_img, cv2.COLOR_RGB2BGR)
406
 
407
+ # Check and fix orientation
408
+ processed_img, was_rotated = verify_and_rotate_image(optimized_img, card_detector)
409
+
410
+ # Detect cards
411
+ cards = detect_cards(processed_img, card_detector)
412
+ if not cards:
413
+ return cv2.cvtColor(optimized_img, cv2.COLOR_BGR2RGB), "No cards detected. Please check that it's a SET game board."
414
+
415
+ # Classify cards and find sets
416
+ df_cards = classify_cards_on_board(processed_img, card_detector, shape_detector, fill_model, shape_model)
417
+ found_sets = locate_all_sets(df_cards)
418
+
419
+ if not found_sets:
420
+ return cv2.cvtColor(processed_img, cv2.COLOR_BGR2RGB), "Cards detected, but no valid SETs found!"
421
+
422
+ # Draw sets on the image
423
+ annotated = draw_detected_sets(processed_img.copy(), found_sets)
424
+
425
+ # Restore original orientation if needed
426
+ final_output = restore_orientation(annotated, was_rotated)
427
+
428
+ # Convert back to RGB for display
429
+ final_output_rgb = cv2.cvtColor(final_output, cv2.COLOR_BGR2RGB)
430
+
431
+ process_time = time.time() - start_time
432
+ return final_output_rgb, f"Found {len(found_sets)} SET(s) in {process_time:.2f} seconds."
433
 
434
  except Exception as e:
435
  error_message = f"Error processing image: {str(e)}"
 
437
  logger.error(traceback.format_exc())
438
  return input_image, error_message
439
 
440
+ # Create examples directory
441
+ examples_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "examples")
442
+ os.makedirs(examples_dir, exist_ok=True)
443
+
444
  # =============================================================================
445
  # GRADIO INTERFACE
446
  # =============================================================================
447
+ with gr.Blocks(title="SET Game Detector") as demo:
448
+ gr.HTML("""
449
+ <div style="text-align: center; margin-bottom: 1rem;">
450
+ <h1 style="margin-bottom: 0.5rem;">🎴 SET Game Detector</h1>
451
+ <p>Upload an image of a SET game board to find all valid sets</p>
452
+ </div>
453
+ """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454
 
455
+ with gr.Row():
456
+ with gr.Column():
457
+ input_image = gr.Image(
458
+ label="Upload SET Board Image",
459
+ type="numpy"
460
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
 
462
+ find_sets_btn = gr.Button(
463
+ "🔎 Find Sets",
464
+ variant="primary"
465
+ )
466
+
467
+ with gr.Column():
468
+ output_image = gr.Image(
469
+ label="Detected Sets"
470
+ )
471
+ status = gr.Textbox(
472
+ label="Status",
473
+ value="Upload an image and click 'Find Sets'",
474
+ interactive=False
475
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476
 
477
+ # Examples - simplified
478
+ gr.Examples(
479
+ examples=[
480
+ os.path.join(examples_dir, "set_example1.jpg"),
481
+ os.path.join(examples_dir, "set_example2.jpg")
482
+ ],
483
+ inputs=input_image
484
  )
485
 
486
+ # Function bindings inside the Blocks context
487
+ find_sets_btn.click(
488
+ fn=process_image,
489
  inputs=[input_image],
490
  outputs=[output_image, status]
491
  )
492
 
493
+ gr.HTML("""
494
+ <div style="text-align: center; margin-top: 1rem; padding: 0.5rem; font-size: 0.8rem;">
495
+ <p>SET Game Detector by <a href="https://github.com/omamitai" target="_blank">omamitai</a> |
496
+ Gradio version adapted for Hugging Face Spaces</p>
497
+ </div>
498
+ """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
499
 
500
  # =============================================================================
501
  # MAIN EXECUTION
502
  # =============================================================================
503
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
504
  # Launch the app
505
  demo.queue().launch()