import asyncio
import json
import logging
import os
import re
import time
import traceback
import uuid
from typing import Any, Dict, List, Optional, Union

from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse

# Import the core from the cloned Grok-Api repository
from core import Grok

app = FastAPI(title="Grok OpenAI Wrapper (Agent & MCP Compatible)")
logger = logging.getLogger(__name__)

# Model ids advertised by /v1/models, in display order.
SUPPORTED_MODELS = (
    "grok-3-auto",
    "grok-3-fast",
    "grok-4",
    "grok-4-mini-thinking-tahoe",
)

# Fenced code blocks (optionally tagged "json") that may hold the tool-call JSON.
_FENCED_BLOCK_RE = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)


# --- OpenAI Pydantic Models ---


class FunctionCall(BaseModel):
    name: str
    arguments: str  # stringified JSON, as in the OpenAI API


class ToolCall(BaseModel):
    id: str
    type: str = "function"
    function: FunctionCall


class ChatMessage(BaseModel):
    role: str
    # OpenAI clients may send either a plain string or a list of content parts
    # such as [{"type": "text", "text": "..."}].
    content: Optional[Union[str, List[Dict[str, Any]]]] = None
    name: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None
    tool_call_id: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str = "grok-3-fast"
    messages: List[ChatMessage]
    stream: Optional[bool] = False
    tools: Optional[List[Dict[Any, Any]]] = None
    tool_choice: Optional[Any] = None
    temperature: Optional[float] = 0.7  # accepted for compatibility; not forwarded to Grok


# --- Helper Functions ---


