# gcli2api / src /router /antigravity /gemini.py
# a3216's picture
# sync: github -> hf space
# 96d34f2
"""
Gemini Router - Handles native Gemini format API requests (Antigravity backend)
处理原生Gemini格式请求的路由模块(Antigravity后端)
"""
import sys
from pathlib import Path
# Add the project root (four directory levels above this file) to sys.path,
# presumably so the absolute imports that follow (config, log, src.*) resolve
# when this file is executed directly as a script.
# NOTE: `Path` on the next line is pathlib.Path. A later
# `from fastapi import ..., Path, ...` rebinds the name, so this statement
# must stay above that import.
project_root = Path(__file__).resolve().parent.parent.parent.parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))
# 标准库
import asyncio
import json
# 第三方库
from fastapi import APIRouter, Depends, HTTPException, Path, Request
from fastapi.responses import JSONResponse, StreamingResponse
# 本地模块 - 配置和日志
from config import get_anti_truncation_max_attempts
from log import log
# 本地模块 - 工具和认证
from src.utils import (
get_base_model_from_feature_model,
is_anti_truncation_model,
authenticate_gemini_flexible,
is_fake_streaming_model
)
# 本地模块 - 转换器(假流式需要)
from src.converter.fake_stream import (
parse_response_for_fake_stream,
build_gemini_fake_stream_chunks,
create_gemini_heartbeat_chunk,
)
# 本地模块 - 基础路由工具
from src.router.hi_check import is_health_check_request, create_health_check_response
from src.router.stream_passthrough import (
build_streaming_response_or_error,
prepend_async_item,
read_first_async_item,
)
# 本地模块 - 数据模型
from src.models import GeminiRequest, model_to_dict
# 本地模块 - 任务管理
from src.task_manager import create_managed_task
# ==================== Router initialization ====================
# Module-level router; every endpoint in this file registers itself on it.
router = APIRouter()
# ==================== API routes ====================
@router.post("/antigravity/v1beta/models/{model:path}:generateContent")
@router.post("/antigravity/v1/models/{model:path}:generateContent")
async def generate_content(
    gemini_request: "GeminiRequest",
    model: str = Path(..., description="Model name"),
    api_key: str = Depends(authenticate_gemini_flexible),
):
    """
    Handle a non-streaming content generation request in Gemini format.

    Args:
        gemini_request: Request body in Gemini format.
        model: Model name taken from the URL path (may carry a feature marker).
        api_key: API key resolved by the authentication dependency.

    Returns:
        A JSONResponse for health checks and for successfully unwrapped
        upstream payloads; otherwise the upstream response object unchanged.
    """
    log.debug(f"[ANTIGRAVITY] Non-streaming request for model: {model}")

    payload = model_to_dict(gemini_request)

    # Health-check probes are answered locally without calling the backend.
    if is_health_check_request(payload, format="gemini"):
        return JSONResponse(content=create_health_check_response(format="gemini"))

    # Resolve the feature flag and the real model name from the requested one.
    anti_truncation_requested = is_anti_truncation_model(model)
    base_model = get_base_model_from_feature_model(model)
    if anti_truncation_requested:
        # Anti-truncation is only applied on the streaming endpoint.
        log.warning("抗截断功能仅在流式传输时有效,非流式请求将忽略此设置")
    payload["model"] = base_model

    # Normalize the Gemini request in "antigravity" mode.
    from src.converter.gemini_fix import normalize_gemini_request
    payload = await normalize_gemini_request(payload, mode="antigravity")

    # API layer format: the model sits beside the remaining request fields.
    upstream_model = payload.pop("model")
    api_request = {"model": upstream_model, "request": payload}

    from src.api.antigravity import non_stream_request
    upstream = await non_stream_request(body=api_request)

    # The backend may wrap the Gemini payload in an extra "response" layer.
    # Strip it so the caller receives the standard Gemini format; inlineData
    # is passed through untouched (no Markdown conversion).
    try:
        if upstream.status_code != 200:
            # Error responses are returned as-is.
            return upstream
        raw_body = upstream.body if hasattr(upstream, 'body') else upstream.content
        parsed = json.loads(raw_body)
        if "response" not in parsed:
            # Nothing to unwrap.
            return upstream
        return JSONResponse(content=parsed["response"])
    except Exception as e:
        # Best effort: on any unwrapping problem hand back the original response.
        log.warning(f"Failed to unwrap response: {e}, returning original response")
        return upstream
