anaspro committed
Commit: a645494
Parent(s): 3e07df2

Add model caching with lru_cache for ZeroGPU

- Use @lru_cache to cache the loaded model
- Prevents reloading the model on every request
- First load: ~18 seconds
- Subsequent loads: instant (cached)
- Works better with ZeroGPU ephemeral processes
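For reference, the behavior this commit relies on is standard functools.lru_cache semantics: a zero-argument function decorated with @lru_cache(maxsize=1) executes its body on the first call and returns the same cached object on every later call in the same process. A minimal sketch (load_expensive_resource and the 2-second sleep are illustrative stand-ins for the ~18-second pipeline load, not code from this commit):

    from functools import lru_cache
    import time

    @lru_cache(maxsize=1)
    def load_expensive_resource():
        """The body runs only on the first call; later calls hit the cache."""
        time.sleep(2)  # stand-in for the slow model load
        return object()

    first = load_expensive_resource()   # slow: body executes
    second = load_expensive_resource()  # instant: cached result
    assert first is second              # same object, no reload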
app.py CHANGED

@@ -3,6 +3,7 @@ import gradio as gr
 import spaces
 import re
 from threading import Thread
+from functools import lru_cache
 from transformers import pipeline, TextIteratorStreamer
 from huggingface_hub import login
 import logging

@@ -68,8 +69,22 @@ model_id = "unsloth/gpt-oss-20b-unsloth-bnb-4bit"
 # Load harmony encoding (lightweight, can load outside GPU)
 enc = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
 
-#
-
+# ======================================================
+# Cached Model Loader (for ZeroGPU)
+# ======================================================
+@lru_cache(maxsize=1)
+def load_model():
+    """Load model with caching to avoid reloading"""
+    logger.info("🚀 Loading GPT-OSS-20B model on GPU...")
+    model_pipe = pipeline(
+        "text-generation",
+        model=model_id,
+        torch_dtype="auto",
+        device_map="auto",
+        trust_remote_code=True,
+    )
+    logger.info("✅ Model loaded successfully!")
+    return model_pipe
 
 # ======================================================
 # Format Conversation History

@@ -92,19 +107,8 @@ def format_conversation_history(chat_history):
 def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
     """Generate response using GPT-OSS with Harmony format"""
 
-
-
-    # Load pipeline inside GPU context (for ZeroGPU)
-    if pipe is None:
-        logger.info("🚀 Loading GPT-OSS-20B model on GPU...")
-        pipe = pipeline(
-            "text-generation",
-            model=model_id,
-            torch_dtype="auto",
-            device_map="auto",
-            trust_remote_code=True,
-        )
-        logger.info("✅ Model loaded successfully!")
+    # Get cached model (loads only once)
+    pipe = load_model()
 
     # Create new user message
     new_message = {"role": "user", "content": input_data}
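For context, on ZeroGPU Spaces the GPU-bound work runs inside a function decorated with @spaces.GPU, and an lru_cache only persists while the same Python process keeps handling requests. A minimal sketch of how the cached loader pairs with that decorator (the generate handler and its reduced argument list are illustrative, not the full signature from app.py):

    import spaces
    from functools import lru_cache
    from transformers import pipeline

    @lru_cache(maxsize=1)
    def load_model():
        # Body runs once per process; later calls return the cached pipeline.
        return pipeline(
            "text-generation",
            model="unsloth/gpt-oss-20b-unsloth-bnb-4bit",
            torch_dtype="auto",
            device_map="auto",
            trust_remote_code=True,
        )

    @spaces.GPU
    def generate(prompt, max_new_tokens=128):
        pipe = load_model()  # first call loads (~18 s); later calls are instant
        return pipe(prompt, max_new_tokens=max_new_tokens)[0]["generated_text"]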