anaspro committed on
Commit
a645494
·
1 Parent(s): 3e07df2

Add model caching with lru_cache for ZeroGPU

Browse files

- Use @lru_cache to cache loaded model
- Prevents reloading model on every request
- First load: ~18 seconds
- Subsequent loads: instant (cached)
- Works better with ZeroGPU ephemeral processes

Files changed (1) hide show
  1. app.py +19 -15
app.py CHANGED
@@ -3,6 +3,7 @@ import gradio as gr
3
  import spaces
4
  import re
5
  from threading import Thread
 
6
  from transformers import pipeline, TextIteratorStreamer
7
  from huggingface_hub import login
8
  import logging
@@ -68,8 +69,22 @@ model_id = "unsloth/gpt-oss-20b-unsloth-bnb-4bit"
68
  # Load harmony encoding (lightweight, can load outside GPU)
69
  enc = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
70
 
71
- # Pipeline will be created inside @spaces.GPU function
72
- pipe = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  # ======================================================
75
  # Format Conversation History
@@ -92,19 +107,8 @@ def format_conversation_history(chat_history):
92
  def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
93
  """Generate response using GPT-OSS with Harmony format"""
94
 
95
- global pipe
96
-
97
- # Load pipeline inside GPU context (for ZeroGPU)
98
- if pipe is None:
99
- logger.info("🚀 Loading GPT-OSS-20B model on GPU...")
100
- pipe = pipeline(
101
- "text-generation",
102
- model=model_id,
103
- torch_dtype="auto",
104
- device_map="auto",
105
- trust_remote_code=True,
106
- )
107
- logger.info("✅ Model loaded successfully!")
108
 
109
  # Create new user message
110
  new_message = {"role": "user", "content": input_data}
 
3
  import spaces
4
  import re
5
  from threading import Thread
6
+ from functools import lru_cache
7
  from transformers import pipeline, TextIteratorStreamer
8
  from huggingface_hub import login
9
  import logging
 
69
  # Load harmony encoding (lightweight, can load outside GPU)
70
  enc = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
71
 
72
# ======================================================
# Cached Model Loader (for ZeroGPU)
# ======================================================
@lru_cache(maxsize=1)
def load_model():
    """Build the text-generation pipeline exactly once and reuse it.

    ``@lru_cache(maxsize=1)`` memoizes the returned pipeline object, so
    every call after the first is a cache hit and skips the expensive
    (~18 s) model load. Reads the module-level ``model_id`` and ``logger``.

    Returns:
        The ``transformers`` text-generation pipeline for ``model_id``.
    """
    logger.info("🚀 Loading GPT-OSS-20B model on GPU...")
    # Keyword arguments for the pipeline factory; dtype/device are
    # auto-selected so the same code runs on whatever GPU is attached.
    load_kwargs = {
        "model": model_id,
        "torch_dtype": "auto",
        "device_map": "auto",
        "trust_remote_code": True,
    }
    cached_pipe = pipeline("text-generation", **load_kwargs)
    logger.info("✅ Model loaded successfully!")
    return cached_pipe
88
 
89
  # ======================================================
90
  # Format Conversation History
 
107
  def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
108
  """Generate response using GPT-OSS with Harmony format"""
109
 
110
+ # Get cached model (loads only once)
111
+ pipe = load_model()
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  # Create new user message
114
  new_message = {"role": "user", "content": input_data}