Spaces:
Paused
Paused
| import requests | |
| import json | |
| import os | |
| import asyncio | |
| from app.models import ChatCompletionRequest, Message # 相对导入 | |
| from dataclasses import dataclass | |
| from typing import Optional, Dict, Any, List | |
| import httpx | |
| import logging | |
| logger = logging.getLogger('my_logger') | |
# NOTE(review): the source this was recovered from had lost its decorator
# lines; ``dataclass`` is imported by this module and is otherwise unused, so
# the decorator is restored here. Confirm against the original file.
@dataclass
class GeneratedText:
    """Container pairing a text string with an optional finish reason."""

    # The text payload (required).
    text: str
    # Finish reason associated with the text; defaults to None when absent.
    finish_reason: Optional[str] = None
class ResponseWrapper:
    """Read-only view over a decoded Gemini ``generateContent`` JSON response.

    All fields are extracted once in ``__init__`` from the first candidate
    (``data['candidates'][0]``) and from ``data['usageMetadata']``. Missing
    pieces never raise: text fields fall back to ``""`` and the finish reason
    and token counts fall back to ``None``.

    NOTE(review): the accessors below are plain methods, exactly as in the
    recovered source. They read like properties and decorator lines seem to
    have been lost during extraction -- confirm against the callers before
    relying on ``wrapper.text()`` vs ``wrapper.text``.
    """

    def __init__(self, data: Dict[Any, Any]):
        """Store *data* (the decoded response body) and pre-compute all fields."""
        self._data = data
        self._text = self._extract_text()
        self._finish_reason = self._extract_finish_reason()
        self._prompt_token_count = self._extract_prompt_token_count()
        self._candidates_token_count = self._extract_candidates_token_count()
        self._total_token_count = self._extract_total_token_count()
        self._thoughts = self._extract_thoughts()
        self._json_dumps = json.dumps(self._data, indent=4, ensure_ascii=False)

    def _first_part_text(self, want_thought: bool) -> str:
        """Return the text of the first matching part of the first candidate.

        A part matches when the truthiness of its ``thought`` flag equals
        *want_thought* and it carries a ``text`` key. Parts without text
        (for example inline data) are skipped instead of aborting the search.
        Returns ``""`` when nothing matches or the response has no parts.
        """
        try:
            parts = self._data['candidates'][0]['content']['parts']
        except (KeyError, IndexError, TypeError):
            return ""
        for part in parts or []:
            if not isinstance(part, dict) or 'text' not in part:
                continue
            # Test the flag's value, not its presence: {"thought": false}
            # marks an ordinary answer part.
            if bool(part.get('thought')) == want_thought:
                return part['text']
        return ""

    def _usage_count(self, key: str) -> Optional[int]:
        """Return ``usageMetadata[key]`` or ``None`` when unavailable."""
        usage = self._data.get('usageMetadata')
        if not isinstance(usage, dict):
            return None
        return usage.get(key)

    def _extract_thoughts(self) -> Optional[str]:
        """Text of the first part flagged as a thought, or ``""``."""
        return self._first_part_text(want_thought=True)

    def _extract_text(self) -> str:
        """Text of the first part not flagged as a thought, or ``""``."""
        return self._first_part_text(want_thought=False)

    def _extract_finish_reason(self) -> Optional[str]:
        """``finishReason`` of the first candidate, or ``None``."""
        try:
            return self._data['candidates'][0].get('finishReason')
        except (KeyError, IndexError):
            return None

    def _extract_prompt_token_count(self) -> Optional[int]:
        """``usageMetadata.promptTokenCount`` or ``None``."""
        return self._usage_count('promptTokenCount')

    def _extract_candidates_token_count(self) -> Optional[int]:
        """``usageMetadata.candidatesTokenCount`` or ``None``."""
        return self._usage_count('candidatesTokenCount')

    def _extract_total_token_count(self) -> Optional[int]:
        """``usageMetadata.totalTokenCount`` or ``None``."""
        return self._usage_count('totalTokenCount')

    def text(self) -> str:
        """Return the answer text (``""`` when none was found)."""
        return self._text

    def finish_reason(self) -> Optional[str]:
        """Return the first candidate's finish reason, if any."""
        return self._finish_reason

    def prompt_token_count(self) -> Optional[int]:
        """Return the reported prompt token count, if any."""
        return self._prompt_token_count

    def candidates_token_count(self) -> Optional[int]:
        """Return the reported candidates token count, if any."""
        return self._candidates_token_count

    def total_token_count(self) -> Optional[int]:
        """Return the reported total token count, if any."""
        return self._total_token_count

    def thoughts(self) -> Optional[str]:
        """Return the thought text (``""`` when none was found)."""
        return self._thoughts

    def json_dumps(self) -> str:
        """Return the whole response pretty-printed as JSON (4-space indent)."""
        return self._json_dumps
