File size: 8,303 Bytes
aec3672
 
 
 
dae2ba4
 
 
aec3672
dae2ba4
aec3672
 
 
 
 
 
 
 
 
 
dae2ba4
 
 
 
 
 
 
 
 
aec3672
 
 
 
dae2ba4
aec3672
 
 
 
 
 
 
 
 
 
 
 
dae2ba4
aec3672
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dae2ba4
 
 
aec3672
 
dae2ba4
aec3672
 
 
 
 
 
 
dae2ba4
aec3672
 
 
 
 
 
 
 
 
 
 
 
 
dae2ba4
aec3672
 
 
 
 
 
 
 
 
 
 
 
 
dae2ba4
 
 
 
aec3672
dae2ba4
 
 
 
 
 
aec3672
 
 
 
 
dae2ba4
 
 
 
 
 
 
 
 
aec3672
dae2ba4
 
aec3672
dae2ba4
aec3672
 
dae2ba4
aec3672
 
dae2ba4
aec3672
 
 
 
 
 
 
 
 
dae2ba4
aec3672
 
 
 
 
dae2ba4
aec3672
 
dae2ba4
aec3672
dae2ba4
aec3672
 
 
 
 
 
 
 
 
 
 
 
dae2ba4
aec3672
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import asyncio
import json
import os
import re
import time
import traceback
import uuid
from typing import List, Optional, Dict, Any, Union

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse

# Import the core from the cloned Grok-Api repository
from core import Grok

# ASGI application object; the route decorators further down register their handlers on it.
app = FastAPI(title="Grok OpenAI Wrapper (Agent & MCP Compatible)")

# --- OpenAI Pydantic Models ---
class FunctionCall(BaseModel):
    # Name and arguments of one function invocation carried inside a tool call.

    # Name of the function to invoke.
    name: str
    # Arguments as a JSON-encoded string; the prompt builder decodes it with json.loads.
    arguments: str

class ToolCall(BaseModel):
    # One tool call attached to an assistant message.

    # Identifier of this call.
    id: str
    # Kind of tool call; defaults to "function".
    type: str = "function"
    # Function name and JSON-encoded arguments for this call.
    function: FunctionCall

class ChatMessage(BaseModel):
    # A single entry of the conversation history sent by the client.

    # Speaker role; the prompt builder handles "system", "user", "assistant",
    # "tool" and "function" and ignores anything else.
    role: str
    # Text body of the message; may be omitted.
    content: Optional[str] = None
    # Optional name; used to label a tool result when tool_call_id is absent.
    name: Optional[str] = None
    # Tool calls carried by an assistant message.
    tool_calls: Optional[List[ToolCall]] = None
    # Id of the tool call that a tool/function message is answering.
    tool_call_id: Optional[str] = None

class ChatCompletionRequest(BaseModel):
    # Request body accepted by the POST /v1/chat/completions endpoint.

    # Model identifier handed to the Grok client; defaults to "grok-3-fast".
    model: str = "grok-3-fast"
    # Conversation history; flattened into a single prompt string before sending.
    messages: List[ChatMessage]
    # When true the reply is sent as server-sent events instead of one JSON body.
    stream: Optional[bool] = False
    # Tool definitions; when present they are serialized into the prompt and
    # the model's reply is scanned for a tool-call JSON block.
    tools: Optional[List[Dict[Any, Any]]] = None
    # Accepted for compatibility; not read by the handler visible in this file.
    tool_choice: Optional[Any] = None
    # Accepted for compatibility; not read by the handler visible in this file.
    temperature: Optional[float] = 0.7

