serenichron committed on
Commit
ae2483e
·
1 Parent(s): 5ea35f6

Fix Gradio + FastAPI integration for HuggingFace Spaces

Browse files

- Use demo.app to add FastAPI routes to Gradio's internal app
- Remove examples from ChatInterface to avoid caching issues
- Add error handling in gradio_chat for model loading failures
- Simplify app structure for better HF Spaces compatibility

Files changed (1) hide show
  1. app.py +132 -152
app.py CHANGED
@@ -11,12 +11,11 @@ This Gradio app provides:
11
 
12
  import logging
13
  import time
14
- from contextlib import asynccontextmanager
15
  from typing import Optional
16
 
17
  import gradio as gr
18
  import httpx
19
- from fastapi import FastAPI, Header, HTTPException, Request
20
  from fastapi.responses import StreamingResponse, JSONResponse
21
  from huggingface_hub import HfApi
22
 
@@ -195,28 +194,141 @@ async def serverless_generate(
195
  )
196
 
197
 
198
- # --- FastAPI App ---
199
 
200
 
201
- @asynccontextmanager
202
- async def lifespan(app: FastAPI):
203
- """Application lifespan events."""
204
- logger.info("Starting ZeroGPU OpenCode Provider")
205
- logger.info(f"ZeroGPU available: {ZEROGPU_AVAILABLE}")
206
- logger.info(f"Fallback enabled: {config.fallback_enabled}")
207
- yield
208
- logger.info("Shutting down ZeroGPU OpenCode Provider")
 
 
 
 
209
 
 
 
 
 
 
 
 
210
 
