Spaces:
Sleeping
Sleeping
| import json | |
| import re | |
| import time | |
| import uuid | |
| import os | |
| import traceback | |
| from typing import List, Optional, Dict, Any, Union | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from sse_starlette.sse import EventSourceResponse | |
| from pydantic import BaseModel, Field | |
| # Import the core from the cloned Grok-Api repository | |
| from core import Grok | |
# Module-level FastAPI application instance, created with the title below.
app = FastAPI(title="Grok OpenAI Wrapper (Agent & MCP Compatible)")
| # --- OpenAI Pydantic Models --- | |
# Name/arguments pair describing one function invocation.
# `arguments` is declared as a string; extract_tool_calls() fills it with
# JSON-encoded text.
class FunctionCall(BaseModel):
    name: str = Field(...)  # required
    arguments: str = Field(...)  # required
# One tool call attached to an assistant message: an identifier, a call type
# that defaults to "function", and the function payload itself.
class ToolCall(BaseModel):
    id: str = Field(...)  # required
    type: str = Field(default="function")
    function: FunctionCall = Field(...)  # required
# A single entry of the chat history. Only `role` is required; every other
# field defaults to None. format_messages_and_tools() reads `tool_calls` for
# assistant messages and `tool_call_id` / `name` for tool and function
# messages.
class ChatMessage(BaseModel):
    role: str = Field(...)  # required
    content: Optional[str] = Field(default=None)
    name: Optional[str] = Field(default=None)
    tool_calls: Optional[List[ToolCall]] = Field(default=None)
    tool_call_id: Optional[str] = Field(default=None)
# Request body for the chat-completions handler. Only `messages` is required.
# `tool_choice` and `temperature` are accepted here but are not read by
# chat_completions() in this file.
class ChatCompletionRequest(BaseModel):
    model: str = Field(default="grok-3-fast")
    messages: List[ChatMessage] = Field(...)  # required
    stream: Optional[bool] = Field(default=False)
    tools: Optional[List[Dict[Any, Any]]] = Field(default=None)
    tool_choice: Optional[Any] = Field(default=None)
    temperature: Optional[float] = Field(default=0.7)
| # --- Helper Functions --- | |
def _decode_tool_arguments(raw):
    """Return *raw* parsed as JSON, or *raw* unchanged when it is not valid JSON."""
    try:
        return json.loads(raw)
    except (TypeError, ValueError):
        # Clients sometimes send empty or non-JSON argument strings; keep the
        # raw value in the prompt instead of failing the whole request.
        return raw


def format_messages_and_tools(messages: List["ChatMessage"], tools: Optional[List[Dict]]) -> str:
    """Translate an OpenAI-style message history into a single prompt string for Grok.

    Args:
        messages: Chat history entries. Each entry is read through its
            ``role``, ``content``, ``name``, ``tool_calls`` and
            ``tool_call_id`` attributes.
        tools: Optional list of tool definitions. When non-empty, an
            instruction block describing the expected JSON tool-call reply
            format and the JSON-dumped tool list is placed before the history.

    Returns:
        The prompt text, always ending with ``"Assistant: "`` so the model
        continues as the assistant.

    Messages with an unrecognised role are skipped. A ``content`` of ``None``
    is rendered as an empty string rather than the literal text ``None``.
    """
    parts = []

    # 1. Inject tools via the system-prompt strategy if tools exist.
    if tools:
        parts.append(
            "SYSTEM INSTRUCTION: You are an intelligent AI acting as an API. You have access to tools. "
            "If you need to call a tool, you MUST reply ONLY with a JSON block in the exact format below, and no other text.\n"
            '```json\n{"tool_calls":[{"name": "function_name", "arguments": {"arg_name": "arg_value"}}]}\n```\n'
            "Available tools:\n" + json.dumps(tools, indent=2) + "\n\n"
        )

    # 2. Append the message history.
    for msg in messages:
        content = "" if msg.content is None else msg.content
        if msg.role == "system":
            parts.append(f"System: {content}\n\n")
        elif msg.role == "user":
            parts.append(f"User: {content}\n\n")
        elif msg.role == "assistant":
            if msg.tool_calls:
                # Decode the stringified arguments so they are dumped as
                # nested JSON rather than as an escaped string.
                tc_dicts = [
                    {
                        "name": tc.function.name,
                        "arguments": _decode_tool_arguments(tc.function.arguments),
                    }
                    for tc in msg.tool_calls
                ]
                parts.append(f"Assistant called tools: {json.dumps(tc_dicts)}\n\n")
            if msg.content:
                parts.append(f"Assistant: {msg.content}\n\n")
        elif msg.role in ("tool", "function"):
            parts.append(f"TOOL RESULT (for {msg.tool_call_id or msg.name}): {content}\n\n")

    parts.append("Assistant: ")
    return "".join(parts)
