serenichron committed on
Commit
67f3d72
·
1 Parent(s): 16b4dcd

Restructure app to use Gradio's native launch for ZeroGPU

Browse files

- Remove FastAPI-first approach that was breaking ZeroGPU detection
- Use demo.app to add API routes to Gradio's internal FastAPI app
- Use demo.launch() only for local development
- Keep demo object exposed for HF Spaces runtime

Files changed (1) hide show
  1. app.py +121 -133
app.py CHANGED
@@ -12,25 +12,16 @@ This Gradio app provides:
12
  # Import spaces FIRST - required for ZeroGPU GPU detection
13
  import spaces
14
 
15
- # Define a GPU function immediately after importing spaces
16
- # This ensures ZeroGPU detects it at startup
17
- @spaces.GPU(duration=60)
18
- def _zerogpu_test():
19
- """Test function for ZeroGPU detection."""
20
- return True
21
-
22
  import logging
23
  import time
24
  from typing import Optional
25
 
26
  import gradio as gr
27
  import httpx
28
- from fastapi import Header, HTTPException
29
  from fastapi.responses import StreamingResponse, JSONResponse
30
  from huggingface_hub import HfApi
31
 
32
- ZEROGPU_AVAILABLE = True
33
-
34
  from config import get_config, get_quota_tracker