def _unwrap_sse_chunk(chunk, log_tag):
    """
    Strip the backend's ``response`` wrapper from a single SSE chunk.

    Args:
        chunk: One stream item, as ``str`` or ``bytes``.
        log_tag: Prefix for the debug log line written when unwrapping happens.

    Returns:
        A re-encoded ``data: ...`` event (bytes) containing only the inner
        ``response`` object when the chunk carries a JSON object that has a
        ``response`` key and no top-level ``candidates`` key. In every other
        case (not a ``data: `` line, the ``[DONE]`` marker, undecodable bytes,
        invalid JSON, non-object JSON) the original chunk is returned as-is.
    """
    if isinstance(chunk, bytes):
        try:
            text = chunk.decode('utf-8')
        except UnicodeDecodeError:
            # Forward undecodable bytes untouched instead of aborting the stream.
            return chunk
    else:
        text = chunk

    if not text.startswith("data: "):
        return chunk

    payload = text[6:].strip()
    if payload == "[DONE]":
        return chunk

    try:
        data = json.loads(payload)
    except json.JSONDecodeError:
        return chunk

    # json.loads can also yield scalars or lists; only objects can be wrapped.
    if isinstance(data, dict) and "response" in data and "candidates" not in data:
        log.debug(f"{log_tag} 展开response包装")
        return f"data: {json.dumps(data['response'], ensure_ascii=False)}\n\n".encode('utf-8')

    # Already in unwrapped form.
    return chunk


