Spaces:
Sleeping
Sleeping
Sushwetabm
committed on
Commit
·
c16d4e7
1
Parent(s):
08937bf
updated model.py
Browse files
model.py
CHANGED
|
@@ -126,7 +126,7 @@
|
|
| 126 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 127 |
from functools import lru_cache
|
| 128 |
import logging
|
| 129 |
-
|
| 130 |
logger = logging.getLogger(__name__)
|
| 131 |
_model_loaded = False
|
| 132 |
_tokenizer = None
|
|
@@ -157,3 +157,37 @@ def load_model_sync():
|
|
| 157 |
except Exception as e:
|
| 158 |
logger.error(f"❌ Failed to load model: {e}")
|
| 159 |
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 127 |
from functools import lru_cache
|
| 128 |
import logging
|
| 129 |
+
import asyncio
|
| 130 |
logger = logging.getLogger(__name__)
|
| 131 |
_model_loaded = False
|
| 132 |
_tokenizer = None
|
|
|
|
| 157 |
except Exception as e:
|
| 158 |
logger.error(f"❌ Failed to load model: {e}")
|
| 159 |
raise
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
# Serializes concurrent load_model_async() calls. Created lazily inside the
# coroutine so the lock is constructed while an event loop is running.
_async_load_lock = None


async def load_model_async():
    """Load the tokenizer and seq2seq model without blocking the event loop.

    The ``from_pretrained`` calls are blocking, so they are run in the
    loop's default executor instead of directly inside the coroutine. On
    success the module globals ``_tokenizer``, ``_model`` and
    ``_model_loaded`` are updated together, only after both objects have
    loaded. Returns immediately if the model is already loaded.

    Raises:
        Exception: whatever the tokenizer/model loaders raise; the error is
            logged and re-raised unchanged.
    """
    global _tokenizer, _model, _model_loaded, _async_load_lock
    if _model_loaded:
        return

    # No await between the check and the assignment, so only one lock is
    # ever created even with several concurrent callers.
    if _async_load_lock is None:
        _async_load_lock = asyncio.Lock()

    async with _async_load_lock:
        # Another caller may have finished loading while we were waiting.
        if _model_loaded:
            return

        config = get_model_config()
        model_id = config["model_id"]
        loop = asyncio.get_running_loop()

        try:
            # Offload the blocking loaders to a worker thread so other
            # coroutines keep running while the weights are read.
            tokenizer = await loop.run_in_executor(
                None, AutoTokenizer.from_pretrained, model_id
            )
            model = await loop.run_in_executor(
                None, AutoModelForSeq2SeqLM.from_pretrained, model_id
            )
            model.eval()
            # Publish only once everything succeeded, so the globals are
            # never left in a half-initialized state.
            _tokenizer = tokenizer
            _model = model
            _model_loaded = True
            logger.info(f"✅ Model {model_id} loaded successfully.")
        except Exception as e:
            logger.error(f"❌ Failed to load model: {e}")
            raise
|
| 179 |
+
|
| 180 |
+
def get_model():
    """Return the loaded ``(tokenizer, model)`` pair.

    Raises:
        ValueError: if the module-level ``_model_loaded`` flag is not set.
    """
    if _model_loaded:
        return _tokenizer, _model
    raise ValueError("Model not loaded yet")
|
| 184 |
+
|
| 185 |
+
def is_model_loaded():
    """Return the current value of the module-level ``_model_loaded`` flag."""
    loaded_flag = _model_loaded
    return loaded_flag
|
| 187 |
+
|
| 188 |
+
def get_model_info():
    """Return a status dict with the configured model id and load state.

    Keys:
        "model_id": the ``"model_id"`` entry of ``get_model_config()``.
        "loaded": the module-level ``_model_loaded`` flag.
        "loading": the negation of ``_model_loaded``.

    NOTE(review): "loading" is simply ``not loaded``; it does not tell an
    in-progress load apart from one that never started or that failed.
    """
    model_id = get_model_config()["model_id"]
    loaded = _model_loaded
    return {
        "model_id": model_id,
        "loaded": loaded,
        "loading": not loaded,
    }
|