yugidex-scanner / app /model_loader.py
rcrochepeyre
Initial scanner deploy
d116174
"""
Model loading for the scanner service.
Three models are loaded:
1. YOLO OBB — card detection (oriented bounding boxes)
2. ViT classifier — card name classification (HuggingFace pipeline)
3. Draw wrapper — optional, provides rotation correction via draw2
CRITICAL: draw2 uses OBB (result.obb), NEVER result.boxes.
"""
import logging
import torch
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
from transformers import AutoImageProcessor, pipeline as hf_pipeline
from draw.draw import Draw
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Module-level singletons. Each one starts as None and is assigned by
# load_models(); the get_*() accessors below read them.
_yolo_model: YOLO | None = None  # YOLO OBB card detector
_draw_instance: Draw | None = None  # optional Draw wrapper
_image_processor = None  # image processor passed to the classifier pipeline
_classifier = None  # image-classification pipeline
def get_device() -> str:
    """Return the name of the torch device the models should run on.

    Returns:
        ``"cuda"`` when ``torch.cuda.is_available()`` reports a usable GPU,
        otherwise ``"cpu"``.
    """
    return "cuda" if torch.cuda.is_available() else "cpu"
def load_models():
    """Load all models into the module-level singletons.

    Called once at startup. Loads, in order:

    1. The YOLO OBB detector, from the ``ygo_yolo.pt`` weights downloaded
       from the ``HichTala/draw2`` Hugging Face repository.
    2. The image-classification pipeline for ``HichTala/draw2``, using the
       image processor of ``google/vit-base-patch16-224-in21k``.
    3. The optional ``Draw`` wrapper. A failure here is logged (with
       traceback) and leaves the Draw singleton set to ``None``.

    Raises:
        Exception: Any error raised while loading the YOLO model or the
            classifier propagates to the caller; only the Draw wrapper
            step is guarded.
    """
    global _yolo_model, _classifier, _draw_instance, _image_processor

    device = get_device()
    # Lazy %-style arguments: the message is only formatted if it is emitted.
    logger.info("Loading models on device: %s", device)

    # 1. YOLO OBB model (mandatory)
    yolo_path = hf_hub_download(repo_id="HichTala/draw2", filename="ygo_yolo.pt")
    _yolo_model = YOLO(yolo_path)
    logger.info("YOLO OBB model loaded")

    # 2. ViT classifier (mandatory for card names)
    # Loaded independently from the Draw wrapper to guarantee availability.
    _image_processor = AutoImageProcessor.from_pretrained(
        "google/vit-base-patch16-224-in21k"
    )
    _classifier = hf_pipeline(
        "image-classification",
        model="HichTala/draw2",
        image_processor=_image_processor,
        device=device,
    )
    logger.info("ViT classifier loaded")

    # 3. Draw wrapper (optional — better rotation correction)
    # Draw(source=None) -> YOLO uses a default image (boats.jpg), no error.
    # confidence_threshold=5 means 5% internally (draw2 divides by 100).
    try:
        _draw_instance = Draw(source=None, confidence_threshold=5)
    except Exception as exc:
        # Deliberate best-effort: the service can run without Draw, so the
        # failure is downgraded to a warning. exc_info keeps the traceback
        # in the log so the root cause is still diagnosable.
        logger.warning(
            "Draw wrapper failed to load (non-fatal): %s", exc, exc_info=True
        )
        _draw_instance = None
    else:
        logger.info("Draw wrapper loaded")
def get_yolo() -> YOLO:
    """Return the loaded YOLO OBB detection model.

    Returns:
        The module-level YOLO model.

    Raises:
        RuntimeError: If the model is still ``None``, i.e. it has not been
            loaded yet.
    """
    if _yolo_model is None:
        # The em dash in this message was previously mis-encoded (mojibake).
        raise RuntimeError("YOLO model not loaded — call load_models() first")
    return _yolo_model
def get_classifier():
    """Return the loaded card-name classifier.

    Returns:
        The module-level classifier object.

    Raises:
        RuntimeError: If the classifier is still ``None``, i.e. it has not
            been loaded yet.
    """
    if _classifier is None:
        # The em dash in this message was previously mis-encoded (mojibake).
        raise RuntimeError("ViT classifier not loaded — call load_models() first")
    return _classifier
def get_draw() -> Draw | None:
    """Return the module-level Draw wrapper, if one is available.

    Returns:
        The stored ``Draw`` instance, or ``None`` when no instance has been
        stored (for example because loading it failed).
    """
    draw_wrapper = _draw_instance
    return draw_wrapper