Spaces:
Runtime error
Runtime error
import json
import random
import re
from datetime import datetime
from io import BytesIO

import discord
import PIL.Image
from aiohttp import ClientSession
from discord.ext import commands
from google import genai
from google.genai import types

from utils.color_printer import cpr
from utils.config import config
from utils.context_prompter import ContextPrompter
from utils.decorator import auto_delete
from utils.func import async_iter, async_do_thread
from utils.logger import logger
class Gemini(commands.Cog):
    """Discord cog that answers prompts and translates messages with Google Gemini.

    Prompts are assembled from JSON presets supplied by the ``AgentManager`` cog
    (``chat_preset.json``, ``reference_preset.json``, ``attachment_preset.json``,
    ``translate_preset.json`` and ``gemini_config.json``). A preset provides a
    ``system_prompt`` plus three turns (``first_user_message`` -> ``main_content``
    -> ``last_message``) whose ``{placeholder}`` tokens are filled in before the
    request is streamed back into Discord, either with ``ctx.send`` or through
    ``self.webhook`` when a ``username`` is given.

    NOTE(review): none of the command coroutines below carry command decorators
    in this file, and ``models_error`` reads like an ``@models.error`` handler —
    confirm how these are registered with the bot.
    """

    # Model used whenever the config does not name one.
    DEFAULT_MODEL = "gemini-2.0-pro-exp-02-05"
    # Discord rejects message content over 2000 characters; keep some headroom.
    DISCORD_LIMIT = 1995
    # Shown when Gemini fails or produces no text at all.
    ERROR_TEXT = "Uh oh, something went wrong..."
    # ``{name}``-style placeholders used by the preset templates.
    _PLACEHOLDER_RE = re.compile(r"\{(\w+)\}")

    def __init__(
        self,
        bot: commands.Bot,
        webhook: discord.Webhook,
    ):
        self.bot = bot
        self.conversations = {}
        # A missing key list becomes empty so construction succeeds; key
        # selection then raises a descriptive error instead of a TypeError.
        self.apikeys = config.get("gemini_keys") or []
        self.current_key = config.get("current_key") or 0
        self.num = len(self.apikeys)
        # Loads self.servers, self.gemini_models and self.available_models.
        self.update_chat_channels()
        self.config = config
        self.context_length = 20
        self.target_language = config.get("target_language")
        # Persist the default model selection the first time the cog starts.
        if not config.get("gemini_models"):
            config.write("gemini_models", self.gemini_models)
        self.default_gemini_config = types.GenerateContentConfig(
            system_instruction="",
            top_k=55,
            top_p=0.95,
            temperature=1.3,
            safety_settings=self._default_safety_settings(),
        )
        self.webhook = webhook
        self.context_prompter = ContextPrompter()
        self.non_gemini_model = None  # for openai model
        self.openai_api_key = config.get("openai_api_key")
        self.openai_endpoint = config.get("openai_endpoint")
        if self.openai_api_key is not None and self.openai_endpoint is not None:
            print(cpr.info("OpenAI API available."))

    async def cog_unload(self):
        """Close the aiohttp session behind ``self.webhook`` when the cog is removed.

        Assumes the session was created for this webhook only (as ``setup`` in
        this module does) — verify before sharing a session with other code.
        """
        session = getattr(self.webhook, "session", None)
        if session is not None and not session.closed:
            await session.close()

    # ------------------------------------------------------------------ config

    def update_chat_channels(self):
        """Reload server settings, the model selection and the model list from config.

        Writes a default available-model list back to config when none exists.
        """
        self.servers = config.get("servers", {})
        self.gemini_models = config.get("gemini_models", self._default_models())
        self.available_models = config.get("gemini_available_models", [])
        if not self.available_models:
            self.available_models = [
                {"name": self.DEFAULT_MODEL, "description": "默认聊天和翻译模型"},
                {"name": "gemini-pro", "description": "旧版Gemini Pro模型"},
            ]
            config.write("gemini_available_models", self.available_models)
        print("Gemini cog 已更新服务器配置和模型列表")

    def get_channel_config(self, guild_id: str, channel_id: str):
        """Return the ``chat_channels`` entry for ``channel_id``, or None if not configured."""
        server_name, server_config = config.get_server_config(guild_id)
        if not server_config:
            return None
        return server_config.get("chat_channels", {}).get(channel_id)

    def get_next_key(self):
        """Advance the round-robin API key index, persist it, and return that key.

        Raises:
            RuntimeError: if no Gemini API keys are configured.
        """
        if not self.num:
            raise RuntimeError("No Gemini API keys configured (gemini_keys).")
        self.current_key = (self.current_key + 1) % self.num
        config.write("current_key", self.current_key)
        return self.apikeys[self.current_key]

    def get_random_key(self):
        """Return a uniformly random configured API key.

        Raises:
            RuntimeError: if no Gemini API keys are configured.
        """
        if not self.apikeys:
            raise RuntimeError("No Gemini API keys configured (gemini_keys).")
        return random.choice(self.apikeys)

    # ----------------------------------------------------------------- helpers

    @classmethod
    def _default_models(cls):
        """Fresh default ``{"chat": ..., "translate": ...}`` model mapping."""
        return {"chat": cls.DEFAULT_MODEL, "translate": cls.DEFAULT_MODEL}

    @staticmethod
    def _default_safety_settings():
        """Safety settings that switch every harm-category filter OFF."""
        categories = (
            types.HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
            types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            types.HarmCategory.HARM_CATEGORY_HARASSMENT,
            types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        )
        return [
            types.SafetySetting(category=category, threshold=types.HarmBlockThreshold.OFF)
            for category in categories
        ]

    def _channel_config_for(self, ctx: commands.Context):
        """Channel config for the invoking channel; None outside guilds or if unconfigured."""
        if ctx.guild is None:
            return None
        return self.get_channel_config(str(ctx.guild.id), str(ctx.channel.id))

    async def _send_message(self, ctx: commands.Context, content: str, username=None):
        """Send ``content`` via ``ctx`` or, when ``username`` is set, via the webhook."""
        if username is None:
            return await ctx.send(content)
        # wait=True so an editable WebhookMessage is returned.
        return await self.webhook.send(content, wait=True, username=username)

    @staticmethod
    async def _resolve_reference(ctx: commands.Context):
        """Return the message ``ctx.message`` replies to, or None if it cannot be obtained.

        Uses the resolved reference when Discord supplied one and otherwise
        fetches it; deleted or inaccessible messages yield None.
        """
        reference = ctx.message.reference
        if reference is None:
            return None
        if isinstance(reference.resolved, discord.Message):
            return reference.resolved
        if reference.message_id is None:
            return None
        try:
            return await ctx.channel.fetch_message(reference.message_id)
        except discord.HTTPException:
            return None

    async def _get_context(self, ctx: commands.Context, context_length: int) -> str:
        """Return the chat context to inject into the prompt, always as a string.

        Uses a value a caller stored on ``ctx.context`` or ``ctx.history``,
        otherwise fetches ``context_length`` messages of history.
        """
        context = None
        for attr in ("context", "history"):
            value = getattr(ctx, attr, None)
            # Every Context has a ``history()`` method; only accept stored data,
            # never the bound method itself.
            if value and not callable(value):
                context = value
                break
        if context is None:
            context = await self.context_prompter.get_context_for_prompt(ctx, context_length)
        if not isinstance(context, str):
            context = "" if context is None else str(context)
        return context

    @classmethod
    def _fill_template(cls, template, values) -> str:
        """Replace ``{key}`` tokens in ``template`` with ``values[key]`` in one pass.

        A single pass means placeholder-like text inside substituted values
        (chat history, the user's question) is not expanded again. Unknown
        placeholders are left as-is; None values become empty strings.
        """
        def substitute(match):
            key = match.group(1)
            if key not in values:
                return match.group(0)
            value = values[key]
            return "" if value is None else str(value)

        return cls._PLACEHOLDER_RE.sub(substitute, template or "")

    @classmethod
    def _build_contents(cls, preset_data, values, attachment_part=None, last_role="user"):
        """Build the user -> model -> ``last_role`` conversation from a preset.

        ``attachment_part`` (if any) is appended to the final turn.
        """
        last_parts = [
            types.Part.from_text(text=cls._fill_template(preset_data.get("last_message", ""), values))
        ]
        if attachment_part is not None:
            last_parts.append(attachment_part)
        return [
            types.Content(
                role="user",
                parts=[
                    types.Part.from_text(
                        text=cls._fill_template(preset_data.get("first_user_message", ""), values)
                    )
                ],
            ),
            types.Content(
                role="model",
                parts=[
                    types.Part.from_text(
                        text=cls._fill_template(preset_data.get("main_content", ""), values)
                    )
                ],
            ),
            types.Content(role=last_role, parts=last_parts),
        ]

    @staticmethod
    async def _attachment_part(attachment: discord.Attachment):
        """Download ``attachment`` and wrap its bytes as an inline-data Part.

        The bytes are forwarded unchanged (no re-encoding) with the declared MIME
        type minus any ``;`` parameters; ``image/jpeg`` is assumed when Discord
        reports none. Returns ``(part, mime_type)``; ``part`` is None when the
        download is empty.
        """
        data = await attachment.read()
        mime_type = (attachment.content_type or "image/jpeg").split(";")[0].strip()
        if not data:
            return None, mime_type
        return types.Part(inline_data=types.Blob(mime_type=mime_type, data=data)), mime_type

    @classmethod
    def _build_generate_config(cls, gemini_config_data, system_prompt):
        """GenerateContentConfig from a ``gemini_config.json`` dict (or None) and a system prompt."""
        settings = gemini_config_data or {}
        return types.GenerateContentConfig(
            temperature=settings.get("temperature", 1.0),
            top_p=settings.get("top_p", 0.95),
            top_k=settings.get("top_k", 64),
            max_output_tokens=settings.get("max_output_tokens", 8192),
            safety_settings=cls._default_safety_settings(),
            response_mime_type="text/plain",
            system_instruction=[types.Part.from_text(text=system_prompt or "")],
        )

    @staticmethod
    def _log_translate_request(model, contents, generate_config):
        """Write a readable summary of a translation request to the log file."""
        log_contents = []
        for content in contents:
            parts_text = []
            for part in content.parts:
                text = getattr(part, "text", None)
                inline_data = getattr(part, "inline_data", None)
                if text:
                    parts_text.append(f"Text: {text}")
                elif inline_data is not None:
                    parts_text.append(f"InlineData: {inline_data.mime_type}")
                else:
                    parts_text.append(f"Unknown part type: {type(part)}")
            log_contents.append(f"Role: {content.role}, Parts: {parts_text}")
        system_instruction = "None"
        if generate_config.system_instruction:
            system_instruction = generate_config.system_instruction[0].text
        logger.info(
            "Gemini翻译请求发送: 模型=%s, 内容=%s, 系统提示=%s, 配置=%s",
            model,
            log_contents,
            system_instruction,
            {
                "temperature": generate_config.temperature,
                "top_p": generate_config.top_p,
                "top_k": generate_config.top_k,
                "max_tokens": generate_config.max_output_tokens,
            },
        )

    async def _stream_to_discord(
        self,
        ctx: commands.Context,
        msg,
        client,
        *,
        model: str,
        contents,
        generate_config: types.GenerateContentConfig,
        username=None,
        edit_every=None,
        action: str = "chatting",
    ):
        """Stream a Gemini response into ``msg``, spilling into new messages as needed.

        Text is appended chunk by chunk. When the next chunk would push the
        current message past ``DISCORD_LIMIT``, that message is finalised and a
        new ``"..."`` message is posted (through the webhook when ``username``
        is given); a single chunk longer than the limit is hard-split. When
        ``edit_every`` is a positive int the current message is also refreshed
        every ``edit_every`` text chunks; otherwise only on roll-over and at the end.

        Exceptions raised while streaming are logged and ``ERROR_TEXT`` is
        appended instead of propagating; an empty response also ends as
        ``ERROR_TEXT``. Returns the last message written to.
        """
        if not isinstance(edit_every, int) or edit_every <= 0:
            edit_every = None
        limit = self.DISCORD_LIMIT
        buffer = ""
        chunks_since_edit = 0

        async def append(text):
            nonlocal buffer, msg, chunks_since_edit
            # Prefer to roll over at a chunk boundary.
            if buffer and len(buffer) + len(text) > limit:
                await msg.edit(content=buffer)
                msg = await self._send_message(ctx, "...", username)
                buffer = ""
                chunks_since_edit = 0
            buffer += text
            # An oversized chunk still has to be split mid-text.
            while len(buffer) > limit:
                await msg.edit(content=buffer[:limit])
                msg = await self._send_message(ctx, "...", username)
                buffer = buffer[limit:]
                chunks_since_edit = 0

        try:
            stream = client.models.generate_content_stream(
                model=model,
                contents=contents,
                config=generate_config,
            )
            async for chunk in async_iter(stream):
                text = chunk.text
                if not text:
                    continue
                await append(text)
                chunks_since_edit += 1
                if edit_every and chunks_since_edit >= edit_every:
                    await msg.edit(content=buffer)
                    chunks_since_edit = 0
        except Exception as e:
            logger.error(
                "Error when %s with gemini, error: %s",
                action,
                e,
                exc_info=True,
            )
            await append(f"\n{self.ERROR_TEXT}" if buffer else self.ERROR_TEXT)

        await msg.edit(content=buffer or self.ERROR_TEXT)
        return msg

    # ---------------------------------------------------------------- commands

    async def request_gemini(
        self,
        ctx: commands.Context,
        prompt: str,
        model_config: types.GenerateContentConfig = None,
        model=DEFAULT_MODEL,
        username=None,
        extra_attachment: discord.Attachment = None,
    ):
        """Answer ``prompt`` with Gemini using the channel's chat preset and stream the reply.

        The preset is ``attachment_preset.json`` when ``extra_attachment`` is
        given, ``reference_preset.json`` when the command replies to a message
        that can be resolved, and ``chat_preset.json`` otherwise.

        Args:
            ctx: invocation context; must be in a configured chat channel.
            prompt: the user's question (fills ``{question}``).
            model_config: accepted for backward compatibility but not used —
                generation settings come from the channel's ``gemini_config.json``.
            model: Gemini model name. The default model rotates keys
                round-robin; any other model uses a random key.
            username: when set, messages go through the webhook under this name.
            extra_attachment: file forwarded to Gemini as inline data.
        """
        channel_config = self._channel_config_for(ctx)
        if not channel_config:
            await ctx.send("此频道未配置为聊天频道", ephemeral=True)
            return
        print(f"当前频道配置: {channel_config}")
        guild_id = str(ctx.guild.id)
        channel_id = str(ctx.channel.id)

        reference = await self._resolve_reference(ctx)
        if extra_attachment:
            preset_name = "attachment_preset.json"
            print(f"使用附件预设: {preset_name}")
        elif reference is not None:
            preset_name = "reference_preset.json"
        else:
            preset_name = "chat_preset.json"

        agent_manager = self.bot.get_cog("AgentManager")
        preset_data = (
            agent_manager.get_preset_json(preset_name, channel_id, guild_id)
            if agent_manager
            else None
        )
        if not preset_data:
            await self._send_message(ctx, "无法加载预设数据,请联系管理员", username)
            return

        key = self.get_random_key() if model != self.DEFAULT_MODEL else self.get_next_key()
        client = genai.Client(api_key=key)

        # One placeholder message is created here and reused for the reply.
        attachment_part = None
        if extra_attachment:
            msg = await self._send_message(ctx, "Downloading the attachment...", username)
            attachment_part, mime_type = await self._attachment_part(extra_attachment)
            await msg.edit(content="Processing the attachment...")
            print(f"附件已下载: {extra_attachment.filename} ({mime_type})")
        else:
            msg = await self._send_message(ctx, "Typing...", username)

        values = {
            "context": await self._get_context(ctx, self.context_length),
            "question": prompt,
            "name": ctx.me.display_name,
            "bot_name": ctx.me.name,
            "current_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "user_display_name": ctx.author.display_name,
            "user_name": ctx.author.name,
        }
        if reference is not None:
            values.update(
                reference_time=self.context_prompter.get_msg_time(reference),
                reference_user_name=reference.author.name,
                reference_user_display_name=reference.author.display_name,
                reference_content=reference.content,
            )

        # With prefill, the last turn is a partial model reply for Gemini to continue.
        last_role = "model" if preset_data.get("prefill_assistant_reply", False) else "user"
        contents = self._build_contents(preset_data, values, attachment_part, last_role)
        if attachment_part is not None:
            print("附件已添加到用户消息中")

        gemini_config_data = agent_manager.get_preset_json("gemini_config.json", channel_id, guild_id)
        generate_content_config = self._build_generate_config(
            gemini_config_data, preset_data.get("system_prompt", "")
        )
        logger.info(
            "Gemini Raw Request: %s",
            json.dumps(
                generate_content_config.model_dump(mode="json", exclude_none=True),
                indent=2,
                ensure_ascii=False,
            ),
        )
        await self._stream_to_discord(
            ctx,
            msg,
            client,
            model=model,
            contents=contents,
            generate_config=generate_content_config,
            username=username,
            action="chatting",
        )

    async def hey(
        self,
        ctx: commands.Context,
        *,
        question: str,
        context_length: int = None,
    ):
        """Chat entry point: gather history and an attachment, then call ``request_gemini``.

        The attachment is the last one on the replied-to message; if that
        message has none, the last one on the command message itself.
        """
        if not self._channel_config_for(ctx):
            await ctx.send("此频道未配置为聊天频道", ephemeral=True)
            return
        if context_length is None:
            context_length = self.context_length
        # Stored on ctx.context (not ctx.history, which would shadow
        # Messageable.history) so request_gemini reuses it.
        ctx.context = await self.context_prompter.get_context_for_prompt(ctx, context_length)

        extra_attachment = None
        reference = await self._resolve_reference(ctx)
        if reference is not None and reference.attachments:
            extra_attachment = reference.attachments[-1]
        elif ctx.message.attachments:
            extra_attachment = ctx.message.attachments[-1]
        if extra_attachment:
            print(f"处理附件: {extra_attachment.filename} ({extra_attachment.content_type})")

        chat_model = self.gemini_models.get("chat", self.DEFAULT_MODEL)
        await self.request_gemini(
            ctx,
            question,  # the preset is applied inside request_gemini
            model=chat_model,
            extra_attachment=extra_attachment,
        )

    async def translate(
        self,
        ctx: commands.Context,
        target_language: str = None,
        context_length: int = None,
    ):
        """Translate the replied-to message (and its last attachment) with ``translate_preset.json``.

        Args:
            target_language: fills ``{target_language}``; defaults to the cog setting.
            context_length: history size used when no context is stored on ``ctx``.
        """
        if not self._channel_config_for(ctx):
            await ctx.send("此频道未配置为聊天频道", ephemeral=True)
            return
        if ctx.message.reference is None:
            await ctx.send("请回复要翻译的消息", ephemeral=True)
            return
        if context_length is None:
            context_length = self.context_length
        if target_language is None:
            target_language = self.target_language
        guild_id = str(ctx.guild.id)
        channel_id = str(ctx.channel.id)
        translate_model = self.gemini_models.get("translate", self.DEFAULT_MODEL)

        agent_manager = self.bot.get_cog("AgentManager")
        preset_data = agent_manager.get_preset_json("translate_preset.json") if agent_manager else None
        if not preset_data:
            await ctx.send("无法加载翻译预设,请联系管理员", delete_after=5, ephemeral=True)
            return

        reference_message = await self._resolve_reference(ctx)
        if reference_message is None:
            await ctx.send("请回复要翻译的消息", ephemeral=True)
            return

        extra_attachment = reference_message.attachments[-1] if reference_message.attachments else None
        attachment_part = None
        if extra_attachment:
            print(f"翻译附件: {extra_attachment.filename} ({extra_attachment.content_type})")
            msg = await ctx.send("Downloading the attachment...")
            attachment_part, mime_type = await self._attachment_part(extra_attachment)
            await msg.edit(content="Processing the attachment...")
            print(f"附件已下载: {extra_attachment.filename} ({mime_type})")
        else:
            msg = await ctx.send("Translating...")

        client = genai.Client(api_key=self.get_next_key())
        values = {
            "context": await self._get_context(ctx, context_length),
            "target_language": target_language,
            "reference_content": reference_message.content,
            "reference_time": reference_message.created_at.strftime("%Y-%m-%d %H:%M:%S"),
            "reference_user_name": reference_message.author.name,
            "reference_user_display_name": reference_message.author.display_name,
            "name": ctx.me.display_name,
            "bot_name": ctx.me.name,
            "current_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "user_display_name": ctx.author.display_name,
            "user_name": ctx.author.name,
        }
        contents = self._build_contents(preset_data, values, attachment_part)
        if attachment_part is not None:
            print("附件已添加到翻译请求中")

        gemini_config_data = agent_manager.get_preset_json("gemini_config.json", channel_id, guild_id)
        generate_content_config = self._build_generate_config(
            gemini_config_data, preset_data.get("system_prompt", "")
        )
        self._log_translate_request(translate_model, contents, generate_content_config)
        await self._stream_to_discord(
            ctx,
            msg,
            client,
            model=translate_model,
            contents=contents,
            generate_config=generate_content_config,
            edit_every=config.get("gemini_chunk_per_edit"),
            action="translating",
        )

    async def set_context_length(self, ctx: commands.Context, context_length: int):
        """Set how many history messages are used as prompt context."""
        self.context_length = context_length
        await ctx.send("Context length set.", ephemeral=True, delete_after=5)

    async def set_target_language(self, ctx: commands.Context, target_language: str):
        """Set the default target language for ``translate``."""
        self.target_language = target_language
        await ctx.send("Target language set.", ephemeral=True, delete_after=5)

    async def set_timezone(self, ctx: commands.Context, timezone: str):
        """Set the timezone the context prompter uses for message timestamps."""
        try:
            self.context_prompter.set_tz(timezone)
        except Exception:
            # set_tz's failure type is not visible here; any failure means a bad name.
            await ctx.send("Invalid timezone.", ephemeral=True, delete_after=5)
            return
        await ctx.send(f"Timezone set to {timezone}.", ephemeral=True, delete_after=5)

    async def models(self, ctx: commands.Context, model_type: str = None, model_name: str = None):
        """List the Gemini models in use, or switch the chat/translate model.

        Args:
            model_type: ``"chat"`` or ``"translate"``; when omitted, an embed with
                the current selection and the available-model list is sent.
            model_name: model to use for ``model_type``. Names missing from the
                available-model list are accepted after a warning.
        """
        if not model_type:
            embed = discord.Embed(
                title="Gemini模型配置",
                description="当前使用的Gemini模型",
                color=discord.Color.blue(),
            )
            chat_model = self.gemini_models.get("chat", self.DEFAULT_MODEL)
            translate_model = self.gemini_models.get("translate", self.DEFAULT_MODEL)
            embed.add_field(name="聊天模型", value=f"`{chat_model}`", inline=False)
            embed.add_field(name="翻译模型", value=f"`{translate_model}`", inline=False)
            listing = ""
            for entry in self.available_models:
                entry_name = entry.get("name", "")
                entry_desc = entry.get("description", "")
                if entry_desc:
                    listing += f"• `{entry_name}` - {entry_desc}\n"
                else:
                    listing += f"• `{entry_name}`\n"
            # Discord rejects embed field values longer than 1024 characters.
            embed.add_field(
                name="可用模型列表",
                value=listing[:1024] if listing else "没有可用的模型",
                inline=False,
            )
            embed.add_field(
                name="使用方法",
                value="使用 `/models chat <模型名>` 更改聊天模型\n使用 `/models translate <模型名>` 更改翻译模型",
                inline=False,
            )
            await ctx.send(embed=embed)
            return
        if model_type not in ["chat", "translate"]:
            await ctx.send("错误:模型类型必须是 `chat` 或 `translate`", ephemeral=True)
            return
        if not model_name:
            await ctx.send("错误:请指定模型名称", ephemeral=True)
            return
        if not any(entry.get("name") == model_name for entry in self.available_models):
            # Unlisted models are still allowed; just warn.
            await ctx.send(f"警告:模型 `{model_name}` 不在可用模型列表中,但仍将设置为当前模型。", ephemeral=True)
        self.gemini_models[model_type] = model_name
        config.write("gemini_models", self.gemini_models)
        await ctx.send(f"已将{model_type}模型设置为:{model_name}", ephemeral=True)

    async def models_error(self, ctx: commands.Context, error):
        """Tell the user that only administrators may change the model configuration."""
        if isinstance(error, commands.MissingPermissions):
            await ctx.send("错误:只有管理员可以更改模型配置", ephemeral=True)
