# gemini / app / services / gemini.py
# drriver's picture
# Upload 28 files
# 60016b2 verified
import requests
import json
import os
import asyncio
import time
from app.models import ChatCompletionRequest, Message
from dataclasses import dataclass
from typing import Optional, Dict, Any, List
import httpx
import logging
from app.utils import format_log_message
logger = logging.getLogger('my_logger')
# Whether fake streaming is enabled (enabled by default).
FAKE_STREAMING = os.getenv("FAKE_STREAMING", "true").lower() in ("true", "1", "yes")
# Interval, in seconds, between the empty keep-alive chunks sent while fake streaming.
FAKE_STREAMING_INTERVAL = float(os.getenv("FAKE_STREAMING_INTERVAL", "1"))
@dataclass
class GeneratedText:
    """Plain record pairing a piece of text with an optional finish reason."""

    # The text payload.
    text: str
    # Finish reason associated with the text; None when not provided.
    finish_reason: Optional[str] = None
class ResponseWrapper:
    """Read-only view over a decoded ``generateContent`` JSON response.

    All fields are extracted once in the constructor from the first
    candidate (``candidates[0]``) and from ``usageMetadata``. Missing or
    ``null`` sections never raise: text fields fall back to ``""`` and the
    other fields fall back to ``None``.
    """

    def __init__(self, data: Dict[Any, Any]):
        """Extract and cache every exposed field from ``data``.

        Args:
            data: The decoded JSON response body.
        """
        self._data = data
        self._text = self._extract_text()
        self._finish_reason = self._extract_finish_reason()
        self._prompt_token_count = self._extract_prompt_token_count()
        self._candidates_token_count = self._extract_candidates_token_count()
        self._total_token_count = self._extract_total_token_count()
        self._thoughts = self._extract_thoughts()
        self._json_dumps = json.dumps(self._data, indent=4, ensure_ascii=False)

    def _first_candidate_parts(self) -> List[Dict[str, Any]]:
        """Return the dict-shaped parts of the first candidate, or ``[]``."""
        try:
            parts = self._data['candidates'][0]['content']['parts']
        except (KeyError, IndexError, TypeError):
            # TypeError covers sections that are present but null.
            return []
        if not isinstance(parts, list):
            return []
        return [part for part in parts if isinstance(part, dict)]

    def _usage_value(self, key: str) -> Optional[int]:
        """Return ``usageMetadata[key]``, or ``None`` when unavailable."""
        try:
            return self._data['usageMetadata'].get(key)
        except (KeyError, TypeError, AttributeError):
            return None

    def _extract_thoughts(self) -> Optional[str]:
        """Return the text of the first part flagged with ``thought``, else ``""``."""
        for part in self._first_candidate_parts():
            # Parts lacking 'text' are skipped instead of aborting the search.
            if 'thought' in part and 'text' in part:
                return part['text']
        return ""

    def _extract_text(self) -> str:
        """Return the text of the first non-thought text part, else ``""``."""
        for part in self._first_candidate_parts():
            # Skip non-text parts (e.g. inline data) so a later text part
            # is still found.
            if 'thought' not in part and 'text' in part:
                return part['text']
        return ""

    def _extract_finish_reason(self) -> Optional[str]:
        """Return the first candidate's ``finishReason``, or ``None``."""
        try:
            return self._data['candidates'][0].get('finishReason')
        except (KeyError, IndexError, TypeError, AttributeError):
            return None

    def _extract_prompt_token_count(self) -> Optional[int]:
        """Return ``usageMetadata.promptTokenCount``, or ``None``."""
        return self._usage_value('promptTokenCount')

    def _extract_candidates_token_count(self) -> Optional[int]:
        """Return ``usageMetadata.candidatesTokenCount``, or ``None``."""
        return self._usage_value('candidatesTokenCount')

    def _extract_total_token_count(self) -> Optional[int]:
        """Return ``usageMetadata.totalTokenCount``, or ``None``."""
        return self._usage_value('totalTokenCount')

    @property
    def text(self) -> str:
        """Text of the first non-thought part (``""`` if none)."""
        return self._text

    @property
    def finish_reason(self) -> Optional[str]:
        """``finishReason`` of the first candidate, if present."""
        return self._finish_reason

    @property
    def prompt_token_count(self) -> Optional[int]:
        """``promptTokenCount`` from ``usageMetadata``, if present."""
        return self._prompt_token_count

    @property
    def candidates_token_count(self) -> Optional[int]:
        """``candidatesTokenCount`` from ``usageMetadata``, if present."""
        return self._candidates_token_count

    @property
    def total_token_count(self) -> Optional[int]:
        """``totalTokenCount`` from ``usageMetadata``, if present."""
        return self._total_token_count

    @property
    def thoughts(self) -> Optional[str]:
        """Text of the first thought part (``""`` if none)."""
        return self._thoughts

    @property
    def json_dumps(self) -> str:
        """The whole response pretty-printed as JSON (4-space indent, non-ASCII kept)."""
        return self._json_dumps
