serenichron committed on
Commit
217c046
·
1 Parent(s): 728a4ac

Use Gradio's native gr.api() for custom endpoints

Browse files

- Remove all FastAPI/Starlette code
- Use gr.api() to register health and chat_completions endpoints
- Endpoints available at /api/health and /api/chat_completions
- Pure Gradio approach for ZeroGPU compatibility
- OpenAI-compatible response format maintained

Files changed (1) hide show
  1. app.py +85 -255
app.py CHANGED
@@ -2,7 +2,7 @@
2
  HuggingFace ZeroGPU Space - OpenAI-compatible inference provider for opencode.
3
 
4
  This Gradio app provides:
5
- - OpenAI-compatible /v1/chat/completions endpoint
6
  - Pass-through model selection (any HF model ID)
7
  - ZeroGPU H200 inference with HF Serverless fallback
8
  - HF Token authentication
@@ -14,16 +14,11 @@ import spaces
14
 
15
  import logging
16
  import time
17
- import json
18
  from typing import Optional
19
 
20
  import gradio as gr
21
  import httpx
22
  from huggingface_hub import HfApi
23
- from starlette.applications import Starlette
24
- from starlette.routing import Route, Mount
25
- from starlette.responses import JSONResponse, StreamingResponse, RedirectResponse
26
- from starlette.requests import Request
27
 
28
  from config import get_config, get_quota_tracker
