File size: 3,976 Bytes
4e71548
2a728d0
 
 
4e71548
 
2a728d0
 
4e71548
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import asyncio
import logging
from typing import Optional, Dict, Any
from pathlib import Path
from doctr.models import ocr_predictor
import spacy
from config.config import settings
import torch


class ModelManager:
    """Singleton that pre-loads the OCR and NER models once per process.

    The first ``ModelManager()`` call creates the single instance and loads
    the doctr OCR predictor and a spaCy pipeline onto the configured device.
    Later calls return the same instance without reloading. A missing spaCy
    pipeline is tolerated (``spacy_model`` is then ``None``); errors while
    loading doctr propagate to the caller.
    """

    # Class-level defaults; _load_models() overwrites them on the instance.
    _instance = None
    _doctr_model = None
    _spacy_model = None
    # Name actually passed to spacy.load() (may be a fallback), or None.
    _spacy_model_loaded_name: Optional[str] = None
    _device = None
    _models_loaded = False

    # spaCy pipelines tried, in order, when the configured one is not found.
    _SPACY_FALLBACK_MODELS = ("en_core_web_sm", "en_core_web_trf")

    def __new__(cls):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every ModelManager() call even though __new__
        # returns the same object, so guard against reloading.
        if not self._models_loaded:
            self._load_models()

    def _load_models(self) -> None:
        """Select the device and load all models synchronously.

        If loading the doctr model raises, the exception propagates and
        ``models_loaded`` stays False, so a later call will try again.
        """
        print("πŸš€ Starting model pre-loading...")
        self._device = self._select_device()
        self._doctr_model = self._load_doctr(self._device)
        self._spacy_model, self._spacy_model_loaded_name = self._load_spacy()
        self._models_loaded = True
        # Don't claim full success when NER is unavailable.
        if self._spacy_model is None:
            print("⚠️ Models loaded, but no spaCy NER model is available.")
        else:
            print("πŸŽ‰ All models loaded successfully!")

    @staticmethod
    def _select_device():
        """Return CPU when ``settings.force_cpu`` is set, else CUDA if available."""
        if settings.force_cpu:
            device = torch.device("cpu")
            print("πŸ“± Using CPU (forced by config)")
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print(f"πŸ“± Using device: {device}")
        return device

    @staticmethod
    def _load_doctr(device):
        """Build the pretrained doctr OCR predictor and move it to *device*."""
        print("πŸ”„ Loading doctr OCR model...")
        predictor = ocr_predictor(pretrained=True)
        # The predictor wraps two sub-models (det_predictor / reco_predictor);
        # each must be moved to the target device separately.
        predictor.det_predictor.model = predictor.det_predictor.model.to(device)
        predictor.reco_predictor.model = predictor.reco_predictor.model.to(device)
        print("βœ… Doctr model loaded successfully!")
        return predictor

    @classmethod
    def _load_spacy(cls):
        """Load the configured spaCy pipeline, falling back to stock models.

        Returns:
            ``(nlp, name)`` for the first pipeline that loaded, or
            ``(None, None)`` if none could be loaded. Only ``OSError`` from
            ``spacy.load`` is treated as "model not found"; any other
            exception propagates.
        """
        configured = settings.spacy_model_name
        print(f"πŸ”„ Loading spaCy NER model: {configured}...")
        try:
            nlp = spacy.load(configured)
        except OSError:
            print(f"⚠️ spaCy model '{configured}' not found.")
        else:
            print(f"βœ… spaCy model ({configured}) loaded successfully!")
            return nlp, configured

        for fallback in cls._SPACY_FALLBACK_MODELS:
            if fallback == configured:
                continue  # already attempted above
            print(f"πŸ”„ Trying fallback model: {fallback}")
            try:
                nlp = spacy.load(fallback)
            except OSError:
                continue
            print(f"βœ… spaCy model ({fallback}) loaded successfully!")
            return nlp, fallback

        print("⚠️ No spaCy model found. Please install with: python -m spacy download en_core_web_sm")
        return None, None

    @property
    def doctr_model(self):
        """The loaded doctr OCR predictor (None before loading)."""
        return self._doctr_model

    @property
    def spacy_model(self):
        """The loaded spaCy pipeline, or None if none could be loaded."""
        return self._spacy_model

    @property
    def device(self):
        """The torch device the doctr models were moved to."""
        return self._device

    @property
    def models_loaded(self) -> bool:
        """True once _load_models() has completed without raising."""
        return self._models_loaded

    async def ensure_models_loaded(self) -> bool:
        """Load the models off the event loop if they are not loaded yet.

        The blocking load runs in the running loop's default executor so
        other coroutines are not stalled. Always returns True; load errors
        propagate to the awaiting caller.
        """
        if not self._models_loaded:
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._load_models)
        return True

    def get_model_status(self) -> Dict[str, Any]:
        """Return a JSON-friendly summary of what is loaded and how.

        ``spacy_model_name`` is the configured name; ``spacy_model_loaded_name``
        is the pipeline actually loaded (a fallback, or None if none loaded).
        """
        return {
            "doctr_model": self._doctr_model is not None,
            "spacy_model": self._spacy_model is not None,
            "device": str(self._device),
            "models_loaded": self._models_loaded,
            "spacy_model_name": settings.spacy_model_name,
            "spacy_model_loaded_name": self._spacy_model_loaded_name,
            "force_cpu": settings.force_cpu,
        }


# Global model manager instance. Because ModelManager.__init__ calls
# _load_models() when nothing is loaded yet, importing this module loads all
# models synchronously; any doctr/torch load error surfaces at import time.
model_manager = ModelManager()