serenichron committed on
Commit
e8b5e4c
·
1 Parent(s): 2b8f58e

Make gradio_chat a non-generator GPU function for ZeroGPU detection

Browse files

- Add @spaces.GPU(duration=120) decorator to gradio_chat
- Change from generator (yield) to regular function (return)
- Use generate_text instead of zerogpu_generate_stream
- This ensures ZeroGPU detects GPU usage at startup

Files changed (1) hide show
  1. app.py +9 -12
app.py CHANGED
@@ -204,6 +204,7 @@ async def serverless_generate(
204
  # --- Gradio Interface ---
205
 
206
 
 
207
  def gradio_chat(
208
  message: str,
209
  history: list[list[str]],
@@ -211,11 +212,10 @@ def gradio_chat(
211
  temperature: float,
212
  max_tokens: int,
213
  ):
214
- """Gradio chat interface handler."""
215
  # Validate model_id
216
  if not model_id:
217
- yield "Please select a model first."
218
- return
219
 
220
  # Build messages from history
221
  messages = []
@@ -229,24 +229,21 @@ def gradio_chat(
229
  try:
230
  prompt = apply_chat_template(model_id, messages)
231
  except Exception as e:
232
- yield f"Error loading model: {str(e)}"
233
- return
234
 
235
- # Generate response (streaming)
236
- response = ""
237
  try:
238
- for token in zerogpu_generate_stream(
239
  model_id=model_id,
240
  prompt=prompt,
241
  max_new_tokens=max_tokens,
242
  temperature=temperature,
243
  top_p=0.95,
244
  stop_sequences=None,
245
- ):
246
- response += token
247
- yield response
248
  except Exception as e:
249
- yield f"Error generating response: {str(e)}"
250
 
251
 
252
  # --- FastAPI app for OpenAI-compatible routes ---
 
204
  # --- Gradio Interface ---
205
 
206
 
207
+ @spaces.GPU(duration=120)
208
  def gradio_chat(
209
  message: str,
210
  history: list[list[str]],
 
212
  temperature: float,
213
  max_tokens: int,
214
  ):
215
+ """Gradio chat interface handler - GPU decorated for ZeroGPU."""
216
  # Validate model_id
217
  if not model_id:
218
+ return "Please select a model first."
 
219
 
220
  # Build messages from history
221
  messages = []
 
229
  try:
230
  prompt = apply_chat_template(model_id, messages)
231
  except Exception as e:
232
+ return f"Error loading model: {str(e)}"
 
233
 
234
+ # Generate response (non-streaming for simplicity with ZeroGPU)
 
235
  try:
236
+ response = generate_text(
237
  model_id=model_id,
238
  prompt=prompt,
239
  max_new_tokens=max_tokens,
240
  temperature=temperature,
241
  top_p=0.95,
242
  stop_sequences=None,
243
+ )
244
+ return response
 
245
  except Exception as e:
246
+ return f"Error generating response: {str(e)}"
247
 
248
 
249
  # --- FastAPI app for OpenAI-compatible routes ---