"""
Model loading for the scanner service.

Three models are loaded:
    1. YOLO OBB — card detection (oriented bounding boxes)
    2. ViT classifier — card name classification (HuggingFace pipeline)
    3. Draw wrapper — optional, provides rotation correction via draw2

CRITICAL: draw2 uses OBB (result.obb), NEVER result.boxes.
"""

import logging

import torch
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
from transformers import AutoImageProcessor, pipeline as hf_pipeline

from draw.draw import Draw

logger = logging.getLogger(__name__)

# Hugging Face Hub identifiers. The draw2 repo hosts both the YOLO OBB weights
# and the ViT classifier; the image processor is taken from the base ViT
# checkpoint instead.
_DRAW2_REPO_ID = "HichTala/draw2"
_YOLO_WEIGHTS_FILENAME = "ygo_yolo.pt"
_VIT_PROCESSOR_REPO_ID = "google/vit-base-patch16-224-in21k"

# Module-level singletons populated by load_models() and read via the getters.
_yolo_model: YOLO | None = None
_classifier = None  # transformers "image-classification" pipeline once loaded
_draw_instance: Draw | None = None
_image_processor = None  # AutoImageProcessor passed to the classifier pipeline


def get_device() -> str:
    """Return the inference device string: "cuda" when available, else "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"


def load_models() -> None:
    """Load all models. Called once at startup.

    Populates the module-level singletons returned by get_yolo(),
    get_classifier() and get_draw().

    Raises:
        Any exception from hf_hub_download, YOLO, AutoImageProcessor or the
        transformers pipeline if a mandatory model (YOLO OBB or ViT
        classifier) cannot be loaded. A failure of the optional Draw wrapper
        is logged (with traceback) and leaves get_draw() returning None.
    """
    global _yolo_model, _classifier, _draw_instance, _image_processor

    device = get_device()
    logger.info("Loading models on device: %s", device)

    # 1. YOLO OBB model (mandatory)
    yolo_path = hf_hub_download(
        repo_id=_DRAW2_REPO_ID, filename=_YOLO_WEIGHTS_FILENAME
    )
    _yolo_model = YOLO(yolo_path)
    logger.info("YOLO OBB model loaded")

    # 2. ViT classifier (mandatory for card names)
    # Load independently from Draw wrapper to guarantee availability
    _image_processor = AutoImageProcessor.from_pretrained(_VIT_PROCESSOR_REPO_ID)
    _classifier = hf_pipeline(
        "image-classification",
        model=_DRAW2_REPO_ID,
        image_processor=_image_processor,
        device=device,
    )
    logger.info("ViT classifier loaded")

    # 3. Draw wrapper (optional — better rotation correction)
    # Draw(source=None) → YOLO uses a default image (boats.jpg), no error
    # confidence_threshold=5 means 5% internally (draw2 divides by 100)
    try:
        _draw_instance = Draw(source=None, confidence_threshold=5)
    except Exception as e:
        # Deliberately non-fatal, but keep the traceback so the root cause of
        # the failure is diagnosable from the logs.
        logger.warning(
            "Draw wrapper failed to load (non-fatal): %s", e, exc_info=True
        )
        _draw_instance = None
    else:
        logger.info("Draw wrapper loaded")


def get_yolo() -> YOLO:
    """Return the loaded YOLO OBB model.

    Raises:
        RuntimeError: if load_models() has not populated the model yet.
    """
    if _yolo_model is None:
        raise RuntimeError("YOLO model not loaded — call load_models() first")
    return _yolo_model


def get_classifier():
    """Return the loaded ViT image-classification pipeline.

    Raises:
        RuntimeError: if load_models() has not populated the classifier yet.
    """
    if _classifier is None:
        raise RuntimeError("ViT classifier not loaded — call load_models() first")
    return _classifier


def get_draw() -> Draw | None:
    """Returns the Draw instance, or None if it failed to load."""
    return _draw_instance