"""
HuggingFace ZeroGPU Space - OpenAI-compatible inference provider for opencode.
This Gradio app provides:
- OpenAI-compatible API via Gradio's native API system
- Pass-through model selection (any HF model ID)
- ZeroGPU H200 inference with HF Serverless fallback
- HF Token authentication
- SSE streaming support
"""
# Import spaces FIRST - required for ZeroGPU GPU detection
import spaces
import logging
import time
import gradio as gr
import httpx
from huggingface_hub import HfApi
from config import get_config, get_quota_tracker
from models import (
apply_chat_template,
generate_text,
get_current_model,
)
from openai_compat import (
create_chat_response,
create_error_response,
estimate_tokens,
)
logger = logging.getLogger(__name__)
config = get_config()
quota_tracker = get_quota_tracker()
# HuggingFace API for token validation
hf_api = HfApi()
# ZeroGPU hardware is assumed present on this Space; per-request quota
# exhaustion is handled separately via quota_tracker.
ZEROGPU_AVAILABLE = True
# --- Authentication ---
def validate_hf_token(token: str) -> bool:
"""Validate a HuggingFace token by checking with the API."""
if not token or not token.startswith("hf_"):
return False
try:
hf_api.whoami(token=token)
return True
except Exception:
return False
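# Usage sketch (the token value is a placeholder, not a real credential):
#   validate_hf_token("hf_abc123")  # -> True only if hf_api.whoami() accepts it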
# --- ZeroGPU Inference Functions ---
# These MUST be decorated with @spaces.GPU for ZeroGPU detection
@spaces.GPU(duration=120)
def zerogpu_generate(
model_id: str,
prompt: str,
max_new_tokens: int,
temperature: float,
top_p: float,
) -> str:
"""Generate text using ZeroGPU (H200 GPU)."""
start_time = time.time()
result = generate_text(
model_id=model_id,
prompt=prompt,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
stop_sequences=None,
)
# Track quota usage
duration = time.time() - start_time
quota_tracker.add_usage(duration)
return result
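# Note: duration=120 in the decorator above caps each ZeroGPU allocation at
# roughly two minutes; quota_tracker accumulates the measured wall-clock time.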
# --- HF Serverless Fallback ---
def serverless_generate_sync(
model_id: str,
prompt: str,
max_new_tokens: int,
temperature: float,
top_p: float,
token: str,
) -> str:
"""Generate text using HuggingFace Serverless Inference API (sync version)."""
url = f"https://api-inference.huggingface.co/models/{model_id}"
payload = {
"inputs": prompt,
"parameters": {
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"top_p": top_p,
"return_full_text": False,
},
}
with httpx.Client() as client:
response = client.post(
url,
json=payload,
headers={"Authorization": f"Bearer {token}"},
timeout=120.0,
)
    if response.status_code != 200:
        raise RuntimeError(
            f"HF Serverless error {response.status_code}: {response.text}"
        )
    result = response.json()
    # Standard text-generation responses look like [{"generated_text": "..."}]
    if isinstance(result, list) and result and "generated_text" in result[0]:
        return result[0]["generated_text"]
    raise RuntimeError(f"Unexpected response format from HF Serverless: {result}")
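# Example call (a sketch; the model ID and token are placeholders):
#   text = serverless_generate_sync(
#       model_id="mistralai/Mistral-7B-Instruct-v0.3",
#       prompt="[INST] Hello [/INST]",
#       max_new_tokens=64, temperature=0.7, top_p=0.95, token="hf_...",
#   )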
# --- Gradio Chat Function (GPU decorated for ZeroGPU) ---
@spaces.GPU(duration=120)
def gradio_chat(
message: str,
history: list[list[str]],
model_id: str,
temperature: float,
max_tokens: int,
):
"""Gradio chat interface handler - GPU decorated for ZeroGPU."""
# Validate model_id
if not model_id:
return "Please select a model first."
# Build messages from history
messages = []
for user_msg, assistant_msg in history:
messages.append({"role": "user", "content": user_msg})
if assistant_msg:
messages.append({"role": "assistant", "content": assistant_msg})
messages.append({"role": "user", "content": message})
# Apply chat template
try:
prompt = apply_chat_template(model_id, messages)
except Exception as e:
return f"Error loading model: {str(e)}"
# Generate response (non-streaming for simplicity with ZeroGPU)
try:
response = generate_text(
model_id=model_id,
prompt=prompt,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=0.95,
stop_sequences=None,
)
return response
except Exception as e:
return f"Error generating response: {str(e)}"
# --- API Functions for Gradio's gr.api() ---
def api_health() -> dict:
"""Health check endpoint."""
return {
"status": "healthy",
"zerogpu_available": ZEROGPU_AVAILABLE,
"quota_remaining_minutes": quota_tracker.remaining_minutes(),
"fallback_enabled": config.fallback_enabled,
}
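# Example response (values are illustrative):
#   {"status": "healthy", "zerogpu_available": True,
#    "quota_remaining_minutes": 25.0, "fallback_enabled": True}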
def api_chat_completions(
token: str,
model: str,
messages: list[dict],
temperature: float = 0.7,
max_tokens: int = 512,
top_p: float = 0.95,
) -> dict:
"""
OpenAI-compatible chat completions.
Args:
token: HuggingFace API token (hf_xxx)
model: HuggingFace model ID (e.g., "meta-llama/Llama-3.1-8B-Instruct")
messages: List of message dicts with "role" and "content"
temperature: Sampling temperature (0.0-2.0)
max_tokens: Maximum tokens to generate
top_p: Nucleus sampling probability
Returns:
OpenAI-compatible response dict
"""
# Validate authentication
if not token or not validate_hf_token(token):
return create_error_response(
message="Invalid or missing HuggingFace token",
error_type="authentication_error",
code="invalid_api_key",
).model_dump()
# Apply chat template
try:
prompt = apply_chat_template(model, messages)
except Exception as e:
logger.error(f"Failed to apply chat template: {e}")
return create_error_response(
message=f"Failed to load model or apply chat template: {str(e)}",
error_type="invalid_request_error",
param="model",
).model_dump()
prompt_tokens = estimate_tokens(prompt)
# Determine inference method
use_zerogpu = ZEROGPU_AVAILABLE and not quota_tracker.quota_exhausted
if not use_zerogpu and not config.fallback_enabled:
return create_error_response(
message="ZeroGPU quota exhausted and fallback is disabled",
error_type="server_error",
code="quota_exhausted",
).model_dump()
try:
# Non-streaming response
if use_zerogpu:
response_text = zerogpu_generate(
model_id=model,
prompt=prompt,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
)
else:
logger.info("Using HF Serverless fallback")
response_text = serverless_generate_sync(
model_id=model,
prompt=prompt,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
token=token,
)
completion_tokens = estimate_tokens(response_text)
return create_chat_response(
model=model,
content=response_text,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
).model_dump()
except Exception as e:
logger.exception(f"Inference error: {e}")
return create_error_response(
message=f"Inference failed: {str(e)}",
error_type="server_error",
).model_dump()
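# On success, the return value follows the OpenAI chat.completion shape built
# by openai_compat.create_chat_response, roughly (exact fields live in that
# module):
#   {"object": "chat.completion", "model": "...",
#    "choices": [{"message": {"role": "assistant", "content": "..."}}],
#    "usage": {"prompt_tokens": ..., "completion_tokens": ...}}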
# --- Build Gradio Interface ---
with gr.Blocks(title="ZeroGPU OpenCode Provider") as demo:
gr.Markdown(
"""
# ZeroGPU OpenCode Provider
OpenAI-compatible inference endpoint for [opencode](https://github.com/sst/opencode).
**API Endpoints:**
- `/api/health` - Health check
- `/api/chat_completions` - Chat completions (OpenAI-compatible response format)
## Usage with opencode
Configure in `~/.config/opencode/opencode.json`:
```json
{
"providers": {
"zerogpu": {
"npm": "@ai-sdk/openai-compatible",
"options": {
"baseURL": "https://serenichron-opencode-zerogpu.hf.space/api",
"headers": {
"Authorization": "Bearer hf_YOUR_TOKEN"
}
},
"models": {
"llama-8b": {
"name": "meta-llama/Llama-3.1-8B-Instruct"
}
}
}
}
}
```
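        You can also call the API from Python. A minimal sketch using `gradio_client`
        (the Space ID is inferred from the base URL above; the token is a placeholder):
        ```python
        from gradio_client import Client

        client = Client("serenichron/opencode-zerogpu")
        result = client.predict(
            token="hf_YOUR_TOKEN",
            model="meta-llama/Llama-3.1-8B-Instruct",
            messages=[{"role": "user", "content": "Hello!"}],
            api_name="/chat_completions",
        )
        print(result)  # OpenAI-style chat.completion dict
        ```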
---
"""
)
with gr.Row():
with gr.Column(scale=1):
model_dropdown = gr.Dropdown(
label="Model",
choices=[
"meta-llama/Llama-3.1-8B-Instruct",
"mistralai/Mistral-7B-Instruct-v0.3",
"Qwen/Qwen2.5-7B-Instruct",
"Qwen/Qwen2.5-14B-Instruct",
],
value="meta-llama/Llama-3.1-8B-Instruct",
allow_custom_value=True,
)
temperature_slider = gr.Slider(
label="Temperature",
minimum=0.0,
maximum=2.0,
value=0.7,
step=0.1,
)
max_tokens_slider = gr.Slider(
label="Max Tokens",
minimum=64,
maximum=4096,
value=512,
step=64,
)
gr.Markdown(
f"""
### Status
- **ZeroGPU:** {'Available' if ZEROGPU_AVAILABLE else 'Not Available'}
- **Fallback:** {'Enabled' if config.fallback_enabled else 'Disabled'}
"""
)
with gr.Column(scale=3):
chatbot = gr.ChatInterface(
fn=gradio_chat,
additional_inputs=[model_dropdown, temperature_slider, max_tokens_slider],
title="",
)
# Register API endpoints using Gradio's API system
# These will be available at /api/<name>
gr.api(api_health, api_name="health")
gr.api(api_chat_completions, api_name="chat_completions")
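# The exact HTTP routes (and a Python/JS client snippet for each endpoint) can
# be inspected via the app's "Use via API" page once the Space is running.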
# --- Launch the application ---
# On HuggingFace Spaces, the runtime handles the launch automatically
# We just expose the demo object
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860)