@router.post("/antigravity/v1beta/models/{model:path}:streamGenerateContent")
@router.post("/antigravity/v1/models/{model:path}:streamGenerateContent")
async def stream_generate_content(
    gemini_request: GeminiRequest,
    model: str = Path(..., description="Model name"),
    api_key: str = Depends(authenticate_gemini_flexible),
):
    """
    Handle a streaming content generation request in Gemini format.

    Depending on the feature marker detected in the model name, one of three
    generators produces the stream: fake streaming (one non-streaming backend
    call replayed as SSE events), streaming with anti-truncation, or plain
    streaming. The chosen generator is handed to
    ``build_streaming_response_or_error``.

    Args:
        gemini_request: Request body in Gemini format.
        model: Model name taken from the URL path (may carry a feature marker).
        api_key: API key resolved by the authentication dependency.

    Returns:
        Whatever ``build_streaming_response_or_error`` returns for the
        selected generator.
    """
    log.debug(f"[ANTIGRAVITY] Streaming request for model: {model}")

    normalized_dict = model_to_dict(gemini_request)

    # Detect feature flags and resolve the real model name.
    use_fake_streaming = is_fake_streaming_model(model)
    use_anti_truncation = is_anti_truncation_model(model)
    real_model = get_base_model_from_feature_model(model)

    normalized_dict["model"] = real_model

    # ========== Fake-streaming generator ==========
    async def fake_stream_generator():
        from src.converter.gemini_fix import normalize_gemini_request
        from src.api.antigravity import non_stream_request

        normalized_req = await normalize_gemini_request(normalized_dict.copy(), mode="antigravity")

        # API layer format: the model sits beside the remaining request fields.
        api_request = {
            "model": normalized_req.pop("model"),
            "request": normalized_req,
        }

        response = await non_stream_request(body=api_request)

        # Non-200 responses are yielded as objects, not as SSE text.
        if hasattr(response, "status_code") and response.status_code != 200:
            log.error(f"Fake streaming got error response: status={response.status_code}")
            yield response
            return

        # Extract the body text from whichever attribute the response exposes.
        if hasattr(response, "body"):
            response_body = response.body.decode() if isinstance(response.body, bytes) else response.body
        elif hasattr(response, "content"):
            response_body = response.content.decode() if isinstance(response.content, bytes) else response.content
        else:
            response_body = str(response)

        try:
            response_data = json.loads(response_body)
            log.debug(f"Gemini fake stream response data: {response_data}")

            # Some errors arrive with status 200 but carry an "error" field.
            if "error" in response_data:
                log.error(f"Fake streaming got error in response body: {response_data['error']}")
                yield f"data: {json.dumps(response_data)}\n\n".encode()
                yield "data: [DONE]\n\n".encode()
                return

            content, reasoning_content, finish_reason, images = parse_response_for_fake_stream(response_data)

            log.debug(f"Gemini extracted content: {content}")
            log.debug(f"Gemini extracted reasoning: {reasoning_content[:100] if reasoning_content else 'None'}...")
            log.debug(f"Gemini extracted images count: {len(images)}")

            chunks = build_gemini_fake_stream_chunks(content, reasoning_content, finish_reason, images)
            for idx, chunk in enumerate(chunks):
                chunk_json = json.dumps(chunk)
                log.debug(f"[FAKE_STREAM] Yielding chunk #{idx+1}: {chunk_json[:200]}")
                yield f"data: {chunk_json}\n\n".encode()

        except Exception as e:
            log.error(f"Response parsing failed: {e}, directly yield original response")
            # Forward the raw upstream body without any re-wrapping.
            yield f"data: {response_body}\n\n".encode()

        yield "data: [DONE]\n\n".encode()

    # ========== Streaming generator with anti-truncation ==========
    async def anti_truncation_generator():
        from src.converter.gemini_fix import normalize_gemini_request
        from src.converter.anti_truncation import AntiTruncationStreamProcessor
        from src.converter.anti_truncation import apply_anti_truncation
        from src.api.antigravity import stream_request
        from fastapi import Response

        # Base normalization first.
        normalized_req = await normalize_gemini_request(normalized_dict.copy(), mode="antigravity")

        # Fall back to the resolved model name if normalization dropped "model".
        api_request = {
            "model": normalized_req.pop("model", real_model),
            "request": normalized_req,
        }

        max_attempts = await get_anti_truncation_max_attempts()

        # Apply the anti-truncation instructions to the payload.
        anti_truncation_payload = apply_anti_truncation(api_request)

        # Read the first item eagerly so that an upstream error Response can be
        # yielded as the very first item of this generator.
        first_attempt_stream = stream_request(body=anti_truncation_payload, native=False)
        try:
            first_chunk = await read_first_async_item(first_attempt_stream)
        except StopAsyncIteration:
            return

        if isinstance(first_chunk, Response):
            yield first_chunk
            return

        first_attempt_pending = True

        async def stream_request_wrapper(payload):
            nonlocal first_attempt_pending
            if first_attempt_pending:
                # First attempt: replay the already-opened stream, putting the
                # consumed first chunk back in front.
                first_attempt_pending = False
                stream_gen = prepend_async_item(first_chunk, first_attempt_stream)
            else:
                # Later attempts open a fresh upstream stream.
                stream_gen = stream_request(body=payload, native=False)
            return StreamingResponse(stream_gen, media_type="text/event-stream")

        processor = AntiTruncationStreamProcessor(
            stream_request_wrapper,
            anti_truncation_payload,
            max_attempts,
            enable_prefill_mode=("claude" not in str(api_request.get("model", "")).lower()),
        )

        # Iterate the processor's stream and strip the "response" wrapper.
        async for chunk in processor.process_stream():
            if isinstance(chunk, (str, bytes)):
                yield _unwrap_sse_chunk(chunk, "[ANTIGRAVITY-ANTI-TRUNCATION]")
            else:
                # Other item types are forwarded untouched.
                yield chunk

    # ========== Plain streaming generator ==========
    async def normal_stream_generator():
        from src.converter.gemini_fix import normalize_gemini_request
        from src.api.antigravity import stream_request
        from fastapi import Response

        normalized_req = await normalize_gemini_request(normalized_dict.copy(), mode="antigravity")

        # API layer format: the model sits beside the remaining request fields.
        api_request = {
            "model": normalized_req.pop("model"),
            "request": normalized_req,
        }

        # All streaming requests use non-native (SSE) mode and get unwrapped.
        log.debug("[ANTIGRAVITY] 使用非native模式,将展开response包装")
        stream_gen = stream_request(body=api_request, native=False)

        try:
            first_chunk = await read_first_async_item(stream_gen)
        except StopAsyncIteration:
            return

        if isinstance(first_chunk, Response):
            yield first_chunk
            return

        async for chunk in prepend_async_item(first_chunk, stream_gen):
            # A Response object in mid-stream signals an upstream error.
            if isinstance(chunk, Response):
                # Convert it into an SSE error event followed by [DONE].
                try:
                    raw_body = chunk.body or b""
                    if isinstance(raw_body, (bytes, bytearray, memoryview)):
                        raw_body = bytes(raw_body).decode('utf-8')
                    error_json = json.loads(raw_body)
                except Exception:
                    # Body missing or not JSON: fall back to a generic error.
                    error_json = {"error": {"code": chunk.status_code, "message": "upstream error", "status": "ERROR"}}
                log.error(f"[ANTIGRAVITY STREAM] 返回错误给客户端: status={chunk.status_code}, error={str(error_json)[:200]}")
                yield f"data: {json.dumps(error_json)}\n\n".encode('utf-8')
                yield b"data: [DONE]\n\n"
                return

            # Only str/bytes items are forwarded on this path; items of any
            # other type are dropped, which matches the previous behavior.
            if isinstance(chunk, (str, bytes)):
                yield _unwrap_sse_chunk(chunk, "[ANTIGRAVITY]")

    # ========== Pick the generator for the requested mode ==========
    if use_fake_streaming:
        return await build_streaming_response_or_error(fake_stream_generator())
    if use_anti_truncation:
        log.info("启用流式抗截断功能")
        return await build_streaming_response_or_error(anti_truncation_generator())
    return await build_streaming_response_or_error(normal_stream_generator())
