| """ |
| OpenAI Router - Handles OpenAI format API requests via Antigravity |
| 通过Antigravity处理OpenAI格式请求的路由模块 |
| """ |
|
|
| import sys |
| from pathlib import Path |
|
|
| |
# Put the directory four levels above this file at the front of sys.path so
# the absolute imports that follow (config, log, src.*) can be resolved when
# this module is executed directly. Presumably that directory is the
# repository root — confirm against the project layout.
project_root = Path(__file__).resolve()
for _ in range(4):
    project_root = project_root.parent

_project_root_entry = str(project_root)
if _project_root_entry not in sys.path:
    sys.path.insert(0, _project_root_entry)
|
|
| |
| import asyncio |
| import json |
|
|
| |
| from fastapi import APIRouter, Depends, HTTPException |
| from fastapi.responses import JSONResponse, StreamingResponse |
|
|
| |
| from config import get_anti_truncation_max_attempts |
| from log import log |
|
|
| |
| from src.utils import ( |
| get_base_model_from_feature_model, |
| is_anti_truncation_model, |
| is_fake_streaming_model, |
| authenticate_bearer, |
| ) |
|
|
| |
| from src.converter.fake_stream import ( |
| parse_response_for_fake_stream, |
| build_openai_fake_stream_chunks, |
| create_openai_heartbeat_chunk, |
| ) |
|
|
| |
| from src.router.hi_check import is_health_check_request, create_health_check_response |
| from src.router.stream_passthrough import ( |
| build_streaming_response_or_error, |
| prepend_async_item, |
| read_first_async_item, |
| ) |
|
|
| |
| from src.models import OpenAIChatCompletionRequest, model_to_dict |
|
|
| |
| from src.task_manager import create_managed_task |
|
|
|
|
| |
|
|
# Router object that this module's endpoint is registered on via the
# @router.post decorator.
router = APIRouter()
|
|
|
|
| |
|
|
# SSE terminator event sent to the client once a stream has finished.
_SSE_DONE = b"data: [DONE]\n\n"


def _response_text(response):
    """Return the payload carried by an upstream response object.

    ``response.body`` is checked first, then ``response.content``. A ``bytes``
    payload is decoded with the default codec (UTF-8); any other payload is
    returned unchanged. Objects exposing neither attribute are converted with
    ``str()``.
    """
    for attr in ("body", "content"):
        if hasattr(response, attr):
            payload = getattr(response, attr)
            return payload.decode() if isinstance(payload, bytes) else payload
    return str(response)


def _translate_stream_chunk(chunk, model, response_id):
    """Translate one upstream SSE chunk into an OpenAI-format SSE event.

    Args:
        chunk: Chunk yielded by the upstream stream, as ``bytes`` or ``str``.
        model: Model name passed on to the converter.
        response_id: Identifier shared by all chunks of one response.

    Returns:
        An ``(event, done)`` tuple. ``event`` is the encoded event to forward
        to the client, or ``None`` when the chunk produces no output (blank
        chunk, chunk not starting with ``data: ``, empty conversion result, or
        a conversion failure). ``done`` is ``True`` only when the chunk is the
        ``data: [DONE]`` marker.
    """
    text = chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk
    stripped = text.strip()
    if not stripped:
        return None, False
    if stripped == "data: [DONE]":
        return _SSE_DONE, True
    if not text.startswith("data: "):
        return None, False

    try:
        from src.converter.openai2gemini import convert_gemini_to_openai_stream

        converted = convert_gemini_to_openai_stream(text, model, response_id)
        return (converted.encode("utf-8") if converted else None), False
    except Exception as e:
        # One chunk that cannot be converted is logged and dropped so the
        # rest of the stream can still be delivered.
        log.error(f"Failed to convert chunk: {e}")
        return None, False


