File size: 25,698 Bytes
5868187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
"""
Google API Client - Handles all communication with Google's Gemini API.
This module is used by both OpenAI compatibility layer and native Gemini endpoints.
"""
import asyncio
import gc
import json

from fastapi import Response
from fastapi.responses import StreamingResponse

from config import (
    get_code_assist_endpoint,
    DEFAULT_SAFETY_SETTINGS,
    get_base_model_name,
    get_thinking_budget,
    should_include_thoughts,
    is_search_model,
    get_auto_ban_enabled,
    get_auto_ban_error_codes,
    get_retry_429_max_retries,
    get_retry_429_enabled,
    get_retry_429_interval
)
from .httpx_client import http_client, create_streaming_client_with_kwargs
from log import log
from .credential_manager import CredentialManager
from .usage_stats import record_successful_call
from .utils import get_user_agent

def _create_error_response(message: str, status_code: int = 500) -> Response:
    """Create standardized error response."""
    return Response(
        content=json.dumps({
            "error": {
                "message": message,
                "type": "api_error",
                "code": status_code
            }
        }),
        status_code=status_code,
        media_type="application/json"
    )

async def _handle_api_error(credential_manager: CredentialManager, status_code: int, response_content: str = ""):
    """Handle API errors by rotating credentials when needed. Error recording should be done before calling this function."""
    if status_code == 429 and credential_manager:
        if response_content:
            log.error(f"Google API returned status 429 - quota exhausted. Response details: {response_content[:500]}")
        else:
            log.error("Google API returned status 429 - quota exhausted, switching credentials")
        await credential_manager.force_rotate_credential()
    
    # 处理自动封禁的错误码
    elif await get_auto_ban_enabled() and status_code in await get_auto_ban_error_codes() and credential_manager:
        if response_content:
            log.error(f"Google API returned status {status_code} - auto ban triggered. Response details: {response_content[:500]}")
        else:
            log.warning(f"Google API returned status {status_code} - auto ban triggered, rotating credentials")
        await credential_manager.force_rotate_credential()

async def _prepare_request_headers_and_payload(payload: dict, credential_data: dict):
    """Prepare request headers and final payload from credential data."""
    # 尝试获取token,支持多种字段名
    token = credential_data.get('token') or credential_data.get('access_token', '')
    
    if not token:
        raise Exception("凭证中没有找到有效的访问令牌(token或access_token字段)")
    
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
        "User-Agent": get_user_agent(),
    }

    # 直接使用凭证数据中的项目ID
    project_id = credential_data.get("project_id", "")
    if not project_id:
        raise Exception("项目ID不存在于凭证数据中")

    final_payload = {
        "model": payload.get("model"),
        "project": project_id,
        "request": payload.get("request", {})
    }
    
    return headers, final_payload

