Spaces:
Sleeping
Sleeping
File size: 8,887 Bytes
c1ae554 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 | from __future__ import annotations
import asyncio
import json
import time
import uuid
from typing import Any, Dict, List, Optional
import requests
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from .logging import logger
from .models import ChatCompletionsRequest, ChatMessage
from .reorder import reorder_messages_for_anthropic
from .helpers import normalize_content_to_list, segments_to_text
from .packets import packet_template, map_history_to_warp_messages, attach_user_and_tools_to_inputs
from .state import STATE
from .config import BRIDGE_BASE_URL
from .bridge import initialize_once
from .sse_transform import stream_openai_sse
# Router collecting the endpoints defined below (root banner, health check,
# model listing, chat completions); presumably included into the FastAPI app
# elsewhere — TODO confirm.
router = APIRouter()
@router.get("/")
def root():
    """Report the service name and an ``ok`` status at the root path."""
    banner = {"service": "OpenAI Chat Completions (Warp bridge) - Streaming"}
    banner["status"] = "ok"
    return banner
@router.get("/healthz")
def health_check():
    """Liveness probe: always answers ``status: ok`` together with the service name."""
    return dict(status="ok", service="OpenAI Chat Completions (Warp bridge) - Streaming")
@router.get("/v1/models")
def list_models():
    """OpenAI-compatible model listing.

    Proxies ``GET /v1/models`` to the bridge. If that fails for any reason
    (connection error, non-200 status, undecodable body), the list is built
    locally via ``warp2protobuf``; if that also fails, a 502 is raised that
    carries the original bridge error.
    """
    try:
        bridge_resp = requests.get(f"{BRIDGE_BASE_URL}/v1/models", timeout=10.0)
        if bridge_resp.status_code == 200:
            return bridge_resp.json()
        raise HTTPException(bridge_resp.status_code, f"bridge_error: {bridge_resp.text}")
    except Exception as bridge_err:
        # Local fallback: the import is deferred so the endpoint still works
        # when the bridge is healthy but warp2protobuf is unavailable.
        try:
            from warp2protobuf.config.models import get_all_unique_models  # type: ignore
            return {"object": "list", "data": get_all_unique_models()}
        except Exception:
            raise HTTPException(502, f"bridge_unreachable: {bridge_err}")
def _log_json(label: str, build_payload: Any) -> None:
    """Log ``build_payload()`` as JSON at INFO level, prefixed by ``label``.

    ``build_payload`` is a zero-argument callable. It is invoked inside the
    ``try`` so that a failure while building *or* serializing the payload only
    produces a "serialization failed" log line instead of failing the request.
    """
    try:
        logger.info("%s: %s", label, json.dumps(build_payload(), ensure_ascii=False))
    except Exception:
        logger.info("%s 序列化失败", label)


def _extract_system_prompt(history: List[ChatMessage]) -> Optional[str]:
    """Join the text of all non-blank system messages, separated by blank lines.

    Returns None when no system message has visible text. Extraction is
    best-effort: any error is logged and treated as "no system prompt".
    """
    chunks: List[str] = []
    try:
        for message in history:
            if message.role != "system":
                continue
            text = segments_to_text(normalize_content_to_list(message.content))
            if text.strip():
                chunks.append(text)
    except Exception:
        logger.warning("[OpenAI Compat] failed to extract system prompt", exc_info=True)
        return None
    return "\n\n".join(chunks) if chunks else None


def _build_mcp_tools(tools: Any) -> List[Dict[str, Any]]:
    """Convert OpenAI ``function`` tool definitions into MCP tool descriptors.

    ``tools`` is presumably the request's list of tool models (objects with
    ``.type`` and ``.function``) — TODO confirm. Entries whose type is not
    ``"function"`` or that have no ``function`` body are skipped.
    """
    return [
        {
            "name": t.function.name,
            "description": t.function.description or "",
            "input_schema": t.function.parameters or {},
        }
        for t in tools or []
        if t.type == "function" and t.function
    ]


