Spaces:
Running
Running
File size: 2,609 Bytes
e42e330 6874c0e e42e330 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import asyncio
import os

import spacy
import torch
from doctr.models import ocr_predictor
from loguru import logger
class ModelManager:
    """Process-wide singleton that loads and holds the OCR and NLP models.

    The doctr OCR predictor and the spaCy pipeline are loaded at most once
    via ``ensure_models_loaded()`` and then exposed through the read-only
    properties below.
    """

    # spaCy pipeline used when the SPACY_MODEL_NAME env var is not set.
    _DEFAULT_SPACY_MODEL = "en_core_web_sm"

    _instance = None
    _doctr_model = None
    _spacy_model = None
    _spacy_model_name = None  # name of the pipeline actually loaded
    _device = None
    _models_loaded = False
    # Created lazily inside a running event loop (see ensure_models_loaded);
    # constructing it at import time could tie it to the wrong loop on
    # Python < 3.10.
    _load_lock = None

    def __new__(cls):
        # Every ModelManager() call returns the same shared instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Intentionally empty: __init__ runs on every ModelManager() call,
        # so state lives in class-level defaults instead of being reset here.
        pass

    @classmethod
    def _configured_spacy_model_name(cls):
        """Return the spaCy pipeline name from SPACY_MODEL_NAME or the default."""
        return os.getenv("SPACY_MODEL_NAME", cls._DEFAULT_SPACY_MODEL)

    def _load_models_sync(self):
        """Load every model on the calling thread (blocking)."""
        logger.info("๐ Starting model pre-loading...")
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"๐ฑ Using device: {self._device}")

        # Load doctr model and move both of its sub-models
        # (det_predictor, reco_predictor) onto the chosen device.
        logger.info("๐ Loading doctr OCR model...")
        doctr_model = ocr_predictor(pretrained=True)
        doctr_model.det_predictor.model = (
            doctr_model.det_predictor.model.to(self._device)
        )
        doctr_model.reco_predictor.model = (
            doctr_model.reco_predictor.model.to(self._device)
        )
        self._doctr_model = doctr_model
        logger.info("✅ Doctr model loaded successfully!")

        # Load spaCy model, remembering which pipeline was used so the
        # status report reflects reality.
        spacy_name = self._configured_spacy_model_name()
        self._spacy_model = spacy.load(spacy_name)
        self._spacy_model_name = spacy_name
        logger.info("✅ spaCy model loaded successfully!")

        self._models_loaded = True
        logger.info("๐ All models loaded successfully!")

    async def _load_models(self):
        """Load all models without blocking the event loop.

        Model construction is heavy, blocking work, so it runs in the event
        loop's default executor while this coroutine awaits it.
        Exceptions raised during loading propagate to the caller.
        """
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._load_models_sync)

    @property
    def doctr_model(self):
        """Get the loaded doctr model."""
        return self._doctr_model

    @property
    def spacy_model(self):
        """Get the loaded spaCy model."""
        return self._spacy_model

    @property
    def device(self):
        """Get the device being used."""
        return self._device

    @property
    def models_loaded(self):
        """Check if models are loaded."""
        return self._models_loaded

    async def ensure_models_loaded(self):
        """Load the models if needed; return True once they are available.

        Concurrent callers share a single load: the lock plus the re-check
        inside it stop a second coroutine from starting another load while
        the first is still running in the executor.
        """
        if self._models_loaded:
            return True
        if self._load_lock is None:
            # No await between check and assignment, so this cannot race
            # with another coroutine on the same event loop.
            self._load_lock = asyncio.Lock()
        async with self._load_lock:
            if not self._models_loaded:
                await self._load_models()
        return True

    async def get_model_status(self):
        """Return a JSON-friendly snapshot of the model-loading state.

        ``spacy_model_name`` is the pipeline actually loaded. Before loading,
        it is the pipeline that would be loaded (SPACY_MODEL_NAME or the
        default), rather than None when the variable is unset.
        """
        return {
            "doctr_model": self._doctr_model is not None,
            "spacy_model": self._spacy_model is not None,
            "device": str(self._device),
            "models_loaded": self._models_loaded,
            "spacy_model_name": (
                self._spacy_model_name or self._configured_spacy_model_name()
            ),
        }
# Module-level handle to the ModelManager singleton.
model_manager: ModelManager = ModelManager()
|