# Grok-OpenAI-API / openai_server.py
# kazukaraya12's picture
# Update openai_server.py
# dae2ba4 verified
import asyncio
import json
import os
import re
import time
import traceback
import uuid
from typing import List, Optional, Dict, Any, Union

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse

# Import the core from the cloned Grok-Api repository
from core import Grok
# ASGI application object; the @app.post / @app.get decorators further down
# register the chat-completions and model-listing handlers on it.
app = FastAPI(title="Grok OpenAI Wrapper (Agent & MCP Compatible)")
# --- OpenAI Pydantic Models ---
class FunctionCall(BaseModel):
    """Function name plus its arguments, as carried inside a tool call."""

    # Identifier of the function being invoked (required).
    name: str = Field(...)
    # Arguments as text (required); the prompt formatter in this file decodes
    # this value with json.loads.
    arguments: str = Field(...)
class ToolCall(BaseModel):
    """One tool invocation attached to an assistant message."""

    # Call identifier (required).
    id: str = Field(...)
    # Kind of tool; defaults to "function" when the client omits it.
    type: str = Field(default="function")
    # The function name and arguments for this call (required).
    function: FunctionCall = Field(...)
class ChatMessage(BaseModel):
    """A single entry of the chat history sent by the client."""

    # Speaker role (required). The prompt formatter in this file branches on
    # "system", "user", "assistant", "tool" and "function".
    role: str = Field(...)
    # Message text; optional.
    content: Optional[str] = Field(default=None)
    # Optional name; used as the fallback label for tool results.
    name: Optional[str] = Field(default=None)
    # Tool calls previously issued by the assistant, if any.
    tool_calls: Optional[List[ToolCall]] = Field(default=None)
    # Identifier of the tool call that a tool-result message answers.
    tool_call_id: Optional[str] = Field(default=None)
class ChatCompletionRequest(BaseModel):
    """Request body accepted by the chat-completions endpoint."""

    # Model identifier handed to the Grok client; defaults to "grok-3-fast".
    model: str = Field(default="grok-3-fast")
    # Conversation history (required).
    messages: List[ChatMessage] = Field(...)
    # When true the endpoint answers with a server-sent event stream.
    stream: Optional[bool] = Field(default=False)
    # Tool definitions; when present they are serialized into the prompt.
    tools: Optional[List[Dict[Any, Any]]] = Field(default=None)
    # Accepted for client compatibility; not read by the handler in this file.
    tool_choice: Optional[Any] = Field(default=None)
    # Accepted for client compatibility; not read by the handler in this file.
    temperature: Optional[float] = Field(default=0.7)
# --- Helper Functions ---
def _decode_tool_arguments(raw_arguments):
    """Decode a tool call's argument string for embedding in the prompt.

    Empty or blank arguments become ``{}``. Text that is not valid JSON is
    returned unchanged so one malformed history entry cannot abort the request.
    """
    if not raw_arguments or not raw_arguments.strip():
        return {}
    try:
        return json.loads(raw_arguments)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return raw_arguments


