gcli2api / src /router /antigravity /openai.py
a3216's picture
sync: github -> hf space
c50496f
"""
OpenAI Router - Handles OpenAI format API requests via Antigravity
通过Antigravity处理OpenAI格式请求的路由模块
"""
import sys
from pathlib import Path
# 添加项目根目录到Python路径
project_root = Path(__file__).resolve().parent.parent.parent.parent
if str(project_root) not in sys.path:
sys.path.insert(0, str(project_root))
# 标准库
import asyncio
import json
# 第三方库
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
# 本地模块 - 配置和日志
from config import get_anti_truncation_max_attempts
from log import log
# 本地模块 - 工具和认证
from src.utils import (
get_base_model_from_feature_model,
is_anti_truncation_model,
is_fake_streaming_model,
authenticate_bearer,
)
# 本地模块 - 转换器(假流式需要)
from src.converter.fake_stream import (
parse_response_for_fake_stream,
build_openai_fake_stream_chunks,
create_openai_heartbeat_chunk,
)
# 本地模块 - 基础路由工具
from src.router.hi_check import is_health_check_request, create_health_check_response
from src.router.stream_passthrough import (
build_streaming_response_or_error,
prepend_async_item,
read_first_async_item,
)
# 本地模块 - 数据模型
from src.models import OpenAIChatCompletionRequest, model_to_dict
# 本地模块 - 任务管理
from src.task_manager import create_managed_task
# ==================== 路由器初始化 ====================
router = APIRouter()
# ==================== API 路由 ====================
@router.post("/antigravity/v1/chat/completions")
async def chat_completions(
openai_request: OpenAIChatCompletionRequest,
token: str = Depends(authenticate_bearer)
):
"""
处理OpenAI格式的聊天完成请求(流式和非流式)
Args:
openai_request: OpenAI格式的请求体
token: Bearer认证令牌
"""
log.debug(f"[ANTIGRAVITY-OPENAI] Request for model: {openai_request.model}")
# 转换为字典
normalized_dict = model_to_dict(openai_request)
# 健康检查
if is_health_check_request(normalized_dict, format="openai"):
response = create_health_check_response(format="openai")
return JSONResponse(content=response)
# 处理模型名称和功能检测
use_fake_streaming = is_fake_streaming_model(openai_request.model)
use_anti_truncation = is_anti_truncation_model(openai_request.model)
real_model = get_base_model_from_feature_model(openai_request.model)
# 获取流式标志
is_streaming = openai_request.stream
# 对于抗截断模型的非流式请求,给出警告
if use_anti_truncation and not is_streaming:
log.warning("抗截断功能仅在流式传输时有效,非流式请求将忽略此设置")
# 更新模型名为真实模型名
normalized_dict["model"] = real_model
# 转换为 Gemini 格式 (使用 converter)
from src.converter.openai2gemini import convert_openai_to_gemini_request
gemini_dict = await convert_openai_to_gemini_request(normalized_dict)
# convert_openai_to_gemini_request 不包含 model 字段,需要手动添加
gemini_dict["model"] = real_model
# 规范化 Gemini 请求 (使用 antigravity 模式)
from src.converter.gemini_fix import normalize_gemini_request
gemini_dict = await normalize_gemini_request(gemini_dict, mode="antigravity")
# 准备API请求格式 - 提取model并将其他字段放入request中
api_request = {
"model": gemini_dict.pop("model"),
"request": gemini_dict
}
# ========== 非流式请求 ==========
if not is_streaming:
# 调用 API 层的非流式请求
from src.api.antigravity import non_stream_request
response = await non_stream_request(body=api_request)
# 检查响应状态码
status_code = getattr(response, "status_code", 200)
# 提取响应体
if hasattr(response, "body"):
response_body = response.body.decode() if isinstance(response.body, bytes) else response.body
elif hasattr(response, "content"):
response_body = response.content.decode() if isinstance(response.content, bytes) else response.content
else:
response_body = str(response)
try:
gemini_response = json.loads(response_body)
except Exception as e:
log.error(f"Failed to parse Gemini response: {e}")
raise HTTPException(status_code=500, detail="Response parsing failed")
# 转换为 OpenAI 格式
from src.converter.openai2gemini import convert_gemini_to_openai_response
openai_response = convert_gemini_to_openai_response(
gemini_response,
real_model,
status_code
)
return JSONResponse(content=openai_response, status_code=status_code)
# ========== 流式请求 ==========
# ========== 假流式生成器 ==========
async def fake_stream_generator():
from src.api.antigravity import non_stream_request
response = await non_stream_request(body=api_request)
# 检查响应状态码
if hasattr(response, "status_code") and response.status_code != 200:
log.error(f"Fake streaming got error response: status={response.status_code}")
yield response
return
# 处理成功响应 - 提取响应内容
if hasattr(response, "body"):
response_body = response.body.decode() if isinstance(response.body, bytes) else response.body
elif hasattr(response, "content"):
response_body = response.content.decode() if isinstance(response.content, bytes) else response.content
else:
response_body = str(response)
try:
gemini_response = json.loads(response_body)
log.debug(f"OpenAI fake stream Gemini response: {gemini_response}")
# 检查是否是错误响应(有些错误可能status_code是200但包含error字段)
if "error" in gemini_response:
log.error(f"Fake streaming got error in response body: {gemini_response['error']}")
# 转换错误为 OpenAI 格式
from src.converter.openai2gemini import convert_gemini_to_openai_response
openai_error = convert_gemini_to_openai_response(
gemini_response,
real_model,
200
)
yield f"data: {json.dumps(openai_error)}\n\n".encode()
yield "data: [DONE]\n\n".encode()
return
# 使用统一的解析函数
content, reasoning_content, finish_reason, images = parse_response_for_fake_stream(gemini_response)
log.debug(f"OpenAI extracted content: {content}")
log.debug(f"OpenAI extracted reasoning: {reasoning_content[:100] if reasoning_content else 'None'}...")
log.debug(f"OpenAI extracted images count: {len(images)}")
# 构建响应块
chunks = build_openai_fake_stream_chunks(content, reasoning_content, finish_reason, real_model, images)
for idx, chunk in enumerate(chunks):
chunk_json = json.dumps(chunk)
log.debug(f"[FAKE_STREAM] Yielding chunk #{idx+1}: {chunk_json[:200]}")
yield f"data: {chunk_json}\n\n".encode()
except Exception as e:
log.error(f"Response parsing failed: {e}, directly yield error")
# 构建错误响应
error_chunk = {
"id": "error",
"object": "chat.completion.chunk",
"created": int(asyncio.get_event_loop().time()),
"model": real_model,
"choices": [{
"index": 0,
"delta": {"content": f"Error: {str(e)}"},
"finish_reason": "error"
}]
}
yield f"data: {json.dumps(error_chunk)}\n\n".encode()
yield "data: [DONE]\n\n".encode()
# ========== 流式抗截断生成器 ==========
async def anti_truncation_generator():
from src.converter.anti_truncation import AntiTruncationStreamProcessor
from src.api.antigravity import stream_request
from src.converter.anti_truncation import apply_anti_truncation
from fastapi import Response
max_attempts = await get_anti_truncation_max_attempts()
# 首先对payload应用反截断指令
anti_truncation_payload = apply_anti_truncation(api_request)
first_attempt_stream = stream_request(body=anti_truncation_payload, native=False)
try:
first_chunk = await read_first_async_item(first_attempt_stream)
except StopAsyncIteration:
return
if isinstance(first_chunk, Response):
yield first_chunk
return
first_attempt_pending = True
async def stream_request_wrapper(payload):
nonlocal first_attempt_pending
if first_attempt_pending:
first_attempt_pending = False
stream_gen = prepend_async_item(first_chunk, first_attempt_stream)
else:
stream_gen = stream_request(body=payload, native=False)
return StreamingResponse(stream_gen, media_type="text/event-stream")
# 创建反截断处理器
processor = AntiTruncationStreamProcessor(
stream_request_wrapper,
anti_truncation_payload,
max_attempts,
enable_prefill_mode=("claude" not in str(api_request.get("model", "")).lower()),
)
# 转换为 OpenAI 格式
import uuid
response_id = str(uuid.uuid4())
# 直接迭代 process_stream() 生成器,并转换为 OpenAI 格式
async for chunk in processor.process_stream():
if not chunk:
continue
# 解析 Gemini SSE 格式
chunk_str = chunk.decode('utf-8') if isinstance(chunk, bytes) else chunk
# 跳过空行
if not chunk_str.strip():
continue
# 处理 [DONE] 标记
if chunk_str.strip() == "data: [DONE]":
yield "data: [DONE]\n\n".encode('utf-8')
return
# 解析 "data: {...}" 格式
if chunk_str.startswith("data: "):
try:
# 转换为 OpenAI 格式
from src.converter.openai2gemini import convert_gemini_to_openai_stream
openai_chunk_str = convert_gemini_to_openai_stream(
chunk_str,
real_model,
response_id
)
if openai_chunk_str:
yield openai_chunk_str.encode('utf-8')
except Exception as e:
log.error(f"Failed to convert chunk: {e}")
continue
# 发送结束标记
yield "data: [DONE]\n\n".encode('utf-8')
# ========== 普通流式生成器 ==========
async def normal_stream_generator():
from src.api.antigravity import stream_request
from fastapi import Response
import uuid
# 调用 API 层的流式请求(不使用 native 模式)
stream_gen = stream_request(body=api_request, native=False)
try:
first_chunk = await read_first_async_item(stream_gen)
except StopAsyncIteration:
return
if isinstance(first_chunk, Response):
yield first_chunk
return
response_id = str(uuid.uuid4())
# yield所有数据,处理可能的错误Response
async for chunk in prepend_async_item(first_chunk, stream_gen):
# 检查是否是Response对象(错误情况)
if isinstance(chunk, Response):
# 将Response转换为SSE格式的错误消息
try:
error_content = chunk.body if isinstance(chunk.body, bytes) else (chunk.body or b'').encode('utf-8')
gemini_error = json.loads(error_content.decode('utf-8'))
# 转换为 OpenAI 格式错误
from src.converter.openai2gemini import convert_gemini_to_openai_response
openai_error = convert_gemini_to_openai_response(
gemini_error,
real_model,
chunk.status_code
)
yield f"data: {json.dumps(openai_error)}\n\n".encode('utf-8')
except Exception:
yield f"data: {json.dumps({'error': 'Stream error'})}\n\n".encode('utf-8')
yield b"data: [DONE]\n\n"
return
else:
# 正常的bytes数据,转换为 OpenAI 格式
chunk_str = chunk.decode('utf-8') if isinstance(chunk, bytes) else chunk
# 跳过空行
if not chunk_str.strip():
continue
# 处理 [DONE] 标记
if chunk_str.strip() == "data: [DONE]":
yield "data: [DONE]\n\n".encode('utf-8')
return
# 解析并转换 Gemini chunk 为 OpenAI 格式
if chunk_str.startswith("data: "):
try:
# 转换为 OpenAI 格式
from src.converter.openai2gemini import convert_gemini_to_openai_stream
openai_chunk_str = convert_gemini_to_openai_stream(
chunk_str,
real_model,
response_id
)
if openai_chunk_str:
yield openai_chunk_str.encode('utf-8')
except Exception as e:
log.error(f"Failed to convert chunk: {e}")
continue
# 发送结束标记
yield "data: [DONE]\n\n".encode('utf-8')
# ========== 根据模式选择生成器 ==========
if use_fake_streaming:
return await build_streaming_response_or_error(fake_stream_generator())
elif use_anti_truncation:
log.info("启用流式抗截断功能")
return await build_streaming_response_or_error(anti_truncation_generator())
else:
return await build_streaming_response_or_error(normal_stream_generator())
# ==================== 测试代码 ====================
if __name__ == "__main__":
"""
测试代码:演示OpenAI路由的流式和非流式响应
运行方式: python src/router/antigravity/openai.py
"""
from fastapi.testclient import TestClient
from fastapi import FastAPI
print("=" * 80)
print("OpenAI Router 测试")
print("=" * 80)
# 创建测试应用
app = FastAPI()
app.include_router(router)
# 测试客户端
client = TestClient(app)
# 测试请求体 (OpenAI格式)
test_request_body = {
"model": "gemini-2.5-flash",
"messages": [
{"role": "user", "content": "Hello, tell me a joke in one sentence."}
]
}
# 测试Bearer令牌(模拟)
test_token = "Bearer pwd"
def test_non_stream_request():
"""测试非流式请求"""
print("\n" + "=" * 80)
print("【测试1】非流式请求 (POST /antigravity/v1/chat/completions)")
print("=" * 80)
print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n")
response = client.post(
"/antigravity/v1/chat/completions",
json=test_request_body,
headers={"Authorization": test_token}
)
print("非流式响应数据:")
print("-" * 80)
print(f"状态码: {response.status_code}")
print(f"Content-Type: {response.headers.get('content-type', 'N/A')}")
try:
content = response.text
print(f"\n响应内容 (原始):\n{content}\n")
# 尝试解析JSON
try:
json_data = response.json()
print(f"响应内容 (格式化JSON):")
print(json.dumps(json_data, indent=2, ensure_ascii=False))
except json.JSONDecodeError:
print("(非JSON格式)")
except Exception as e:
print(f"内容解析失败: {e}")
def test_stream_request():
"""测试流式请求"""
print("\n" + "=" * 80)
print("【测试2】流式请求 (POST /antigravity/v1/chat/completions)")
print("=" * 80)
stream_request_body = test_request_body.copy()
stream_request_body["stream"] = True
print(f"请求体: {json.dumps(stream_request_body, indent=2, ensure_ascii=False)}\n")
print("流式响应数据 (每个chunk):")
print("-" * 80)
with client.stream(
"POST",
"/antigravity/v1/chat/completions",
json=stream_request_body,
headers={"Authorization": test_token}
) as response:
print(f"状态码: {response.status_code}")
print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n")
chunk_count = 0
for chunk in response.iter_bytes():
if chunk:
chunk_count += 1
print(f"\nChunk #{chunk_count}:")
print(f" 类型: {type(chunk).__name__}")
print(f" 长度: {len(chunk)}")
# 解码chunk
try:
chunk_str = chunk.decode('utf-8')
print(f" 内容预览: {repr(chunk_str[:200] if len(chunk_str) > 200 else chunk_str)}")
# 如果是SSE格式,尝试解析每一行
if chunk_str.startswith("data: "):
# 按行分割,处理每个SSE事件
for line in chunk_str.strip().split('\n'):
line = line.strip()
if not line:
continue
if line == "data: [DONE]":
print(f" => 流结束标记")
elif line.startswith("data: "):
try:
json_str = line[6:] # 去掉 "data: " 前缀
json_data = json.loads(json_str)
print(f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}")
except Exception as e:
print(f" SSE解析失败: {e}")
except Exception as e:
print(f" 解码失败: {e}")
print(f"\n总共收到 {chunk_count} 个chunk")
def test_fake_stream_request():
"""测试假流式请求"""
print("\n" + "=" * 80)
print("【测试3】假流式请求 (POST /antigravity/v1/chat/completions with 假流式 prefix)")
print("=" * 80)
fake_stream_request_body = test_request_body.copy()
fake_stream_request_body["model"] = "假流式/gemini-2.5-flash"
fake_stream_request_body["stream"] = True
print(f"请求体: {json.dumps(fake_stream_request_body, indent=2, ensure_ascii=False)}\n")
print("假流式响应数据 (每个chunk):")
print("-" * 80)
with client.stream(
"POST",
"/antigravity/v1/chat/completions",
json=fake_stream_request_body,
headers={"Authorization": test_token}
) as response:
print(f"状态码: {response.status_code}")
print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n")
chunk_count = 0
for chunk in response.iter_bytes():
if chunk:
chunk_count += 1
chunk_str = chunk.decode('utf-8')
print(f"\nChunk #{chunk_count}:")
print(f" 长度: {len(chunk_str)} 字节")
# 解析chunk中的所有SSE事件
events = []
for line in chunk_str.split('\n'):
line = line.strip()
if line.startswith("data: "):
events.append(line)
print(f" 包含 {len(events)} 个SSE事件")
# 显示每个事件
for event_idx, event_line in enumerate(events, 1):
if event_line == "data: [DONE]":
print(f" 事件 #{event_idx}: [DONE]")
else:
try:
json_str = event_line[6:] # 去掉 "data: " 前缀
json_data = json.loads(json_str)
# 提取content内容
content = json_data.get("choices", [{}])[0].get("delta", {}).get("content", "")
finish_reason = json_data.get("choices", [{}])[0].get("finish_reason")
print(f" 事件 #{event_idx}: content={repr(content[:50])}{'...' if len(content) > 50 else ''}, finish_reason={finish_reason}")
except Exception as e:
print(f" 事件 #{event_idx}: 解析失败 - {e}")
print(f"\n总共收到 {chunk_count} 个HTTP chunk")
# 运行测试
try:
# 测试非流式请求
test_non_stream_request()
# 测试流式请求
test_stream_request()
# 测试假流式请求
test_fake_stream_request()
print("\n" + "=" * 80)
print("测试完成")
print("=" * 80)
except Exception as e:
print(f"\n❌ 测试过程中出现异常: {e}")
import traceback
traceback.print_exc()