def _estimate_tokens_from_contents(contents):
    """
    Estimate the token count of a Gemini ``contents`` list.

    Uses the heuristic "about 4 characters = 1 token", with a minimum of one
    token per text part. Entries that do not have the expected shape
    (non-list ``contents``/``parts``, non-dict items, non-string ``text``)
    are skipped rather than raising.

    Args:
        contents: Value of a ``contents`` field from the request JSON.

    Returns:
        The estimated number of tokens as an int (0 when nothing is countable).
    """
    if not isinstance(contents, list):
        return 0

    total = 0
    for content in contents:
        if not isinstance(content, dict):
            continue
        parts = content.get("parts")
        if not isinstance(parts, list):
            continue
        for part in parts:
            if not isinstance(part, dict):
                continue
            text = part.get("text")
            if isinstance(text, str):
                # Roughly 4 characters per token, at least 1 per text part.
                total += max(1, len(text) // 4)
    return total


@router.post("/antigravity/v1beta/models/{model:path}:countTokens")
@router.post("/antigravity/v1/models/{model:path}:countTokens")
async def count_tokens(
    request: Request = None,
    api_key: str = Depends(authenticate_gemini_flexible),
):
    """
    Simulate Gemini-format token counting.

    No backend call is made; the count is a heuristic based on text length
    (about 4 characters = 1 token). The texts are read from the top-level
    ``contents`` field when that key is present, otherwise from
    ``generateContentRequest.contents``.

    Args:
        request: Incoming HTTP request whose JSON body is inspected.
        api_key: API key resolved by the authentication dependency.

    Returns:
        JSONResponse of the form ``{"totalTokens": <int>}``.

    Raises:
        HTTPException: 400 when the request body cannot be read as JSON.
    """
    try:
        request_data = await request.json()
    except Exception as e:
        log.error(f"Failed to parse JSON request: {e}")
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}") from e

    total_tokens = 0

    # A valid JSON body is not necessarily an object; anything else counts as 0.
    if isinstance(request_data, dict):
        if "contents" in request_data:
            total_tokens = _estimate_tokens_from_contents(request_data["contents"])
        elif "generateContentRequest" in request_data:
            gen_request = request_data["generateContentRequest"]
            if isinstance(gen_request, dict):
                total_tokens = _estimate_tokens_from_contents(gen_request.get("contents"))

    # Respond in Gemini format.
    return JSONResponse(content={"totalTokens": total_tokens})
