rzline committed on
Commit
813815f
·
verified ·
1 Parent(s): 8db57ed

Create gemini.py

Browse files
Files changed (1) hide show
  1. app/gemini.py +278 -0
app/gemini.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import json
3
+ import os
4
+ import asyncio
5
+ from app.models import ChatCompletionRequest, Message # 相对导入
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Dict, Any, List
8
+ import httpx
9
+ import logging
10
+
11
+ logger = logging.getLogger('my_logger')
12
+
13
+
14
+ @dataclass
15
+ class GeneratedText:
16
+ text: str
17
+ finish_reason: Optional[str] = None
18
+
19
+
20
class ResponseWrapper:
    """Read-only view over a raw Gemini ``generateContent`` response dict.

    Extracts the answer text, "thought" text, finish reason and token usage
    once at construction time and exposes them as properties. All extraction
    is defensive: missing keys yield "" / None instead of raising.
    """

    def __init__(self, data: Dict[Any, Any]):
        self._data = data
        self._text = self._extract_text()
        self._finish_reason = self._extract_finish_reason()
        self._prompt_token_count = self._extract_prompt_token_count()
        self._candidates_token_count = self._extract_candidates_token_count()
        self._total_token_count = self._extract_total_token_count()
        self._thoughts = self._extract_thoughts()
        # Pretty-printed copy of the raw payload, kept for logging/debugging.
        self._json_dumps = json.dumps(self._data, indent=4, ensure_ascii=False)

    def _extract_thoughts(self) -> Optional[str]:
        """Return the text of the first 'thought' part, or "" if none exists."""
        try:
            for part in self._data['candidates'][0]['content']['parts']:
                if 'thought' in part:
                    return part['text']
            return ""
        except (KeyError, IndexError):
            return ""

    def _extract_text(self) -> str:
        """Concatenate the text of every non-thought part.

        Fix: the previous implementation returned only the FIRST non-thought
        part, silently dropping the rest of a multi-part response.
        """
        try:
            return "".join(
                part['text']
                for part in self._data['candidates'][0]['content']['parts']
                if 'thought' not in part and 'text' in part
            )
        except (KeyError, IndexError):
            return ""

    def _extract_finish_reason(self) -> Optional[str]:
        """Return 'finishReason' of the first candidate, or None."""
        try:
            return self._data['candidates'][0].get('finishReason')
        except (KeyError, IndexError):
            return None

    def _usage_value(self, key: str) -> Optional[int]:
        """Read one counter from 'usageMetadata'; None when absent or malformed."""
        usage = self._data.get('usageMetadata')
        if isinstance(usage, dict):
            return usage.get(key)
        return None

    def _extract_prompt_token_count(self) -> Optional[int]:
        return self._usage_value('promptTokenCount')

    def _extract_candidates_token_count(self) -> Optional[int]:
        return self._usage_value('candidatesTokenCount')

    def _extract_total_token_count(self) -> Optional[int]:
        return self._usage_value('totalTokenCount')

    @property
    def text(self) -> str:
        return self._text

    @property
    def finish_reason(self) -> Optional[str]:
        return self._finish_reason

    @property
    def prompt_token_count(self) -> Optional[int]:
        return self._prompt_token_count

    @property
    def candidates_token_count(self) -> Optional[int]:
        return self._candidates_token_count

    @property
    def total_token_count(self) -> Optional[int]:
        return self._total_token_count

    @property
    def thoughts(self) -> Optional[str]:
        return self._thoughts

    @property
    def json_dumps(self) -> str:
        return self._json_dumps
