| """ |
| Gemini Router - Handles native Gemini format API requests |
| 处理原生Gemini格式请求的路由模块 |
| """ |
|
|
| import sys |
| from pathlib import Path |
|
|
| |
# Locate the directory four levels above this file and make sure it is the
# first entry on sys.path (presumably so the absolute `config`, `log` and
# `src.*` imports below resolve when this file is run as a script).
# NOTE: `Path` here is pathlib.Path; the later `from fastapi import ... Path`
# rebinds the name for use in the route signatures.
project_root = Path(__file__).resolve().parent.parent.parent.parent
if sys.path.count(str(project_root)) == 0:
    sys.path.insert(0, str(project_root))
|
|
| import json |
|
|
| |
| from fastapi import APIRouter, Depends, HTTPException, Path, Request |
| from fastapi.responses import JSONResponse, StreamingResponse |
|
|
| |
| from config import get_anti_truncation_max_attempts |
| from log import log |
|
|
| |
| from src.utils import ( |
| get_base_model_from_feature_model, |
| is_anti_truncation_model, |
| authenticate_gemini_flexible, |
| is_fake_streaming_model |
| ) |
|
|
| |
| from src.converter.fake_stream import ( |
| parse_response_for_fake_stream, |
| build_gemini_fake_stream_chunks, |
| create_gemini_heartbeat_chunk, |
| ) |
|
|
| |
| from src.router.hi_check import is_health_check_request, create_health_check_response |
| from src.router.stream_passthrough import ( |
| build_streaming_response_or_error, |
| prepend_async_item, |
| read_first_async_item, |
| ) |
|
|
| |
| from src.models import GeminiRequest, model_to_dict |
|
|
| |
| from src.task_manager import create_managed_task |
|
|
|
|
| |
|
|
# Router on which the Gemini-format endpoints defined below are registered
# via the @router.post decorators.
router = APIRouter()
|
|
|
|
| |
|
|
@router.post("/v1beta/models/{model:path}:generateContent")
@router.post("/v1/models/{model:path}:generateContent")
async def generate_content(
    gemini_request: "GeminiRequest",
    model: str = Path(..., description="Model name"),
    api_key: str = Depends(authenticate_gemini_flexible),
):
    """
    Handle a non-streaming, Gemini-format content generation request.

    Args:
        gemini_request: Request body in Gemini format.
        model: Model name taken from the URL path; may carry feature
            markers that are stripped before the upstream call.
        api_key: Result of the ``authenticate_gemini_flexible`` dependency.
            It is not read in the body; declaring it enforces authentication.

    Returns:
        A ``JSONResponse`` with the health-check payload or with the
        unwrapped ``"response"`` object of a 200 upstream reply; otherwise
        the upstream response object unchanged.
    """
    log.debug(f"[GEMINICLI] Non-streaming request for model: {model}")

    payload = model_to_dict(gemini_request)

    # Health-check probes are answered locally and never reach the upstream.
    if is_health_check_request(payload, format="gemini"):
        return JSONResponse(content=create_health_check_response(format="gemini"))

    # Split the feature markers from the model name.
    anti_truncation_requested = is_anti_truncation_model(model)
    base_model = get_base_model_from_feature_model(model)

    # Anti-truncation is only implemented by the streaming endpoint.
    if anti_truncation_requested:
        log.warning("抗截断功能仅在流式传输时有效,非流式请求将忽略此设置")

    payload["model"] = base_model

    # Imported at call time, as in the rest of this module.
    from src.converter.gemini_fix import normalize_gemini_request
    payload = await normalize_gemini_request(payload, mode="geminicli")

    # The upstream body carries the model next to the request, not inside it.
    upstream_model = payload.pop("model")
    upstream_body = {"model": upstream_model, "request": payload}

    from src.api.geminicli import non_stream_request
    upstream = await non_stream_request(body=upstream_body)

    # A successful reply may be wrapped as {"response": {...}}; hand the
    # client only the inner object. Any failure while unwrapping falls back
    # to returning the upstream response untouched.
    try:
        if upstream.status_code == 200:
            raw = upstream.body if hasattr(upstream, 'body') else upstream.content
            parsed = json.loads(raw)
            if "response" in parsed:
                return JSONResponse(content=parsed["response"])
    except Exception as e:
        log.warning(f"Failed to unwrap response: {e}, returning original response")

    return upstream