211
- api = FastAPI(
212
- title="ZeroGPU OpenCode Provider",
213
- description="OpenAI-compatible API for HuggingFace models on ZeroGPU",
214
- version="1.0.0",
215
- lifespan=lifespan,
216
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
- @api.post("/v1/chat/completions")
 
220
  async def chat_completions(
221
  request: ChatCompletionRequest,
222
  authorization: Optional[str] = Header(None),
@@ -352,7 +464,7 @@ async def chat_completions(
352
  )
353
 
354
 
355
- @api.get("/v1/models")
356
  async def list_models(authorization: Optional[str] = Header(None)):
357
  """List available models (returns info about current model if loaded)."""
358
  token = extract_token(authorization)
@@ -382,7 +494,7 @@ async def list_models(authorization: Optional[str] = Header(None)):
382
  return {"object": "list", "data": models}
383
 
384
 
385
- @api.get("/health")
386
  async def health_check():
387
  """Health check endpoint."""
388
  return {
@@ -393,137 +505,5 @@ async def health_check():
393
  }
394
 
395
 
396
- # --- Gradio Interface ---
397
-
398
-
399
- def gradio_chat(
400
- message: str,
401
- history: list[list[str]],
402
- model_id: str,
403
- temperature: float,
404
- max_tokens: int,
405
- ):
406
- """Gradio chat interface handler."""
407
- # Validate model_id
408
- if not model_id:
409
- yield "Please select a model first."
410
- return
411
-
412
- # Build messages from history
413
- messages = []
414
- for user_msg, assistant_msg in history:
415
- messages.append({"role": "user", "content": user_msg})
416
- if assistant_msg:
417
- messages.append({"role": "assistant", "content": assistant_msg})
418
- messages.append({"role": "user", "content": message})
419
-
420
- # Apply chat template
421
- prompt = apply_chat_template(model_id, messages)
422
-
423
- # Generate response (streaming)
424
- response = ""
425
- for token in zerogpu_generate_stream(
426
- model_id=model_id,
427
- prompt=prompt,
428
- max_new_tokens=max_tokens,
429
- temperature=temperature,
430
- top_p=0.95,
431
- stop_sequences=None,
432
- ):
433
- response += token
434
- yield response
435
-
436
-
437
- # Gradio Blocks interface
438
- with gr.Blocks(title="ZeroGPU OpenCode Provider") as demo:
439
- gr.Markdown(
440
- """
441
- # ZeroGPU OpenCode Provider
442
-
443
- OpenAI-compatible inference endpoint for [opencode](https://github.com/sst/opencode).
444
-
445
- **API Endpoint:** `/v1/chat/completions`
446
-
447
- ## Usage with opencode
448
-
449
- Configure in `~/.config/opencode/opencode.json`:
450
-
451
- ```json
452
- {
453
- "providers": {
454
- "zerogpu": {
455
- "npm": "@ai-sdk/openai-compatible",
456
- "options": {
457
- "baseURL": "https://serenichron-opencode-zerogpu.hf.space/v1",
458
- "headers": {
459
- "Authorization": "Bearer hf_YOUR_TOKEN"
460
- }
461
- },
462
- "models": {
463
- "llama-8b": {
464
- "name": "meta-llama/Llama-3.1-8B-Instruct"
465
- }
466
- }
467
- }
468
- }
469
- }
470
- ```
471
-
472
- ---
473
- """
474
- )
475
-
476
- with gr.Row():
477
- with gr.Column(scale=1):
478
- model_dropdown = gr.Dropdown(
479
- label="Model",
480
- choices=[
481
- "meta-llama/Llama-3.1-8B-Instruct",
482
- "mistralai/Mistral-7B-Instruct-v0.3",
483
- "Qwen/Qwen2.5-7B-Instruct",
484
- "Qwen/Qwen2.5-14B-Instruct",
485
- ],
486
- value="meta-llama/Llama-3.1-8B-Instruct",
487
- allow_custom_value=True,
488
- )
489
- temperature_slider = gr.Slider(
490
- label="Temperature",
491
- minimum=0.0,
492
- maximum=2.0,
493
- value=0.7,
494
- step=0.1,
495
- )
496
- max_tokens_slider = gr.Slider(
497
- label="Max Tokens",
498
- minimum=64,
499
- maximum=4096,
500
- value=512,
501
- step=64,
502
- )
503
-
504
- gr.Markdown(
505
- f"""
506
- ### Status
507
- - **ZeroGPU:** {'Available' if ZEROGPU_AVAILABLE else 'Not Available'}
508
- - **Fallback:** {'Enabled' if config.fallback_enabled else 'Disabled'}
509
- """
510
- )
511
-
512
- with gr.Column(scale=3):
513
- chatbot = gr.ChatInterface(
514
- fn=gradio_chat,
515
- additional_inputs=[model_dropdown, temperature_slider, max_tokens_slider],
516
- title="",
517
- examples=[
518
- ["Hello! How are you?"],
519
- ["Explain quantum computing in simple terms."],
520
- ["Write a Python function to calculate fibonacci numbers."],
521
- ],
522
- )
523
-
524
- # Mount FastAPI to Gradio
525
- demo = gr.mount_gradio_app(demo, api, path="/")
526
-
527
-
528
  if __name__ == "__main__":
529
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
11
 
12
  import logging
13
  import time
 
14
  from typing import Optional
15
 
16
  import gradio as gr
17
  import httpx
18
+ from fastapi import Header, HTTPException
19
  from fastapi.responses import StreamingResponse, JSONResponse
20
  from huggingface_hub import HfApi
21
 
 
194
  )
195
 
196
 
197
+ # --- Gradio Interface ---
198
 
199
 
200
+ def gradio_chat(
201
+ message: str,
202
+ history: list[list[str]],
203
+ model_id: str,
204
+ temperature: float,
205
+ max_tokens: int,
206
+ ):
207
+ """Gradio chat interface handler."""
208
+ # Validate model_id
209
+ if not model_id:
210
+ yield "Please select a model first."
211
+ return
212
 
213
+ # Build messages from history
214
+ messages = []
215
+ for user_msg, assistant_msg in history:
216
+ messages.append({"role": "user", "content": user_msg})
217
+ if assistant_msg:
218
+ messages.append({"role": "assistant", "content": assistant_msg})
219
+ messages.append({"role": "user", "content": message})
220
 
221
+ # Apply chat template
222
+ try:
223
+ prompt = apply_chat_template(model_id, messages)
224
+ except Exception as e:
225
+ yield f"Error loading model: {str(e)}"
226
+ return
227
+
228
+ # Generate response (streaming)
229
+ response = ""
230
+ try:
231
+ for token in zerogpu_generate_stream(
232
+ model_id=model_id,
233
+ prompt=prompt,
234
+ max_new_tokens=max_tokens,
235
+ temperature=temperature,
236
+ top_p=0.95,
237
+ stop_sequences=None,
238
+ ):
239
+ response += token
240
+ yield response
241
+ except Exception as e:
242
+ yield f"Error generating response: {str(e)}"
243
+
244
+
245
+ # Build Gradio Blocks interface
246
+ with gr.Blocks(title="ZeroGPU OpenCode Provider") as demo:
247
+ gr.Markdown(
248
+ """
249
+ # ZeroGPU OpenCode Provider
250
+
251
+ OpenAI-compatible inference endpoint for [opencode](https://github.com/sst/opencode).
252
+
253
+ **API Endpoint:** `/v1/chat/completions`
254
+
255
+ ## Usage with opencode
256
+
257
+ Configure in `~/.config/opencode/opencode.json`:
258
+
259
+ ```json
260
+ {
261
+ "providers": {
262
+ "zerogpu": {
263
+ "npm": "@ai-sdk/openai-compatible",
264
+ "options": {
265
+ "baseURL": "https://serenichron-opencode-zerogpu.hf.space/v1",
266
+ "headers": {
267
+ "Authorization": "Bearer hf_YOUR_TOKEN"
268
+ }
269
+ },
270
+ "models": {
271
+ "llama-8b": {
272
+ "name": "meta-llama/Llama-3.1-8B-Instruct"
273
+ }
274
+ }
275
+ }
276
+ }
277
+ }
278
+ ```
279
+
280
+ ---
281
+ """
282
+ )
283
+
284
+ with gr.Row():
285
+ with gr.Column(scale=1):
286
+ model_dropdown = gr.Dropdown(
287
+ label="Model",
288
+ choices=[
289
+ "meta-llama/Llama-3.1-8B-Instruct",
290
+ "mistralai/Mistral-7B-Instruct-v0.3",
291
+ "Qwen/Qwen2.5-7B-Instruct",
292
+ "Qwen/Qwen2.5-14B-Instruct",
293
+ ],
294
+ value="meta-llama/Llama-3.1-8B-Instruct",
295
+ allow_custom_value=True,
296
+ )
297
+ temperature_slider = gr.Slider(
298
+ label="Temperature",
299
+ minimum=0.0,
300
+ maximum=2.0,
301
+ value=0.7,
302
+ step=0.1,
303
+ )
304
+ max_tokens_slider = gr.Slider(
305
+ label="Max Tokens",
306
+ minimum=64,
307
+ maximum=4096,
308
+ value=512,
309
+ step=64,
310
+ )
311
 
312
+ gr.Markdown(
313
+ f"""
314
+ ### Status
315
+ - **ZeroGPU:** {'Available' if ZEROGPU_AVAILABLE else 'Not Available'}
316
+ - **Fallback:** {'Enabled' if config.fallback_enabled else 'Disabled'}
317
+ """
318
+ )
319
+
320
+ with gr.Column(scale=3):
321
+ chatbot = gr.ChatInterface(
322
+ fn=gradio_chat,
323
+ additional_inputs=[model_dropdown, temperature_slider, max_tokens_slider],
324
+ title="",
325
+ )
326
+
327
+
328
+ # --- Add OpenAI-compatible API routes to Gradio's FastAPI app ---
329
 
330
+
331
+ @demo.app.post("/v1/chat/completions")
332
  async def chat_completions(
333
  request: ChatCompletionRequest,
334
  authorization: Optional[str] = Header(None),
 
464
  )
465
 
466
 
467
+ @demo.app.get("/v1/models")
468
  async def list_models(authorization: Optional[str] = Header(None)):
469
  """List available models (returns info about current model if loaded)."""
470
  token = extract_token(authorization)
 
494
  return {"object": "list", "data": models}
495
 
496
 
497
+ @demo.app.get("/health")
498
  async def health_check():
499
  """Health check endpoint."""
500
  return {
 
505
  }
506
 
507
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508
  if __name__ == "__main__":
509
  demo.launch(server_name="0.0.0.0", server_port=7860)