""" HuggingFace ZeroGPU Space - OpenAI-compatible inference provider for opencode. This Gradio app provides: - OpenAI-compatible API via Gradio's native API system - Pass-through model selection (any HF model ID) - ZeroGPU H200 inference with HF Serverless fallback - HF Token authentication - SSE streaming support """ # Import spaces FIRST - required for ZeroGPU GPU detection import spaces import logging import time from typing import Optional import gradio as gr import httpx from huggingface_hub import HfApi from config import get_config, get_quota_tracker from models import ( apply_chat_template, generate_text, generate_text_stream, get_current_model, ) from openai_compat import ( ChatCompletionRequest, InferenceParams, create_chat_response, create_error_response, estimate_tokens, ) logger = logging.getLogger(__name__) config = get_config() quota_tracker = get_quota_tracker() # HuggingFace API for token validation hf_api = HfApi() ZEROGPU_AVAILABLE = True # --- Authentication --- def validate_hf_token(token: str) -> bool: """Validate a HuggingFace token by checking with the API.""" if not token or not token.startswith("hf_"): return False try: hf_api.whoami(token=token) return True except Exception: return False # --- ZeroGPU Inference Functions --- # These MUST be decorated with @spaces.GPU for ZeroGPU detection @spaces.GPU(duration=120) def zerogpu_generate( model_id: str, prompt: str, max_new_tokens: int, temperature: float, top_p: float, ) -> str: """Generate text using ZeroGPU (H200 GPU).""" start_time = time.time() result = generate_text( model_id=model_id, prompt=prompt, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, stop_sequences=None, ) # Track quota usage duration = time.time() - start_time quota_tracker.add_usage(duration) return result # --- HF Serverless Fallback --- def serverless_generate_sync( model_id: str, prompt: str, max_new_tokens: int, temperature: float, top_p: float, token: str, ) -> str: """Generate text using HuggingFace Serverless Inference API (sync version).""" url = f"https://api-inference.huggingface.co/models/{model_id}" payload = { "inputs": prompt, "parameters": { "max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "return_full_text": False, }, } with httpx.Client() as client: response = client.post( url, json=payload, headers={"Authorization": f"Bearer {token}"}, timeout=120.0, ) if response.status_code != 200: raise Exception(f"HF Serverless error: {response.text}") result = response.json() # Handle different response formats if isinstance(result, list) and len(result) > 0: if "generated_text" in result[0]: return result[0]["generated_text"] raise Exception(f"Unexpected response format from HF Serverless: {result}") # --- Gradio Chat Function (GPU decorated for ZeroGPU) --- @spaces.GPU(duration=120) def gradio_chat( message: str, history: list[list[str]], model_id: str, temperature: float, max_tokens: int, ): """Gradio chat interface handler - GPU decorated for ZeroGPU.""" # Validate model_id if not model_id: return "Please select a model first." 

# --- HF Serverless Fallback ---

def serverless_generate_sync(
    model_id: str,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    token: str,
) -> str:
    """Generate text using the HuggingFace Serverless Inference API (sync version)."""
    url = f"https://api-inference.huggingface.co/models/{model_id}"
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "return_full_text": False,
        },
    }
    with httpx.Client() as client:
        response = client.post(
            url,
            json=payload,
            headers={"Authorization": f"Bearer {token}"},
            timeout=120.0,
        )
    if response.status_code != 200:
        raise Exception(f"HF Serverless error: {response.text}")
    result = response.json()
    # Text-generation models return a list of dicts with "generated_text"
    if isinstance(result, list) and len(result) > 0 and "generated_text" in result[0]:
        return result[0]["generated_text"]
    raise Exception(f"Unexpected response format from HF Serverless: {result}")


# --- Gradio Chat Function (GPU decorated for ZeroGPU) ---

@spaces.GPU(duration=120)
def gradio_chat(
    message: str,
    history: list[list[str]],
    model_id: str,
    temperature: float,
    max_tokens: int,
):
    """Gradio chat interface handler - GPU decorated for ZeroGPU."""
    # Validate model_id
    if not model_id:
        return "Please select a model first."

    # Build the message list from the chat history
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # Apply the model's chat template
    try:
        prompt = apply_chat_template(model_id, messages)
    except Exception as e:
        return f"Error loading model: {str(e)}"

    # Generate the response (non-streaming for simplicity with ZeroGPU)
    try:
        return generate_text(
            model_id=model_id,
            prompt=prompt,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=0.95,
            stop_sequences=None,
        )
    except Exception as e:
        return f"Error generating response: {str(e)}"


# --- API Functions for Gradio's gr.api() ---

def api_health() -> dict:
    """Health check endpoint."""
    return {
        "status": "healthy",
        "zerogpu_available": ZEROGPU_AVAILABLE,
        "quota_remaining_minutes": quota_tracker.remaining_minutes(),
        "fallback_enabled": config.fallback_enabled,
    }


def api_chat_completions(
    token: str,
    model: str,
    messages: list[dict],
    temperature: float = 0.7,
    max_tokens: int = 512,
    top_p: float = 0.95,
) -> dict:
    """OpenAI-compatible chat completions.

    Args:
        token: HuggingFace API token (hf_xxx)
        model: HuggingFace model ID (e.g., "meta-llama/Llama-3.1-8B-Instruct")
        messages: List of message dicts with "role" and "content"
        temperature: Sampling temperature (0.0-2.0)
        max_tokens: Maximum tokens to generate
        top_p: Nucleus sampling probability

    Returns:
        OpenAI-compatible response dict
    """
    # Validate authentication
    if not token or not validate_hf_token(token):
        return create_error_response(
            message="Invalid or missing HuggingFace token",
            error_type="authentication_error",
            code="invalid_api_key",
        ).model_dump()

    # Apply the model's chat template
    try:
        prompt = apply_chat_template(model, messages)
    except Exception as e:
        logger.error(f"Failed to apply chat template: {e}")
        return create_error_response(
            message=f"Failed to load model or apply chat template: {str(e)}",
            error_type="invalid_request_error",
            param="model",
        ).model_dump()

    prompt_tokens = estimate_tokens(prompt)

    # Prefer ZeroGPU while quota remains; otherwise fall back to HF Serverless
    use_zerogpu = ZEROGPU_AVAILABLE and not quota_tracker.quota_exhausted
    if not use_zerogpu and not config.fallback_enabled:
        return create_error_response(
            message="ZeroGPU quota exhausted and fallback is disabled",
            error_type="server_error",
            code="quota_exhausted",
        ).model_dump()

    try:
        # Non-streaming response
        if use_zerogpu:
            response_text = zerogpu_generate(
                model_id=model,
                prompt=prompt,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            )
        else:
            logger.info("Using HF Serverless fallback")
            response_text = serverless_generate_sync(
                model_id=model,
                prompt=prompt,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                token=token,
            )

        completion_tokens = estimate_tokens(response_text)
        return create_chat_response(
            model=model,
            content=response_text,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        ).model_dump()
    except Exception as e:
        logger.exception(f"Inference error: {e}")
        return create_error_response(
            message=f"Inference failed: {str(e)}",
            error_type="server_error",
        ).model_dump()
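
# --- Example client call (illustrative) ---
# A minimal sketch of calling api_chat_completions remotely via
# gradio_client. The Space URL is taken from the opencode config shown in
# the UI Markdown below and may differ for your deployment; the token is a
# placeholder, and exact argument handling depends on your gradio_client
# version (positional arguments also work).
#
#   from gradio_client import Client
#
#   client = Client("https://serenichron-opencode-zerogpu.hf.space")
#   result = client.predict(
#       token="hf_YOUR_TOKEN",
#       model="meta-llama/Llama-3.1-8B-Instruct",
#       messages=[{"role": "user", "content": "Say hello."}],
#       temperature=0.7,
#       max_tokens=256,
#       top_p=0.95,
#       api_name="/chat_completions",
#   )
#   print(result["choices"][0]["message"]["content"])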

# --- Build Gradio Interface ---

with gr.Blocks(title="ZeroGPU OpenCode Provider") as demo:
    gr.Markdown(
        """
        # ZeroGPU OpenCode Provider

        OpenAI-compatible inference endpoint for [opencode](https://github.com/sst/opencode).

        **API Endpoints:**
        - `/api/health` - Health check
        - `/api/chat_completions` - Chat completions (OpenAI-compatible response format)

        ## Usage with opencode

        Configure in `~/.config/opencode/opencode.json`:

        ```json
        {
          "providers": {
            "zerogpu": {
              "npm": "@ai-sdk/openai-compatible",
              "options": {
                "baseURL": "https://serenichron-opencode-zerogpu.hf.space/api",
                "headers": {
                  "Authorization": "Bearer hf_YOUR_TOKEN"
                }
              },
              "models": {
                "llama-8b": {
                  "name": "meta-llama/Llama-3.1-8B-Instruct"
                }
              }
            }
          }
        }
        ```

        ---
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                label="Model",
                choices=[
                    "meta-llama/Llama-3.1-8B-Instruct",
                    "mistralai/Mistral-7B-Instruct-v0.3",
                    "Qwen/Qwen2.5-7B-Instruct",
                    "Qwen/Qwen2.5-14B-Instruct",
                ],
                value="meta-llama/Llama-3.1-8B-Instruct",
                allow_custom_value=True,
            )
            temperature_slider = gr.Slider(
                label="Temperature",
                minimum=0.0,
                maximum=2.0,
                value=0.7,
                step=0.1,
            )
            max_tokens_slider = gr.Slider(
                label="Max Tokens",
                minimum=64,
                maximum=4096,
                value=512,
                step=64,
            )
            gr.Markdown(
                f"""
                ### Status
                - **ZeroGPU:** {'Available' if ZEROGPU_AVAILABLE else 'Not Available'}
                - **Fallback:** {'Enabled' if config.fallback_enabled else 'Disabled'}
                """
            )

        with gr.Column(scale=3):
            chatbot = gr.ChatInterface(
                fn=gradio_chat,
                additional_inputs=[model_dropdown, temperature_slider, max_tokens_slider],
                title="",
            )

    # Register API endpoints using Gradio's API system.
    # These will be available under /api/
    gr.api(api_health, api_name="health")
    gr.api(api_chat_completions, api_name="chat_completions")


# --- Launch the application ---
# On HuggingFace Spaces the runtime handles the launch automatically;
# we just expose the demo object.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
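
# --- Remote health check (illustrative) ---
# Assuming the routes listed in the UI Markdown above are live, the health
# endpoint can be probed with curl; the exact path at which Gradio mounts
# gr.api endpoints may vary by Gradio version, and the sample output below
# is illustrative only.
#
#   $ curl https://serenichron-opencode-zerogpu.hf.space/api/health
#   {"status": "healthy", "zerogpu_available": true, ...}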