Spaces:
Paused
Paused
| """ | |
| Model loading for the scanner service. | |
| Three models are loaded: | |
| 1. YOLO OBB β card detection (oriented bounding boxes) | |
| 2. ViT classifier β card name classification (HuggingFace pipeline) | |
| 3. Draw wrapper β optional, provides rotation correction via draw2 | |
| CRITICAL: draw2 uses OBB (result.obb), NEVER result.boxes. | |
| """ | |
| import logging | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from ultralytics import YOLO | |
| from transformers import AutoImageProcessor, pipeline as hf_pipeline | |
| from draw.draw import Draw | |
logger = logging.getLogger(__name__)

# Module-level model singletons. Every slot starts out as None and is filled
# in by load_models(); the get_*() accessors read these names.
_yolo_model: YOLO | None = None  # YOLO OBB card detector
_draw_instance: Draw | None = None  # optional draw2 wrapper; None if it failed to load
_image_processor = None  # image processor handed to the classifier pipeline
_classifier = None  # HuggingFace image-classification pipeline
def get_device() -> str:
    """Choose the torch device used for inference.

    Returns:
        ``"cuda"`` when torch reports CUDA as available, otherwise ``"cpu"``.
    """
    return "cuda" if torch.cuda.is_available() else "cpu"
def load_models(device: str | None = None) -> None:
    """Load all models into the module-level globals. Called once at startup.

    Three models are loaded, in order:

    1. YOLO OBB card detector (mandatory), downloaded from the
       ``HichTala/draw2`` Hub repo as ``ygo_yolo.pt``.
    2. ViT image-classification pipeline for card names (mandatory), built
       from ``HichTala/draw2`` with the ``google/vit-base-patch16-224-in21k``
       image processor.
    3. draw2 ``Draw`` wrapper (optional, better rotation correction). A failure
       here is logged with its traceback and leaves ``_draw_instance`` as None.

    The two mandatory models are built into locals first. They are published
    to the module globals only once both have succeeded. A failure part-way
    through therefore never leaves a detector published without its classifier.

    Args:
        device: Device string passed to the classifier pipeline. Defaults to
            ``get_device()`` ("cuda" when available, else "cpu").

    Raises:
        Exception: whatever the Hub download, ``YOLO`` or the transformers
            pipeline raise while loading the mandatory models propagates
            unchanged.
    """
    global _yolo_model, _classifier, _draw_instance, _image_processor

    if device is None:
        device = get_device()
    logger.info("Loading models on device: %s", device)

    # 1. YOLO OBB model (mandatory)
    yolo_path = hf_hub_download(repo_id="HichTala/draw2", filename="ygo_yolo.pt")
    yolo_model = YOLO(yolo_path)
    logger.info("YOLO OBB model loaded")

    # 2. ViT classifier (mandatory for card names).
    # Loaded independently of the Draw wrapper to guarantee availability.
    image_processor = AutoImageProcessor.from_pretrained(
        "google/vit-base-patch16-224-in21k"
    )
    classifier = hf_pipeline(
        "image-classification",
        model="HichTala/draw2",
        image_processor=image_processor,
        device=device,
    )
    logger.info("ViT classifier loaded")

    # Publish the mandatory models together, only after both loaded.
    _yolo_model = yolo_model
    _image_processor = image_processor
    _classifier = classifier

    # 3. Draw wrapper (optional — better rotation correction).
    # Draw(source=None) — YOLO uses a default image (boats.jpg), no error.
    # confidence_threshold=5 means 5% internally (draw2 divides by 100).
    try:
        _draw_instance = Draw(source=None, confidence_threshold=5)
    except Exception as e:
        # Deliberately non-fatal: the wrapper is optional. Keep the traceback
        # so the cause of the failure stays diagnosable from the logs.
        logger.warning(
            "Draw wrapper failed to load (non-fatal): %s", e, exc_info=True
        )
        _draw_instance = None
    else:
        logger.info("Draw wrapper loaded")
def get_yolo() -> YOLO:
    """Return the loaded YOLO OBB card detector.

    Raises:
        RuntimeError: if ``load_models()`` has not populated the model yet.
    """
    if _yolo_model is None:
        # Message previously contained a mis-encoded em-dash ("β").
        raise RuntimeError("YOLO model not loaded — call load_models() first")
    return _yolo_model
def get_classifier():
    """Return the loaded ViT image-classification pipeline.

    Raises:
        RuntimeError: if ``load_models()`` has not populated the classifier yet.
    """
    if _classifier is None:
        # Message previously contained a mis-encoded em-dash ("β").
        raise RuntimeError("ViT classifier not loaded — call load_models() first")
    return _classifier
def get_draw() -> Draw | None:
    """Give back the optional draw2 ``Draw`` wrapper.

    Unlike the other accessors this never raises. Callers receive ``None``
    when the wrapper was not loaded or failed to load, and must handle that.
    """
    draw = _draw_instance
    return draw