# --- Helper Functions ---
def format_messages_and_tools(messages: "List[ChatMessage]", tools: Optional[List[Dict]]) -> str:
    """Flatten an OpenAI-style message history into one prompt string for Grok.

    When ``tools`` is non-empty, a system instruction is placed first that
    tells the model to answer with a fenced JSON ``tool_calls`` block and
    lists the tool definitions. Each message is then rendered as a labelled
    paragraph, and the prompt always ends with ``"Assistant: "`` as the cue
    for the model's reply.

    Args:
        messages: Conversation history in order. Messages whose role is not
            system, user, assistant, tool or function are skipped.
        tools: Tool definitions to advertise, or ``None``/empty for none.

    Returns:
        The assembled prompt.
    """
    parts = []

    # 1. Inject Tools via System Prompt Strategy if tools exist
    if tools:
        parts.append(
            "SYSTEM INSTRUCTION: You are an intelligent AI acting as an API. You have access to tools. "
            "If you need to call a tool, you MUST reply ONLY with a JSON block in the exact format below, and no other text.\n"
            '```json\n{"tool_calls":[{"name": "function_name", "arguments": {"arg_name": "arg_value"}}]}\n```\n'
            "Available tools:\n" + json.dumps(tools, indent=2) + "\n\n"
        )

    # 2. Append Message History
    for msg in messages:
        # A missing body is rendered as empty text rather than the word "None".
        content = "" if msg.content is None else msg.content
        if msg.role == "system":
            parts.append(f"System: {content}\n\n")
        elif msg.role == "user":
            parts.append(f"User: {content}\n\n")
        elif msg.role == "assistant":
            if msg.tool_calls:
                tc_dicts = [
                    {
                        "name": tc.function.name,
                        "arguments": _decode_tool_arguments(tc.function.arguments),
                    }
                    for tc in msg.tool_calls
                ]
                parts.append(f"Assistant called tools: {json.dumps(tc_dicts)}\n\n")
            if msg.content:
                parts.append(f"Assistant: {msg.content}\n\n")
        elif msg.role in ("tool", "function"):
            parts.append(f"TOOL RESULT (for {msg.tool_call_id or msg.name}): {content}\n\n")

    parts.append("Assistant: ")
    return "".join(parts)


def _decode_tool_arguments(raw):
    """Return *raw* parsed as JSON, or *raw* unchanged when it is not valid JSON."""
    try:
        return json.loads(raw)
    except (TypeError, ValueError):
        # Clients occasionally replay malformed or empty argument strings;
        # keep the raw text in the prompt instead of failing the whole request.
        return raw

def extract_tool_calls(text: str):
    """Extract OpenAI-style tool calls from Grok's reply, if it emitted any.

    The reply is expected to contain a JSON object of the form
    ``{"tool_calls": [{"name": ..., "arguments": ...}]}``, either inside a
    fenced ``json`` code block or as the entire reply text.

    Args:
        text: Raw reply text from the model.

    Returns:
        A list of tool-call dicts (``id``, ``type``, ``function`` with
        ``name`` and a JSON-encoded ``arguments`` string), or ``None`` when
        the reply holds no usable tool call.
    """
    if not text:
        return None

    # Look for a markdown JSON block, or fall back to raw text
    match = re.search(r'```json\s*(.*?)\s*```', text, re.DOTALL)
    json_str = match.group(1) if match else text.strip()

    try:
        parsed = json.loads(json_str)
    except json.JSONDecodeError:
        return None

    # A plain answer such as "42" or "true" parses to a JSON scalar; only an
    # object can carry the tool_calls key.
    if not isinstance(parsed, dict):
        return None
    raw_calls = parsed.get("tool_calls")
    if not isinstance(raw_calls, list):
        return None

    formatted_calls = []
    for tc in raw_calls:
        if not isinstance(tc, dict):
            continue
        name = tc.get("name")
        if not isinstance(name, str) or not name:
            # A call without a function name cannot be dispatched by the client.
            continue
        arguments = tc.get("arguments")
        if arguments is None:
            arguments = {}
        if not isinstance(arguments, str):
            # OpenAI expects a stringified JSON; a string is assumed to be
            # encoded already and is passed through to avoid double-encoding.
            arguments = json.dumps(arguments)
        formatted_calls.append({
            "id": f"call_{uuid.uuid4().hex[:8]}",
            "type": "function",
            "function": {"name": name, "arguments": arguments},
        })

    return formatted_calls or None

