Spaces:
Runtime error
Runtime error
File size: 3,929 Bytes
4e71548 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import asyncio
import torch
from typing import Optional
from doctr.models import ocr_predictor
import spacy
from src.config.config import settings
class ModelManager:
    """Singleton model manager for pre-loading all models at startup.

    Every ``ModelManager()`` call returns the same shared instance. The first
    construction loads a doctr OCR predictor and a spaCy pipeline; later
    constructions skip loading once it has completed successfully.
    """

    _instance: Optional["ModelManager"] = None
    _doctr_model = None  # doctr predictor from ocr_predictor(); None until loaded
    _spacy_model = None  # spaCy pipeline from spacy.load(); stays None if no model could be loaded
    _device: Optional[torch.device] = None  # device chosen in _load_models
    _models_loaded: bool = False  # set True only at the very end of _load_models

    def __new__(cls):
        # Hand back the one shared instance, creating it on first use.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        # __init__ still runs on every ModelManager() call even though __new__
        # returns the shared instance, so guard against loading twice.
        if not self._models_loaded:
            self._load_models()

    def _load_models(self) -> None:
        """Load all models synchronously.

        Device: CPU when ``settings.force_cpu`` is set, otherwise CUDA if
        available, else CPU.

        spaCy: tries ``settings.spacy_model_name`` first, then the fallbacks
        ``en_core_web_sm`` / ``en_core_web_trf``. If none loads,
        ``spacy_model`` stays ``None`` and a warning is printed.

        Exceptions raised while building or moving the doctr model propagate.
        In that case ``models_loaded`` stays False, so a later call can retry.
        """
        # NOTE(review): the "π" / "π±" message prefixes look like emoji whose
        # UTF-8 bytes were mis-decoded; the original glyphs cannot be recovered
        # unambiguously from this copy. Restore them from version control.
        print("π Starting model pre-loading...")

        # Pick the device: honour the force_cpu setting, otherwise prefer CUDA.
        if settings.force_cpu:
            self._device = torch.device("cpu")
            print("π± Using CPU (forced by config)")
        else:
            self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print(f"π± Using device: {self._device}")

        # doctr: build the pretrained OCR predictor and move both its detection
        # and recognition networks onto the chosen device. Publish it on self
        # only afterwards, so a failure here doesn't leave a half-set model.
        print("π Loading doctr OCR model...")
        predictor = ocr_predictor(pretrained=True)
        predictor.det_predictor.model = predictor.det_predictor.model.to(self._device)
        predictor.reco_predictor.model = predictor.reco_predictor.model.to(self._device)
        self._doctr_model = predictor
        print("✅ Doctr model loaded successfully!")

        # spaCy: try the configured model first, then the fallbacks.
        print(f"π Loading spaCy NER model: {settings.spacy_model_name}...")
        try:
            self._spacy_model = spacy.load(settings.spacy_model_name)
            print(f"✅ spaCy model ({settings.spacy_model_name}) loaded successfully!")
        except OSError:
            print(f"⚠️ spaCy model '{settings.spacy_model_name}' not found.")
            # Try fallback models
            fallback_models = ["en_core_web_sm", "en_core_web_trf"]
            for fallback_model in fallback_models:
                if fallback_model != settings.spacy_model_name:
                    try:
                        print(f"π Trying fallback model: {fallback_model}")
                        self._spacy_model = spacy.load(fallback_model)
                        print(f"✅ spaCy model ({fallback_model}) loaded successfully!")
                        break
                    except OSError:
                        continue
            if self._spacy_model is None:
                print("⚠️ No spaCy model found. Please install with: python -m spacy download en_core_web_sm")

        self._models_loaded = True
        print("π All models loaded successfully!")

    @property
    def doctr_model(self):
        """Get the loaded doctr model (None until loading succeeds)."""
        return self._doctr_model

    @property
    def spacy_model(self):
        """Get the loaded spaCy model (None if no model could be loaded)."""
        return self._spacy_model

    @property
    def device(self) -> Optional[torch.device]:
        """Get the device being used (None before loading)."""
        return self._device

    @property
    def models_loaded(self) -> bool:
        """Check if models are loaded."""
        return self._models_loaded

    async def ensure_models_loaded(self) -> bool:
        """Ensure models are loaded without blocking the event loop.

        Loading is blocking, so it runs in the running loop's default
        executor. Returns True once loading is complete (or was already
        done). Exceptions from loading propagate to the caller.
        """
        if not self._models_loaded:
            # The asyncio docs recommend get_running_loop() inside coroutines
            # rather than get_event_loop().
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._load_models)
        return True

    def get_model_status(self) -> dict:
        """Get status of all models as a JSON-friendly dict."""
        return {
            "doctr_model": self._doctr_model is not None,
            "spacy_model": self._spacy_model is not None,
            "device": str(self._device),
            "models_loaded": self._models_loaded,
            "spacy_model_name": settings.spacy_model_name,
            "force_cpu": settings.force_cpu,
        }
# Global model manager instance. Constructing it at import time runs
# ModelManager.__init__, which loads the models synchronously on first use.
model_manager = ModelManager()