Spaces:
Runtime error
Runtime error
| import asyncio | |
| import torch | |
| from typing import Optional | |
| from doctr.models import ocr_predictor | |
| import spacy | |
| from src.config.config import settings | |
class ModelManager:
    """Singleton model manager for pre-loading all models at startup.

    ``ModelManager()`` always returns the same shared instance (see
    ``__new__``). Models are loaded on the first construction; later
    constructions find ``_models_loaded`` already set and skip loading.
    """

    _instance: Optional["ModelManager"] = None
    _doctr_model = None  # doctr OCR predictor; None until loaded
    _spacy_model = None  # spaCy pipeline; stays None if no candidate loads
    _device: Optional["torch.device"] = None
    _models_loaded: bool = False

    # spaCy pipelines tried, in order, when settings.spacy_model_name cannot be loaded.
    _SPACY_FALLBACK_MODELS = ("en_core_web_sm", "en_core_web_trf")

    def __new__(cls):
        # Hand back the one shared instance on every ModelManager() call.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every ModelManager() call, so guard against reloading.
        if not self._models_loaded:
            self._load_models()

    def _load_models(self):
        """Load all models synchronously.

        Selects the torch device, loads the doctr OCR predictor onto it, then
        loads a spaCy pipeline (trying ``_SPACY_FALLBACK_MODELS`` if the
        configured one is missing). A missing spaCy model is reported but is
        not fatal; exceptions raised while loading doctr propagate and leave
        ``_models_loaded`` False.
        """
        print("π Starting model pre-loading...")
        self._device = self._select_device()
        self._doctr_model = self._load_doctr_model(self._device)
        self._spacy_model = self._load_spacy_model()
        self._models_loaded = True
        if self._spacy_model is None:
            # Don't claim full success after already warning that spaCy is missing.
            print("Models loaded, but no spaCy NER model is available.")
        else:
            print("π All models loaded successfully!")

    @staticmethod
    def _select_device():
        """Return the torch device to run on, honouring ``settings.force_cpu``."""
        if settings.force_cpu:
            print("π± Using CPU (forced by config)")
            return torch.device("cpu")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"π± Using device: {device}")
        return device

    @staticmethod
    def _load_doctr_model(device):
        """Build the pretrained doctr predictor and move its detection and
        recognition networks onto *device*."""
        print("π Loading doctr OCR model...")
        predictor = ocr_predictor(pretrained=True)
        predictor.det_predictor.model = predictor.det_predictor.model.to(device)
        predictor.reco_predictor.model = predictor.reco_predictor.model.to(device)
        print("β Doctr model loaded successfully!")
        return predictor

    def _load_spacy_model(self):
        """Return the pipeline from ``spacy.load`` for the configured model.

        Falls back to each entry of ``_SPACY_FALLBACK_MODELS`` (skipping the
        configured name) when loading raises ``OSError``. Returns None if no
        candidate can be loaded.
        """
        configured = settings.spacy_model_name
        print(f"π Loading spaCy NER model: {configured}...")
        try:
            nlp = spacy.load(configured)
        except OSError:
            print(f"β οΈ spaCy model '{configured}' not found.")
        else:
            print(f"β spaCy model ({configured}) loaded successfully!")
            return nlp

        for fallback_model in self._SPACY_FALLBACK_MODELS:
            if fallback_model == configured:
                continue
            print(f"π Trying fallback model: {fallback_model}")
            try:
                nlp = spacy.load(fallback_model)
            except OSError:
                continue
            print(f"β spaCy model ({fallback_model}) loaded successfully!")
            return nlp

        print("β οΈ No spaCy model found. Please install with: python -m spacy download en_core_web_sm")
        return None

    # NOTE(review): the accessors below are plain methods, so callers must call
    # them (e.g. ``model_manager.doctr_model()``); confirm no caller reads them
    # as attributes.
    def doctr_model(self):
        """Return the loaded doctr predictor (None before loading)."""
        return self._doctr_model

    def spacy_model(self):
        """Return the loaded spaCy pipeline, or None if none could be loaded."""
        return self._spacy_model

    def device(self):
        """Return the torch device selected during loading (None before loading)."""
        return self._device

    def models_loaded(self):
        """Return True once ``_load_models`` has completed."""
        return self._models_loaded

    async def ensure_models_loaded(self):
        """Ensure models are loaded without blocking the event loop.

        If loading has not completed yet, ``_load_models`` runs in the running
        loop's default executor. Returns True once loading has finished;
        exceptions from loading propagate to the awaiting caller.
        """
        if not self._models_loaded:
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._load_models)
        return True

    def get_model_status(self):
        """Return a JSON-friendly summary of which models are loaded and the config used."""
        return {
            "doctr_model": self._doctr_model is not None,
            "spacy_model": self._spacy_model is not None,
            "device": str(self._device),
            "models_loaded": self._models_loaded,
            "spacy_model_name": settings.spacy_model_name,
            "force_cpu": settings.force_cpu,
        }
# Process-wide shared manager; constructing it here runs model loading when this module is imported.
model_manager: ModelManager = ModelManager()