|
|
def _unwrap_response_sse_chunk(chunk, log_tag):
    """
    Remove the outer ``{"response": ...}`` envelope from a single SSE chunk.

    Args:
        chunk: One item produced by an upstream stream.
        log_tag: Tag placed in square brackets in the debug log line that is
            emitted when a chunk is unwrapped.

    Returns:
        UTF-8 bytes of a new ``data: <json>\\n\\n`` event holding only the
        inner ``"response"`` object when the chunk is a ``data: `` event
        whose JSON payload is an object with a ``"response"`` key and no
        ``"candidates"`` key. In every other case (non str/bytes item,
        undecodable bytes, no ``data: `` prefix, ``[DONE]``, invalid JSON,
        non-object JSON) the original ``chunk`` is returned unchanged.
    """
    if not isinstance(chunk, (str, bytes)):
        return chunk

    try:
        chunk_str = chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk
    except UnicodeDecodeError:
        # Forward bytes we cannot interpret instead of aborting the stream.
        return chunk

    if not chunk_str.startswith("data: "):
        return chunk

    json_str = chunk_str[6:].strip()
    if json_str == "[DONE]":
        return chunk

    try:
        data = json.loads(json_str)
    except json.JSONDecodeError:
        return chunk

    # The isinstance guard matters: `"response" in data` raises TypeError
    # for valid JSON scalars such as `1` or `null`.
    if isinstance(data, dict) and "response" in data and "candidates" not in data:
        log.debug(f"[{log_tag}] 展开response包装")
        return f"data: {json.dumps(data['response'], ensure_ascii=False)}\n\n".encode("utf-8")

    return chunk


