kazukaraya12 committed on
Commit
dae2ba4
·
verified ·
1 Parent(s): bb960b6

Update openai_server.py

Browse files
Files changed (1) hide show
  1. openai_server.py +50 -20
openai_server.py CHANGED
@@ -2,9 +2,11 @@ import json
2
  import re
3
  import time
4
  import uuid
5
- from typing import List, Optional, Dict, Any
 
 
6
 
7
- from fastapi import FastAPI, Request, HTTPException
8
  from fastapi.responses import JSONResponse
9
  from sse_starlette.sse import EventSourceResponse
10
  from pydantic import BaseModel, Field
@@ -15,11 +17,20 @@ from core import Grok
15
  app = FastAPI(title="Grok OpenAI Wrapper (Agent & MCP Compatible)")
16
 
17
  # --- OpenAI Pydantic Models ---
 
 
 
 
 
 
 
 
 
18
  class ChatMessage(BaseModel):
19
  role: str
20
  content: Optional[str] = None
21
  name: Optional[str] = None
22
- tool_calls: Optional[List[Dict[Any, Any]]] = None
23
  tool_call_id: Optional[str] = None
24
 
25
  class ChatCompletionRequest(BaseModel):
@@ -32,7 +43,7 @@ class ChatCompletionRequest(BaseModel):
32
 
33
  # --- Helper Functions ---
34
  def format_messages_and_tools(messages: List[ChatMessage], tools: Optional[List[Dict]]) -> str:
35
- """Translates the standard OpenAI message history into a single string for the web scraper"""
36
  prompt = ""
37
 
38
  # 1. Inject Tools via System Prompt Strategy if tools exist