class GeminiClient:
    """Client for the Gemini REST API on generativelanguage.googleapis.com.

    Provides a streaming call (``stream_chat``), a blocking call
    (``complete_chat``), conversion of OpenAI-style chat messages into the
    Gemini ``contents`` format (``convert_messages``) and model listing
    (``list_available_models``).

    The API key is sent in the ``x-goog-api-key`` header rather than in the
    URL query string, so it does not end up in the URLs that HTTP libraries
    include in their exception messages.
    """

    AVAILABLE_MODELS = []
    # Comma-separated extra model names taken from the environment. Blank
    # entries are dropped so an unset or empty variable yields [] instead of
    # [""], which would otherwise add an empty model name to the listing.
    EXTRA_MODELS = [
        name.strip()
        for name in os.environ.get("EXTRA_MODELS", "").split(",")
        if name.strip()
    ]

    _BASE_URL = "https://generativelanguage.googleapis.com"
    # Seconds; shared by the streaming and the blocking call.
    _REQUEST_TIMEOUT = 600
    # Chat role -> Gemini role. System messages that are not absorbed into
    # the system instruction are sent as user turns.
    _ROLE_MAP = {"user": "user", "system": "user", "assistant": "model"}

    def __init__(self, api_key: str):
        """Remember the API key used to authenticate every request."""
        self.api_key = api_key

    def _headers(self) -> Dict[str, str]:
        """Return the request headers, including the API key."""
        return {
            "Content-Type": "application/json",
            "x-goog-api-key": self.api_key,
        }

    @classmethod
    def _endpoint(cls, model: str, action: str) -> str:
        """Build the URL for *action* on *model*.

        Models whose name contains ``"think"`` are routed to ``v1alpha``,
        all others to ``v1beta``.
        """
        api_version = "v1alpha" if "think" in model else "v1beta"
        return f"{cls._BASE_URL}/{api_version}/models/{model}:{action}"

    @staticmethod
    def _build_payload(request, contents, safety_settings, system_instruction) -> Dict[str, Any]:
        """Assemble the JSON body shared by both generation calls."""
        payload = {
            "contents": contents,
            "generationConfig": {
                "temperature": request.temperature,
                "maxOutputTokens": request.max_tokens,
            },
            "safetySettings": safety_settings,
        }
        if system_instruction:
            payload["system_instruction"] = system_instruction
        return payload

    @staticmethod
    def _candidate_text(candidate) -> str:
        """Concatenate the ``text`` of every part of *candidate* (may be "")."""
        content = candidate.get('content') or {}
        parts = content.get('parts') or []
        return "".join(part['text'] for part in parts if 'text' in part)

    @staticmethod
    def _raise_if_blocked(candidate) -> None:
        """Raise ``ValueError`` when *candidate* was cut short.

        That is the case when ``finishReason`` is set to anything other than
        ``"STOP"`` or when any safety rating has probability ``"HIGH"``.
        """
        finish_reason = candidate.get("finishReason")
        if finish_reason and finish_reason != "STOP":
            raise ValueError(f"模型的响应被截断: {finish_reason}")
        for rating in candidate.get('safetyRatings') or []:
            if rating.get('probability') == 'HIGH':
                raise ValueError(f"模型的响应被截断: {rating.get('category')}")

    async def stream_chat(self, request: "ChatCompletionRequest", contents, safety_settings, system_instruction):
        """Stream a generation and yield the text of each received chunk.

        Args:
            request: supplies ``model``, ``temperature`` and ``max_tokens``.
            contents: Gemini ``contents`` list (see ``convert_messages``).
            safety_settings: value sent as ``safetySettings``.
            system_instruction: sent as ``system_instruction`` when truthy.

        Yields:
            The concatenated text of the first candidate of every chunk that
            contains text.

        Raises:
            httpx.HTTPStatusError: the API answered with a 4xx/5xx status.
                Without this check an error body was parsed as a chunk with
                no candidates and the stream ended silently empty.
            ValueError: a chunk reported a finish reason other than ``STOP``
                or a ``HIGH`` probability safety rating.
        """
        logger.info("流式开始 →")
        url = self._endpoint(request.model, "streamGenerateContent") + "?alt=sse"
        payload = self._build_payload(request, contents, safety_settings, system_instruction)
        async with httpx.AsyncClient() as client:
            async with client.stream("POST", url, headers=self._headers(), json=payload,
                                     timeout=self._REQUEST_TIMEOUT) as response:
                try:
                    response.raise_for_status()
                    buffer = ""
                    async for line in response.aiter_lines():
                        if not line.strip():
                            continue
                        if line.startswith("data: "):
                            line = line[len("data: "):]
                        buffer += line
                        try:
                            chunk = json.loads(buffer)
                        except json.JSONDecodeError:
                            # Not a complete JSON document yet; keep
                            # accumulating lines.
                            continue
                        buffer = ""
                        if not ('candidates' in chunk and chunk['candidates']):
                            continue
                        candidate = chunk['candidates'][0]
                        text = self._candidate_text(candidate)
                        if text:
                            yield text
                        # Checked after yielding so text delivered together
                        # with the stop signal is not lost.
                        self._raise_if_blocked(candidate)
                finally:
                    logger.info("流式结束 ←")

    def complete_chat(self, request: "ChatCompletionRequest", contents, safety_settings, system_instruction):
        """Run a blocking generation and return it wrapped in ResponseWrapper.

        Takes the same arguments as ``stream_chat``. The request now carries
        an explicit timeout; ``requests`` otherwise waits indefinitely.

        Raises:
            requests.HTTPError: the API answered with a 4xx/5xx status.
        """
        url = self._endpoint(request.model, "generateContent")
        payload = self._build_payload(request, contents, safety_settings, system_instruction)
        response = requests.post(url, headers=self._headers(), json=payload,
                                 timeout=self._REQUEST_TIMEOUT)
        response.raise_for_status()
        return ResponseWrapper(response.json())

    @staticmethod
    def _convert_content_items(items, errors: List[str]) -> List[Dict[str, Any]]:
        """Convert a list-style message content into Gemini parts.

        ``text`` items become ``{"text": ...}``; ``image_url`` items holding a
        ``data:image/...`` URI become ``inline_data`` parts. Problems are
        appended to *errors*; items of any other type are ignored.
        """
        parts = []
        for item in items:
            kind = item.get('type')
            if kind == 'text':
                parts.append({"text": item.get('text')})
            elif kind == 'image_url':
                image_data = item.get('image_url', {}).get('url', '')
                if not image_data.startswith('data:image/'):
                    errors.append(
                        f"Invalid image URL format for item: {item}")
                    continue
                try:
                    # "data:<mime>;base64,<payload>"
                    mime_type = image_data.split(';')[0].split(':')[1]
                    base64_data = image_data.split(',')[1]
                except (IndexError, ValueError):
                    errors.append(
                        f"Invalid data URI for image: {image_data}")
                    continue
                parts.append({
                    "inline_data": {
                        "mime_type": mime_type,
                        "data": base64_data
                    }
                })
        return parts

    def convert_messages(self, messages, use_system_prompt=False):
        """Convert chat messages into Gemini history plus system instruction.

        Each message must expose ``role`` and ``content``; content is either
        a string or a list of ``{"type": ...}`` items. Consecutive messages
        that map to the same Gemini role are merged into one entry. Messages
        whose content is neither a string nor a list are skipped.

        With *use_system_prompt*, leading string-content ``system`` messages
        are joined with newlines into the system instruction. The first
        message that is not absorbed this way ends that phase -- including
        list-content messages, which previously left it open so that a later
        system message was still pulled into the instruction.

        Returns:
            ``(history, {"parts": [{"text": system_text}]})`` on success, or
            a list of error strings when any message was invalid. Callers
            have to distinguish the two shapes.
        """
        gemini_history = []
        errors = []
        system_instruction_text = ""
        is_system_phase = use_system_prompt
        for message in messages:
            role = message.role
            content = message.content
            if isinstance(content, str):
                if is_system_phase and role == 'system':
                    if system_instruction_text:
                        system_instruction_text += "\n" + content
                    else:
                        system_instruction_text = content
                    continue
                is_system_phase = False
                parts = [{"text": content}]
            elif isinstance(content, list):
                is_system_phase = False
                parts = self._convert_content_items(content, errors)
                if not parts:
                    continue
            else:
                continue
            role_to_use = self._ROLE_MAP.get(role)
            if role_to_use is None:
                errors.append(f"Invalid role: {role}")
                continue
            if gemini_history and gemini_history[-1]['role'] == role_to_use:
                gemini_history[-1]['parts'].extend(parts)
            else:
                gemini_history.append({"role": role_to_use, "parts": parts})
        if errors:
            return errors
        return gemini_history, {"parts": [{"text": system_instruction_text}]}

    @staticmethod
    async def list_available_models(api_key) -> list:
        """Return the model names visible to *api_key* plus ``EXTRA_MODELS``.

        Raises:
            httpx.HTTPStatusError: the API answered with a 4xx/5xx status.
        """
        url = f"{GeminiClient._BASE_URL}/v1beta/models"
        async with httpx.AsyncClient() as client:
            response = await client.get(url, headers={"x-goog-api-key": api_key})
            response.raise_for_status()
            data = response.json()
        models = [model["name"] for model in data.get("models", [])]
        models.extend(GeminiClient.EXTRA_MODELS)
        return models


# NOTE(review): indentation and decorator lines were lost in the recovered
# source, so it is unclear whether list_available_models was a static method
# of GeminiClient or a module-level function. This alias keeps both call
# styles working; confirm against the callers and drop whichever is unused.
list_available_models = GeminiClient.list_available_models