Spaces:
Running
Running
Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
|
@@ -148,9 +148,8 @@ MODELS = {
|
|
| 148 |
},
|
| 149 |
}
|
| 150 |
|
| 151 |
-
# Global model cache
|
| 152 |
-
|
| 153 |
-
_current_model_name = None
|
| 154 |
_llama_class = None
|
| 155 |
|
| 156 |
|
|
@@ -208,29 +207,39 @@ def transform_text(text: str, style: str) -> str:
|
|
| 208 |
|
| 209 |
|
| 210 |
def load_model(model_key: str):
|
| 211 |
-
"""
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
if _current_model_name == model_key and _current_model is not None:
|
| 215 |
-
return _current_model
|
| 216 |
-
|
| 217 |
-
# Unload previous model
|
| 218 |
-
if _current_model is not None:
|
| 219 |
-
del _current_model
|
| 220 |
-
_current_model = None
|
| 221 |
-
_current_model_name = None
|
| 222 |
|
|
|
|
| 223 |
config = MODELS[model_key]
|
| 224 |
Llama = _get_llama_class()
|
| 225 |
-
|
| 226 |
repo_id=config['repo_id'],
|
| 227 |
filename=config['filename'],
|
| 228 |
n_ctx=256,
|
| 229 |
n_threads=8,
|
| 230 |
verbose=False,
|
| 231 |
)
|
| 232 |
-
|
| 233 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
|
| 235 |
|
| 236 |
def get_prediction(model, text: str, task: str, model_key: str) -> str:
|
|
@@ -587,6 +596,8 @@ def create_demo():
|
|
| 587 |
# =============================================================================
|
| 588 |
|
| 589 |
if __name__ == "__main__":
|
|
|
|
|
|
|
| 590 |
demo = create_demo()
|
| 591 |
demo.queue(default_concurrency_limit=1)
|
| 592 |
demo.launch()
|
|
|
|
| 148 |
},
|
| 149 |
}
|
| 150 |
|
| 151 |
+
# Global model cache — all three models are pre-loaded at startup
|
| 152 |
+
_loaded_models = {}
|
|
|
|
| 153 |
_llama_class = None
|
| 154 |
|
| 155 |
|
|
|
|
| 207 |
|
| 208 |
|
| 209 |
def load_model(model_key: str):
|
| 210 |
+
"""Return a pre-loaded model, or load on demand as fallback."""
|
| 211 |
+
if model_key in _loaded_models:
|
| 212 |
+
return _loaded_models[model_key]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
+
# Fallback: load on demand if not yet ready
|
| 215 |
config = MODELS[model_key]
|
| 216 |
Llama = _get_llama_class()
|
| 217 |
+
model = Llama.from_pretrained(
|
| 218 |
repo_id=config['repo_id'],
|
| 219 |
filename=config['filename'],
|
| 220 |
n_ctx=256,
|
| 221 |
n_threads=8,
|
| 222 |
verbose=False,
|
| 223 |
)
|
| 224 |
+
_loaded_models[model_key] = model
|
| 225 |
+
return model
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def preload_all_models():
|
| 229 |
+
"""Pre-load all models at startup so switching is instant."""
|
| 230 |
+
Llama = _get_llama_class()
|
| 231 |
+
for key, config in MODELS.items():
|
| 232 |
+
if key not in _loaded_models:
|
| 233 |
+
print(f"Pre-loading {config['name']}...")
|
| 234 |
+
_loaded_models[key] = Llama.from_pretrained(
|
| 235 |
+
repo_id=config['repo_id'],
|
| 236 |
+
filename=config['filename'],
|
| 237 |
+
n_ctx=256,
|
| 238 |
+
n_threads=8,
|
| 239 |
+
verbose=False,
|
| 240 |
+
)
|
| 241 |
+
print(f" {config['name']} ready.")
|
| 242 |
+
print("All models pre-loaded.")
|
| 243 |
|
| 244 |
|
| 245 |
def get_prediction(model, text: str, task: str, model_key: str) -> str:
|
|
|
|
| 596 |
# =============================================================================
|
| 597 |
|
| 598 |
if __name__ == "__main__":
|
| 599 |
+
print("Starting model pre-load...")
|
| 600 |
+
preload_all_models()
|
| 601 |
demo = create_demo()
|
| 602 |
demo.queue(default_concurrency_limit=1)
|
| 603 |
demo.launch()
|