async def send_gemini_request(payload: dict, is_streaming: bool = False, credential_manager: CredentialManager = None) -> Response:
    """
    Send a request to Google's Gemini API.
    
    Args:
        payload: The request payload in Gemini format
        is_streaming: Whether this is a streaming request
        credential_manager: CredentialManager instance
        
    Returns:
        FastAPI Response object
    """
    # 获取429重试配置
    max_retries = await get_retry_429_max_retries()
    retry_429_enabled = await get_retry_429_enabled()
    retry_interval = await get_retry_429_interval()
    
    # 确定API端点
    action = "streamGenerateContent" if is_streaming else "generateContent"
    target_url = f"{await get_code_assist_endpoint()}/v1internal:{action}"
    if is_streaming:
        target_url += "?alt=sse"

    # 确保有credential_manager
    if not credential_manager:
        return _create_error_response("Credential manager not provided", 500)
    
    # 获取当前凭证
    try:
        credential_result = await credential_manager.get_valid_credential()
        if not credential_result:
            return _create_error_response("No valid credentials available", 500)
        
        current_file, credential_data = credential_result
        headers, final_payload = await _prepare_request_headers_and_payload(payload, credential_data)
    except Exception as e:
        return _create_error_response(str(e), 500)

    # 预序列化payload,避免重试时重复序列化
    final_post_data = json.dumps(final_payload)
    
    # Debug日志:打印请求体结构
    log.debug(f"Final request payload structure: {json.dumps(final_payload, ensure_ascii=False, indent=2)}")

    for attempt in range(max_retries + 1):
        try:
            if is_streaming:
                # 流式请求处理 - 使用httpx_client模块的统一配置
                client = await create_streaming_client_with_kwargs()
                
                try:
                    # 使用stream方法但不在async with块中消费数据
                    stream_ctx = client.stream("POST", target_url, content=final_post_data, headers=headers)
                    resp = await stream_ctx.__aenter__()
                    
                    if resp.status_code == 429:
                        # 记录429错误并获取响应内容
                        response_content = ""
                        try:
                            content_bytes = await resp.aread()
                            if isinstance(content_bytes, bytes):
                                response_content = content_bytes.decode('utf-8', errors='ignore')
                        except Exception as e:
                            log.debug(f"[STREAMING] Failed to read 429 response content: {e}")
                        
                        # 显示详细的429错误信息
                        if response_content:
                            log.error(f"Google API returned status 429 (STREAMING). Response details: {response_content[:500]}")
                        else:
                            log.error("Google API returned status 429 (STREAMING) - quota exhausted, no response details available")
                        
                        if credential_manager and current_file:
                            await credential_manager.record_api_call_result(current_file, False, 429)
                        
                        # 清理资源
                        try:
                            await stream_ctx.__aexit__(None, None, None)
                        except:
                            pass
                        await client.aclose()
                        
                        # 如果重试可用且未达到最大次数,进行重试
                        if retry_429_enabled and attempt < max_retries:
                            log.warning(f"[RETRY] 429 error encountered, retrying ({attempt + 1}/{max_retries})")
                            if credential_manager:
                                # 429错误时强制轮换凭证,不增加调用计数
                                await credential_manager.force_rotate_credential()
                                # 重新获取凭证和headers(凭证可能已轮换)
                                new_credential_result = await credential_manager.get_valid_credential()
                                if new_credential_result:
                                    current_file, credential_data = new_credential_result
                                    headers, updated_payload = await _prepare_request_headers_and_payload(payload, credential_data)
                                    final_post_data = json.dumps(updated_payload)
                            await asyncio.sleep(retry_interval)
                            continue  # 跳出内层处理,继续外层循环重试
                        else:
                            # 返回429错误流
                            async def error_stream():
                                error_response = {
                                    "error": {
                                        "message": "429 rate limit exceeded, max retries reached",
                                        "type": "api_error",
                                        "code": 429
                                    }
                                }
                                yield f"data: {json.dumps(error_response)}\n\n"
                            return StreamingResponse(error_stream(), media_type="text/event-stream", status_code=429)
                    elif resp.status_code != 200:
                        # 处理其他非200状态码的错误
                        response_content = ""
                        try:
                            content_bytes = await resp.aread()
                            if isinstance(content_bytes, bytes):
                                response_content = content_bytes.decode('utf-8', errors='ignore')
                        except Exception as e:
                            log.debug(f"[STREAMING] Failed to read error response content: {e}")
                        
                        # 显示详细的错误信息
                        if response_content:
                            log.error(f"Google API returned status {resp.status_code} (STREAMING). Response details: {response_content[:500]}")
                        else:
                            log.error(f"Google API returned status {resp.status_code} (STREAMING) - no response details available")
                        
                        # 记录API调用错误
                        if credential_manager and current_file:
                            await credential_manager.record_api_call_result(current_file, False, resp.status_code)
                        
                        # 清理资源
                        try:
                            await stream_ctx.__aexit__(None, None, None)
                        except:
                            pass
                        await client.aclose()
                        
                        # 处理凭证轮换
                        await _handle_api_error(credential_manager, resp.status_code, response_content)
                        
                        # 返回错误流
                        async def error_stream():
                            error_response = {
                                "error": {
                                    "message": f"API error: {resp.status_code}",
                                    "type": "api_error", 
                                    "code": resp.status_code
                                }
                            }
                            yield f"data: {json.dumps(error_response)}\n\n"
                        return StreamingResponse(error_stream(), media_type="text/event-stream", status_code=resp.status_code)
                    else:
                        # 成功响应,传递所有资源给流式处理函数管理
                        return _handle_streaming_response_managed(resp, stream_ctx, client, credential_manager, payload.get("model", ""), current_file)
                        
                except Exception as e:
                    # 清理资源
                    try:
                        await client.aclose()
                    except:
                        pass
                    raise e

            else:
                # 非流式请求处理 - 使用httpx_client模块
                async with http_client.get_client(timeout=None) as client:
                    resp = await client.post(
                        target_url, content=final_post_data, headers=headers
                    )
                    
                    if resp.status_code == 429:
                        # 记录429错误
                        if credential_manager and current_file:
                            await credential_manager.record_api_call_result(current_file, False, 429)
                        
                        # 如果重试可用且未达到最大次数,继续重试
                        if retry_429_enabled and attempt < max_retries:
                            log.warning(f"[RETRY] 429 error encountered, retrying ({attempt + 1}/{max_retries})")
                            if credential_manager:
                                # 429错误时强制轮换凭证,不增加调用计数
                                await credential_manager.force_rotate_credential()
                                # 重新获取凭证和headers(凭证可能已轮换)
                                new_credential_result = await credential_manager.get_valid_credential()
                                if new_credential_result:
                                    current_file, credential_data = new_credential_result
                                    headers, updated_payload = await _prepare_request_headers_and_payload(payload, credential_data)
                                    final_post_data = json.dumps(updated_payload)
                            await asyncio.sleep(retry_interval)
                            continue
                        else:
                            log.error(f"[RETRY] Max retries exceeded for 429 error")
                            return _create_error_response("429 rate limit exceeded, max retries reached", 429)
                    else:
                        # 非429错误或成功响应,正常处理
                        return await _handle_non_streaming_response(resp, credential_manager, payload.get("model", ""), current_file)
                    
        except Exception as e:
            if attempt < max_retries:
                log.warning(f"[RETRY] Request failed with exception, retrying ({attempt + 1}/{max_retries}): {str(e)}")
                await asyncio.sleep(retry_interval)
                continue
            else:
                log.error(f"Request to Google API failed: {str(e)}")
                return _create_error_response(f"Request failed: {str(e)}")
    
    # 如果循环结束仍未成功,返回错误
    return _create_error_response("Max retries exceeded", 429)


