end-rin committed on
Commit
8141c38
·
verified ·
1 Parent(s): 1340b8f

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +28 -17
app.py CHANGED
@@ -148,9 +148,8 @@ MODELS = {
148
  },
149
  }
150
 
151
# Global model cache (only keep one model loaded at a time to save memory)
_current_model = None        # the currently loaded model object, or None if none resident
_current_model_name = None   # key into MODELS identifying _current_model, or None
_llama_class = None          # cached class returned by _get_llama_class() — presumably llama_cpp.Llama; confirm in _get_llama_class
155
 
156
 
@@ -208,29 +207,39 @@ def transform_text(text: str, style: str) -> str:
208
 
209
 
210
def load_model(model_key: str):
    """Load a GGUF model. Unloads previous model to save memory."""
    global _current_model, _current_model_name

    # Cache hit: the requested model is already the one resident in memory.
    if _current_model_name == model_key and _current_model is not None:
        return _current_model

    # Unload previous model
    if _current_model is not None:
        # Drop the reference so the previous model can be garbage-collected
        # (and its native buffers released) before the next load.
        del _current_model
        _current_model = None
        _current_model_name = None

    # Load the requested model from its Hugging Face repo/file entry in MODELS.
    config = MODELS[model_key]
    Llama = _get_llama_class()
    _current_model = Llama.from_pretrained(
        repo_id=config['repo_id'],
        filename=config['filename'],
        n_ctx=256,      # small context window keeps memory usage low
        n_threads=8,
        verbose=False,
    )
    _current_model_name = model_key
    return _current_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
 
236
  def get_prediction(model, text: str, task: str, model_key: str) -> str:
@@ -587,6 +596,8 @@ def create_demo():
587
  # =============================================================================
588
 
589
if __name__ == "__main__":
    demo = create_demo()
    # NOTE(review): concurrency limit of 1 looks intended to serialize requests
    # against the single shared loaded model — confirm.
    demo.queue(default_concurrency_limit=1)
    demo.launch()
 
148
  },
149
  }
150
 
151
# Global model cache — all three models are pre-loaded at startup
# (keyed by their MODELS key; see load_model / preload_all_models).
_loaded_models = {}

_llama_class = None  # cached class returned by _get_llama_class()
154
 
155
 
 
207
 
208
 
209
def load_model(model_key: str):
    """Return the cached model for *model_key*, loading it on demand if absent.

    Normally every model is pre-loaded at startup, so this is a fast
    dictionary lookup; the load path below only runs as a fallback.
    """
    try:
        return _loaded_models[model_key]
    except KeyError:
        pass  # not pre-loaded yet — fall through and load it now

    # Fallback: load on demand if not yet ready
    cfg = MODELS[model_key]
    llama_cls = _get_llama_class()
    loaded = llama_cls.from_pretrained(
        repo_id=cfg['repo_id'],
        filename=cfg['filename'],
        n_ctx=256,
        n_threads=8,
        verbose=False,
    )
    _loaded_models[model_key] = loaded
    return loaded
226
+
227
+
228
def preload_all_models():
    """Pre-load all models at startup so switching is instant.

    Walks every entry in MODELS and loads any model that is not already in
    the _loaded_models cache. Delegates the actual loading to load_model()
    so the loading parameters (repo/file, n_ctx, n_threads, ...) are defined
    in exactly one place instead of being duplicated here.
    """
    for key, config in MODELS.items():
        if key not in _loaded_models:
            print(f"Pre-loading {config['name']}...")
            # load_model() performs the from_pretrained call and caches
            # the result in _loaded_models for us.
            load_model(key)
            print(f" {config['name']} ready.")
    print("All models pre-loaded.")
243
 
244
 
245
  def get_prediction(model, text: str, task: str, model_key: str) -> str:
 
596
  # =============================================================================
597
 
598
if __name__ == "__main__":
    print("Starting model pre-load...")
    # Load every model up front so the first user request is not blocked
    # by a model download/initialization.
    preload_all_models()
    demo = create_demo()
    # NOTE(review): concurrency limit of 1 looks intended to serialize
    # inference against the shared module-level models — confirm.
    demo.queue(default_concurrency_limit=1)
    demo.launch()