def format_messages_and_tools(messages: List["ChatMessage"], tools: Optional[List[Dict]]) -> str:
    """Translate an OpenAI-style message history into a single prompt string for Grok.

    Args:
        messages: Chat history entries. Only the attributes ``role``,
            ``content``, ``name``, ``tool_calls`` and ``tool_call_id`` are read.
        tools: Optional tool definitions. When non-empty, an instruction block
            describing the expected JSON tool-call reply format is emitted
            first, followed by the JSON-dumped definitions.

    Returns:
        The flattened prompt, always ending with ``"Assistant: "`` so the model
        continues in the assistant's voice.

    Notes:
        * Missing (``None``) content is rendered as an empty string instead of
          the literal text ``None``.
        * ``developer`` messages are rendered like ``system`` messages; any
          other unrecognized role is skipped.
    """
    # Collect fragments and join once instead of repeated string concatenation.
    parts = []

    # 1. Inject tools via the system-prompt strategy if tools exist.
    if tools:
        parts.append(
            "SYSTEM INSTRUCTION: You are an intelligent AI acting as an API. You have access to tools. "
            "If you need to call a tool, you MUST reply ONLY with a JSON block in the exact format below, and no other text.\n"
            '```json\n{"tool_calls":[{"name": "function_name", "arguments": {"arg_name": "arg_value"}}]}\n```\n'
            "Available tools:\n" + json.dumps(tools, indent=2) + "\n\n"
        )

    # 2. Append the message history.
    for msg in messages:
        content = "" if msg.content is None else msg.content
        if msg.role in ("system", "developer"):
            parts.append(f"System: {content}\n\n")
        elif msg.role == "user":
            parts.append(f"User: {content}\n\n")
        elif msg.role == "assistant":
            if msg.tool_calls:
                # Re-expand each call so arguments appear as JSON, not as an escaped string.
                tc_dicts = [
                    {
                        "name": tc.function.name,
                        "arguments": _decode_tool_arguments(tc.function.arguments),
                    }
                    for tc in msg.tool_calls
                ]
                parts.append(f"Assistant called tools: {json.dumps(tc_dicts)}\n\n")
            if msg.content:
                parts.append(f"Assistant: {msg.content}\n\n")
        elif msg.role in ("tool", "function"):
            parts.append(f"TOOL RESULT (for {msg.tool_call_id or msg.name}): {content}\n\n")

    parts.append("Assistant: ")
    return "".join(parts)
def extract_tool_calls(text: str):
    """Parse Grok's reply and return OpenAI-style tool calls, if it emitted any.

    The reply is searched for a fenced ``json`` markdown block; when none is
    found the whole stripped text is tried as JSON. The decoded value must be
    an object whose ``"tool_calls"`` entry is a list.

    Args:
        text: Raw reply text from the model.

    Returns:
        A non-empty list of dicts shaped like
        ``{"id": "call_xxxxxxxx", "type": "function",
        "function": {"name": ..., "arguments": "<JSON string>"}}``,
        or ``None`` when the reply contains no usable tool call.
    """
    if not text:
        return None

    # Look for a markdown JSON block, or fall back to the raw text.
    fenced = re.search(r'```json\s*(.*?)\s*```', text, re.DOTALL)
    candidate = fenced.group(1) if fenced else text.strip()

    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        return None

    # A bare scalar or array (e.g. the reply "42" or "null") is valid JSON but
    # not a tool-call envelope; membership tests on it would raise TypeError.
    if not isinstance(parsed, dict):
        return None
    raw_calls = parsed.get("tool_calls")
    if not isinstance(raw_calls, list):
        return None

    formatted_calls = []
    for call in raw_calls:
        # Skip malformed entries rather than emitting a call with no name.
        if not isinstance(call, dict):
            continue
        name = call.get("name")
        if not isinstance(name, str) or not name:
            continue

        arguments = call.get("arguments")
        if arguments is None:
            arguments = {}
        # OpenAI expects a stringified JSON; do not double-encode if the model
        # already produced a string.
        if not isinstance(arguments, str):
            arguments = json.dumps(arguments)

        formatted_calls.append({
            "id": f"call_{uuid.uuid4().hex[:8]}",
            "type": "function",
            "function": {"name": name, "arguments": arguments},
        })

    return formatted_calls or None