def _handle_streaming_response_managed(resp, stream_ctx, client, credential_manager: CredentialManager = None, model_name: str = "", current_file: str = None) -> StreamingResponse:
    """Handle streaming response with complete resource lifecycle management."""
    
    # 检查HTTP错误
    if resp.status_code != 200:
        # 立即清理资源并返回错误
        async def cleanup_and_error():
            try:
                await stream_ctx.__aexit__(None, None, None)
            except:
                pass
            try:
                await client.aclose()
            except:
                pass
            
            # 获取响应内容用于详细错误显示
            response_content = ""
            try:
                content_bytes = await resp.aread()
                if isinstance(content_bytes, bytes):
                    response_content = content_bytes.decode('utf-8', errors='ignore')
            except Exception as e:
                log.debug(f"[STREAMING] Failed to read response content for error analysis: {e}")
                response_content = ""
            
            # 显示详细错误信息
            if resp.status_code == 429:
                if response_content:
                    log.error(f"Google API returned status 429 (STREAMING). Response details: {response_content[:500]}")
                else:
                    log.error(f"Google API returned status 429 (STREAMING)")
            else:
                if response_content:
                    log.error(f"Google API returned status {resp.status_code} (STREAMING). Response details: {response_content[:500]}")
                else:
                    log.error(f"Google API returned status {resp.status_code} (STREAMING)")
            
            # 记录API调用错误
            if credential_manager and current_file:
                await credential_manager.record_api_call_result(current_file, False, resp.status_code)
            
            await _handle_api_error(credential_manager, resp.status_code, response_content)
            
            error_response = {
                "error": {
                    "message": f"API error: {resp.status_code}",
                    "type": "api_error",
                    "code": resp.status_code
                }
            }
            yield f'data: {json.dumps(error_response)}\n\n'.encode('utf-8')
        
        return StreamingResponse(
            cleanup_and_error(),
            media_type="text/event-stream",
            status_code=resp.status_code
        )
    
    # 正常流式响应处理,确保资源在流结束时被清理
    async def managed_stream_generator():
        success_recorded = False
        managed_stream_generator._chunk_count = 0  # 初始化chunk计数器
        try:
            async for chunk in resp.aiter_lines():
                if not chunk or not chunk.startswith('data: '):
                    continue
                    
                # 记录第一次成功响应
                if not success_recorded:
                    if current_file and credential_manager:
                        await credential_manager.record_api_call_result(current_file, True)
                        # 记录到使用统计
                        try:
                            await record_successful_call(current_file, model_name)
                        except Exception as e:
                            log.debug(f"Failed to record usage statistics: {e}")
                    success_recorded = True
                
                payload = chunk[len('data: '):]
                try:
                    obj = json.loads(payload)
                    if "response" in obj:
                        data = obj["response"]
                        yield f"data: {json.dumps(data, separators=(',',':'))}\n\n".encode()
                        await asyncio.sleep(0)  # 让其他协程有机会运行
                        
                        # 定期释放内存(每100个chunk)
                        if hasattr(managed_stream_generator, '_chunk_count'):
                            managed_stream_generator._chunk_count += 1
                            if managed_stream_generator._chunk_count % 100 == 0:
                                gc.collect()
                    else:
                        yield f"data: {json.dumps(obj, separators=(',',':'))}\n\n".encode()
                except json.JSONDecodeError:
                    continue
                    
        except Exception as e:
            log.error(f"Streaming error: {e}")
            err = {"error": {"message": str(e), "type": "api_error", "code": 500}}
            yield f"data: {json.dumps(err)}\n\n".encode()
        finally:
            # 确保清理所有资源
            try:
                await stream_ctx.__aexit__(None, None, None)
            except Exception as e:
                log.debug(f"Error closing stream context: {e}")
            try:
                await client.aclose()
            except Exception as e:
                log.debug(f"Error closing client: {e}")

    return StreamingResponse(
        managed_stream_generator(),
        media_type="text/event-stream"
    )