@router.post("/antigravity/v1/chat/completions")
async def chat_completions(
    openai_request: OpenAIChatCompletionRequest,
    token: str = Depends(authenticate_bearer)
):
    """Handle an OpenAI-format chat completion request (streaming or not).

    The request is converted to a Gemini-format payload, sent through the
    Antigravity API helpers, and the reply is converted back to OpenAI format.

    Args:
        openai_request: Request body in OpenAI chat-completion format.
        token: Value produced by the ``authenticate_bearer`` dependency. It is
            not read in the body; it is declared so the dependency runs.

    Returns:
        A ``JSONResponse`` for health checks and non-streaming requests.
        For streaming requests, whatever ``build_streaming_response_or_error``
        returns for the selected generator.

    Raises:
        HTTPException: Status 500 when the non-streaming upstream reply cannot
            be parsed as JSON.
    """
    log.debug(f"[ANTIGRAVITY-OPENAI] Request for model: {openai_request.model}")

    # Plain-dict form of the request, used by the health check and converters.
    normalized_dict = model_to_dict(openai_request)

    # Health-check requests are answered locally without an upstream call.
    if is_health_check_request(normalized_dict, format="openai"):
        response = create_health_check_response(format="openai")
        return JSONResponse(content=response)

    # The requested model name selects optional features; the helpers derive
    # the feature flags and the underlying model name from it.
    use_fake_streaming = is_fake_streaming_model(openai_request.model)
    use_anti_truncation = is_anti_truncation_model(openai_request.model)
    real_model = get_base_model_from_feature_model(openai_request.model)

    is_streaming = openai_request.stream

    # Anti-truncation is only wired into the streaming branch below.
    if use_anti_truncation and not is_streaming:
        log.warning("抗截断功能仅在流式传输时有效,非流式请求将忽略此设置")

    # From here on only the underlying model name is used.
    normalized_dict["model"] = real_model

    from src.converter.openai2gemini import convert_openai_to_gemini_request
    gemini_dict = await convert_openai_to_gemini_request(normalized_dict)

    # Re-apply the model after conversion so the "model" key is guaranteed to
    # exist for the pop() further down.
    gemini_dict["model"] = real_model

    from src.converter.gemini_fix import normalize_gemini_request
    gemini_dict = await normalize_gemini_request(gemini_dict, mode="antigravity")

    # Upstream payload: the model sits beside the request body, not inside it.
    api_request = {
        "model": gemini_dict.pop("model"),
        "request": gemini_dict
    }

    # ---------------------------------------------------------------- non-stream
    if not is_streaming:
        from src.api.antigravity import non_stream_request
        response = await non_stream_request(body=api_request)

        # The upstream status code is propagated to the client unchanged.
        status_code = getattr(response, "status_code", 200)
        response_body = _response_text(response)

        try:
            gemini_response = json.loads(response_body)
        except Exception as e:
            log.error(f"Failed to parse Gemini response: {e}")
            raise HTTPException(status_code=500, detail="Response parsing failed") from e

        from src.converter.openai2gemini import convert_gemini_to_openai_response
        openai_response = convert_gemini_to_openai_response(
            gemini_response,
            real_model,
            status_code
        )

        return JSONResponse(content=openai_response, status_code=status_code)

    # -------------------------------------------------------------------- stream

    async def fake_stream_generator():
        """Issue a non-streaming upstream call and replay it as SSE chunks.

        Yields the upstream response object itself when its status is not 200;
        otherwise yields encoded SSE events followed by the DONE marker.
        """
        import time
        from src.api.antigravity import non_stream_request

        response = await non_stream_request(body=api_request)

        # Hand a non-200 upstream response to the caller as-is instead of
        # wrapping it in an SSE event.
        if hasattr(response, "status_code") and response.status_code != 200:
            log.error(f"Fake streaming got error response: status={response.status_code}")
            yield response
            return

        response_body = _response_text(response)

        try:
            gemini_response = json.loads(response_body)
            log.debug(f"OpenAI fake stream Gemini response: {gemini_response}")

            # A 200 reply can still carry an error object in its body.
            if "error" in gemini_response:
                log.error(f"Fake streaming got error in response body: {gemini_response['error']}")
                from src.converter.openai2gemini import convert_gemini_to_openai_response
                openai_error = convert_gemini_to_openai_response(
                    gemini_response,
                    real_model,
                    200
                )
                yield f"data: {json.dumps(openai_error)}\n\n".encode()
                yield _SSE_DONE
                return

            content, reasoning_content, finish_reason, images = parse_response_for_fake_stream(gemini_response)

            log.debug(f"OpenAI extracted content: {content}")
            log.debug(f"OpenAI extracted reasoning: {reasoning_content[:100] if reasoning_content else 'None'}...")
            log.debug(f"OpenAI extracted images count: {len(images)}")

            chunks = build_openai_fake_stream_chunks(content, reasoning_content, finish_reason, real_model, images)
            for idx, chunk in enumerate(chunks):
                chunk_json = json.dumps(chunk)
                log.debug(f"[FAKE_STREAM] Yielding chunk #{idx+1}: {chunk_json[:200]}")
                yield f"data: {chunk_json}\n\n".encode()

        except Exception as e:
            log.error(f"Response parsing failed: {e}, directly yield error")
            error_chunk = {
                "id": "error",
                "object": "chat.completion.chunk",
                # Wall-clock Unix timestamp. The event loop clock used before
                # is monotonic and is not an epoch time.
                "created": int(time.time()),
                "model": real_model,
                "choices": [{
                    "index": 0,
                    "delta": {"content": f"Error: {str(e)}"},
                    "finish_reason": "error"
                }]
            }
            yield f"data: {json.dumps(error_chunk)}\n\n".encode()

        yield _SSE_DONE

    async def anti_truncation_generator():
        """Stream through ``AntiTruncationStreamProcessor`` and emit OpenAI SSE.

        The first upstream item is read eagerly; if it is a ``Response`` object
        it is yielded unchanged and the generator stops.
        """
        from src.converter.anti_truncation import AntiTruncationStreamProcessor
        from src.api.antigravity import stream_request
        from src.converter.anti_truncation import apply_anti_truncation
        from fastapi import Response

        max_attempts = await get_anti_truncation_max_attempts()

        anti_truncation_payload = apply_anti_truncation(api_request)

        first_attempt_stream = stream_request(body=anti_truncation_payload, native=False)
        try:
            first_chunk = await read_first_async_item(first_attempt_stream)
        except StopAsyncIteration:
            # The upstream stream produced nothing at all.
            return

        if isinstance(first_chunk, Response):
            yield first_chunk
            return

        # The stream opened above must be consumed by the processor's first
        # attempt rather than opening a second upstream request.
        first_attempt_pending = True

        async def stream_request_wrapper(payload):
            """Return a StreamingResponse for one attempt of the processor."""
            nonlocal first_attempt_pending

            if first_attempt_pending:
                first_attempt_pending = False
                # Put the already-read first item back in front of the stream.
                stream_gen = prepend_async_item(first_chunk, first_attempt_stream)
            else:
                stream_gen = stream_request(body=payload, native=False)

            return StreamingResponse(stream_gen, media_type="text/event-stream")

        processor = AntiTruncationStreamProcessor(
            stream_request_wrapper,
            anti_truncation_payload,
            max_attempts,
            # Prefill mode is switched off when the model name contains "claude".
            enable_prefill_mode=("claude" not in str(api_request.get("model", "")).lower()),
        )

        import uuid
        # One id shared by every converted chunk of this response.
        response_id = str(uuid.uuid4())

        async for chunk in processor.process_stream():
            if not chunk:
                continue

            event, done = _translate_stream_chunk(chunk, real_model, response_id)
            if event is not None:
                yield event
            if done:
                return

        # The upstream ended without its own DONE marker; send one anyway.
        yield _SSE_DONE

    async def normal_stream_generator():
        """Stream the upstream reply and emit it as OpenAI-format SSE events.

        The first upstream item is read eagerly; if it is a ``Response`` object
        it is yielded unchanged and the generator stops. A ``Response`` seen
        later in the stream is converted to an error event followed by DONE.
        """
        from src.api.antigravity import stream_request
        from fastapi import Response
        import uuid

        stream_gen = stream_request(body=api_request, native=False)
        try:
            first_chunk = await read_first_async_item(stream_gen)
        except StopAsyncIteration:
            # The upstream stream produced nothing at all.
            return

        if isinstance(first_chunk, Response):
            yield first_chunk
            return

        # One id shared by every converted chunk of this response.
        response_id = str(uuid.uuid4())

        async for chunk in prepend_async_item(first_chunk, stream_gen):
            if isinstance(chunk, Response):
                # Mid-stream error: report it inside the stream, then finish.
                try:
                    gemini_error = json.loads(_response_text(chunk))
                    from src.converter.openai2gemini import convert_gemini_to_openai_response
                    openai_error = convert_gemini_to_openai_response(
                        gemini_error,
                        real_model,
                        chunk.status_code
                    )
                    error_event = f"data: {json.dumps(openai_error)}\n\n"
                except Exception as e:
                    log.error(f"Failed to convert stream error response: {e}")
                    error_event = f"data: {json.dumps({'error': 'Stream error'})}\n\n"
                yield error_event.encode('utf-8')
                yield _SSE_DONE
                return

            event, done = _translate_stream_chunk(chunk, real_model, response_id)
            if event is not None:
                yield event
            if done:
                return

        # The upstream ended without its own DONE marker; send one anyway.
        yield _SSE_DONE

    # Fake streaming takes precedence over anti-truncation when both apply.
    if use_fake_streaming:
        return await build_streaming_response_or_error(fake_stream_generator())
    elif use_anti_truncation:
        log.info("启用流式抗截断功能")
        return await build_streaming_response_or_error(anti_truncation_generator())
    else:
        return await build_streaming_response_or_error(normal_stream_generator())