def extract_tool_calls(text: str):
    """Parse a model reply and return OpenAI-style tool calls if it contains any.

    The reply is expected to hold a JSON object of the form
    ``{"tool_calls": [{"name": ..., "arguments": {...}}]}``, either inside a
    fenced ```` ```json ```` block or as the entire text.

    Args:
        text: Raw reply text from the model.

    Returns:
        A non-empty list of ``{"id", "type", "function": {"name", "arguments"}}``
        dicts, where ``arguments`` is a JSON-encoded string, or ``None`` when
        the text holds no usable tool call (not JSON, not an object, no
        ``tool_calls`` list, or no entry with a name).
    """
    if not text:
        return None

    # Look for a markdown JSON block, or fall back to the raw text.
    match = re.search(r'```json\s*(.*?)\s*```', text, re.DOTALL)
    json_str = match.group(1) if match else text.strip()

    try:
        parsed = json.loads(json_str)
    except json.JSONDecodeError:
        return None

    # A plain answer such as "42" or "[1, 2]" is valid JSON but not an object;
    # a membership test on it would raise or silently misbehave.
    if not isinstance(parsed, dict):
        return None
    raw_calls = parsed.get("tool_calls")
    if not isinstance(raw_calls, list):
        return None

    formatted_calls = []
    for tc in raw_calls:
        if not isinstance(tc, dict):
            continue
        # Accept both the flat shape requested in the prompt and the nested
        # OpenAI shape {"function": {"name": ..., "arguments": ...}}.
        nested = tc.get("function")
        spec = nested if isinstance(nested, dict) else tc
        name = spec.get("name")
        if not isinstance(name, str) or not name:
            continue
        arguments = spec.get("arguments")
        if arguments is None:
            arguments = {}
        if not isinstance(arguments, str):
            # OpenAI expects a stringified JSON; leave already-encoded strings
            # alone so they are not double-encoded.
            arguments = json.dumps(arguments)
        formatted_calls.append({
            "id": f"call_{uuid.uuid4().hex[:8]}",
            "type": "function",
            "function": {"name": name, "arguments": arguments},
        })

    return formatted_calls or None