async def _handle_non_streaming_response(resp, credential_manager: CredentialManager = None, model_name: str = "", current_file: str = None) -> Response:
    """Handle non-streaming response from Google API."""
    if resp.status_code == 200:
        try:
            # 记录成功响应
            if current_file and credential_manager:
                await credential_manager.record_api_call_result(current_file, True)
                # 记录到使用统计
                try:
                    await record_successful_call(current_file, model_name)
                except Exception as e:
                    log.debug(f"Failed to record usage statistics: {e}")
            
            raw = await resp.aread()
            google_api_response = raw.decode('utf-8')
            if google_api_response.startswith('data: '):
                google_api_response = google_api_response[len('data: '):]
            google_api_response = json.loads(google_api_response)
            log.debug(f"Google API原始响应: {json.dumps(google_api_response, ensure_ascii=False)[:500]}...")
            standard_gemini_response = google_api_response.get("response")
            log.debug(f"提取的response字段: {json.dumps(standard_gemini_response, ensure_ascii=False)[:500]}...")
            return Response(
                content=json.dumps(standard_gemini_response),
                status_code=200,
                media_type="application/json; charset=utf-8"
            )
        except Exception as e:
            log.error(f"Failed to parse Google API response: {str(e)}")
            return Response(
                content=resp.content,
                status_code=resp.status_code,
                media_type=resp.headers.get("Content-Type")
            )
    else:
        # 获取响应内容用于详细错误显示
        response_content = ""
        try:
            if hasattr(resp, 'content'):
                content = resp.content
                if isinstance(content, bytes):
                    response_content = content.decode('utf-8', errors='ignore')
            else:
                content_bytes = await resp.aread()
                if isinstance(content_bytes, bytes):
                    response_content = content_bytes.decode('utf-8', errors='ignore')
        except Exception as e:
            log.debug(f"[NON-STREAMING] Failed to read response content for error analysis: {e}")
            response_content = ""
        
        # 显示详细错误信息
        if resp.status_code == 429:
            if response_content:
                log.error(f"Google API returned status 429 (NON-STREAMING). Response details: {response_content[:500]}")
            else:
                log.error(f"Google API returned status 429 (NON-STREAMING)")
        else:
            if response_content:
                log.error(f"Google API returned status {resp.status_code} (NON-STREAMING). Response details: {response_content[:500]}")
            else:
                log.error(f"Google API returned status {resp.status_code} (NON-STREAMING)")
        
        # 记录API调用错误
        if credential_manager and current_file:
            await credential_manager.record_api_call_result(current_file, False, resp.status_code)
        
        await _handle_api_error(credential_manager, resp.status_code, response_content)
        
        return _create_error_response(f"API error: {resp.status_code}", resp.status_code)