@router.post("/v1beta/models/{model:path}:streamGenerateContent")
@router.post("/v1/models/{model:path}:streamGenerateContent")
async def stream_generate_content(
    gemini_request: GeminiRequest,
    model: str = Path(..., description="Model name"),
    api_key: str = Depends(authenticate_gemini_flexible),
):
    """
    Handle a streaming, Gemini-format content generation request.

    Depending on the feature markers in ``model`` one of three generators is
    used: fake streaming (a non-streaming upstream call replayed as SSE
    events), anti-truncation streaming, or plain streaming.

    Args:
        gemini_request: Request body in Gemini format.
        model: Model name taken from the URL path; may carry feature markers.
        api_key: Result of the ``authenticate_gemini_flexible`` dependency.
            It is not read in the body; declaring it enforces authentication.

    Returns:
        Whatever ``build_streaming_response_or_error`` produces for the
        selected generator (presumably a streaming response, or the error
        response when the generator's first item is one -- confirm in
        ``src.router.stream_passthrough``).
    """
    log.debug(f"[GEMINICLI] Streaming request for model: {model}")

    normalized_dict = model_to_dict(gemini_request)

    # Split the feature markers from the model name.
    use_fake_streaming = is_fake_streaming_model(model)
    use_anti_truncation = is_anti_truncation_model(model)
    real_model = get_base_model_from_feature_model(model)

    normalized_dict["model"] = real_model

    # ---- fake streaming: one non-streaming call replayed as SSE events ----
    async def fake_stream_generator():
        from src.converter.gemini_fix import normalize_gemini_request
        from src.api.geminicli import non_stream_request

        normalized_req = await normalize_gemini_request(normalized_dict.copy(), mode="geminicli")

        # The upstream body carries the model next to the request.
        api_request = {
            "model": normalized_req.pop("model"),
            "request": normalized_req
        }

        response = await non_stream_request(body=api_request)

        # Hand a non-200 response object to the caller as the only item.
        if hasattr(response, "status_code") and response.status_code != 200:
            log.error(f"Fake streaming got error response: status={response.status_code}")
            yield response
            return

        # Obtain the response payload as text, whichever attribute holds it.
        if hasattr(response, "body"):
            response_body = response.body.decode() if isinstance(response.body, bytes) else response.body
        elif hasattr(response, "content"):
            response_body = response.content.decode() if isinstance(response.content, bytes) else response.content
        else:
            response_body = str(response)

        try:
            response_data = json.loads(response_body)
            log.debug(f"Gemini fake stream response data: {response_data}")

            # An error reported inside a 200 body is forwarded as one event.
            if "error" in response_data:
                log.error(f"Fake streaming got error in response body: {response_data['error']}")
                yield f"data: {json.dumps(response_data)}\n\n".encode()
                yield "data: [DONE]\n\n".encode()
                return

            content, reasoning_content, finish_reason, images = parse_response_for_fake_stream(response_data)

            log.debug(f"Gemini extracted content: {content}")
            log.debug(f"Gemini extracted reasoning: {reasoning_content[:100] if reasoning_content else 'None'}...")
            log.debug(f"Gemini extracted images count: {len(images)}")

            # Re-emit the parsed response as a sequence of SSE events.
            chunks = build_gemini_fake_stream_chunks(content, reasoning_content, finish_reason, images)
            for idx, chunk in enumerate(chunks):
                chunk_json = json.dumps(chunk)
                log.debug(f"[FAKE_STREAM] Yielding chunk #{idx+1}: {chunk_json[:200]}")
                yield f"data: {chunk_json}\n\n".encode()

        except Exception as e:
            log.error(f"Response parsing failed: {e}, directly yield original response")
            # Best effort: forward the raw upstream payload as one event.
            yield f"data: {response_body}\n\n".encode()

        yield "data: [DONE]\n\n".encode()

    # ---- anti-truncation streaming ----
    async def anti_truncation_generator():
        from src.converter.gemini_fix import normalize_gemini_request
        from src.converter.anti_truncation import AntiTruncationStreamProcessor
        from src.converter.anti_truncation import apply_anti_truncation
        from src.api.geminicli import stream_request
        from fastapi import Response

        normalized_req = await normalize_gemini_request(normalized_dict.copy(), mode="geminicli")

        api_request = {
            "model": normalized_req.pop("model") if "model" in normalized_req else real_model,
            "request": normalized_req
        }

        max_attempts = await get_anti_truncation_max_attempts()

        anti_truncation_payload = apply_anti_truncation(api_request)

        # Open the first upstream attempt eagerly and look at its first item,
        # so an upstream error response can be yielded as the only item.
        first_attempt_stream = stream_request(body=anti_truncation_payload, native=False)
        try:
            first_chunk = await read_first_async_item(first_attempt_stream)
        except StopAsyncIteration:
            return

        if isinstance(first_chunk, Response):
            yield first_chunk
            return

        first_attempt_pending = True

        async def stream_request_wrapper(payload):
            """Return a StreamingResponse over the upstream stream for ``payload``."""
            nonlocal first_attempt_pending

            if first_attempt_pending:
                # First call: reuse the stream opened above, with its
                # already-consumed first chunk put back in front.
                first_attempt_pending = False
                stream_gen = prepend_async_item(first_chunk, first_attempt_stream)
            else:
                # Later calls open a new upstream stream.
                stream_gen = stream_request(body=payload, native=False)
            return StreamingResponse(stream_gen, media_type="text/event-stream")

        processor = AntiTruncationStreamProcessor(
            stream_request_wrapper,
            anti_truncation_payload,
            max_attempts,
            enable_prefill_mode=True,
        )

        # Every item is forwarded; only "response"-wrapped events are rewritten.
        async for chunk in processor.process_stream():
            yield _unwrap_response_sse_chunk(chunk, "GEMINICLI-ANTI-TRUNCATION")

    # ---- plain streaming ----
    async def normal_stream_generator():
        from src.converter.gemini_fix import normalize_gemini_request
        from src.api.geminicli import stream_request
        from fastapi import Response

        normalized_req = await normalize_gemini_request(normalized_dict.copy(), mode="geminicli")

        api_request = {
            "model": normalized_req.pop("model"),
            "request": normalized_req
        }

        log.debug("[GEMINICLI] 使用非native模式,将展开response包装")
        stream_gen = stream_request(body=api_request, native=False)
        try:
            first_chunk = await read_first_async_item(stream_gen)
        except StopAsyncIteration:
            return

        # An upstream error before any data is yielded as the only item.
        if isinstance(first_chunk, Response):
            yield first_chunk
            return

        async for chunk in prepend_async_item(first_chunk, stream_gen):
            # An error response arriving mid-stream is converted into a
            # final SSE error event followed by [DONE].
            if isinstance(chunk, Response):
                try:
                    raw_body = chunk.body
                    if isinstance(raw_body, (bytes, bytearray, memoryview)):
                        raw_body = bytes(raw_body).decode('utf-8')
                    error_json = json.loads(raw_body)
                except Exception:
                    error_json = {"error": {"code": chunk.status_code, "message": "upstream error", "status": "ERROR"}}
                log.error(f"[GEMINICLI STREAM] 返回错误给客户端: status={chunk.status_code}, error={str(error_json)[:200]}")

                yield f"data: {json.dumps(error_json)}\n\n".encode('utf-8')
                yield b"data: [DONE]\n\n"
                return

            # Items that are neither Response nor str/bytes are not forwarded.
            if isinstance(chunk, (str, bytes)):
                yield _unwrap_response_sse_chunk(chunk, "GEMINICLI")

    # Pick the generator that matches the requested feature.
    if use_fake_streaming:
        return await build_streaming_response_or_error(fake_stream_generator())
    elif use_anti_truncation:
        log.info("启用流式抗截断功能")
        return await build_streaming_response_or_error(anti_truncation_generator())
    else:
        return await build_streaming_response_or_error(normal_stream_generator())