29
  from models import (
@@ -38,7 +33,6 @@ from openai_compat import (
38
  create_chat_response,
39
  create_error_response,
40
  estimate_tokens,
41
- stream_response_generator,
42
  )
43
 
44
  logger = logging.getLogger(__name__)
@@ -67,17 +61,6 @@ def validate_hf_token(token: str) -> bool:
67
  return False
68
 
69
 
70
- def extract_token(authorization: Optional[str]) -> Optional[str]:
71
- """Extract the token from the Authorization header."""
72
- if not authorization:
73
- return None
74
-
75
- if authorization.startswith("Bearer "):
76
- return authorization[7:]
77
-
78
- return authorization
79
-
80
-
81
  # --- ZeroGPU Inference Functions ---
82
  # These MUST be decorated with @spaces.GPU for ZeroGPU detection
83
 
@@ -89,7 +72,6 @@ def zerogpu_generate(
89
  max_new_tokens: int,
90
  temperature: float,
91
  top_p: float,
92
- stop_sequences: Optional[list[str]],
93
  ) -> str:
94
  """Generate text using ZeroGPU (H200 GPU)."""
95
  start_time = time.time()
@@ -100,7 +82,7 @@ def zerogpu_generate(
100
  max_new_tokens=max_new_tokens,
101
  temperature=temperature,
102
  top_p=top_p,
103
- stop_sequences=stop_sequences,
104
  )
105
 
106
  # Track quota usage
@@ -110,37 +92,10 @@ def zerogpu_generate(
110
  return result
111
 
112
 
113
- @spaces.GPU(duration=120)
114
- def zerogpu_generate_stream(
115
- model_id: str,
116
- prompt: str,
117
- max_new_tokens: int,
118
- temperature: float,
119
- top_p: float,
120
- stop_sequences: Optional[list[str]],
121
- ):
122
- """Generate text with streaming using ZeroGPU (H200 GPU)."""
123
- start_time = time.time()
124
-
125
- for token in generate_text_stream(
126
- model_id=model_id,
127
- prompt=prompt,
128
- max_new_tokens=max_new_tokens,
129
- temperature=temperature,
130
- top_p=top_p,
131
- stop_sequences=stop_sequences,
132
- ):
133
- yield token
134
-
135
- # Track quota usage
136
- duration = time.time() - start_time
137
- quota_tracker.add_usage(duration)
138
-
139
-
140
  # --- HF Serverless Fallback ---
141
 
142
 
143
- async def serverless_generate(
144
  model_id: str,
145
  prompt: str,
146
  max_new_tokens: int,
@@ -148,7 +103,7 @@ async def serverless_generate(
148
  top_p: float,
149
  token: str,
150
  ) -> str:
151
- """Generate text using HuggingFace Serverless Inference API."""
152
  url = f"https://api-inference.huggingface.co/models/{model_id}"
153
 
154
  payload = {
@@ -161,8 +116,8 @@ async def serverless_generate(
161
  },
162
  }
163
 
164
- async with httpx.AsyncClient() as client:
165
- response = await client.post(
166
  url,
167
  json=payload,
168
  headers={"Authorization": f"Bearer {token}"},
@@ -227,99 +182,59 @@ def gradio_chat(
227
  return f"Error generating response: {str(e)}"
228
 
229
 
230
- # --- API Route Handlers (Starlette) ---
231
 
232
 
233
- async def health_check(request: Request):
234
  """Health check endpoint."""
235
- return JSONResponse({
236
  "status": "healthy",
237
  "zerogpu_available": ZEROGPU_AVAILABLE,
238
  "quota_remaining_minutes": quota_tracker.remaining_minutes(),
239
  "fallback_enabled": config.fallback_enabled,
240
- })
241
-
242
-
243
- async def list_models(request: Request):
244
- """List available models (returns info about current model if loaded)."""
245
- authorization = request.headers.get("authorization")
246
- token = extract_token(authorization)
247
- if not token or not validate_hf_token(token):
248
- return JSONResponse(
249
- create_error_response(
250
- message="Invalid or missing HuggingFace token",
251
- error_type="authentication_error",
252
- code="invalid_api_key",
253
- ).model_dump(),
254
- status_code=401,
255
- )
256
-
257
- current = get_current_model()
258
- models = []
259
-
260
- if current:
261
- models.append(
262
- {
263
- "id": current.model_id,
264
- "object": "model",
265
- "created": int(time.time()),
266
- "owned_by": "huggingface",
267
- }
268
- )
269
-
270
- return JSONResponse({"object": "list", "data": models})
271
 
272
 
273
- async def chat_completions(request: Request):
 
 
 
 
 
 
 
274
  """
275
- OpenAI-compatible chat completions endpoint.
276
-
277
- Supports both streaming and non-streaming responses.
 
 
 
 
 
 
 
 
 
278
  """
279
- # Get authorization header
280
- authorization = request.headers.get("authorization")
281
-
282
  # Validate authentication
283
- token = extract_token(authorization)
284
  if not token or not validate_hf_token(token):
285
- return JSONResponse(
286
- create_error_response(
287
- message="Invalid or missing HuggingFace token",
288
- error_type="authentication_error",
289
- code="invalid_api_key",
290
- ).model_dump(),
291
- status_code=401,
292
- )
293
-
294
- # Parse request body
295
- try:
296
- body = await request.json()
297
- chat_request = ChatCompletionRequest(**body)
298
- except Exception as e:
299
- return JSONResponse(
300
- create_error_response(
301
- message=f"Invalid request body: {str(e)}",
302
- error_type="invalid_request_error",
303
- ).model_dump(),
304
- status_code=400,
305
- )
306
-
307
- # Extract inference parameters
308
- params = InferenceParams.from_request(chat_request)
309
 
310
  # Apply chat template
311
  try:
312
- prompt = apply_chat_template(params.model_id, params.messages)
313
  except Exception as e:
314
  logger.error(f"Failed to apply chat template: {e}")
315
- return JSONResponse(
316
- create_error_response(
317
- message=f"Failed to load model or apply chat template: {str(e)}",
318
- error_type="invalid_request_error",
319
- param="model",
320
- ).model_dump(),
321
- status_code=400,
322
- )
323
 
324
  prompt_tokens = estimate_tokens(prompt)
325
 
@@ -327,96 +242,48 @@ async def chat_completions(request: Request):
327
  use_zerogpu = ZEROGPU_AVAILABLE and not quota_tracker.quota_exhausted
328
 
329
  if not use_zerogpu and not config.fallback_enabled:
330
- return JSONResponse(
331
- create_error_response(
332
- message="ZeroGPU quota exhausted and fallback is disabled",
333
- error_type="server_error",
334
- code="quota_exhausted",
335
- ).model_dump(),
336
- status_code=503,
337
- )
338
 
339
  try:
340
- if params.stream:
341
- # Streaming response
342
- if use_zerogpu:
343
- token_gen = zerogpu_generate_stream(
344
- model_id=params.model_id,
345
- prompt=prompt,
346
- max_new_tokens=params.max_new_tokens,
347
- temperature=params.temperature,
348
- top_p=params.top_p,
349
- stop_sequences=params.stop_sequences,
350
- )
351
- else:
352
- # Fallback doesn't support streaming, so generate full response
353
- # and simulate streaming
354
- logger.info("Using HF Serverless fallback (no streaming)")
355
- full_response = await serverless_generate(
356
- model_id=params.model_id,
357
- prompt=prompt,
358
- max_new_tokens=params.max_new_tokens,
359
- temperature=params.temperature,
360
- top_p=params.top_p,
361
- token=token,
362
- )
363
-
364
- def simulate_stream():
365
- # Yield the full response as a single chunk
366
- yield full_response
367
-
368
- token_gen = simulate_stream()
369
-
370
- return StreamingResponse(
371
- stream_response_generator(params.model_id, token_gen),
372
- media_type="text/event-stream",
373
- headers={
374
- "Cache-Control": "no-cache",
375
- "Connection": "keep-alive",
376
- "X-Accel-Buffering": "no",
377
- },
378
  )
379
  else:
380
- # Non-streaming response
381
- if use_zerogpu:
382
- response_text = zerogpu_generate(
383
- model_id=params.model_id,
384
- prompt=prompt,
385
- max_new_tokens=params.max_new_tokens,
386
- temperature=params.temperature,
387
- top_p=params.top_p,
388
- stop_sequences=params.stop_sequences,
389
- )
390
- else:
391
- logger.info("Using HF Serverless fallback")
392
- response_text = await serverless_generate(
393
- model_id=params.model_id,
394
- prompt=prompt,
395
- max_new_tokens=params.max_new_tokens,
396
- temperature=params.temperature,
397
- top_p=params.top_p,
398
- token=token,
399
- )
400
-
401
- completion_tokens = estimate_tokens(response_text)
402
-
403
- response = create_chat_response(
404
- model=params.model_id,
405
- content=response_text,
406
- prompt_tokens=prompt_tokens,
407
- completion_tokens=completion_tokens,
408
  )
409
- return JSONResponse(response.model_dump())
 
 
 
 
 
 
 
 
410
 
411
  except Exception as e:
412
  logger.exception(f"Inference error: {e}")
413
- return JSONResponse(
414
- create_error_response(
415
- message=f"Inference failed: {str(e)}",
416
- error_type="server_error",
417
- ).model_dump(),
418
- status_code=500,
419
- )
420
 
421
 
422
  # --- Build Gradio Interface ---
@@ -428,7 +295,9 @@ with gr.Blocks(title="ZeroGPU OpenCode Provider") as demo:
428
 
429
  OpenAI-compatible inference endpoint for [opencode](https://github.com/sst/opencode).
430
 
431
- **API Endpoint:** `/v1/chat/completions`
 
 
432
 
433
  ## Usage with opencode
434
 
@@ -440,7 +309,7 @@ with gr.Blocks(title="ZeroGPU OpenCode Provider") as demo:
440
  "zerogpu": {
441
  "npm": "@ai-sdk/openai-compatible",
442
  "options": {
443
- "baseURL": "https://serenichron-opencode-zerogpu.hf.space/v1",
444
  "headers": {
445
  "Authorization": "Bearer hf_YOUR_TOKEN"
446
  }
@@ -502,54 +371,15 @@ with gr.Blocks(title="ZeroGPU OpenCode Provider") as demo:
502
  title="",
503
  )
504
 
505
-
506
- # --- Create combined ASGI app with API routes BEFORE Gradio ---
507
- # This ensures our API routes take precedence over Gradio's catch-all
508
-
509
- api_routes = [
510
- Route("/health", health_check, methods=["GET"]),
511
- Route("/v1/models", list_models, methods=["GET"]),
512
- Route("/v1/chat/completions", chat_completions, methods=["POST"]),
513
- ]
514
-
515
- # Create a Starlette app for API routes
516
- api_app = Starlette(routes=api_routes)
517
-
518
-
519
- # Custom ASGI middleware that routes API paths to our handlers
520
- class APIRoutingMiddleware:
521
- def __init__(self, app, api_app, api_paths):
522
- self.app = app # Gradio app
523
- self.api_app = api_app # Starlette app with API routes
524
- self.api_paths = api_paths # Paths to route to API
525
-
526
- async def __call__(self, scope, receive, send):
527
- if scope["type"] == "http":
528
- path = scope["path"]
529
- # Check if this path should go to our API
530
- for api_path in self.api_paths:
531
- if path == api_path or path.startswith(api_path + "/"):
532
- await self.api_app(scope, receive, send)
533
- return
534
- # Otherwise, let Gradio handle it
535
- await self.app(scope, receive, send)
536
-
537
-
538
- # Get Gradio's ASGI app and wrap it with our middleware
539
- gradio_app = demo.app
540
-
541
- # Wrap Gradio with our API routing middleware
542
- app = APIRoutingMiddleware(
543
- gradio_app,
544
- api_app,
545
- api_paths=["/health", "/v1"]
546
- )
547
 
548
 
549
  # --- Launch the application ---
550
  # On HuggingFace Spaces, the runtime handles the launch automatically
551
- # The demo object is exposed for ZeroGPU detection
552
 
553
  if __name__ == "__main__":
554
- import uvicorn
555
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
2
  HuggingFace ZeroGPU Space - OpenAI-compatible inference provider for opencode.
3
 
4
  This Gradio app provides:
5
+ - OpenAI-compatible API via Gradio's native API system
6
  - Pass-through model selection (any HF model ID)
7
  - ZeroGPU H200 inference with HF Serverless fallback
8
  - HF Token authentication
 
14
 
15
  import logging
16
  import time
 
17
  from typing import Optional
18
 
19
  import gradio as gr
20
  import httpx
21
  from huggingface_hub import HfApi
 
 
 
 
22
 
23
  from config import get_config, get_quota_tracker
24
  from models import (
 
33
  create_chat_response,
34
  create_error_response,
35
  estimate_tokens,
 
36
  )
37
 
38
  logger = logging.getLogger(__name__)
 
61
  return False
62
 
63
 
 
 
 
 
 
 
 
 
 
 
 
64
  # --- ZeroGPU Inference Functions ---
65
  # These MUST be decorated with @spaces.GPU for ZeroGPU detection
66
 
 
72
  max_new_tokens: int,
73
  temperature: float,
74
  top_p: float,
 
75
  ) -> str:
76
  """Generate text using ZeroGPU (H200 GPU)."""
77
  start_time = time.time()
 
82
  max_new_tokens=max_new_tokens,
83
  temperature=temperature,
84
  top_p=top_p,
85
+ stop_sequences=None,
86
  )
87
 
88
  # Track quota usage
 
92
  return result
93
 
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  # --- HF Serverless Fallback ---
96
 
97
 
98
+ def serverless_generate_sync(
99
  model_id: str,
100
  prompt: str,
101
  max_new_tokens: int,
 
103
  top_p: float,
104
  token: str,
105
  ) -> str:
106
+ """Generate text using HuggingFace Serverless Inference API (sync version)."""
107
  url = f"https://api-inference.huggingface.co/models/{model_id}"
108
 
109
  payload = {
 
116
  },
117
  }
118
 
119
+ with httpx.Client() as client:
120
+ response = client.post(
121
  url,
122
  json=payload,
123
  headers={"Authorization": f"Bearer {token}"},
 
182
  return f"Error generating response: {str(e)}"
183
 
184
 
185
+ # --- API Functions for Gradio's gr.api() ---
186
 
187
 
188
+ def api_health() -> dict:
189
  """Health check endpoint."""
190
+ return {
191
  "status": "healthy",
192
  "zerogpu_available": ZEROGPU_AVAILABLE,
193
  "quota_remaining_minutes": quota_tracker.remaining_minutes(),
194
  "fallback_enabled": config.fallback_enabled,
195
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
 
198
+ def api_chat_completions(
199
+ token: str,
200
+ model: str,
201
+ messages: list[dict],
202
+ temperature: float = 0.7,
203
+ max_tokens: int = 512,
204
+ top_p: float = 0.95,
205
+ ) -> dict:
206
  """
207
+ OpenAI-compatible chat completions.
208
+
209
+ Args:
210
+ token: HuggingFace API token (hf_xxx)
211
+ model: HuggingFace model ID (e.g., "meta-llama/Llama-3.1-8B-Instruct")
212
+ messages: List of message dicts with "role" and "content"
213
+ temperature: Sampling temperature (0.0-2.0)
214
+ max_tokens: Maximum tokens to generate
215
+ top_p: Nucleus sampling probability
216
+
217
+ Returns:
218
+ OpenAI-compatible response dict
219
  """
 
 
 
220
  # Validate authentication
 
221
  if not token or not validate_hf_token(token):
222
+ return create_error_response(
223
+ message="Invalid or missing HuggingFace token",
224
+ error_type="authentication_error",
225
+ code="invalid_api_key",
226
+ ).model_dump()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
  # Apply chat template
229
  try:
230
+ prompt = apply_chat_template(model, messages)
231
  except Exception as e:
232
  logger.error(f"Failed to apply chat template: {e}")
233
+ return create_error_response(
234
+ message=f"Failed to load model or apply chat template: {str(e)}",
235
+ error_type="invalid_request_error",
236
+ param="model",
237
+ ).model_dump()
 
 
 
238
 
239
  prompt_tokens = estimate_tokens(prompt)
240
 
 
242
  use_zerogpu = ZEROGPU_AVAILABLE and not quota_tracker.quota_exhausted
243
 
244
  if not use_zerogpu and not config.fallback_enabled:
245
+ return create_error_response(
246
+ message="ZeroGPU quota exhausted and fallback is disabled",
247
+ error_type="server_error",
248
+ code="quota_exhausted",
249
+ ).model_dump()
 
 
 
250
 
251
  try:
252
+ # Non-streaming response
253
+ if use_zerogpu:
254
+ response_text = zerogpu_generate(
255
+ model_id=model,
256
+ prompt=prompt,
257
+ max_new_tokens=max_tokens,
258
+ temperature=temperature,
259
+ top_p=top_p,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  )
261
  else:
262
+ logger.info("Using HF Serverless fallback")
263
+ response_text = serverless_generate_sync(
264
+ model_id=model,
265
+ prompt=prompt,
266
+ max_new_tokens=max_tokens,
267
+ temperature=temperature,
268
+ top_p=top_p,
269
+ token=token,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  )
271
+
272
+ completion_tokens = estimate_tokens(response_text)
273
+
274
+ return create_chat_response(
275
+ model=model,
276
+ content=response_text,
277
+ prompt_tokens=prompt_tokens,
278
+ completion_tokens=completion_tokens,
279
+ ).model_dump()
280
 
281
  except Exception as e:
282
  logger.exception(f"Inference error: {e}")
283
+ return create_error_response(
284
+ message=f"Inference failed: {str(e)}",
285
+ error_type="server_error",
286
+ ).model_dump()
 
 
 
287
 
288
 
289
  # --- Build Gradio Interface ---
 
295
 
296
  OpenAI-compatible inference endpoint for [opencode](https://github.com/sst/opencode).
297
 
298
+ **API Endpoints:**
299
+ - `/api/health` - Health check
300
+ - `/api/chat_completions` - Chat completions (OpenAI-compatible response format)
301
 
302
  ## Usage with opencode
303
 
 
309
  "zerogpu": {
310
  "npm": "@ai-sdk/openai-compatible",
311
  "options": {
312
+ "baseURL": "https://serenichron-opencode-zerogpu.hf.space/api",
313
  "headers": {
314
  "Authorization": "Bearer hf_YOUR_TOKEN"
315
  }
 
371
  title="",
372
  )
373
 
374
+ # Register API endpoints using Gradio's API system
375
+ # These will be available at /api/<name>
376
+ gr.api(api_health, api_name="health")
377
+ gr.api(api_chat_completions, api_name="chat_completions")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
378
 
379
 
380
  # --- Launch the application ---
381
  # On HuggingFace Spaces, the runtime handles the launch automatically
382
+ # We just expose the demo object
383
 
384
  if __name__ == "__main__":
385
+ demo.launch(server_name="0.0.0.0", server_port=7860)