# --- API Endpoints ---
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Answer an OpenAI-style chat-completion request using one Grok conversation.

    The message history (and tool definitions, if any) is flattened into one
    prompt, sent to Grok, and the reply is returned either as a JSON
    ``chat.completion`` object or, when ``request.stream`` is true, as a
    server-sent event stream of ``chat.completion.chunk`` objects terminated by
    ``[DONE]``.

    Args:
        request: Parsed request body.

    Returns:
        ``EventSourceResponse`` when streaming, otherwise ``JSONResponse``.

    Raises:
        HTTPException: status 500 when the Grok client raises.
    """
    # 1. Prepare prompt.
    mega_prompt = format_messages_and_tools(request.messages, request.tools)

    # 2. Check for a proxy in the environment.
    # If Hugging Face IPs are blocked by Cloudflare, setting this in HF Secrets fixes it.
    proxy_url = os.environ.get("GROK_PROXY")

    def call_grok():
        """Build the client and run the conversation (synchronous)."""
        client = Grok(request.model, proxy_url) if proxy_url else Grok(request.model)
        return client.start_convo(mega_prompt)

    try:
        # 3. Call Grok in a worker thread: the client is synchronous, and running
        # it inline would stall every other request on the event loop.
        raw_response = await asyncio.get_running_loop().run_in_executor(None, call_grok)
        # `or` guards against keys that are present but None.
        response_text = raw_response.get("response") or ""
        stream_array = raw_response.get("stream_response") or []
    except UnboundLocalError as exc:
        # This catches the specific "local variable 'script_content1' referenced before assignment" error
        error_msg = (
            "Grok API Scraper failed to find session tokens. "
            "This usually means Hugging Face's IP is blocked by Grok's Cloudflare, or Grok updated their website DOM. "
            "Fix: Add a 'GROK_PROXY' secret in your HF Space settings (e.g., http://user:pass@ip:port)."
        )
        print(f"Scraper Error: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=error_msg) from exc
    except Exception as exc:
        print(f"Unknown Error: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Upstream Grok Error: {str(exc)}") from exc

    # 4. Parse tool calls (only meaningful when the client supplied tools).
    tool_calls = extract_tool_calls(response_text) if request.tools else None

    # One id and timestamp per completion, shared by every chunk of a stream.
    completion_id = f"chatcmpl-{uuid.uuid4()}"
    created = int(time.time())
    finish_reason = "tool_calls" if tool_calls else "stop"

    # 5. Handle streaming response.
    if request.stream:
        def sse_chunk(delta, chunk_finish_reason=None):
            """Wrap a delta in a chat.completion.chunk SSE payload."""
            payload = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": request.model,
                "choices": [{"index": 0, "delta": delta, "finish_reason": chunk_finish_reason}],
            }
            return {"data": json.dumps(payload)}

        async def event_generator():
            if tool_calls:
                # Tool calls are emitted as one chunk to prevent breaking JSON
                # parsers in agents; each entry carries its position as `index`.
                indexed_calls = [
                    {"index": position, **call} for position, call in enumerate(tool_calls)
                ]
                yield sse_chunk({"role": "assistant", "tool_calls": indexed_calls})
            else:
                # Simulate streaming using the token array; fall back to the
                # full text when no token array was provided.
                tokens = stream_array or ([response_text] if response_text else [])
                yield sse_chunk({"role": "assistant", "content": ""})
                for token in tokens:
                    yield sse_chunk({"content": token})
                    # Non-blocking pause for smooth streaming.
                    await asyncio.sleep(0.01)

            # Final chunk carries the real finish reason ("tool_calls" or "stop").
            yield sse_chunk({}, finish_reason)
            yield {"data": "[DONE]"}

        return EventSourceResponse(event_generator())

    # 6. Handle standard (non-streaming) response.
    response_msg = {"role": "assistant", "content": None if tool_calls else response_text}
    if tool_calls:
        response_msg["tool_calls"] = tool_calls

    # Rough estimate of ~4 characters per token; total is the sum of both parts.
    prompt_tokens = len(mega_prompt) // 4
    completion_tokens = len(response_text) // 4

    return JSONResponse(content={
        "id": completion_id,
        "object": "chat.completion",
        "created": created,
        "model": request.model,
        "choices": [{
            "index": 0,
            "message": response_msg,
            "finish_reason": finish_reason,
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    })
@app.get("/v1/models")
async def list_models():
    """Return the advertised Grok model ids as an OpenAI-style model list."""
    advertised_ids = (
        "grok-3-auto",
        "grok-3-fast",
        "grok-4",
        "grok-4-mini-thinking-tahoe",
    )
    # Every entry shares the same shape; `created` is stamped at request time.
    catalogue = [
        {"id": model_id, "object": "model", "created": int(time.time()), "owned_by": "xai"}
        for model_id in advertised_ids
    ]
    return {"object": "list", "data": catalogue}