| # --- API Endpoints --- | |
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app` (presumably POST /v1/chat/completions).
async def chat_completions(request: ChatCompletionRequest):
    """Answer an OpenAI-style chat completion request using Grok.

    The message history (and tool definitions, if any) is flattened into one
    prompt, sent to Grok, and the reply is returned either as a single JSON
    completion or, when ``request.stream`` is true, as server-sent events.

    Args:
        request: The parsed chat completion request.

    Returns:
        An ``EventSourceResponse`` when streaming, otherwise a ``JSONResponse``
        holding one ``chat.completion`` object.

    Raises:
        HTTPException: Status 500 when the Grok call fails for any reason.
    """
    # Local import so this handler does not depend on the file's import block.
    import asyncio

    # 1. Prepare prompt
    mega_prompt = format_messages_and_tools(request.messages, request.tools)

    # 2. Check for proxy in environment variables.
    # If Hugging Face IPs are blocked by Cloudflare, setting this in HF Secrets fixes it.
    proxy_url = os.environ.get("GROK_PROXY", None)

    def call_grok():
        """Build the Grok client and run the conversation (blocking)."""
        if proxy_url:
            grok_client = Grok(request.model, proxy_url)
        else:
            grok_client = Grok(request.model)
        return grok_client.start_convo(mega_prompt)

    try:
        # 3. Call Grok. The client is synchronous, so run it in the default
        # executor instead of stalling the event loop for every other request.
        loop = asyncio.get_running_loop()
        raw_response = await loop.run_in_executor(None, call_grok)
        # `or` guards against keys that are present but None.
        response_text = raw_response.get("response") or ""
        stream_array = raw_response.get("stream_response") or []
    except UnboundLocalError as e:
        # This catches the specific "local variable 'script_content1' referenced before assignment" error
        error_msg = (
            "Grok API Scraper failed to find session tokens. "
            "This usually means Hugging Face's IP is blocked by Grok's Cloudflare, or Grok updated their website DOM. "
            "Fix: Add a 'GROK_PROXY' secret in your HF Space settings (e.g., http://user:pass@ip:port)."
        )
        print(f"Scraper Error: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=error_msg) from e
    except Exception as e:
        print(f"Unknown Error: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Upstream Grok Error: {str(e)}") from e

    # 4. Parse tool calls
    tool_calls = extract_tool_calls(response_text) if request.tools else None

    # One id / timestamp / finish reason for the whole completion, shared by
    # every streamed chunk.
    completion_id = f"chatcmpl-{uuid.uuid4()}"
    created = int(time.time())
    finish_reason = "tool_calls" if tool_calls else "stop"

    # 5. Handle streaming response
    if request.stream:
        def make_chunk(delta, finish=None):
            """Wrap one delta in a chat.completion.chunk SSE payload."""
            return {
                "data": json.dumps({
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": request.model,
                    "choices": [{"index": 0, "delta": delta, "finish_reason": finish}],
                })
            }

        async def event_generator():
            if tool_calls:
                # Tool calls are emitted as one chunk to prevent breaking JSON
                # parsers in agents; each streamed call carries its index.
                indexed_calls = [
                    {"index": position, **call}
                    for position, call in enumerate(tool_calls)
                ]
                yield make_chunk({"role": "assistant", "tool_calls": indexed_calls})
            else:
                # Simulate streaming using the token array; fall back to the
                # full text when no token array was returned.
                tokens = stream_array or ([response_text] if response_text else [])
                for token in tokens:
                    yield make_chunk({"content": token})
                    # Small delay for smooth streaming, without blocking the loop.
                    await asyncio.sleep(0.01)

            # Final chunk carries the single finish reason for this completion.
            yield make_chunk({}, finish_reason)
            yield {"data": "[DONE]"}

        return EventSourceResponse(event_generator())

    # 6. Handle standard (non-streaming) response
    response_msg = {"role": "assistant", "content": None if tool_calls else response_text}
    if tool_calls:
        response_msg["tool_calls"] = tool_calls

    return JSONResponse(content={
        "id": completion_id,
        "object": "chat.completion",
        "created": created,
        "model": request.model,
        "choices": [{
            "index": 0,
            "message": response_msg,
            "finish_reason": finish_reason,
        }],
        # Rough estimate: about four characters per token.
        "usage": {
            "prompt_tokens": len(mega_prompt) // 4,
            "completion_tokens": len(response_text) // 4,
            "total_tokens": (len(mega_prompt) + len(response_text)) // 4,
        },
    })
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app` (presumably GET /v1/models).
async def list_models():
    """Return the fixed set of advertised model ids as an OpenAI-style list object."""
    model_ids = (
        "grok-3-auto",
        "grok-3-fast",
        "grok-4",
        "grok-4-mini-thinking-tahoe",
    )
    entries = [
        {"id": model_id, "object": "model", "created": int(time.time()), "owned_by": "xai"}
        for model_id in model_ids
    ]
    return {"object": "list", "data": entries}