Spaces:
Sleeping
Sleeping
File size: 8,303 Bytes
aec3672 dae2ba4 aec3672 dae2ba4 aec3672 dae2ba4 aec3672 dae2ba4 aec3672 dae2ba4 aec3672 dae2ba4 aec3672 dae2ba4 aec3672 dae2ba4 aec3672 dae2ba4 aec3672 dae2ba4 aec3672 dae2ba4 aec3672 dae2ba4 aec3672 dae2ba4 aec3672 dae2ba4 aec3672 dae2ba4 aec3672 dae2ba4 aec3672 dae2ba4 aec3672 dae2ba4 aec3672 dae2ba4 aec3672 dae2ba4 aec3672 dae2ba4 aec3672 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 | import json
import re
import time
import uuid
import os
import traceback
from typing import List, Optional, Dict, Any, Union
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse
from pydantic import BaseModel, Field
# Import the core from the cloned Grok-Api repository
from core import Grok
app = FastAPI(title="Grok OpenAI Wrapper (Agent & MCP Compatible)")
# --- OpenAI Pydantic Models ---
class FunctionCall(BaseModel):
    """Function name and argument payload carried inside a tool call."""

    # Name of the function the assistant wants to invoke.
    name: str
    # Arguments serialized as a string; format_messages_and_tools decodes
    # this value with json.loads when replaying history.
    arguments: str
class ToolCall(BaseModel):
    """A single tool call attached to an assistant message."""

    # Identifier of this call.
    id: str
    # Kind of tool being called; defaults to "function".
    type: str = "function"
    # The function name and its serialized arguments.
    function: FunctionCall
class ChatMessage(BaseModel):
    """One entry of the chat history sent by the client."""

    # Speaker of the message; format_messages_and_tools handles "system",
    # "user", "assistant", "tool" and "function".
    role: str
    # Text of the message; optional.
    content: Optional[str] = None
    # Optional name; used as a fallback label for tool results when
    # tool_call_id is missing.
    name: Optional[str] = None
    # Tool calls attached to the message; read for assistant messages only.
    tool_calls: Optional[List[ToolCall]] = None
    # Id of the tool call that a "tool"/"function" message answers.
    tool_call_id: Optional[str] = None
class ChatCompletionRequest(BaseModel):
    """Request body accepted by the /v1/chat/completions endpoint."""

    # Model identifier handed to the Grok client.
    model: str = "grok-3-fast"
    # Conversation history, flattened into a single prompt before the call.
    messages: List[ChatMessage]
    # When true the endpoint answers with server-sent events.
    stream: Optional[bool] = False
    # Tool definitions; injected into the prompt as JSON when present.
    tools: Optional[List[Dict[Any, Any]]] = None
    # Accepted on the request; not read by the code visible in this file.
    tool_choice: Optional[Any] = None
    # Accepted on the request; not read by the code visible in this file.
    temperature: Optional[float] = 0.7
# --- Helper Functions ---
def _decode_tool_arguments(raw):
    """Return *raw* parsed as JSON, or *raw* itself when it is not valid JSON."""
    try:
        return json.loads(raw)
    except (TypeError, ValueError):
        # json.JSONDecodeError is a ValueError. Clients may replay an empty or
        # malformed arguments string; keep it verbatim instead of failing the
        # whole request.
        return raw