def _content_to_text(content: Optional[Union[str, List[Dict[str, Any]]]]) -> str:
    """Flatten an OpenAI message ``content`` value into plain text.

    ``None`` becomes ``""``. For a list of content parts, the ``text`` of every
    part that has a string ``text`` field is joined with newlines; other parts
    (e.g. images) are dropped because only a text prompt is sent to Grok.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    texts = [
        part["text"]
        for part in content
        if isinstance(part, dict) and isinstance(part.get("text"), str)
    ]
    return "\n".join(texts)


def _decode_arguments(arguments: str) -> Any:
    """Decode a tool call's stringified JSON arguments without ever raising.

    Returns the parsed value, ``{}`` for an empty string, or the raw string
    itself when it is not valid JSON, so replaying history cannot crash a request.
    """
    if not arguments:
        return {}
    try:
        return json.loads(arguments)
    except json.JSONDecodeError:
        return arguments


def format_messages_and_tools(messages: List[ChatMessage], tools: Optional[List[Dict]]) -> str:
    """Translates the standard OpenAI message history into a single string for Grok.

    Args:
        messages: OpenAI-style chat history.
        tools: OpenAI tool definitions; when non-empty, an instruction block
            describing the forced JSON tool-call format is prepended.

    Returns:
        One prompt string ending with ``"Assistant: "`` so Grok continues as
        the assistant. Messages with unrecognised roles are skipped.
    """
    parts: List[str] = []

    # 1. Inject Tools via System Prompt Strategy if tools exist
    if tools:
        parts.append(
            "SYSTEM INSTRUCTION: You are an intelligent AI acting as an API. You have access to tools. "
            "If you need to call a tool, you MUST reply ONLY with a JSON block in the exact format below, and no other text.\n"
            '```json\n{"tool_calls":[{"name": "function_name", "arguments": {"arg_name": "arg_value"}}]}\n```\n'
            "Available tools:\n" + json.dumps(tools, indent=2) + "\n\n"
        )

    # 2. Append Message History
    for msg in messages:
        text = _content_to_text(msg.content)
        if msg.role in ("system", "developer"):
            parts.append(f"System: {text}\n\n")
        elif msg.role == "user":
            parts.append(f"User: {text}\n\n")
        elif msg.role == "assistant":
            if msg.tool_calls:
                # Convert tool calls to dicts to cleanly dump them
                tc_dicts = [
                    {"name": tc.function.name, "arguments": _decode_arguments(tc.function.arguments)}
                    for tc in msg.tool_calls
                ]
                parts.append(f"Assistant called tools: {json.dumps(tc_dicts)}\n\n")
            if text:
                parts.append(f"Assistant: {text}\n\n")
        elif msg.role in ("tool", "function"):
            parts.append(f"TOOL RESULT (for {msg.tool_call_id or msg.name}): {text}\n\n")

    parts.append("Assistant: ")
    return "".join(parts)


def _normalize_tool_calls(parsed: Any) -> Optional[List[Dict[str, Any]]]:
    """Convert a decoded ``{"tool_calls": [...]}`` payload into OpenAI tool calls.

    Each entry may be ``{"name", "arguments"}`` (the format the prompt asks
    for) or OpenAI-shaped ``{"function": {"name", "arguments"}}``. Entries
    without a usable name are skipped. Returns ``None`` when nothing usable
    is found.
    """
    if not isinstance(parsed, dict):
        return None
    raw_calls = parsed.get("tool_calls")
    if not isinstance(raw_calls, list):
        return None

    formatted_calls = []
    for tc in raw_calls:
        if not isinstance(tc, dict):
            continue
        fn = tc["function"] if isinstance(tc.get("function"), dict) else tc
        name = fn.get("name")
        if not isinstance(name, str) or not name:
            continue
        arguments = fn.get("arguments")
        if arguments is None:
            arguments = {}
        if not isinstance(arguments, str):
            # OpenAI expects a stringified JSON; keep strings as-is to avoid double-encoding
            arguments = json.dumps(arguments)
        formatted_calls.append({
            "id": f"call_{uuid.uuid4().hex[:8]}",
            "type": "function",
            "function": {"name": name, "arguments": arguments},
        })
    return formatted_calls or None


def extract_tool_calls(text: str):
    """Parses Grok's response to check if it emitted our forced JSON tool call.

    Every fenced code block is tried in order, then the whole stripped reply.
    The first candidate that decodes to a valid ``tool_calls`` payload wins.

    Returns:
        A list of OpenAI-format tool call dicts, or ``None`` if the reply is
        not a tool call.
    """
    if not text:
        return None
    candidates = [m.group(1) for m in _FENCED_BLOCK_RE.finditer(text)]
    candidates.append(text.strip())
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        calls = _normalize_tool_calls(parsed)
        if calls:
            return calls
    return None


def _call_grok(model: str, prompt: str, proxy_url: Optional[str]) -> Dict[str, Any]:
    """Create a Grok client (through ``proxy_url`` if set) and run one blocking turn."""
    grok_client = Grok(model, proxy_url) if proxy_url else Grok(model)
    return grok_client.start_convo(prompt)


def _sse_chunk(
    completion_id: str,
    created: int,
    model: str,
    delta: Dict[str, Any],
    finish_reason: Optional[str],
) -> Dict[str, str]:
    """Build one SSE event carrying an OpenAI ``chat.completion.chunk``."""
    return {
        "data": json.dumps({
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model,
            "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
        })
    }


# --- API Endpoints ---


@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """OpenAI-compatible chat completion endpoint backed by the Grok scraper.

    Tool use is emulated by prompt injection. Grok's reply is parsed for the
    forced JSON tool-call block.

    Raises:
        HTTPException: status 500 when the upstream Grok call fails.
    """
    # 1. Prepare Prompt. tool_choice="none" forbids tool use for this turn, so
    # tools are neither advertised to Grok nor parsed out of its reply.
    active_tools = None if request.tool_choice == "none" else request.tools
    mega_prompt = format_messages_and_tools(request.messages, active_tools)

    # 2. Check for Proxy in Environment Variables
    # If Hugging Face IPs are blocked by Cloudflare, setting this in HF Secrets fixes it.
    proxy_url = os.environ.get("GROK_PROXY", None)

    try:
        # 3. Call Grok in a worker thread. The scraper is synchronous and would
        # otherwise block the event loop, and every other request, while it runs.
        raw_response = await run_in_threadpool(_call_grok, request.model, mega_prompt, proxy_url)
        response_text = raw_response.get("response", "") or ""
        stream_array = raw_response.get("stream_response", []) or []
    except UnboundLocalError as e:
        # This catches the specific "local variable 'script_content1' referenced before assignment" error
        error_msg = (
            "Grok API Scraper failed to find session tokens. "
            "This usually means Hugging Face's IP is blocked by Grok's Cloudflare, or Grok updated their website DOM. "
            "Fix: Add a 'GROK_PROXY' secret in your HF Space settings (e.g., http://user:pass@ip:port)."
        )
        logger.error("Scraper Error: %s", traceback.format_exc())
        raise HTTPException(status_code=500, detail=error_msg) from e
    except Exception as e:
        logger.error("Unknown Error: %s", traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Upstream Grok Error: {str(e)}") from e

    # 4. Parse Tool Calls
    tool_calls = extract_tool_calls(response_text) if active_tools else None

    # One id/timestamp per completion; OpenAI clients expect every chunk to share them.
    completion_id = f"chatcmpl-{uuid.uuid4()}"
    created = int(time.time())

    # 5. Handle Streaming Response
    if request.stream:
        async def event_generator():
            if tool_calls:
                # Tool calls are emitted whole in one delta to prevent breaking JSON
                # parsers in agents; streaming deltas require an "index" per call.
                streamed_calls = [{"index": i, **call} for i, call in enumerate(tool_calls)]
                yield _sse_chunk(
                    completion_id, created, request.model,
                    {"role": "assistant", "content": None, "tool_calls": streamed_calls}, None,
                )
                finish_reason = "tool_calls"
            else:
                yield _sse_chunk(completion_id, created, request.model,
                                 {"role": "assistant", "content": ""}, None)
                # Simulate streaming using the token array; fall back to the full
                # text when the scraper returned no per-token stream.
                tokens = stream_array or ([response_text] if response_text else [])
                for token in tokens:
                    yield _sse_chunk(completion_id, created, request.model, {"content": token}, None)
                    await asyncio.sleep(0.01)  # Small non-blocking delay for smooth streaming
                finish_reason = "stop"

            # Final chunk carries the finish reason, then the SSE terminator.
            yield _sse_chunk(completion_id, created, request.model, {}, finish_reason)
            yield {"data": "[DONE]"}

        return EventSourceResponse(event_generator())

    # 6. Handle Standard Sync Response
    response_msg: Dict[str, Any] = {"role": "assistant", "content": None if tool_calls else response_text}
    if tool_calls:
        response_msg["tool_calls"] = tool_calls

    # Rough usage estimate (~4 characters per token); total is kept consistent.
    prompt_tokens = len(mega_prompt) // 4
    completion_tokens = len(response_text) // 4
    return JSONResponse(content={
        "id": completion_id,
        "object": "chat.completion",
        "created": created,
        "model": request.model,
        "choices": [{
            "index": 0,
            "message": response_msg,
            "finish_reason": "tool_calls" if tool_calls else "stop",
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    })


@app.get("/v1/models")
async def list_models():
    """List the Grok model ids this wrapper advertises, in OpenAI format."""
    created = int(time.time())
    return {
        "object": "list",
        "data": [
            {"id": model_id, "object": "model", "created": created, "owned_by": "xai"}
            for model_id in SUPPORTED_MODELS
        ],
    }