class GeminiClient:
    """Client for the Generative Language (Gemini) REST API bound to one API key.

    Provides streaming and non-streaming chat calls, conversion of
    OpenAI-style chat messages into the Gemini ``contents`` format, and a
    model-listing helper.
    """

    AVAILABLE_MODELS = []
    # Extra model names from the comma-separated EXTRA_MODELS variable. Blank
    # entries are dropped so an unset variable yields [] rather than [""].
    EXTRA_MODELS = [
        name.strip()
        for name in os.environ.get("EXTRA_MODELS", "").split(",")
        if name.strip()
    ]
    # Timeout, in seconds, applied to both the streaming and blocking calls.
    _REQUEST_TIMEOUT = 600

    def __init__(self, api_key: str):
        """Store the API key used for every request made by this client."""
        self.api_key = api_key

    @staticmethod
    def _build_payload(request, contents, safety_settings, system_instruction):
        """Build the JSON request body shared by both chat endpoints."""
        payload = {
            "contents": contents,
            "generationConfig": {
                "temperature": request.temperature,
                "maxOutputTokens": request.max_tokens,
            },
            "safetySettings": safety_settings,
        }
        if system_instruction:
            payload["system_instruction"] = system_instruction
        return payload

    @staticmethod
    def _gemini_role(role):
        """Map a chat role to a Gemini role; return None for unknown roles."""
        return {'user': 'user', 'system': 'user', 'assistant': 'model'}.get(role)

    @staticmethod
    def _convert_content_items(items, errors):
        """Convert a list-style message content into Gemini parts.

        Text items become ``{"text": ...}`` and ``data:image/`` URLs become
        ``inline_data`` parts. Problems are appended to ``errors``. Items of
        any other type are ignored.
        """
        parts = []
        for item in items:
            item_type = item.get('type')
            if item_type == 'text':
                parts.append({"text": item.get('text')})
            elif item_type == 'image_url':
                image_data = item.get('image_url', {}).get('url', '')
                if not image_data.startswith('data:image/'):
                    errors.append(
                        f"Invalid image URL format for item: {item}")
                    continue
                try:
                    # "data:<mime>;base64,<payload>"
                    mime_type = image_data.split(';')[0].split(':')[1]
                    base64_data = image_data.split(',')[1]
                except (IndexError, ValueError):
                    errors.append(
                        f"Invalid data URI for image: {image_data}")
                    continue
                parts.append({
                    "inline_data": {
                        "mime_type": mime_type,
                        "data": base64_data,
                    }
                })
        return parts

    def _raise_if_blocked(self, candidate, request):
        """Log and raise ValueError if the candidate was cut short.

        Triggers on a ``finishReason`` other than ``STOP`` or on any safety
        rating whose probability is ``HIGH``.
        """
        finish_reason = candidate.get("finishReason")
        if finish_reason and finish_reason != "STOP":
            error_msg = f"模型的响应被截断: {finish_reason}"
            extra_log_error = {'key': self.api_key[:8], 'request_type': 'stream', 'model': request.model, 'status_code': 'ERROR', 'error_message': error_msg}
            logger.warning(format_log_message('WARNING', error_msg, extra=extra_log_error))
            raise ValueError(error_msg)

        for rating in candidate.get('safetyRatings') or []:
            # .get avoids a KeyError on ratings that omit 'probability'.
            if rating.get('probability') == 'HIGH':
                error_msg = f"模型的响应被截断: {rating.get('category')}"
                extra_log_safety = {'key': self.api_key[:8], 'request_type': 'stream', 'model': request.model, 'status_code': 'ERROR', 'error_message': error_msg}
                logger.warning(format_log_message('WARNING', error_msg, extra=extra_log_safety))
                raise ValueError(error_msg)

    async def stream_chat(self, request: ChatCompletionRequest, contents, safety_settings, system_instruction):
        """Async generator producing the streamed reply for a chat request.

        With ``FAKE_STREAMING`` enabled, no HTTP request is made here. The
        generator only yields ``"\\n"`` keep-alive chunks every
        ``FAKE_STREAMING_INTERVAL`` seconds until it is closed or cancelled
        from outside, or until 300 seconds have elapsed.

        Otherwise it calls ``streamGenerateContent`` with ``alt=sse`` and
        yields the concatenated text of each received chunk.

        Raises:
            TimeoutError: Fake streaming ran for more than 300 seconds.
            ValueError: A chunk reported a non-``STOP`` finish reason or a
                ``HIGH`` safety rating.
            httpx.HTTPStatusError: The API answered with a 4xx/5xx status.
        """
        extra_log = {'key': self.api_key[:8], 'request_type': 'stream', 'model': request.model, 'status_code': 'N/A'}
        logger.info(format_log_message('INFO', "流式请求开始", extra=extra_log))

        if FAKE_STREAMING:
            logger.info(format_log_message('INFO', "使用假流式请求模式(发送换行符保持连接)", extra=extra_log))
            try:
                # The real request and key handling happen elsewhere (the
                # original comment attributes them to main.py). This branch
                # only keeps the connection alive until it is cancelled.
                start_time = time.time()
                while True:
                    yield "\n"
                    await asyncio.sleep(FAKE_STREAMING_INTERVAL)
                    # Guard against waiting forever.
                    if time.time() - start_time > 300:
                        logger.warning(format_log_message('WARNING', "假流式请求等待时间过长,强制结束", extra=extra_log))
                        error_msg = "假流式请求等待时间过长,所有API密钥均已尝试"
                        extra_log_timeout = {'key': self.api_key[:8], 'request_type': 'fake-stream', 'model': request.model, 'status_code': 'TIMEOUT', 'error_message': error_msg}
                        logger.error(format_log_message('ERROR', error_msg, extra=extra_log_timeout))
                        raise TimeoutError(error_msg)
            except Exception as e:
                # Cancellation is expected and not logged. On Python 3.8+
                # CancelledError is a BaseException and never gets here.
                if not isinstance(e, asyncio.CancelledError):
                    error_msg = f"假流式处理期间发生错误: {str(e)}"
                    extra_log_error = {'key': self.api_key[:8], 'request_type': 'fake-stream', 'model': request.model, 'status_code': 'ERROR', 'error_message': error_msg}
                    logger.error(format_log_message('ERROR', error_msg, extra=extra_log_error))
                raise
            finally:
                logger.info(format_log_message('INFO', "假流式请求结束", extra=extra_log))
        else:
            # Real streaming request.
            api_version = "v1alpha" if "think" in request.model else "v1beta"
            url = f"https://generativelanguage.googleapis.com/{api_version}/models/{request.model}:streamGenerateContent?key={self.api_key}&alt=sse"
            headers = {
                "Content-Type": "application/json",
            }
            payload = self._build_payload(request, contents, safety_settings, system_instruction)

            async with httpx.AsyncClient() as client:
                async with client.stream("POST", url, headers=headers, json=payload, timeout=self._REQUEST_TIMEOUT) as response:
                    # Accumulates lines until they form one complete JSON document.
                    buffer = ""
                    try:
                        # Surface HTTP errors. Otherwise the error body is
                        # parsed as a chunk without candidates and the
                        # stream ends silently.
                        response.raise_for_status()
                        async for line in response.aiter_lines():
                            if not line.strip():
                                continue
                            if line.startswith("data: "):
                                line = line[len("data: "):]
                            buffer += line
                            try:
                                chunk = json.loads(buffer)
                                buffer = ""
                                if 'candidates' in chunk and chunk['candidates']:
                                    candidate = chunk['candidates'][0]
                                    if 'content' in candidate:
                                        content = candidate['content']
                                        if 'parts' in content and content['parts']:
                                            text = "".join(
                                                part['text']
                                                for part in content['parts']
                                                if 'text' in part
                                            )
                                            if text:
                                                yield text
                                    self._raise_if_blocked(candidate, request)
                            except json.JSONDecodeError:
                                # Incomplete JSON so far; keep buffering.
                                continue
                            except Exception as e:
                                error_msg = f"流式处理期间发生错误: {str(e)}"
                                extra_log_stream_error = {'key': self.api_key[:8], 'request_type': 'stream', 'model': request.model, 'status_code': 'ERROR', 'error_message': error_msg}
                                logger.error(format_log_message('ERROR', error_msg, extra=extra_log_stream_error))
                                raise
                    finally:
                        logger.info(format_log_message('INFO', "流式请求结束", extra=extra_log))

    def complete_chat(self, request: ChatCompletionRequest, contents, safety_settings, system_instruction):
        """Send a blocking ``generateContent`` request and wrap the reply.

        Returns:
            ResponseWrapper: Wrapper around the decoded JSON response.

        Raises:
            requests.exceptions.RequestException: Transport failure,
                timeout, or a 4xx/5xx status (via ``raise_for_status``).
        """
        extra_log = {'key': self.api_key[:8], 'request_type': 'non-stream', 'model': request.model, 'status_code': 'N/A'}
        logger.info(format_log_message('INFO', "非流式请求开始", extra=extra_log))

        api_version = "v1alpha" if "think" in request.model else "v1beta"
        url = f"https://generativelanguage.googleapis.com/{api_version}/models/{request.model}:generateContent?key={self.api_key}"
        headers = {
            "Content-Type": "application/json",
        }
        payload = self._build_payload(request, contents, safety_settings, system_instruction)

        # Without a timeout, requests can block indefinitely. Use the same
        # limit as the streaming path.
        response = requests.post(url, headers=headers, json=payload, timeout=self._REQUEST_TIMEOUT)
        response.raise_for_status()
        logger.info(format_log_message('INFO', "非流式请求成功完成", extra=extra_log))
        return ResponseWrapper(response.json())

    def convert_messages(self, messages, use_system_prompt=False):
        """Convert chat messages into Gemini ``contents`` plus a system instruction.

        Each message must expose ``.role`` and ``.content``, where content
        is a string or a list of typed items. Consecutive messages that map
        to the same Gemini role are merged into one entry.

        Args:
            messages: Iterable of message objects.
            use_system_prompt: When true, leading string ``system`` messages
                are joined with newlines into the system instruction
                instead of being added to the history as ``user`` turns.

        Returns:
            On success, a tuple ``(gemini_history, system_instruction)``
            where ``system_instruction`` is ``{"parts": [{"text": ...}]}``
            (the text may be empty). If any error was recorded, a list of
            error strings is returned instead.
        """
        gemini_history = []
        errors = []
        system_instruction_text = ""
        is_system_phase = use_system_prompt

        for message in messages:
            role = message.role
            content = message.content

            if isinstance(content, str):
                if is_system_phase and role == 'system':
                    if system_instruction_text:
                        system_instruction_text += "\n" + content
                    else:
                        system_instruction_text = content
                    continue
                # The first non-system string message ends the system phase.
                is_system_phase = False
                parts = [{"text": content}]
            elif isinstance(content, list):
                # NOTE(review): list-type content does not end the system
                # phase, so a later string 'system' message is still hoisted
                # into the system instruction. Confirm this is intended.
                parts = self._convert_content_items(content, errors)
                if not parts:
                    continue
            else:
                # Content of any other type is ignored.
                continue

            role_to_use = self._gemini_role(role)
            if role_to_use is None:
                errors.append(f"Invalid role: {role}")
                continue

            if gemini_history and gemini_history[-1]['role'] == role_to_use:
                gemini_history[-1]['parts'].extend(parts)
            else:
                gemini_history.append({"role": role_to_use, "parts": parts})

        if errors:
            return errors
        return gemini_history, {"parts": [{"text": system_instruction_text}]}

    @staticmethod
    async def list_available_models(api_key) -> list:
        """Fetch model names from the API and append ``EXTRA_MODELS``.

        Raises:
            httpx.HTTPStatusError: The API answered with a 4xx/5xx status.
        """
        url = "https://generativelanguage.googleapis.com/v1beta/models?key={}".format(
            api_key)
        async with httpx.AsyncClient() as client:
            response = await client.get(url)
            response.raise_for_status()
            data = response.json()
            models = [model["name"] for model in data.get("models", [])]
            models.extend(GeminiClient.EXTRA_MODELS)
            return models