async def setup(bot: commands.Bot):
    """Extension entry point: build the Gemini cog and register it with ``bot``.

    Creates a dedicated aiohttp session for the configured webhook; that session
    is closed again if building or registering the cog fails. When the
    ``AgentManager`` cog is already loaded, it is wired into the cog's context
    prompter, and ``gemini_config.json`` (if present) replaces the cog's
    ``default_gemini_config``. Errors in that optional step are printed, not raised.

    Raises:
        RuntimeError: if ``webhook_url`` is missing from the config.
    """
    apikeys = config.get("gemini_keys") or []
    print(cpr.info(f"{len(apikeys)} keys loaded."))
    webhook_url = config.get("webhook_url")
    if not webhook_url:
        raise RuntimeError("config 'webhook_url' is required by the Gemini cog")

    session = ClientSession()
    try:
        webhook = discord.Webhook.from_url(webhook_url, session=session)
        cog = Gemini(bot, webhook)
        agent_manager = bot.get_cog("AgentManager")
        if agent_manager:
            try:
                cog.context_prompter.set_agent_manager(agent_manager)
                gemini_config = agent_manager.get_preset_json("gemini_config.json")
                if gemini_config:
                    # Safety settings are stored by enum member name, e.g.
                    # {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}.
                    safety_settings = [
                        types.SafetySetting(
                            category=getattr(types.HarmCategory, setting["category"]),
                            threshold=getattr(types.HarmBlockThreshold, setting["threshold"]),
                        )
                        for setting in gemini_config.get("safety_settings", [])
                    ]
                    cog.default_gemini_config = types.GenerateContentConfig(
                        system_instruction=gemini_config.get("system_instruction", ""),
                        top_k=gemini_config.get("top_k", 55),
                        top_p=gemini_config.get("top_p", 0.95),
                        temperature=gemini_config.get("temperature", 1.3),
                        safety_settings=safety_settings,
                    )
            except Exception as e:
                # Best effort: the cog still works with its built-in defaults.
                print(f"Error loading Gemini config: {e}")
        await bot.add_cog(cog)
    except Exception:
        # Don't leak the webhook session when the cog never gets registered.
        await session.close()
        raise
    print(cpr.success("Cog loaded: Gemini"))