35
  from models import (
36
  apply_chat_template,
@@ -55,6 +46,8 @@ quota_tracker = get_quota_tracker()
55
  # HuggingFace API for token validation
56
  hf_api = HfApi()
57
 
 
 
58
 
59
  # --- Authentication ---
60
 
@@ -82,17 +75,8 @@ def extract_token(authorization: Optional[str]) -> Optional[str]:
82
  return authorization
83
 
84
 
85
- # --- ZeroGPU Inference ---
86
-
87
-
88
- # Simple GPU function for ZeroGPU detection at startup
89
- @spaces.GPU(duration=60)
90
- def gpu_warmup():
91
- """Warmup function to ensure ZeroGPU detects GPU usage."""
92
- import torch
93
- if torch.cuda.is_available():
94
- return f"GPU available: {torch.cuda.get_device_name(0)}"
95
- return "No GPU detected"
96
 
97
 
98
  @spaces.GPU(duration=120)
@@ -201,7 +185,7 @@ async def serverless_generate(
201
  )
202
 
203
 
204
- # --- Gradio Interface ---
205
 
206
 
207
  @spaces.GPU(duration=120)
@@ -246,26 +230,108 @@ def gradio_chat(
246
  return f"Error generating response: {str(e)}"
247
 
248
 
249
- # --- FastAPI app for OpenAI-compatible routes ---
250
- from fastapi import FastAPI
251
 
252
- api_app = FastAPI(
253
- title="ZeroGPU OpenCode Provider",
254
- description="OpenAI-compatible API for HuggingFace models on ZeroGPU",
255
- version="1.0.0",
256
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
 
 
 
 
 
 
 
258
 
259
- @api_app.post("/v1/chat/completions")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  async def chat_completions(
261
- request: ChatCompletionRequest,
262
- authorization: Optional[str] = Header(None),
263
  ):
264
  """
265
  OpenAI-compatible chat completions endpoint.
266
 
267
  Supports both streaming and non-streaming responses.
268
  """
 
 
 
269
  # Validate authentication
270
  token = extract_token(authorization)
271
  if not token or not validate_hf_token(token):
@@ -278,8 +344,21 @@ async def chat_completions(
278
  ).model_dump(),
279
  )
280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  # Extract inference parameters
282
- params = InferenceParams.from_request(request)
283
 
284
  # Apply chat template
285
  try:
@@ -392,9 +471,10 @@ async def chat_completions(
392
  )
393
 
394
 
395
- @api_app.get("/v1/models")
396
- async def list_models(authorization: Optional[str] = Header(None)):
397
  """List available models (returns info about current model if loaded)."""
 
398
  token = extract_token(authorization)
399
  if not token or not validate_hf_token(token):
400
  return JSONResponse(
@@ -422,7 +502,7 @@ async def list_models(authorization: Optional[str] = Header(None)):
422
  return {"object": "list", "data": models}
423
 
424
 
425
- @api_app.get("/health")
426
  async def health_check():
427
  """Health check endpoint."""
428
  return {
@@ -433,102 +513,10 @@ async def health_check():
433
  }
434
 
435
 
436
- # Build Gradio Blocks interface
437
- with gr.Blocks(title="ZeroGPU OpenCode Provider") as demo:
438
- gr.Markdown(
439
- """
440
- # ZeroGPU OpenCode Provider
441
-
442
- OpenAI-compatible inference endpoint for [opencode](https://github.com/sst/opencode).
443
-
444
- **API Endpoint:** `/v1/chat/completions`
445
-
446
- ## Usage with opencode
447
-
448
- Configure in `~/.config/opencode/opencode.json`:
449
-
450
- ```json
451
- {
452
- "providers": {
453
- "zerogpu": {
454
- "npm": "@ai-sdk/openai-compatible",
455
- "options": {
456
- "baseURL": "https://serenichron-opencode-zerogpu.hf.space/v1",
457
- "headers": {
458
- "Authorization": "Bearer hf_YOUR_TOKEN"
459
- }
460
- },
461
- "models": {
462
- "llama-8b": {
463
- "name": "meta-llama/Llama-3.1-8B-Instruct"
464
- }
465
- }
466
- }
467
- }
468
- }
469
- ```
470
-
471
- ---
472
- """
473
- )
474
-
475
- with gr.Row():
476
- with gr.Column(scale=1):
477
- model_dropdown = gr.Dropdown(
478
- label="Model",
479
- choices=[
480
- "meta-llama/Llama-3.1-8B-Instruct",
481
- "mistralai/Mistral-7B-Instruct-v0.3",
482
- "Qwen/Qwen2.5-7B-Instruct",
483
- "Qwen/Qwen2.5-14B-Instruct",
484
- ],
485
- value="meta-llama/Llama-3.1-8B-Instruct",
486
- allow_custom_value=True,
487
- )
488
- temperature_slider = gr.Slider(
489
- label="Temperature",
490
- minimum=0.0,
491
- maximum=2.0,
492
- value=0.7,
493
- step=0.1,
494
- )
495
- max_tokens_slider = gr.Slider(
496
- label="Max Tokens",
497
- minimum=64,
498
- maximum=4096,
499
- value=512,
500
- step=64,
501
- )
502
-
503
- gr.Markdown(
504
- f"""
505
- ### Status
506
- - **ZeroGPU:** {'Available' if ZEROGPU_AVAILABLE else 'Not Available'}
507
- - **Fallback:** {'Enabled' if config.fallback_enabled else 'Disabled'}
508
- """
509
- )
510
-
511
- with gr.Column(scale=3):
512
- chatbot = gr.ChatInterface(
513
- fn=gradio_chat,
514
- additional_inputs=[model_dropdown, temperature_slider, max_tokens_slider],
515
- title="",
516
- )
517
-
518
-
519
- # Add redirect from root to /ui for convenience
520
- from fastapi.responses import RedirectResponse
521
-
522
- @api_app.get("/", include_in_schema=False)
523
- async def redirect_to_ui():
524
- """Redirect root to Gradio UI."""
525
- return RedirectResponse(url="/ui")
526
-
527
-
528
- # Mount Gradio into FastAPI app at /ui, API at root level
529
- app = gr.mount_gradio_app(api_app, demo, path="/ui")
530
-
531
 
532
  if __name__ == "__main__":
533
- import uvicorn
534
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
12
  # Import spaces FIRST - required for ZeroGPU GPU detection
13
  import spaces
14
 
 
 
 
 
 
 
 
15
  import logging
16
  import time
17
  from typing import Optional
18
 
19
  import gradio as gr
20
  import httpx
21
+ from fastapi import Header, HTTPException, Request
22
  from fastapi.responses import StreamingResponse, JSONResponse
23
  from huggingface_hub import HfApi
24
 
 
 
25
  from config import get_config, get_quota_tracker
26
  from models import (
27
  apply_chat_template,
 
46
  # HuggingFace API for token validation
47
  hf_api = HfApi()
48
 
49
+ ZEROGPU_AVAILABLE = True
50
+
51
 
52
  # --- Authentication ---
53
 
 
75
  return authorization
76
 
77
 
78
+ # --- ZeroGPU Inference Functions ---
79
+ # These MUST be decorated with @spaces.GPU for ZeroGPU detection
 
 
 
 
 
 
 
 
 
80
 
81
 
82
  @spaces.GPU(duration=120)
 
185
  )
186
 
187
 
188
+ # --- Gradio Chat Function (GPU decorated for ZeroGPU) ---
189
 
190
 
191
  @spaces.GPU(duration=120)
 
230
  return f"Error generating response: {str(e)}"
231
 
232
 
233
+ # --- Build Gradio Interface ---
 
234
 
235
+ with gr.Blocks(title="ZeroGPU OpenCode Provider") as demo:
236
+ gr.Markdown(
237
+ """
238
+ # ZeroGPU OpenCode Provider
239
+
240
+ OpenAI-compatible inference endpoint for [opencode](https://github.com/sst/opencode).
241
+
242
+ **API Endpoint:** `/v1/chat/completions`
243
+
244
+ ## Usage with opencode
245
+
246
+ Configure in `~/.config/opencode/opencode.json`:
247
+
248
+ ```json
249
+ {
250
+ "providers": {
251
+ "zerogpu": {
252
+ "npm": "@ai-sdk/openai-compatible",
253
+ "options": {
254
+ "baseURL": "https://serenichron-opencode-zerogpu.hf.space/v1",
255
+ "headers": {
256
+ "Authorization": "Bearer hf_YOUR_TOKEN"
257
+ }
258
+ },
259
+ "models": {
260
+ "llama-8b": {
261
+ "name": "meta-llama/Llama-3.1-8B-Instruct"
262
+ }
263
+ }
264
+ }
265
+ }
266
+ }
267
+ ```
268
+
269
+ ---
270
+ """
271
+ )
272
+
273
+ with gr.Row():
274
+ with gr.Column(scale=1):
275
+ model_dropdown = gr.Dropdown(
276
+ label="Model",
277
+ choices=[
278
+ "meta-llama/Llama-3.1-8B-Instruct",
279
+ "mistralai/Mistral-7B-Instruct-v0.3",
280
+ "Qwen/Qwen2.5-7B-Instruct",
281
+ "Qwen/Qwen2.5-14B-Instruct",
282
+ ],
283
+ value="meta-llama/Llama-3.1-8B-Instruct",
284
+ allow_custom_value=True,
285
+ )
286
+ temperature_slider = gr.Slider(
287
+ label="Temperature",
288
+ minimum=0.0,
289
+ maximum=2.0,
290
+ value=0.7,
291
+ step=0.1,
292
+ )
293
+ max_tokens_slider = gr.Slider(
294
+ label="Max Tokens",
295
+ minimum=64,
296
+ maximum=4096,
297
+ value=512,
298
+ step=64,
299
+ )
300
 
301
+ gr.Markdown(
302
+ f"""
303
+ ### Status
304
+ - **ZeroGPU:** {'Available' if ZEROGPU_AVAILABLE else 'Not Available'}
305
+ - **Fallback:** {'Enabled' if config.fallback_enabled else 'Disabled'}
306
+ """
307
+ )
308
 
309
+ with gr.Column(scale=3):
310
+ chatbot = gr.ChatInterface(
311
+ fn=gradio_chat,
312
+ additional_inputs=[model_dropdown, temperature_slider, max_tokens_slider],
313
+ title="",
314
+ )
315
+
316
+
317
+ # --- Add OpenAI-compatible API routes to Gradio's internal FastAPI app ---
318
+
319
+ # Get the underlying FastAPI app from Gradio
320
+ app = demo.app
321
+
322
+
323
+ @app.post("/v1/chat/completions")
324
  async def chat_completions(
325
+ request: Request,
 
326
  ):
327
  """
328
  OpenAI-compatible chat completions endpoint.
329
 
330
  Supports both streaming and non-streaming responses.
331
  """
332
+ # Get authorization header
333
+ authorization = request.headers.get("authorization")
334
+
335
  # Validate authentication
336
  token = extract_token(authorization)
337
  if not token or not validate_hf_token(token):
 
344
  ).model_dump(),
345
  )
346
 
347
+ # Parse request body
348
+ try:
349
+ body = await request.json()
350
+ chat_request = ChatCompletionRequest(**body)
351
+ except Exception as e:
352
+ return JSONResponse(
353
+ status_code=400,
354
+ content=create_error_response(
355
+ message=f"Invalid request body: {str(e)}",
356
+ error_type="invalid_request_error",
357
+ ).model_dump(),
358
+ )
359
+
360
  # Extract inference parameters
361
+ params = InferenceParams.from_request(chat_request)
362
 
363
  # Apply chat template
364
  try:
 
471
  )
472
 
473
 
474
+ @app.get("/v1/models")
475
+ async def list_models(request: Request):
476
  """List available models (returns info about current model if loaded)."""
477
+ authorization = request.headers.get("authorization")
478
  token = extract_token(authorization)
479
  if not token or not validate_hf_token(token):
480
  return JSONResponse(
 
502
  return {"object": "list", "data": models}
503
 
504
 
505
+ @app.get("/health")
506
  async def health_check():
507
  """Health check endpoint."""
508
  return {
 
513
  }
514
 
515
 
516
+ # --- Launch the application ---
517
+ # On HuggingFace Spaces, the runtime handles the launch automatically
518
+ # The demo object is exposed for the Gradio SDK to use
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
519
 
520
  if __name__ == "__main__":
521
+ # Local development
522
+ demo.launch(server_name="0.0.0.0", server_port=7860)