def format_messages_and_tools(messages: "List[ChatMessage]", tools: Optional[List[Dict]]) -> str:
    """Flatten an OpenAI-style message history into a single prompt string.

    Args:
        messages: Chat history. Each item exposes ``role``, ``content``,
            ``name``, ``tool_calls`` and ``tool_call_id``.
        tools: Optional tool definitions. When non-empty, an instruction
            block describing the expected JSON tool-call reply format and the
            JSON dump of the tools is placed at the top of the prompt.

    Returns:
        The prompt text, always ending with ``"Assistant: "`` so the model
        continues as the assistant.

    Messages whose role is not system/user/assistant/tool/function are
    skipped. A ``content`` of ``None`` is rendered as an empty string rather
    than the literal text "None".
    """
    parts = []

    # 1. Inject tools via the system-prompt strategy if tools exist.
    if tools:
        parts.append(
            "SYSTEM INSTRUCTION: You are an intelligent AI acting as an API. You have access to tools. "
            "If you need to call a tool, you MUST reply ONLY with a JSON block in the exact format below, and no other text.\n"
            '```json\n{"tool_calls":[{"name": "function_name", "arguments": {"arg_name": "arg_value"}}]}\n```\n'
            "Available tools:\n" + json.dumps(tools, indent=2) + "\n\n"
        )

    # 2. Append the message history.
    for msg in messages:
        text = "" if msg.content is None else msg.content
        if msg.role == "system":
            parts.append(f"System: {text}\n\n")
        elif msg.role == "user":
            parts.append(f"User: {text}\n\n")
        elif msg.role == "assistant":
            if msg.tool_calls:
                # Replay earlier tool calls as plain dicts so they dump cleanly.
                tc_dicts = [
                    {
                        "name": tc.function.name,
                        "arguments": _decode_tool_arguments(tc.function.arguments),
                    }
                    for tc in msg.tool_calls
                ]
                parts.append(f"Assistant called tools: {json.dumps(tc_dicts)}\n\n")
            if msg.content:
                parts.append(f"Assistant: {msg.content}\n\n")
        elif msg.role in ("tool", "function"):
            parts.append(f"TOOL RESULT (for {msg.tool_call_id or msg.name}): {text}\n\n")

    parts.append("Assistant: ")
    return "".join(parts)
def extract_tool_calls(text: str):
    """Parse a model reply and return OpenAI-style tool calls, if it contains any.

    The reply is expected to be a JSON object of the form
    ``{"tool_calls": [{"name": ..., "arguments": {...}}]}``, either inside a
    fenced code block (with or without the ``json`` tag) or as the entire
    reply text.

    Args:
        text: Raw reply text from the model.

    Returns:
        A list of dicts with ``id``, ``type`` and ``function`` (``name`` plus
        ``arguments`` as a JSON string), or ``None`` when the reply is not a
        usable tool-call payload.
    """
    if not text:
        return None

    # Look for a fenced block first, then fall back to the raw text.
    fenced = re.search(r'```(?:json)?\s*(.*?)\s*```', text, re.DOTALL)
    json_str = fenced.group(1) if fenced else text.strip()

    try:
        parsed = json.loads(json_str)
    except json.JSONDecodeError:
        return None

    # A reply such as `42`, `true` or `"text"` is valid JSON but not an
    # object; membership or index access on it would raise TypeError.
    if not isinstance(parsed, dict):
        return None
    raw_calls = parsed.get("tool_calls")
    if not isinstance(raw_calls, list):
        return None

    formatted_calls = []
    for tc in raw_calls:
        # Skip entries that cannot describe a call instead of crashing.
        if not isinstance(tc, dict):
            continue
        name = tc.get("name")
        if not isinstance(name, str) or not name:
            continue
        arguments = tc.get("arguments")
        if arguments is None:
            arguments = {}
        # OpenAI expects stringified JSON. If the model already emitted a
        # string, pass it through rather than encoding it a second time.
        if not isinstance(arguments, str):
            arguments = json.dumps(arguments)
        formatted_calls.append({
            "id": f"call_{uuid.uuid4().hex[:8]}",
            "type": "function",
            "function": {"name": name, "arguments": arguments},
        })

    return formatted_calls or None
