# ocr-engine-1 / src/utils/_model_manager.py
# Author: kanha-upadhyay
# Update spaCy model loading to use 'en_core_web_sm' instead of 'en_core_web_trf'
# Commit: 6874c0e
import os
import spacy
import torch
from doctr.models import ocr_predictor
from loguru import logger
class ModelManager:
    """Singleton model manager for pre-loading all models at startup.

    Holds a doctr OCR predictor and a spaCy pipeline. The models are loaded on
    the first ``await ensure_models_loaded()``. Later calls are no-ops once a
    load has completed successfully.

    Environment:
        SPACY_MODEL_NAME: name passed to ``spacy.load``. When unset,
            ``DEFAULT_SPACY_MODEL`` is used.
    """

    # spaCy pipeline loaded when SPACY_MODEL_NAME is not set.
    DEFAULT_SPACY_MODEL = "en_core_web_sm"

    # Shared state. It is declared on the class so the attributes read as
    # None/False before loading; loading assigns them on the singleton.
    _instance = None
    _doctr_model = None
    _spacy_model = None
    _spacy_model_name = None  # name of the spaCy model actually loaded
    _device = None
    _models_loaded = False

    def __new__(cls):
        # Every ModelManager() call returns the same shared instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Intentionally a no-op. __init__ runs on every ModelManager() call,
        # so it must not reset state already held by the singleton.
        pass

    @classmethod
    def _configured_spacy_model_name(cls):
        """Return the spaCy model name from SPACY_MODEL_NAME, or the default."""
        return os.getenv("SPACY_MODEL_NAME", cls.DEFAULT_SPACY_MODEL)

    def _load_doctr_model(self):
        """Build the pretrained doctr OCR predictor and move it to ``self._device``."""
        logger.info("πŸ”„ Loading doctr OCR model...")
        predictor = ocr_predictor(pretrained=True)
        # The detection and recognition sub-models are each moved to the
        # selected device explicitly.
        predictor.det_predictor.model = predictor.det_predictor.model.to(
            self._device
        )
        predictor.reco_predictor.model = predictor.reco_predictor.model.to(
            self._device
        )
        # Publish the predictor only after the device moves succeed. This way
        # a failure never leaves a half-configured model that a retry would
        # then skip.
        self._doctr_model = predictor
        logger.info("βœ… Doctr model loaded successfully!")

    def _load_spacy_model(self):
        """Load the configured spaCy pipeline and record which name was loaded."""
        model_name = self._configured_spacy_model_name()
        logger.info(f"Loading spaCy model '{model_name}'...")
        self._spacy_model = spacy.load(model_name)
        self._spacy_model_name = model_name
        logger.info("βœ… spaCy model loaded successfully!")

    async def _load_models(self):
        """Load every model that is not loaded yet, then mark the manager ready.

        This coroutine contains no ``await``, so model loading blocks the
        calling event loop until it finishes. Models that are already present
        are not reloaded. For example, doctr is kept after an earlier attempt
        that failed while loading spaCy.

        Raises:
            Whatever ``ocr_predictor`` / ``spacy.load`` raise. In that case
            ``models_loaded`` stays ``False``.
        """
        logger.info("πŸš€ Starting model pre-loading...")
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"πŸ“± Using device: {self._device}")
        if self._doctr_model is None:
            self._load_doctr_model()
        if self._spacy_model is None:
            self._load_spacy_model()
        self._models_loaded = True
        logger.info("πŸŽ‰ All models loaded successfully!")

    @property
    def doctr_model(self):
        """Get the loaded doctr model (``None`` before loading)."""
        return self._doctr_model

    @property
    def spacy_model(self):
        """Get the loaded spaCy model (``None`` before loading)."""
        return self._spacy_model

    @property
    def device(self):
        """Get the torch device chosen at load time (``None`` before loading)."""
        return self._device

    @property
    def models_loaded(self):
        """Return ``True`` once every model has loaded successfully."""
        return self._models_loaded

    async def ensure_models_loaded(self):
        """Load the models on first use (async wrapper).

        Returns:
            bool: always ``True``. Load failures propagate as exceptions.
        """
        if not self._models_loaded:
            await self._load_models()
        return True

    async def get_model_status(self):
        """Return a snapshot of which models are loaded and how they are configured.

        Once spaCy has loaded, ``spacy_model_name`` is the name actually
        loaded. Before that it is the name that would be loaded:
        SPACY_MODEL_NAME, or ``DEFAULT_SPACY_MODEL`` when the variable is
        unset. It is never ``None``.
        """
        return {
            "doctr_model": self._doctr_model is not None,
            "spacy_model": self._spacy_model is not None,
            "device": str(self._device),
            "models_loaded": self._models_loaded,
            "spacy_model_name": (
                self._spacy_model_name or self._configured_spacy_model_name()
            ),
        }
# Module-level handle to the shared ModelManager singleton. Because of
# ModelManager.__new__, any other ModelManager() call yields this same object.
model_manager: ModelManager = ModelManager()