def build_gemini_payload_from_native(native_request: dict, model_from_path: str) -> dict:
    """
    Build a Gemini API payload from a native Gemini request with full pass-through support.
    """
    # 创建请求副本以避免修改原始数据
    request_data = native_request.copy()
    
    # 应用默认安全设置(如果未指定)
    if "safetySettings" not in request_data:
        request_data["safetySettings"] = DEFAULT_SAFETY_SETTINGS
    
    # 确保generationConfig存在
    if "generationConfig" not in request_data:
        request_data["generationConfig"] = {}
    
    generation_config = request_data["generationConfig"]
    
    # 配置thinking(如果未指定thinkingConfig)
    if "thinkingConfig" not in generation_config:
        generation_config["thinkingConfig"] = {}
    
    thinking_config = generation_config["thinkingConfig"]
    
    # 只有在未明确设置时才应用默认thinking配置
    if "includeThoughts" not in thinking_config:
        thinking_config["includeThoughts"] = should_include_thoughts(model_from_path)
    if "thinkingBudget" not in thinking_config:
        thinking_config["thinkingBudget"] = get_thinking_budget(model_from_path)
    
    # 为搜索模型添加Google Search工具(如果未指定且没有functionDeclarations)
    if is_search_model(model_from_path):
        if "tools" not in request_data:
            request_data["tools"] = []
        # 检查是否已有functionDeclarations或googleSearch工具
        has_function_declarations = any(tool.get("functionDeclarations") for tool in request_data["tools"])
        has_google_search = any(tool.get("googleSearch") for tool in request_data["tools"])
        
        # 只有在没有任何工具时才添加googleSearch,或者只有googleSearch工具时可以添加更多googleSearch
        if not has_function_declarations and not has_google_search:
            request_data["tools"].append({"googleSearch": {}})
    
    # 透传所有其他Gemini原生字段:
    # - contents (必需)
    # - systemInstruction (可选)
    # - generationConfig (已处理)
    # - safetySettings (已处理)  
    # - tools (已处理)
    # - toolConfig (透传)
    # - cachedContent (透传)
    # - 以及任何其他未知字段都会被透传
    
    return {
        "model": get_base_model_name(model_from_path),
        "request": request_data
    }