def _send_to_bridge(packet: Dict[str, Any]) -> requests.Response:
    """POST ``packet`` to the bridge; on HTTP 429 try a JWT refresh, then retry once.

    This is blocking (``requests``), so async callers must run it in an
    executor. Network errors from the send itself propagate to the caller;
    a failing refresh call is only logged.
    """
    def _post_once() -> requests.Response:
        return requests.post(
            f"{BRIDGE_BASE_URL}/api/warp/send_stream",
            json={"json_data": packet, "message_type": "warp.multi_agent.v1.Request"},
            timeout=(5.0, 180.0),
        )

    resp = _post_once()
    if resp.status_code == 429:
        try:
            r = requests.post(f"{BRIDGE_BASE_URL}/api/auth/refresh", timeout=10.0)
            logger.warning("[OpenAI Compat] Bridge returned 429. Tried JWT refresh -> HTTP %s", getattr(r, 'status_code', 'N/A'))
        except Exception as _e:
            logger.warning("[OpenAI Compat] JWT refresh attempt failed after 429: %s", _e)
        resp = _post_once()
    return resp


def _first_truthy(obj: Dict[str, Any], *keys: str) -> Any:
    """Return the first truthy ``obj.get(key)`` among ``keys`` (snake/camel aliases), else None."""
    for key in keys:
        value = obj.get(key)
        if value:
            return value
    return None


def _iter_added_messages(bridge_resp: Dict[str, Any]):
    """Yield each message dict carried by ``add_messages_to_task`` actions in ``parsed_events``.

    Every level is type-checked so a single malformed node is skipped on its
    own instead of aborting the whole scan.
    """
    events = bridge_resp.get("parsed_events") or []
    if not isinstance(events, list):
        return
    for ev in events:
        if not isinstance(ev, dict):
            continue
        evd = _first_truthy(ev, "parsed_data", "raw_data")
        if not isinstance(evd, dict):
            continue
        client_actions = _first_truthy(evd, "client_actions", "clientActions")
        if not isinstance(client_actions, dict):
            continue
        actions = _first_truthy(client_actions, "actions", "Actions")
        if not isinstance(actions, list):
            continue
        for action in actions:
            if not isinstance(action, dict):
                continue
            add_msgs = _first_truthy(action, "add_messages_to_task", "addMessagesToTask")
            if not isinstance(add_msgs, dict):
                continue
            messages = add_msgs.get("messages") or []
            if isinstance(messages, list):
                yield from (m for m in messages if isinstance(m, dict))