|
|
|
|
| |
|
|
if __name__ == "__main__":
    # Manual smoke test: sends non-streaming, streaming and fake-streaming
    # requests to the OpenAI router and prints what comes back.
    # Run with: python src/router/antigravity/openai.py

    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    RULE = "=" * 80
    THIN_RULE = "-" * 80
    ENDPOINT = "/antigravity/v1/chat/completions"

    def _print_banner(title, leading_blank=True):
        """Print *title* framed by two 80-column rules."""
        print(("\n" if leading_blank else "") + RULE)
        print(title)
        print(RULE)

    _print_banner("OpenAI Router 测试", leading_blank=False)

    # Minimal application hosting only this module's router.
    app = FastAPI()
    app.include_router(router)

    client = TestClient(app)

    # Base request body shared by all three scenarios.
    test_request_body = {
        "model": "gemini-2.5-flash",
        "messages": [
            {"role": "user", "content": "Hello, tell me a joke in one sentence."}
        ]
    }

    # Authorization header value sent with every request.
    test_token = "Bearer pwd"

    def test_non_stream_request():
        """Send a non-streaming request and print the raw and parsed reply."""
        _print_banner("【测试1】非流式请求 (POST /antigravity/v1/chat/completions)")
        request_dump = json.dumps(test_request_body, indent=2, ensure_ascii=False)
        print(f"请求体: {request_dump}\n")

        reply = client.post(
            ENDPOINT,
            json=test_request_body,
            headers={"Authorization": test_token},
        )

        print("非流式响应数据:")
        print(THIN_RULE)
        print(f"状态码: {reply.status_code}")
        print(f"Content-Type: {reply.headers.get('content-type', 'N/A')}")

        try:
            raw_text = reply.text
            print(f"\n响应内容 (原始):\n{raw_text}\n")

            # Pretty-print the body when it is JSON.
            try:
                parsed = reply.json()
                print("响应内容 (格式化JSON):")
                print(json.dumps(parsed, indent=2, ensure_ascii=False))
            except json.JSONDecodeError:
                print("(非JSON格式)")
        except Exception as e:
            print(f"内容解析失败: {e}")

    def test_stream_request():
        """Send a streaming request and print every received chunk."""
        _print_banner("【测试2】流式请求 (POST /antigravity/v1/chat/completions)")

        stream_request_body = dict(test_request_body, stream=True)
        request_dump = json.dumps(stream_request_body, indent=2, ensure_ascii=False)
        print(f"请求体: {request_dump}\n")

        print("流式响应数据 (每个chunk):")
        print(THIN_RULE)

        with client.stream(
            "POST",
            ENDPOINT,
            json=stream_request_body,
            headers={"Authorization": test_token},
        ) as response:
            print(f"状态码: {response.status_code}")
            print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n")

            chunk_count = 0
            # Empty chunks are skipped and do not count.
            for chunk_count, chunk in enumerate(filter(None, response.iter_bytes()), start=1):
                print(f"\nChunk #{chunk_count}:")
                print(f" 类型: {type(chunk).__name__}")
                print(f" 长度: {len(chunk)}")

                try:
                    chunk_str = chunk.decode('utf-8')
                    print(f" 内容预览: {chunk_str[:200]!r}")

                    # Only chunks that begin with an SSE data line are parsed.
                    if not chunk_str.startswith("data: "):
                        continue

                    for sse_line in map(str.strip, chunk_str.strip().split('\n')):
                        if not sse_line:
                            continue

                        if sse_line == "data: [DONE]":
                            print(" => 流结束标记")
                        elif sse_line.startswith("data: "):
                            try:
                                # Drop the 6-character "data: " prefix.
                                payload = json.loads(sse_line[6:])
                                print(f" 解析后的JSON: {json.dumps(payload, indent=4, ensure_ascii=False)}")
                            except Exception as e:
                                print(f" SSE解析失败: {e}")
                except Exception as e:
                    print(f" 解码失败: {e}")

            print(f"\n总共收到 {chunk_count} 个chunk")

    def test_fake_stream_request():
        """Send a fake-streaming request and summarise each SSE event."""
        _print_banner("【测试3】假流式请求 (POST /antigravity/v1/chat/completions with 假流式 prefix)")

        fake_stream_request_body = dict(
            test_request_body,
            model="假流式/gemini-2.5-flash",
            stream=True,
        )
        request_dump = json.dumps(fake_stream_request_body, indent=2, ensure_ascii=False)
        print(f"请求体: {request_dump}\n")

        print("假流式响应数据 (每个chunk):")
        print(THIN_RULE)

        with client.stream(
            "POST",
            ENDPOINT,
            json=fake_stream_request_body,
            headers={"Authorization": test_token},
        ) as response:
            print(f"状态码: {response.status_code}")
            print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n")

            chunk_count = 0
            # Empty chunks are skipped and do not count.
            for chunk_count, chunk in enumerate(filter(None, response.iter_bytes()), start=1):
                chunk_str = chunk.decode('utf-8')

                print(f"\nChunk #{chunk_count}:")
                print(f" 长度: {len(chunk_str)} 字节")

                # One HTTP chunk may carry several SSE events.
                stripped_lines = (raw_line.strip() for raw_line in chunk_str.split('\n'))
                events = [entry for entry in stripped_lines if entry.startswith("data: ")]

                print(f" 包含 {len(events)} 个SSE事件")

                for event_idx, event_line in enumerate(events, 1):
                    if event_line == "data: [DONE]":
                        print(f" 事件 #{event_idx}: [DONE]")
                        continue

                    try:
                        # Drop the 6-character "data: " prefix.
                        json_data = json.loads(event_line[6:])
                        content = json_data.get("choices", [{}])[0].get("delta", {}).get("content", "")
                        finish_reason = json_data.get("choices", [{}])[0].get("finish_reason")
                        ellipsis = '...' if len(content) > 50 else ''
                        print(f" 事件 #{event_idx}: content={content[:50]!r}{ellipsis}, finish_reason={finish_reason}")
                    except Exception as e:
                        print(f" 事件 #{event_idx}: 解析失败 - {e}")

            print(f"\n总共收到 {chunk_count} 个HTTP chunk")

    # Run the scenarios in order; the first exception aborts the rest.
    try:
        for scenario in (test_non_stream_request, test_stream_request, test_fake_stream_request):
            scenario()

        _print_banner("测试完成")

    except Exception as e:
        print(f"\n❌ 测试过程中出现异常: {e}")
        import traceback
        traceback.print_exc()
|
|