# --- API Endpoints ---
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Serve an OpenAI-compatible chat completion backed by a Grok conversation.

    The message history (and tool definitions, if any) is flattened into one
    prompt, sent through a fresh ``Grok`` client, and the reply is returned
    either as a single JSON body or, when ``request.stream`` is set, as
    server-sent ``chat.completion.chunk`` events followed by ``[DONE]``.
    When tools were supplied and the reply contains a tool-call JSON block,
    the tool calls are returned instead of text.

    Args:
        request: Parsed request body.

    Returns:
        ``EventSourceResponse`` for streaming requests, otherwise
        ``JSONResponse``.

    Raises:
        HTTPException: 400 when the message history cannot be rendered;
            500 when the Grok client fails.
    """
    # 1. Prepare Prompt
    try:
        mega_prompt = format_messages_and_tools(request.messages, request.tools)
    except ValueError as e:
        # json.JSONDecodeError (a ValueError) from undecodable tool-call
        # arguments in the history is a client error, not a server fault.
        raise HTTPException(status_code=400, detail=f"Invalid message history: {str(e)}") from e

    # 2. Check for Proxy in Environment Variables
    # If Hugging Face IPs are blocked by Cloudflare, setting this in HF Secrets fixes it.
    proxy_url = os.environ.get("GROK_PROXY", None)

    try:
        # 3. Call Grok
        if proxy_url:
            grok_client = Grok(request.model, proxy_url)
        else:
            grok_client = Grok(request.model)

        # NOTE(review): this call is synchronous and runs on the event loop;
        # consider moving it to a worker thread once the client's
        # thread-safety is confirmed.
        raw_response = grok_client.start_convo(mega_prompt)

        # `or` also covers keys that are present but hold None.
        response_text = raw_response.get("response") or ""
        stream_array = raw_response.get("stream_response") or []

    except UnboundLocalError as e:
        # This catches the specific "local variable 'script_content1' referenced before assignment" error
        error_msg = (
            "Grok API Scraper failed to find session tokens. "
            "This usually means Hugging Face's IP is blocked by Grok's Cloudflare, or Grok updated their website DOM. "
            "Fix: Add a 'GROK_PROXY' secret in your HF Space settings (e.g., http://user:pass@ip:port)."
        )
        print(f"Scraper Error: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=error_msg) from e
    except Exception as e:
        print(f"Unknown Error: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Upstream Grok Error: {str(e)}") from e

    # 4. Parse Tool Calls
    tool_calls = extract_tool_calls(response_text) if request.tools else None

    # One id and timestamp per completion: every chunk of a stream must carry
    # the same id so clients can correlate them.
    completion_id = f"chatcmpl-{uuid.uuid4()}"
    created = int(time.time())

    # 5. Handle Streaming Response
    if request.stream:
        def make_event(delta, finish_reason=None):
            """Wrap one delta in a chat.completion.chunk SSE payload."""
            return {
                "data": json.dumps({
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": request.model,
                    "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
                })
            }

        async def event_generator():
            # Tool calls are emitted as one chunk to prevent breaking JSON parsers in agents
            if tool_calls:
                # Streamed tool-call deltas carry a positional index that
                # clients use to reassemble the calls.
                indexed_calls = [{"index": i, **tc} for i, tc in enumerate(tool_calls)]
                yield make_event({"role": "assistant", "tool_calls": indexed_calls}, "tool_calls")
            else:
                # Simulate streaming using the token array; if the client
                # returned no tokens, send the full text as a single chunk.
                tokens = stream_array or ([response_text] if response_text else [])
                yield make_event({"role": "assistant", "content": ""})
                for token in tokens:
                    yield make_event({"content": token})
                    # Non-blocking pause: time.sleep here would stall the
                    # event loop for every other request.
                    await asyncio.sleep(0.01)

                # Final STOP chunk
                yield make_event({}, "stop")
            yield {"data": "[DONE]"}

        return EventSourceResponse(event_generator())

    # 6. Handle Standard Sync Response
    response_msg = {"role": "assistant", "content": None if tool_calls else response_text}
    if tool_calls:
        response_msg["tool_calls"] = tool_calls

    return JSONResponse(content={
        "id": completion_id,
        "object": "chat.completion",
        "created": created,
        "model": request.model,
        "choices": [{
            "index": 0,
            "message": response_msg,
            "finish_reason": "tool_calls" if tool_calls else "stop"
        }],
        # Rough estimate: about four characters per token.
        "usage": {
            "prompt_tokens": len(mega_prompt) // 4,
            "completion_tokens": len(response_text) // 4,
            "total_tokens": (len(mega_prompt) + len(response_text)) // 4
        }
    })

@app.get("/v1/models")
async def list_models():
    """Return the fixed catalogue of Grok model ids in OpenAI list format."""
    model_ids = (
        "grok-3-auto",
        "grok-3-fast",
        "grok-4",
        "grok-4-mini-thinking-tahoe",
    )
    # Each entry is stamped with the current time, evaluated per entry.
    entries = [
        {"id": model_id, "object": "model", "created": int(time.time()), "owned_by": "xai"}
        for model_id in model_ids
    ]
    return {"object": "list", "data": entries}