from __future__ import annotations

import json
import time
import uuid
from typing import Any, Dict, List, Optional

import requests
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from .logging import logger
from .models import ChatCompletionsRequest, ChatMessage
from .reorder import reorder_messages_for_anthropic
from .helpers import normalize_content_to_list, segments_to_text
from .packets import packet_template, map_history_to_warp_messages, attach_user_and_tools_to_inputs
from .state import STATE
from .config import BRIDGE_BASE_URL
from .bridge import initialize_once
from .sse_transform import stream_openai_sse

router = APIRouter()


@router.get("/")
def root():
    return {"service": "OpenAI Chat Completions (Warp bridge) - Streaming", "status": "ok"}


@router.get("/healthz")
def health_check():
    return {"status": "ok", "service": "OpenAI Chat Completions (Warp bridge) - Streaming"}


@router.get("/v1/models")
def list_models():
    """OpenAI-compatible model listing. Forwards to the bridge, with a local fallback."""
    try:
        resp = requests.get(f"{BRIDGE_BASE_URL}/v1/models", timeout=10.0)
        if resp.status_code != 200:
            raise HTTPException(resp.status_code, f"bridge_error: {resp.text}")
        return resp.json()
    except Exception as e:
        try:
            # Local fallback: construct the model list directly when the bridge is
            # unreachable (or answered with a non-200 status).
            from warp2protobuf.config.models import get_all_unique_models  # type: ignore

            models = get_all_unique_models()
            return {"object": "list", "data": models}
        except Exception:
            raise HTTPException(502, f"bridge_unreachable: {e}")


@router.post("/v1/chat/completions")
async def chat_completions(req: ChatCompletionsRequest):
    try:
        initialize_once()
    except Exception as e:
        logger.warning(f"[OpenAI Compat] initialize_once failed or skipped: {e}")

    if not req.messages:
        raise HTTPException(400, "messages must not be empty")

    # 1) Log the raw Chat Completions request body as received.
    try:
        logger.info("[OpenAI Compat] Received Chat Completions request body (raw): %s",
                    json.dumps(req.dict(), ensure_ascii=False))
    except Exception:
        logger.info("[OpenAI Compat] Failed to serialize the raw Chat Completions request body")

    # Reorder messages into the shape the Anthropic-style backend expects.
    history: List[ChatMessage] = reorder_messages_for_anthropic(list(req.messages))

    # 2) Log the request body after reordering.
    try:
        logger.info("[OpenAI Compat] Request body (post-reorder): %s",
                    json.dumps({**req.dict(), "messages": [m.dict() for m in history]},
                               ensure_ascii=False))
    except Exception:
        logger.info("[OpenAI Compat] Failed to serialize the post-reorder request body")

    # Collapse all system messages into a single system prompt, separated by
    # blank lines.
    system_prompt_text: Optional[str] = None
    try:
        chunks: List[str] = []
        for _m in history:
            if _m.role == "system":
                _txt = segments_to_text(normalize_content_to_list(_m.content))
                if _txt.strip():
                    chunks.append(_txt)
        if chunks:
            system_prompt_text = "\n\n".join(chunks)
    except Exception:
        system_prompt_text = None

    # Build the Warp packet. Reuse the baseline task id when one is known so the
    # bridge can thread the conversation across turns.
    task_id = STATE.baseline_task_id or str(uuid.uuid4())
    packet = packet_template()
    packet["task_context"] = {
        "tasks": [{
            "id": task_id,
            "description": "",
            "status": {"in_progress": {}},
            "messages": map_history_to_warp_messages(history, task_id, None, False),
        }],
        "active_task_id": task_id,
    }
    packet.setdefault("settings", {}).setdefault("model_config", {})
    packet["settings"]["model_config"]["base"] = (
        req.model
        or packet["settings"]["model_config"].get("base")
        or "claude-4.1-opus"
    )
    if STATE.conversation_id:
        packet.setdefault("metadata", {})["conversation_id"] = STATE.conversation_id
    attach_user_and_tools_to_inputs(packet, history, system_prompt_text)

    # Map OpenAI function tools onto MCP tool descriptors.
    if req.tools:
        mcp_tools: List[Dict[str, Any]] = []
        for t in req.tools:
            if t.type != "function" or not t.function:
                continue
            mcp_tools.append({
                "name": t.function.name,
                "description": t.function.description or "",
                "input_schema": t.function.parameters or {},
            })
        if mcp_tools:
            packet.setdefault("mcp_context", {}).setdefault("tools", []).extend(mcp_tools)

    # 3) Log the request body converted to protobuf JSON (the packet sent to the bridge).
    try:
        logger.info("[OpenAI Compat] Request body converted to protobuf JSON: %s",
                    json.dumps(packet, ensure_ascii=False))
    except Exception:
        logger.info("[OpenAI Compat] Failed to serialize the protobuf JSON request body")

    created_ts = int(time.time())
    completion_id = str(uuid.uuid4())
    model_id = req.model or "warp-default"

    # Streaming mode: proxy the bridge's SSE stream as OpenAI-style chunks.
    if req.stream:
        async def _agen():
            async for chunk in stream_openai_sse(packet, completion_id, created_ts, model_id):
                yield chunk

        return StreamingResponse(
            _agen(),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
        )

    # Non-streaming mode: a single POST to the bridge, retried once after a JWT
    # refresh if the bridge answers 429.
    def _post_once() -> requests.Response:
        return requests.post(
            f"{BRIDGE_BASE_URL}/api/warp/send_stream",
            json={"json_data": packet, "message_type": "warp.multi_agent.v1.Request"},
            timeout=(5.0, 180.0),
        )

    try:
        resp = _post_once()
        if resp.status_code == 429:
            try:
                r = requests.post(f"{BRIDGE_BASE_URL}/api/auth/refresh", timeout=10.0)
                logger.warning("[OpenAI Compat] Bridge returned 429. Tried JWT refresh -> HTTP %s",
                               getattr(r, "status_code", "N/A"))
            except Exception as _e:
                logger.warning("[OpenAI Compat] JWT refresh attempt failed after 429: %s", _e)
            resp = _post_once()
        if resp.status_code != 200:
            raise HTTPException(resp.status_code, f"bridge_error: {resp.text}")
        bridge_resp = resp.json()
    except HTTPException:
        # Propagate bridge errors with their original status instead of
        # rewrapping them as 502.
        raise
    except Exception as e:
        raise HTTPException(502, f"bridge_unreachable: {e}")

    # Persist the conversation/task ids returned by the bridge for follow-up turns.
    try:
        STATE.conversation_id = bridge_resp.get("conversation_id") or STATE.conversation_id
        ret_task_id = bridge_resp.get("task_id")
        if isinstance(ret_task_id, str) and ret_task_id:
            STATE.baseline_task_id = ret_task_id
    except Exception:
        pass

    # Extract MCP tool calls from the bridge's parsed events, tolerating both
    # snake_case and camelCase field names.
    tool_calls: List[Dict[str, Any]] = []
    try:
        parsed_events = bridge_resp.get("parsed_events", []) or []
        for ev in parsed_events:
            evd = ev.get("parsed_data") or ev.get("raw_data") or {}
            client_actions = evd.get("client_actions") or evd.get("clientActions") or {}
            actions = client_actions.get("actions") or client_actions.get("Actions") or []
            for action in actions:
                add_msgs = action.get("add_messages_to_task") or action.get("addMessagesToTask") or {}
                if not isinstance(add_msgs, dict):
                    continue
                for message in add_msgs.get("messages", []) or []:
                    tc = message.get("tool_call") or message.get("toolCall") or {}
                    call_mcp = tc.get("call_mcp_tool") or tc.get("callMcpTool") or {}
                    if isinstance(call_mcp, dict) and call_mcp.get("name"):
                        try:
                            args_obj = call_mcp.get("args", {}) or {}
                            args_str = json.dumps(args_obj, ensure_ascii=False)
                        except Exception:
                            args_str = "{}"
                        tool_calls.append({
                            "id": tc.get("tool_call_id") or str(uuid.uuid4()),
                            "type": "function",
                            "function": {"name": call_mcp.get("name"), "arguments": args_str},
                        })
    except Exception:
        pass

    # Shape the final OpenAI-compatible response.
    if tool_calls:
        msg_payload = {"role": "assistant", "content": "", "tool_calls": tool_calls}
        finish_reason = "tool_calls"
    else:
        response_text = bridge_resp.get("response", "")
        msg_payload = {"role": "assistant", "content": response_text}
        finish_reason = "stop"

    return {
        "id": completion_id,
        "object": "chat.completion",
        "created": created_ts,
        "model": model_id,
        "choices": [{"index": 0, "message": msg_payload, "finish_reason": finish_reason}],
    }
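# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the router): a minimal client
# call against this endpoint. The host/port (127.0.0.1:8000) and the model
# name ("claude-4.1-opus") are assumptions; substitute whatever your
# deployment actually serves.
#
#   import requests
#
#   resp = requests.post(
#       "http://127.0.0.1:8000/v1/chat/completions",
#       json={
#           "model": "claude-4.1-opus",
#           "stream": False,
#           "messages": [{"role": "user", "content": "Hello"}],
#       },
#       timeout=60,
#   )
#   print(resp.json()["choices"][0]["message"]["content"])
#
# With "stream": True the endpoint instead returns a text/event-stream body of
# OpenAI-style "data:" chunks, so the response must be consumed incrementally.
# ---------------------------------------------------------------------------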