File size: 2,609 Bytes
e42e330
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6874c0e
e42e330
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os

import spacy
import torch
from doctr.models import ocr_predictor
from loguru import logger


class ModelManager:
    """Singleton that loads and holds the docTR OCR and spaCy models.

    Every ``ModelManager()`` call returns the same object. Constructing it
    loads nothing; models are loaded the first time
    :meth:`ensure_models_loaded` is awaited and are then exposed through the
    read-only properties below.

    The spaCy pipeline name is read from the ``SPACY_MODEL_NAME`` environment
    variable and falls back to ``en_core_web_sm`` when it is unset.
    """

    # Fallback spaCy pipeline used when SPACY_MODEL_NAME is not set.
    _DEFAULT_SPACY_MODEL = "en_core_web_sm"

    _instance = None
    _doctr_model = None
    _spacy_model = None
    _device = None
    _models_loaded = False

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Do nothing: this runs on every ``ModelManager()`` call."""
        # Intentionally empty so repeated construction never resets the
        # already-loaded models held by the shared instance.
        pass

    @classmethod
    def _spacy_model_name(cls):
        """Return the spaCy pipeline name to load (env var or default)."""
        return os.getenv("SPACY_MODEL_NAME", cls._DEFAULT_SPACY_MODEL)

    async def _load_models(self):
        """Load the docTR and spaCy models and mark them as loaded.

        Although declared ``async`` so callers can ``await`` it, this
        coroutine contains no ``await``: all loading runs inline and blocks
        the event loop until it finishes.

        ``models_loaded`` is set only after both models loaded without
        raising, so a failed attempt is retried on the next
        :meth:`ensure_models_loaded` call.
        """
        logger.info("🚀 Starting model pre-loading...")

        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"📱 Using device: {self._device}")

        # Load doctr model
        logger.info("🔄 Loading doctr OCR model...")
        self._doctr_model = ocr_predictor(pretrained=True)
        # Move the detection and recognition networks to the selected device.
        self._doctr_model.det_predictor.model = (
            self._doctr_model.det_predictor.model.to(self._device)
        )
        self._doctr_model.reco_predictor.model = (
            self._doctr_model.reco_predictor.model.to(self._device)
        )
        logger.info("✅ Doctr model loaded successfully!")

        # Load spaCy model
        self._spacy_model = spacy.load(self._spacy_model_name())
        logger.info("✅ spaCy model loaded successfully!")
        self._models_loaded = True
        logger.info("🎉 All models loaded successfully!")

    @property
    def doctr_model(self):
        """The loaded docTR OCR predictor, or ``None`` before loading."""
        return self._doctr_model

    @property
    def spacy_model(self):
        """The loaded spaCy pipeline, or ``None`` before loading."""
        return self._spacy_model

    @property
    def device(self):
        """The ``torch.device`` chosen at load time, or ``None`` before loading."""
        return self._device

    @property
    def models_loaded(self) -> bool:
        """``True`` once both models have been loaded successfully."""
        return self._models_loaded

    async def ensure_models_loaded(self) -> bool:
        """Load the models if they are not loaded yet.

        Returns:
            Always ``True``; a loading failure propagates as an exception.
        """
        if not self._models_loaded:
            await self._load_models()
        return True

    async def get_model_status(self) -> dict:
        """Return a snapshot of the manager's state.

        Returns:
            A dict with the keys ``doctr_model`` and ``spacy_model`` (bools,
            whether each model object is present), ``device`` (``str()`` of
            the device, i.e. ``"None"`` before loading), ``models_loaded``
            (bool) and ``spacy_model_name`` (the pipeline name that is, or
            would be, passed to ``spacy.load``, including the default).
        """
        return {
            "doctr_model": self._doctr_model is not None,
            "spacy_model": self._spacy_model is not None,
            "device": str(self._device),
            "models_loaded": self._models_loaded,
            # Use the same resolution as _load_models so the reported name
            # is never None when the default pipeline is in use.
            "spacy_model_name": self._spacy_model_name(),
        }


# Module-level shared ModelManager instance for importers of this module.
# Constructing it here loads no models: ModelManager.__init__ is a no-op, and
# loading only happens when ensure_models_loaded() is awaited.
model_manager = ModelManager()