|
|
def _estimate_text_tokens(contents) -> int:
    """
    Estimate the token count of the text parts in a Gemini ``contents`` list.

    Each part with a string ``"text"`` field contributes
    ``max(1, len(text) // 4)`` tokens. Entries that do not have the expected
    shape (non-list ``contents``/``parts``, non-dict items, non-string text)
    are skipped rather than raising, since this is only a heuristic.

    Args:
        contents: The decoded ``contents`` value from the request body.

    Returns:
        The estimated number of tokens; 0 when nothing countable is found.
    """
    if not isinstance(contents, list):
        return 0

    total = 0
    for content in contents:
        if not isinstance(content, dict):
            continue
        parts = content.get("parts")
        if not isinstance(parts, list):
            continue
        for part in parts:
            if not isinstance(part, dict):
                continue
            text = part.get("text")
            if isinstance(text, str):
                # Roughly 4 characters per token, at least 1 per text part.
                total += max(1, len(text) // 4)
    return total


@router.post("/v1beta/models/{model:path}:countTokens")
@router.post("/v1/models/{model:path}:countTokens")
async def count_tokens(
    request: Request = None,
    api_key: str = Depends(authenticate_gemini_flexible),
):
    """
    Simulate the Gemini-format token counting endpoint.

    Uses a simple heuristic of roughly 4 characters per token. The contents
    are read from the top-level ``"contents"`` key or, when that key is
    absent, from ``"generateContentRequest"."contents"``.

    Args:
        request: Incoming HTTP request whose JSON body is inspected.
        api_key: Result of the ``authenticate_gemini_flexible`` dependency.
            It is not read in the body; declaring it enforces authentication.

    Returns:
        ``JSONResponse`` of the form ``{"totalTokens": <int>}``.

    Raises:
        HTTPException: 400 when the body cannot be parsed as JSON or is not
            a JSON object.
    """
    try:
        request_data = await request.json()
    except Exception as e:
        log.error(f"Failed to parse JSON request: {e}")
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")

    # A JSON array or scalar is valid JSON but not a valid request body.
    if not isinstance(request_data, dict):
        log.error(f"countTokens body is not a JSON object: {type(request_data).__name__}")
        raise HTTPException(status_code=400, detail="Invalid JSON: request body must be a JSON object")

    # A top-level "contents" key takes precedence over the wrapped form,
    # even when it is empty.
    if "contents" in request_data:
        contents = request_data["contents"]
    elif "generateContentRequest" in request_data:
        gen_request = request_data["generateContentRequest"]
        contents = gen_request.get("contents") if isinstance(gen_request, dict) else None
    else:
        contents = None

    return JSONResponse(content={"totalTokens": _estimate_text_tokens(contents)})
|
|
| |
|
|
if __name__ == "__main__":
    # Manual smoke test: exercises the non-streaming and the three streaming
    # variants of the Gemini router through a FastAPI TestClient and prints
    # what comes back.
    # Run with: python src/router/geminicli/gemini.py

    from fastapi.testclient import TestClient
    from fastapi import FastAPI

    print("=" * 80)
    print("Gemini Router 测试")
    print("=" * 80)

    # Minimal application that hosts only this module's router.
    app = FastAPI()
    app.include_router(router)

    client = TestClient(app)

    # Request body shared by every test.
    test_request_body = {
        "contents": [
            {
                "role": "user",
                "parts": [{"text": "Hello, tell me a joke in one sentence."}]
            }
        ]
    }

    # Key sent as the `key` query parameter.
    test_api_key = "pwd"

    def _print_test_header(title):
        """Print the banner and the request body that open every test."""
        print("\n" + "=" * 80)
        print(title)
        print("=" * 80)
        print(f"请求体: {json.dumps(test_request_body, indent=2, ensure_ascii=False)}\n")

    def _print_sse_lines(chunk_str):
        """Print every SSE `data:` line of a chunk, pretty-printing JSON payloads."""
        for line in chunk_str.strip().split('\n'):
            line = line.strip()
            if not line:
                continue

            if line == "data: [DONE]":
                print(f" => 流结束标记")
            elif line.startswith("data: "):
                try:
                    json_data = json.loads(line[6:])
                    print(f" 解析后的JSON: {json.dumps(json_data, indent=4, ensure_ascii=False)}")
                except Exception as e:
                    print(f" SSE解析失败: {e}")

    def _run_verbose_stream_test(title, url, intro):
        """POST the shared body to a streaming URL and dump each HTTP chunk."""
        _print_test_header(title)

        print(intro)
        print("-" * 80)

        with client.stream(
            "POST",
            url,
            json=test_request_body,
            params={"key": test_api_key}
        ) as response:
            print(f"状态码: {response.status_code}")
            print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n")

            chunk_count = 0
            for chunk in response.iter_bytes():
                if not chunk:
                    continue
                chunk_count += 1
                print(f"\nChunk #{chunk_count}:")
                print(f" 类型: {type(chunk).__name__}")
                print(f" 长度: {len(chunk)}")

                try:
                    chunk_str = chunk.decode('utf-8')
                    print(f" 内容预览: {repr(chunk_str[:200])}")

                    if chunk_str.startswith("data: "):
                        _print_sse_lines(chunk_str)
                except Exception as e:
                    print(f" 解码失败: {e}")

            print(f"\n总共收到 {chunk_count} 个chunk")

    def test_non_stream_request():
        """Exercise the non-streaming endpoint and print the response."""
        _print_test_header("【测试2】非流式请求 (POST /v1/models/gemini-2.5-flash:generateContent)")

        response = client.post(
            "/v1/models/gemini-2.5-flash:generateContent",
            json=test_request_body,
            params={"key": test_api_key}
        )

        print("非流式响应数据:")
        print("-" * 80)
        print(f"状态码: {response.status_code}")
        print(f"Content-Type: {response.headers.get('content-type', 'N/A')}")

        try:
            content = response.text
            print(f"\n响应内容 (原始):\n{content}\n")

            # Pretty-print the body when it is JSON.
            try:
                json_data = response.json()
                print(f"响应内容 (格式化JSON):")
                print(json.dumps(json_data, indent=2, ensure_ascii=False))
            except json.JSONDecodeError:
                print("(非JSON格式)")
        except Exception as e:
            print(f"内容解析失败: {e}")

    def test_stream_request():
        """Exercise the plain streaming endpoint."""
        _run_verbose_stream_test(
            "【测试3】流式请求 (POST /v1/models/gemini-2.5-flash:streamGenerateContent)",
            "/v1/models/gemini-2.5-flash:streamGenerateContent",
            "流式响应数据 (每个chunk):",
        )

    def test_fake_stream_request():
        """Exercise the fake-streaming endpoint and summarise each SSE event."""
        _print_test_header("【测试4】假流式请求 (POST /v1/models/假流式/gemini-2.5-flash:streamGenerateContent)")

        print("假流式响应数据 (每个chunk):")
        print("-" * 80)

        with client.stream(
            "POST",
            "/v1/models/假流式/gemini-2.5-flash:streamGenerateContent",
            json=test_request_body,
            params={"key": test_api_key}
        ) as response:
            print(f"状态码: {response.status_code}")
            print(f"Content-Type: {response.headers.get('content-type', 'N/A')}\n")

            chunk_count = 0
            for chunk in response.iter_bytes():
                if not chunk:
                    continue
                chunk_count += 1
                chunk_str = chunk.decode('utf-8')

                print(f"\nChunk #{chunk_count}:")
                # The label says bytes, so report the byte length of the raw
                # chunk rather than the character count of the decoded text.
                print(f" 长度: {len(chunk)} 字节")

                # One HTTP chunk may carry several SSE events.
                events = [
                    line.strip()
                    for line in chunk_str.split('\n')
                    if line.strip().startswith("data: ")
                ]

                print(f" 包含 {len(events)} 个SSE事件")

                for event_idx, event_line in enumerate(events, 1):
                    if event_line == "data: [DONE]":
                        print(f" 事件 #{event_idx}: [DONE]")
                    else:
                        try:
                            json_data = json.loads(event_line[6:])
                            # Show only the first text part and the finish reason.
                            text = json_data.get("candidates", [{}])[0].get("content", {}).get("parts", [{}])[0].get("text", "")
                            finish_reason = json_data.get("candidates", [{}])[0].get("finishReason")
                            print(f" 事件 #{event_idx}: text={repr(text[:50])}{'...' if len(text) > 50 else ''}, finishReason={finish_reason}")
                        except Exception as e:
                            print(f" 事件 #{event_idx}: 解析失败 - {e}")

            print(f"\n总共收到 {chunk_count} 个HTTP chunk")

    def test_anti_truncation_stream_request():
        """Exercise the anti-truncation streaming endpoint."""
        _run_verbose_stream_test(
            "【测试5】流式抗截断请求 (POST /v1/models/流式抗截断/gemini-2.5-flash:streamGenerateContent)",
            "/v1/models/流式抗截断/gemini-2.5-flash:streamGenerateContent",
            "流式抗截断响应数据 (每个chunk):",
        )

    # Run all tests in order; any exception is reported with a traceback.
    try:
        test_non_stream_request()
        test_stream_request()
        test_fake_stream_request()
        test_anti_truncation_stream_request()

        print("\n" + "=" * 80)
        print("测试完成")
        print("=" * 80)

    except Exception as e:
        print(f"\n❌ 测试过程中出现异常: {e}")
        import traceback
        traceback.print_exc()
|
|