# Bank-Scrubber / src/utils/model_manager.py
# Aryan Jain
# Bank scrubber Streamlit application
# 4e71548
import asyncio
import torch
from typing import Optional
from doctr.models import ocr_predictor
import spacy
from src.config.config import settings
class ModelManager:
    """Process-wide singleton that pre-loads the OCR and NER models once.

    Every ``ModelManager()`` call returns the same instance. The first call
    selects a torch device, loads the pretrained doctr OCR predictor onto it,
    and loads a spaCy pipeline (with fallbacks). Later calls reuse the
    already-loaded models because ``__init__`` checks ``_models_loaded``.
    """

    # Class-level defaults; ``_load_models`` replaces them with instance
    # attributes on the singleton.
    _instance = None
    _doctr_model = None
    _spacy_model = None
    _device = None
    _models_loaded = False

    # spaCy pipelines tried, in order, when ``settings.spacy_model_name``
    # cannot be loaded.
    _FALLBACK_SPACY_MODELS = ("en_core_web_sm", "en_core_web_trf")

    def __new__(cls):
        # Always hand back the single shared instance. Python still runs
        # ``__init__`` on every call, hence the guard there.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if not self._models_loaded:
            self._load_models()

    def _load_models(self) -> None:
        """Load all models synchronously (blocking).

        The manager is marked as loaded even when no spaCy pipeline could be
        found, so that OCR remains usable; ``spacy_model`` is then None.

        Exceptions from loading doctr are not caught; in that case
        ``_models_loaded`` stays False.
        """
        print("🚀 Starting model pre-loading...")
        self._device = self._select_device()
        self._doctr_model = self._load_doctr(self._device)
        self._spacy_model = self._load_spacy()
        self._models_loaded = True
        # Only claim full success when every model actually loaded.
        if self._spacy_model is None:
            print("⚠️ Models loaded, but no spaCy model is available.")
        else:
            print("🎉 All models loaded successfully!")

    @staticmethod
    def _select_device() -> torch.device:
        """Return the torch device for doctr, honoring ``settings.force_cpu``."""
        if settings.force_cpu:
            print("📱 Using CPU (forced by config)")
            return torch.device("cpu")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"📱 Using device: {device}")
        return device

    @staticmethod
    def _load_doctr(device: torch.device):
        """Build the pretrained doctr OCR predictor and move its networks to *device*."""
        print("🔄 Loading doctr OCR model...")
        predictor = ocr_predictor(pretrained=True)
        # Move both the detection and recognition networks to the chosen device.
        predictor.det_predictor.model = predictor.det_predictor.model.to(device)
        predictor.reco_predictor.model = predictor.reco_predictor.model.to(device)
        print("✅ Doctr model loaded successfully!")
        return predictor

    @classmethod
    def _load_spacy(cls):
        """Load the configured spaCy pipeline, falling back to stock English ones.

        Returns:
            The loaded pipeline, or None when ``spacy.load`` raised ``OSError``
            for the configured model and for every fallback.
        """
        name = settings.spacy_model_name
        print(f"🔄 Loading spaCy NER model: {name}...")
        try:
            model = spacy.load(name)
        except OSError:
            print(f"⚠️ spaCy model '{name}' not found.")
        else:
            print(f"✅ spaCy model ({name}) loaded successfully!")
            return model

        for fallback in cls._FALLBACK_SPACY_MODELS:
            if fallback == name:
                continue  # already attempted above
            print(f"🔄 Trying fallback model: {fallback}")
            try:
                model = spacy.load(fallback)
            except OSError:
                continue
            print(f"✅ spaCy model ({fallback}) loaded successfully!")
            return model

        print("⚠️ No spaCy model found. Please install with: python -m spacy download en_core_web_sm")
        return None

    @property
    def doctr_model(self):
        """The loaded doctr OCR predictor, or None before loading."""
        return self._doctr_model

    @property
    def spacy_model(self) -> Optional["spacy.language.Language"]:
        """The loaded spaCy pipeline, or None if none could be loaded."""
        return self._spacy_model

    @property
    def device(self) -> Optional[torch.device]:
        """The torch device doctr was moved to, or None before loading."""
        return self._device

    @property
    def models_loaded(self) -> bool:
        """Whether ``_load_models`` has completed."""
        return self._models_loaded

    async def ensure_models_loaded(self) -> bool:
        """Load the models in a worker thread if they are not loaded yet.

        Returns:
            Always True (kept for compatibility with existing callers).
        """
        if not self._models_loaded:
            # Inside a coroutine the running loop is the one to use;
            # get_event_loop() is discouraged for this purpose.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._load_models)
        return True

    def get_model_status(self) -> dict:
        """Summarize which models are loaded plus the relevant settings."""
        return {
            "doctr_model": self._doctr_model is not None,
            "spacy_model": self._spacy_model is not None,
            "device": str(self._device),
            "models_loaded": self._models_loaded,
            "spacy_model_name": settings.spacy_model_name,
            "force_cpu": settings.force_cpu,
        }
# Shared singleton instance. Constructing it here means importing this module
# runs ModelManager.__init__, which loads all models on first construction.
model_manager = ModelManager()