# --- API Endpoints ---
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Serve an OpenAI-style chat completion backed by the Grok client.

    The message history (and tool definitions, if any) is flattened into one
    prompt, sent to Grok, and the reply is returned either as a single JSON
    body or, when ``request.stream`` is true, as server-sent events ending
    with a ``[DONE]`` sentinel.

    Raises:
        HTTPException: 400 when the history cannot be turned into a prompt;
            500 when the Grok call fails.
    """
    # Local import keeps this block self-contained; asyncio is needed so the
    # event loop is never blocked from inside this coroutine.
    import asyncio

    # 1. Prepare the prompt.
    try:
        mega_prompt = format_messages_and_tools(request.messages, request.tools)
    except ValueError as exc:
        # json.JSONDecodeError is a ValueError: malformed tool-call arguments
        # in the request are a client error, not a server failure.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid tool call arguments in message history: {exc}",
        ) from exc

    # 2. Check for a proxy in the environment variables.
    # If Hugging Face IPs are blocked by Cloudflare, setting this in HF Secrets fixes it.
    proxy_url = os.environ.get("GROK_PROXY")

    try:
        # 3. Call Grok.
        grok_client = Grok(request.model, proxy_url) if proxy_url else Grok(request.model)
        # start_convo is a synchronous call; run it in the default executor so
        # other requests keep being served while it is in flight.
        loop = asyncio.get_running_loop()
        raw_response = await loop.run_in_executor(None, grok_client.start_convo, mega_prompt)
        # Normalize missing or None values so the code below can rely on types.
        response_text = raw_response.get("response") or ""
        stream_array = raw_response.get("stream_response") or []
    except UnboundLocalError as exc:
        # This catches the specific "local variable 'script_content1' referenced before assignment" error
        error_msg = (
            "Grok API Scraper failed to find session tokens. "
            "This usually means Hugging Face's IP is blocked by Grok's Cloudflare, or Grok updated their website DOM. "
            "Fix: Add a 'GROK_PROXY' secret in your HF Space settings (e.g., http://user:pass@ip:port)."
        )
        print(f"Scraper Error: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=error_msg) from exc
    except Exception as exc:
        print(f"Unknown Error: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Upstream Grok Error: {str(exc)}") from exc

    # 4. Parse tool calls (only when the client offered tools).
    tool_calls = extract_tool_calls(response_text) if request.tools else None
    finish_reason = "tool_calls" if tool_calls else "stop"

    # 5. Handle the streaming response.
    if request.stream:
        # All chunks of one completion share the same id and creation time.
        completion_id = f"chatcmpl-{uuid.uuid4()}"
        created = int(time.time())

        def make_chunk(delta, reason=None):
            """Wrap *delta* in a chat.completion.chunk SSE event."""
            return {
                "data": json.dumps({
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": request.model,
                    "choices": [{"index": 0, "delta": delta, "finish_reason": reason}],
                })
            }

        async def event_generator():
            if tool_calls:
                # Tool calls are emitted as one chunk to prevent breaking JSON
                # parsers in agents; streaming deltas carry a per-call index.
                indexed = [{"index": i, **call} for i, call in enumerate(tool_calls)]
                yield make_chunk({"tool_calls": indexed})
            else:
                # Simulate streaming using the token array; fall back to the
                # full text when no token array was provided.
                tokens = stream_array or ([response_text] if response_text else [])
                for token in tokens:
                    yield make_chunk({"content": token})
                    # Small non-blocking delay for smooth streaming.
                    await asyncio.sleep(0.01)

            # Final chunk: the finish reason must match what was streamed.
            yield make_chunk({}, finish_reason)
            yield {"data": "[DONE]"}

        return EventSourceResponse(event_generator())

    # 6. Handle the standard (non-streaming) response.
    response_msg = {"role": "assistant", "content": None if tool_calls else response_text}
    if tool_calls:
        response_msg["tool_calls"] = tool_calls

    return JSONResponse(content={
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": request.model,
        "choices": [{
            "index": 0,
            "message": response_msg,
            "finish_reason": finish_reason,
        }],
        # Rough estimate: about four characters per token.
        "usage": {
            "prompt_tokens": len(mega_prompt) // 4,
            "completion_tokens": len(response_text) // 4,
            "total_tokens": (len(mega_prompt) + len(response_text)) // 4,
        },
    })
@app.get("/v1/models")
async def list_models():
    """Return the advertised Grok model ids in an OpenAI-style list envelope."""
    model_ids = (
        "grok-3-auto",
        "grok-3-fast",
        "grok-4",
        "grok-4-mini-thinking-tahoe",
    )
    # One entry per model; "created" is stamped with the current time.
    entries = [
        {"id": model_id, "object": "model", "created": int(time.time()), "owned_by": "xai"}
        for model_id in model_ids
    ]
    return {"object": "list", "data": entries}