# ==================== Manual smoke tests ====================
if __name__ == "__main__":
    # Demonstrates the router's non-streaming and streaming endpoints by
    # mounting it on a throwaway FastAPI app and printing the raw responses.
    # Run with: python src/router/antigravity/gemini.py
    from fastapi.testclient import TestClient
    from fastapi import FastAPI

    print("=" * 80)
    print("Gemini Router (Antigravity Backend) 测试")
    print("=" * 80)

    # Throwaway application that only mounts this router.
    app = FastAPI()
    app.include_router(router)

    client = TestClient(app)

    # Request body in Gemini format, shared by all tests.
    test_request_body = {
        "contents": [
            {
                "role": "user",
                "parts": [{"text": "Hello, tell me a joke in one sentence."}]
            }
        ]
    }

    # API key sent as the "key" query parameter (placeholder value).
    test_api_key = "pwd"

    def _print_test_header(title):
        """Print the banner for one test followed by the request body."""
        print("\n" + "=" * 80)
        print(title)
        print("=" * 80)
        print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n")

    def _dump_raw_sse_stream(url):
        """POST the shared body to *url* and print every received chunk in detail."""
        with client.stream(
            "POST",
            url,
            json=test_request_body,
            params={"key": test_api_key}
        ) as response:
            print(f"状态码: {response.status_code}")
            print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n")

            chunk_count = 0
            for chunk in response.iter_bytes():
                if not chunk:
                    continue
                chunk_count += 1
                print(f"\nChunk #{chunk_count}:")
                print(f" 类型: {type(chunk).__name__}")
                print(f" 长度: {len(chunk)}")

                try:
                    chunk_str = chunk.decode('utf-8')
                    # Slicing already caps the preview at 200 characters.
                    print(f" 内容预览: {repr(chunk_str[:200])}")

                    # For SSE data, try to parse each event line.
                    if chunk_str.startswith("data: "):
                        for line in chunk_str.strip().split('\n'):
                            line = line.strip()
                            if not line:
                                continue

                            if line == "data: [DONE]":
                                print(f" => 流结束标记")
                            elif line.startswith("data: "):
                                try:
                                    json_str = line[6:]  # drop the "data: " prefix
                                    json_data = json.loads(json_str)
                                    print(f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}")
                                except Exception as e:
                                    print(f" SSE解析失败: {e}")
                except Exception as e:
                    print(f" 解码失败: {e}")

            print(f"\n总共收到 {chunk_count} 个chunk")

    def test_non_stream_request():
        """Exercise the non-streaming generateContent endpoint."""
        _print_test_header("【测试2】非流式请求 (POST /antigravity/v1/models/gemini-2.5-flash:generateContent)")

        response = client.post(
            "/antigravity/v1/models/gemini-2.5-flash:generateContent",
            json=test_request_body,
            params={"key": test_api_key}
        )

        print("非流式响应数据:")
        print("-" * 80)
        print(f"状态码: {response.status_code}")
        print(f"Content-Type: {response.headers.get('content-type', 'N/A')}")

        try:
            content = response.text
            print(f"\n响应内容 (原始):\n{content}\n")

            # Pretty-print the body when it is JSON.
            try:
                json_data = response.json()
                print(f"响应内容 (格式化JSON):")
                print(json.dumps(json_data, indent=2, ensure_ascii=False))
            except json.JSONDecodeError:
                print("(非JSON格式)")
        except Exception as e:
            print(f"内容解析失败: {e}")

    def test_stream_request():
        """Exercise the plain streaming endpoint."""
        _print_test_header("【测试3】流式请求 (POST /antigravity/v1/models/gemini-2.5-flash:streamGenerateContent)")

        print("流式响应数据 (每个chunk):")
        print("-" * 80)

        _dump_raw_sse_stream("/antigravity/v1/models/gemini-2.5-flash:streamGenerateContent")

    def test_fake_stream_request():
        """Exercise the fake-streaming endpoint and summarize each SSE event."""
        _print_test_header("【测试4】假流式请求 (POST /antigravity/v1/models/假流式/gemini-2.5-flash:streamGenerateContent)")

        print("假流式响应数据 (每个chunk):")
        print("-" * 80)

        with client.stream(
            "POST",
            "/antigravity/v1/models/假流式/gemini-2.5-flash:streamGenerateContent",
            json=test_request_body,
            params={"key": test_api_key}
        ) as response:
            print(f"状态码: {response.status_code}")
            print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n")

            chunk_count = 0
            for chunk in response.iter_bytes():
                if not chunk:
                    continue
                chunk_count += 1
                chunk_str = chunk.decode('utf-8')

                print(f"\nChunk #{chunk_count}:")
                # The label says bytes, so report the byte length of the raw
                # chunk rather than the decoded character count.
                print(f" 长度: {len(chunk)} 字节")

                # Collect every SSE data line contained in this HTTP chunk.
                events = [
                    stripped
                    for stripped in (line.strip() for line in chunk_str.split('\n'))
                    if stripped.startswith("data: ")
                ]

                print(f" 包含 {len(events)} 个SSE事件")

                for event_idx, event_line in enumerate(events, 1):
                    if event_line == "data: [DONE]":
                        print(f" 事件 #{event_idx}: [DONE]")
                        continue
                    try:
                        json_str = event_line[6:]  # drop the "data: " prefix
                        json_data = json.loads(json_str)
                        # Pull the first text part and finish reason of the
                        # first candidate.
                        first_candidate = json_data.get("candidates", [{}])[0]
                        text = first_candidate.get("content", {}).get("parts", [{}])[0].get("text", "")
                        finish_reason = first_candidate.get("finishReason")
                        print(f" 事件 #{event_idx}: text={repr(text[:50])}{'...' if len(text) > 50 else ''}, finishReason={finish_reason}")
                    except Exception as e:
                        print(f" 事件 #{event_idx}: 解析失败 - {e}")

            print(f"\n总共收到 {chunk_count} 个HTTP chunk")

    def test_anti_truncation_stream_request():
        """Exercise the streaming endpoint with anti-truncation enabled."""
        _print_test_header("【测试5】流式抗截断请求 (POST /antigravity/v1/models/流式抗截断/gemini-2.5-flash:streamGenerateContent)")

        print("流式抗截断响应数据 (每个chunk):")
        print("-" * 80)

        _dump_raw_sse_stream("/antigravity/v1/models/流式抗截断/gemini-2.5-flash:streamGenerateContent")

    # Run all tests; the first unhandled exception aborts the remaining ones.
    try:
        test_non_stream_request()
        test_stream_request()
        test_fake_stream_request()
        test_anti_truncation_stream_request()

        print("\n" + "=" * 80)
        print("测试完成")
        print("=" * 80)

    except Exception as e:
        print(f"\n❌ 测试过程中出现异常: {e}")
        import traceback
        traceback.print_exc()