@@ -52,11 +63,12 @@ def format_messages_and_tools(messages: List[ChatMessage], tools: Optional[List[
52
  prompt += f"User: {msg.content}\n\n"
53
  elif msg.role == "assistant":
54
  if msg.tool_calls:
55
- prompt += f"Assistant called tools: {json.dumps(msg.tool_calls)}\n\n"
 
 
56
  if msg.content:
57
  prompt += f"Assistant: {msg.content}\n\n"
58
- elif msg.role == "tool" or msg.role == "function":
59
- # Pass tool results back to the model
60
  prompt += f"TOOL RESULT (for {msg.tool_call_id or msg.name}): {msg.content}\n\n"
61
 
62
  prompt += "Assistant: "
@@ -64,7 +76,7 @@ def format_messages_and_tools(messages: List[ChatMessage], tools: Optional[List[
64
 
65
  def extract_tool_calls(text: str):
66
  """Parses Grok's response to check if it emitted our forced JSON tool call"""
67
- # Look for a markdown JSON block
68
  match = re.search(r'```json\s*(.*?)\s*```', text, re.DOTALL)
69
  json_str = match.group(1) if match else text.strip()
70
 
@@ -78,7 +90,7 @@ def extract_tool_calls(text: str):
78
  "type": "function",
79
  "function": {
80
  "name": tc.get("name"),
81
- "arguments": json.dumps(tc.get("arguments", {})) # OpenAI expects a stringified JSON here
82
  }
83
  })
84
  return formatted_calls
@@ -92,24 +104,42 @@ async def chat_completions(request: ChatCompletionRequest):
92
  # 1. Prepare Prompt
93
  mega_prompt = format_messages_and_tools(request.messages, request.tools)
94
 
 
 
 
 
95
  try:
96
- # 2. Call the underlying Grok Wrapper (Stateless, passing entire context in prompt)
97
- grok_client = Grok(request.model)
 
 
 
 
98
  raw_response = grok_client.start_convo(mega_prompt)
99
 
100
  response_text = raw_response.get("response", "")
101
  stream_array = raw_response.get("stream_response",[])
102
 
 
 
 
 
 
 
 
 
 
103
  except Exception as e:
104
- raise HTTPException(status_code=500, detail=str(e))
 
105
 
106
- # 3. Check if response is a tool call
107
  tool_calls = extract_tool_calls(response_text) if request.tools else None
108
 
109
- # 4. Handle Streaming Response
110
  if request.stream:
111
  async def event_generator():
112
- # If it's a tool call, we typically don't stream it, but send it as one chunk
113
  if tool_calls:
114
  chunk = {
115
  "id": f"chatcmpl-{uuid.uuid4()}",
@@ -119,18 +149,18 @@ async def chat_completions(request: ChatCompletionRequest):
119
  }
120
  yield {"data": json.dumps(chunk)}
121
  else:
122
- # Fake the stream using the token array returned by the API
123
  for token in stream_array:
124
  chunk = {
125
  "id": f"chatcmpl-{uuid.uuid4()}",
126
  "object": "chat.completion.chunk",
127
  "model": request.model,
128
- "choices":[{"index": 0, "delta": {"content": token}, "finish_reason": None}]
129
  }
130
  yield {"data": json.dumps(chunk)}
131
- time.sleep(0.01) # slight delay to emulate natural streaming
132
 
133
- # Final finish reason chunk
134
  yield {
135
  "data": json.dumps({
136
  "id": f"chatcmpl-{uuid.uuid4()}",
@@ -143,7 +173,7 @@ async def chat_completions(request: ChatCompletionRequest):
143
 
144
  return EventSourceResponse(event_generator())
145
 
146
- # 5. Handle Standard Sync Response
147
  response_msg = {"role": "assistant", "content": None if tool_calls else response_text}
148
  if tool_calls:
149
  response_msg["tool_calls"] = tool_calls
 
2
  import re
3
  import time
4
  import uuid
5
+ import os
6
+ import traceback
7
+ from typing import List, Optional, Dict, Any, Union
8
 
9
+ from fastapi import FastAPI, HTTPException
10
  from fastapi.responses import JSONResponse
11
  from sse_starlette.sse import EventSourceResponse
12
  from pydantic import BaseModel, Field
 
17
  app = FastAPI(title="Grok OpenAI Wrapper (Agent & MCP Compatible)")
18
 
19
  # --- OpenAI Pydantic Models ---
20
+ class FunctionCall(BaseModel):
21
+ name: str
22
+ arguments: str
23
+
24
+ class ToolCall(BaseModel):
25
+ id: str
26
+ type: str = "function"
27
+ function: FunctionCall
28
+
29
  class ChatMessage(BaseModel):
30
  role: str
31
  content: Optional[str] = None
32
  name: Optional[str] = None
33
+ tool_calls: Optional[List[ToolCall]] = None
34
  tool_call_id: Optional[str] = None
35
 
36
  class ChatCompletionRequest(BaseModel):
 
43
 
44
  # --- Helper Functions ---
45
  def format_messages_and_tools(messages: List[ChatMessage], tools: Optional[List[Dict]]) -> str:
46
+ """Translates the standard OpenAI message history into a single string for Grok"""
47
  prompt = ""
48
 
49
  # 1. Inject Tools via System Prompt Strategy if tools exist
 
63
  prompt += f"User: {msg.content}\n\n"
64
  elif msg.role == "assistant":
65
  if msg.tool_calls:
66
+ # Convert tool calls to dicts to cleanly dump them
67
+ tc_dicts =[{"name": tc.function.name, "arguments": json.loads(tc.function.arguments)} for tc in msg.tool_calls]
68
+ prompt += f"Assistant called tools: {json.dumps(tc_dicts)}\n\n"
69
  if msg.content:
70
  prompt += f"Assistant: {msg.content}\n\n"
71
+ elif msg.role in ["tool", "function"]:
 
72
  prompt += f"TOOL RESULT (for {msg.tool_call_id or msg.name}): {msg.content}\n\n"
73
 
74
  prompt += "Assistant: "
 
76
 
77
  def extract_tool_calls(text: str):
78
  """Parses Grok's response to check if it emitted our forced JSON tool call"""
79
+ # Look for a markdown JSON block, or fall back to raw text
80
  match = re.search(r'```json\s*(.*?)\s*```', text, re.DOTALL)
81
  json_str = match.group(1) if match else text.strip()
82
 
 
90
  "type": "function",
91
  "function": {
92
  "name": tc.get("name"),
93
+ "arguments": json.dumps(tc.get("arguments", {})) # OpenAI expects a stringified JSON
94
  }
95
  })
96
  return formatted_calls
 
104
  # 1. Prepare Prompt
105
  mega_prompt = format_messages_and_tools(request.messages, request.tools)
106
 
107
+ # 2. Check for Proxy in Environment Variables
108
+ # If Hugging Face IPs are blocked by Cloudflare, setting this in HF Secrets fixes it.
109
+ proxy_url = os.environ.get("GROK_PROXY", None)
110
+
111
  try:
112
+ # 3. Call Grok
113
+ if proxy_url:
114
+ grok_client = Grok(request.model, proxy_url)
115
+ else:
116
+ grok_client = Grok(request.model)
117
+
118
  raw_response = grok_client.start_convo(mega_prompt)
119
 
120
  response_text = raw_response.get("response", "")
121
  stream_array = raw_response.get("stream_response",[])
122
 
123
+ except UnboundLocalError as e:
124
+ # This catches the specific "local variable 'script_content1' referenced before assignment" error
125
+ error_msg = (
126
+ "Grok API Scraper failed to find session tokens. "
127
+ "This usually means Hugging Face's IP is blocked by Grok's Cloudflare, or Grok updated their website DOM. "
128
+ "Fix: Add a 'GROK_PROXY' secret in your HF Space settings (e.g., http://user:pass@ip:port)."
129
+ )
130
+ print(f"Scraper Error: {traceback.format_exc()}")
131
+ raise HTTPException(status_code=500, detail=error_msg)
132
  except Exception as e:
133
+ print(f"Unknown Error: {traceback.format_exc()}")
134
+ raise HTTPException(status_code=500, detail=f"Upstream Grok Error: {str(e)}")
135
 
136
+ # 4. Parse Tool Calls
137
  tool_calls = extract_tool_calls(response_text) if request.tools else None
138
 
139
+ # 5. Handle Streaming Response
140
  if request.stream:
141
  async def event_generator():
142
+ # Tool calls are emitted as one chunk to prevent breaking JSON parsers in agents
143
  if tool_calls:
144
  chunk = {
145
  "id": f"chatcmpl-{uuid.uuid4()}",
 
149
  }
150
  yield {"data": json.dumps(chunk)}
151
  else:
152
+ # Simulate streaming using the token array
153
  for token in stream_array:
154
  chunk = {
155
  "id": f"chatcmpl-{uuid.uuid4()}",
156
  "object": "chat.completion.chunk",
157
  "model": request.model,
158
+ "choices": [{"index": 0, "delta": {"content": token}, "finish_reason": None}]
159
  }
160
  yield {"data": json.dumps(chunk)}
161
+ time.sleep(0.01) # Small delay for smooth streaming
162
 
163
+ # Final STOP chunk
164
  yield {
165
  "data": json.dumps({
166
  "id": f"chatcmpl-{uuid.uuid4()}",
 
173
 
174
  return EventSourceResponse(event_generator())
175
 
176
+ # 6. Handle Standard Sync Response
177
  response_msg = {"role": "assistant", "content": None if tool_calls else response_text}
178
  if tool_calls:
179
  response_msg["tool_calls"] = tool_calls