ananttripathiak commited on
Commit
85edb93
·
verified ·
1 Parent(s): d75a2c4

Create model_loader.py

Browse files
Files changed (1) hide show
  1. models/model_loader.py +92 -0
models/model_loader.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model Loader Module
3
+ Handles loading and caching of AI models.
4
+ """
5
+
6
+ import os
7
+ import logging
8
+ from typing import Optional
9
+ from sentence_transformers import SentenceTransformer
10
+ import spacy
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class ModelLoader:
16
+ """
17
+ Singleton class for loading and caching models.
18
+ """
19
+
20
+ _instance = None
21
+ _models = {}
22
+
23
+ def __new__(cls):
24
+ if cls._instance is None:
25
+ cls._instance = super(ModelLoader, cls).__new__(cls)
26
+ return cls._instance
27
+
28
+ def __init__(self):
29
+ """Initialize model loader."""
30
+ self.cache_dir = os.getenv('MODEL_CACHE_DIR', './model_cache')
31
+ os.makedirs(self.cache_dir, exist_ok=True)
32
+
33
+ def load_sentence_transformer(
34
+ self,
35
+ model_name: str = "all-MiniLM-L6-v2"
36
+ ) -> SentenceTransformer:
37
+ """
38
+ Load sentence transformer model with caching.
39
+
40
+ Args:
41
+ model_name: HuggingFace model name
42
+
43
+ Returns:
44
+ Loaded model
45
+ """
46
+ if model_name in self._models:
47
+ logger.info(f"Using cached model: {model_name}")
48
+ return self._models[model_name]
49
+
50
+ try:
51
+ logger.info(f"Loading sentence transformer: {model_name}")
52
+ model = SentenceTransformer(model_name, cache_folder=self.cache_dir)
53
+ self._models[model_name] = model
54
+ logger.info(f"Successfully loaded: {model_name}")
55
+ return model
56
+ except Exception as e:
57
+ logger.error(f"Failed to load model {model_name}: {e}")
58
+ raise
59
+
60
+ def load_spacy_model(self, model_name: str = "en_core_web_sm"):
61
+ """
62
+ Load spaCy model with caching.
63
+
64
+ Args:
65
+ model_name: spaCy model name
66
+
67
+ Returns:
68
+ Loaded spaCy model
69
+ """
70
+ if model_name in self._models:
71
+ logger.info(f"Using cached spaCy model: {model_name}")
72
+ return self._models[model_name]
73
+
74
+ try:
75
+ logger.info(f"Loading spaCy model: {model_name}")
76
+ nlp = spacy.load(model_name)
77
+ self._models[model_name] = nlp
78
+ logger.info(f"Successfully loaded: {model_name}")
79
+ return nlp
80
+ except Exception as e:
81
+ logger.error(f"Failed to load spaCy model: {e}")
82
+ return None
83
+
84
+ def clear_cache(self):
85
+ """Clear model cache."""
86
+ self._models.clear()
87
+ logger.info("Model cache cleared")
88
+
89
+ def get_loaded_models(self):
90
+ """Get list of currently loaded models."""
91
+ return list(self._models.keys())
92
+