100
+
101
+
102
class GeminiClient:
    """Thin client for the Google Generative Language (Gemini) REST API.

    Provides streaming (`stream_chat`) and blocking (`complete_chat`) chat
    completion, conversion from OpenAI-style messages to Gemini `contents`
    (`convert_messages`), and model discovery (`list_available_models`).
    """

    AVAILABLE_MODELS = []
    # Extra model names injected via the EXTRA_MODELS env var (comma separated).
    # Fix: filter out empty entries — "".split(",") yields [""], which used to
    # leak an empty model name into list_available_models().
    EXTRA_MODELS = [m for m in os.environ.get("EXTRA_MODELS", "").split(",") if m]

    def __init__(self, api_key: str):
        self.api_key = api_key

    @staticmethod
    def _api_version(model: str) -> str:
        # "Thinking" models are only exposed on the v1alpha endpoint.
        return "v1alpha" if "think" in model else "v1beta"

    @staticmethod
    def _build_payload(request: ChatCompletionRequest, contents, safety_settings, system_instruction):
        """Assemble the JSON request body shared by both chat endpoints."""
        payload = {
            "contents": contents,
            "generationConfig": {
                "temperature": request.temperature,
                "maxOutputTokens": request.max_tokens,
            },
            "safetySettings": safety_settings,
        }
        if system_instruction:
            payload["system_instruction"] = system_instruction
        return payload

    async def stream_chat(self, request: ChatCompletionRequest, contents, safety_settings, system_instruction):
        """Stream a chat completion, yielding text fragments as they arrive.

        Raises:
            httpx.HTTPStatusError: on a non-2xx response (fix: previously an
                error body was silently fed to the SSE parser, yielding nothing).
            ValueError: when the response is truncated by a non-STOP finish
                reason or blocked by a HIGH-probability safety rating.
        """
        logger.info("流式开始 →")
        api_version = self._api_version(request.model)
        url = f"https://generativelanguage.googleapis.com/{api_version}/models/{request.model}:streamGenerateContent?key={self.api_key}&alt=sse"
        headers = {
            "Content-Type": "application/json",
        }
        payload = self._build_payload(request, contents, safety_settings, system_instruction)

        async with httpx.AsyncClient() as client:
            async with client.stream("POST", url, headers=headers, json=payload, timeout=600) as response:
                # Fail fast on HTTP errors instead of parsing the error body as SSE.
                response.raise_for_status()
                buffer = b""
                try:
                    async for line in response.aiter_lines():
                        if not line.strip():
                            continue
                        if line.startswith("data: "):
                            line = line[len("data: "):]
                        buffer += line.encode('utf-8')
                        try:
                            # Renamed from `data` to avoid shadowing the request payload.
                            chunk = json.loads(buffer.decode('utf-8'))
                        except json.JSONDecodeError:
                            # Incomplete JSON — keep buffering until it parses.
                            logger.debug("JSON incomplete, buffering %d bytes", len(buffer))
                            continue
                        buffer = b""
                        if 'candidates' in chunk and chunk['candidates']:
                            candidate = chunk['candidates'][0]
                            parts = candidate.get('content', {}).get('parts') or []
                            text = "".join(part['text'] for part in parts if 'text' in part)
                            if text:
                                yield text

                            finish_reason = candidate.get("finishReason")
                            if finish_reason and finish_reason != "STOP":
                                raise ValueError(f"模型的响应被截断: {finish_reason}")

                            for rating in candidate.get('safetyRatings', []):
                                if rating['probability'] == 'HIGH':
                                    raise ValueError(f"模型的响应被截断: {rating['category']}")
                except Exception:
                    logger.exception("streaming failed")
                    raise
                finally:
                    logger.info("流式结束 ←")

    def complete_chat(self, request: ChatCompletionRequest, contents, safety_settings, system_instruction):
        """Blocking (non-streaming) chat completion.

        Returns a ResponseWrapper over the parsed JSON response.
        Raises requests.HTTPError on a non-2xx response.
        """
        api_version = self._api_version(request.model)
        url = f"https://generativelanguage.googleapis.com/{api_version}/models/{request.model}:generateContent?key={self.api_key}"
        headers = {
            "Content-Type": "application/json",
        }
        payload = self._build_payload(request, contents, safety_settings, system_instruction)
        # Fix: a timeout bounds the call so a stalled upstream cannot hang forever.
        response = requests.post(url, headers=headers, json=payload, timeout=600)
        response.raise_for_status()
        return ResponseWrapper(response.json())

    @staticmethod
    def _map_role(role: str) -> Optional[str]:
        """Map an OpenAI role onto the Gemini role, or None if unsupported."""
        if role in ('user', 'system'):
            return 'user'
        if role == 'assistant':
            return 'model'
        return None

    @staticmethod
    def _append_parts(history, role_to_use, parts):
        """Append parts to history, merging with the previous same-role turn."""
        if history and history[-1]['role'] == role_to_use:
            history[-1]['parts'].extend(parts)
        else:
            history.append({"role": role_to_use, "parts": parts})

    def convert_messages(self, messages, use_system_prompt=False):
        """Convert OpenAI-style messages into Gemini `contents`.

        When use_system_prompt is True, leading 'system' messages are folded
        into a system instruction instead of the history.

        Returns:
            (gemini_history, system_instruction) on success, or a list of
            error strings on failure. NOTE(review): this inconsistent return
            type is preserved for backward compatibility with callers.
        """
        gemini_history = []
        errors = []
        system_instruction_text = ""
        is_system_phase = use_system_prompt
        for message in messages:
            role = message.role
            content = message.content

            if isinstance(content, str):
                if is_system_phase and role == 'system':
                    # Accumulate leading system prompts, newline-separated.
                    if system_instruction_text:
                        system_instruction_text += "\n" + content
                    else:
                        system_instruction_text = content
                    continue
                is_system_phase = False

                role_to_use = self._map_role(role)
                if role_to_use is None:
                    errors.append(f"Invalid role: {role}")
                    continue
                self._append_parts(gemini_history, role_to_use, [{"text": content}])
            elif isinstance(content, list):
                # Multimodal content: a list of typed items (text / image_url).
                parts = []
                for item in content:
                    if item.get('type') == 'text':
                        parts.append({"text": item.get('text')})
                    elif item.get('type') == 'image_url':
                        image_data = item.get('image_url', {}).get('url', '')
                        if image_data.startswith('data:image/'):
                            try:
                                mime_type = image_data.split(';')[0].split(':')[1]
                                base64_data = image_data.split(',')[1]
                                parts.append({
                                    "inline_data": {
                                        "mime_type": mime_type,
                                        "data": base64_data
                                    }
                                })
                            except (IndexError, ValueError):
                                errors.append(
                                    f"Invalid data URI for image: {image_data}")
                        else:
                            errors.append(
                                f"Invalid image URL format for item: {item}")

                if parts:
                    role_to_use = self._map_role(role)
                    if role_to_use is None:
                        errors.append(f"Invalid role: {role}")
                        continue
                    self._append_parts(gemini_history, role_to_use, parts)
        if errors:
            return errors
        return gemini_history, {"parts": [{"text": system_instruction_text}]}

    @staticmethod
    async def list_available_models(api_key) -> list:
        """Fetch the model list from the v1beta endpoint plus EXTRA_MODELS."""
        url = "https://generativelanguage.googleapis.com/v1beta/models?key={}".format(
            api_key)
        async with httpx.AsyncClient() as client:
            response = await client.get(url)
            response.raise_for_status()
            data = response.json()
        models = [model["name"] for model in data.get("models", [])]
        models.extend(GeminiClient.EXTRA_MODELS)
        return models