def _collect_tool_calls(bridge_resp: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Translate MCP tool calls found in the bridge response into OpenAI ``tool_calls`` entries.

    A call without a ``tool_call_id``/``toolCallId`` gets a fresh UUID; args
    that cannot be JSON-encoded are sent as ``"{}"``.
    """
    tool_calls: List[Dict[str, Any]] = []
    for message in _iter_added_messages(bridge_resp):
        tc = _first_truthy(message, "tool_call", "toolCall")
        if not isinstance(tc, dict):
            continue
        call_mcp = _first_truthy(tc, "call_mcp_tool", "callMcpTool")
        if not isinstance(call_mcp, dict) or not call_mcp.get("name"):
            continue
        try:
            args_str = json.dumps(call_mcp.get("args") or {}, ensure_ascii=False)
        except (TypeError, ValueError):
            args_str = "{}"
        tool_calls.append({
            "id": _first_truthy(tc, "tool_call_id", "toolCallId") or str(uuid.uuid4()),
            "type": "function",
            "function": {"name": call_mcp.get("name"), "arguments": args_str},
        })
    return tool_calls


@router.post("/v1/chat/completions")
async def chat_completions(req: ChatCompletionsRequest):
    """OpenAI-compatible Chat Completions endpoint backed by the Warp bridge.

    Converts the OpenAI request into a Warp protobuf-JSON packet. Streaming
    requests are delegated to ``stream_openai_sse``. Otherwise the packet is
    sent to the bridge's ``/api/warp/send_stream`` and the answer is returned
    as a ``chat.completion`` object (``finish_reason`` is ``"tool_calls"`` when
    the bridge requested MCP tool calls, ``"stop"`` otherwise).

    Raises:
        HTTPException: 400 if ``messages`` is empty; the bridge's own status
            code if it answers non-200; 502 if the bridge cannot be reached or
            returns an unusable body.
    """
    # NOTE(review): initialize_once() runs on the event-loop thread; if it does
    # blocking I/O, consider moving it to an executor as well — TODO confirm.
    try:
        initialize_once()
    except Exception as e:
        logger.warning("[OpenAI Compat] initialize_once failed or skipped: %s", e)
    if not req.messages:
        raise HTTPException(400, "messages 不能为空")

    # 1) Log the raw incoming Chat Completions request body.
    _log_json("[OpenAI Compat] 接收到的 Chat Completions 请求体(原始)", lambda: req.dict())

    # Reorder the conversation history.
    history: List[ChatMessage] = reorder_messages_for_anthropic(list(req.messages))

    # 2) Log the request body after reordering.
    _log_json(
        "[OpenAI Compat] 整理后的请求体(post-reorder)",
        lambda: {**req.dict(), "messages": [m.dict() for m in history]},
    )

    system_prompt_text = _extract_system_prompt(history)

    # Reuse the task id remembered from the previous bridge response, if any.
    task_id = STATE.baseline_task_id or str(uuid.uuid4())
    packet = packet_template()
    packet["task_context"] = {
        "tasks": [{
            "id": task_id,
            "description": "",
            "status": {"in_progress": {}},
            "messages": map_history_to_warp_messages(history, task_id, None, False),
        }],
        "active_task_id": task_id,
    }
    model_config = packet.setdefault("settings", {}).setdefault("model_config", {})
    model_config["base"] = req.model or model_config.get("base") or "claude-4.1-opus"
    if STATE.conversation_id:
        packet.setdefault("metadata", {})["conversation_id"] = STATE.conversation_id
    attach_user_and_tools_to_inputs(packet, history, system_prompt_text)

    if req.tools:
        mcp_tools = _build_mcp_tools(req.tools)
        if mcp_tools:
            packet.setdefault("mcp_context", {}).setdefault("tools", []).extend(mcp_tools)

    # 3) Log the protobuf-JSON packet that is sent to the bridge.
    _log_json("[OpenAI Compat] 转换成 Protobuf JSON 的请求体", lambda: packet)

    created_ts = int(time.time())
    completion_id = str(uuid.uuid4())
    model_id = req.model or "warp-default"

    if req.stream:
        async def _agen():
            async for chunk in stream_openai_sse(packet, completion_id, created_ts, model_id):
                yield chunk
        return StreamingResponse(_agen(), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive"})

    # The bridge call is blocking (up to 180 s read timeout); run it in the
    # default executor so other requests on this event loop are not stalled.
    loop = asyncio.get_running_loop()
    try:
        resp = await loop.run_in_executor(None, _send_to_bridge, packet)
    except Exception as e:
        raise HTTPException(502, f"bridge_unreachable: {e}") from e

    # Checked outside any broad ``except`` so the bridge's real status code is
    # propagated instead of being rewritten to 502.
    if resp.status_code != 200:
        raise HTTPException(resp.status_code, f"bridge_error: {resp.text}")
    try:
        bridge_resp = resp.json()
    except ValueError as e:
        raise HTTPException(502, f"bridge_unreachable: {e}") from e
    if not isinstance(bridge_resp, dict):
        raise HTTPException(502, f"bridge_error: unexpected response payload of type {type(bridge_resp).__name__}")

    # Remember ids returned by the bridge; they are reused when building the
    # next request's packet (task_id / metadata.conversation_id above).
    STATE.conversation_id = bridge_resp.get("conversation_id") or STATE.conversation_id
    ret_task_id = bridge_resp.get("task_id")
    if isinstance(ret_task_id, str) and ret_task_id:
        STATE.baseline_task_id = ret_task_id

    tool_calls = _collect_tool_calls(bridge_resp)
    if tool_calls:
        msg_payload = {"role": "assistant", "content": "", "tool_calls": tool_calls}
        finish_reason = "tool_calls"
    else:
        # ``or ""`` also covers an explicit ``"response": null`` from the bridge.
        msg_payload = {"role": "assistant", "content": bridge_resp.get("response") or ""}
        finish_reason = "stop"

    return {
        "id": completion_id,
        "object": "chat.completion",
        "created": created_ts,
        "model": model_id,
        "choices": [{"index": 0, "message": msg_payload, "finish_reason": finish_reason}],
    }