| |
|
|
| import time |
|
|
| import openai |
| import openai.error |
|
|
| from bot.bot import Bot |
| from bot.openai.open_ai_image import OpenAIImage |
| from bot.openai.open_ai_session import OpenAISession |
| from bot.session_manager import SessionManager |
| from bridge.context import ContextType |
| from bridge.reply import Reply, ReplyType |
| from common.log import logger |
| from config import conf |
|
|
# Module-level registry of per-user sessions (kept for backward compatibility).
user_session = {}
|
|
|
|
| |
class OpenAIBot(Bot, OpenAIImage):
    """Bot backed by the legacy OpenAI text-completion API.

    Handles plain-text chat (with per-session memory via SessionManager)
    and image-creation requests (inherited from OpenAIImage).
    """

    def __init__(self):
        super().__init__()
        openai.api_key = conf().get("open_ai_api_key")
        if conf().get("open_ai_api_base"):
            openai.api_base = conf().get("open_ai_api_base")
        proxy = conf().get("proxy")
        if proxy:
            openai.proxy = proxy

        # Fall back to the legacy completion model when none is configured;
        # compute once so SessionManager and the request args stay in sync.
        model = conf().get("model") or "text-davinci-003"
        self.sessions = SessionManager(OpenAISession, model=model)
        self.args = {
            "model": model,
            "temperature": conf().get("temperature", 0.9),
            "max_tokens": 1200,
            "top_p": 1,
            "frequency_penalty": conf().get("frequency_penalty", 0.0),
            "presence_penalty": conf().get("presence_penalty", 0.0),
            # "request_timeout" is the per-HTTP-request timeout; "timeout"
            # bounds the whole call including SDK-internal retries.
            "request_timeout": conf().get("request_timeout", None),
            "timeout": conf().get("request_timeout", None),
            "stop": ["\n\n\n"],
        }

    def reply(self, query, context=None):
        """Produce a Reply for *query* according to the context type.

        :param query: user input text (or image prompt).
        :param context: bridge Context; must carry a ``type`` and, for text,
            a ``session_id`` entry.
        :return: a Reply, or None when the context is missing, has no type,
            or is of an unsupported type (original behavior preserved).
        """
        if context and context.type:
            if context.type == ContextType.TEXT:
                logger.info("[OPEN_AI] query={}".format(query))
                session_id = context["session_id"]
                reply = None
                # Special commands that manage conversation memory.
                if query == "#清除记忆":
                    self.sessions.clear_session(session_id)
                    reply = Reply(ReplyType.INFO, "记忆已清除")
                elif query == "#清除所有":
                    self.sessions.clear_all_session()
                    reply = Reply(ReplyType.INFO, "所有人记忆已清除")
                else:
                    session = self.sessions.session_query(query, session_id)
                    result = self.reply_text(session)
                    total_tokens, completion_tokens, reply_content = (
                        result["total_tokens"],
                        result["completion_tokens"],
                        result["content"],
                    )
                    logger.debug(
                        "[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)
                    )

                    # total_tokens == 0 marks a failed API call (see reply_text);
                    # only successful replies are appended to the session.
                    if total_tokens == 0:
                        reply = Reply(ReplyType.ERROR, reply_content)
                    else:
                        self.sessions.session_reply(reply_content, session_id, total_tokens)
                        reply = Reply(ReplyType.TEXT, reply_content)
                return reply
            elif context.type == ContextType.IMAGE_CREATE:
                ok, retstring = self.create_img(query, 0)
                reply = None
                if ok:
                    reply = Reply(ReplyType.IMAGE_URL, retstring)
                else:
                    reply = Reply(ReplyType.ERROR, retstring)
                return reply

    def reply_text(self, session: OpenAISession, retry_count=0):
        """Call the completion endpoint for *session*, retrying transient errors.

        :param session: conversation session rendered into the prompt via str().
        :param retry_count: current retry attempt (at most 2 retries total).
        :return: dict with keys ``total_tokens``, ``completion_tokens`` and
            ``content``; ``total_tokens`` is 0 on failure so callers can
            distinguish errors from real replies.
        """
        try:
            response = openai.Completion.create(prompt=str(session), **self.args)
            res_content = response.choices[0]["text"].strip().replace("<|endoftext|>", "")
            total_tokens = response["usage"]["total_tokens"]
            completion_tokens = response["usage"]["completion_tokens"]
            logger.info("[OPEN_AI] reply={}".format(res_content))
            return {
                "total_tokens": total_tokens,
                "completion_tokens": completion_tokens,
                "content": res_content,
            }
        except Exception as e:
            need_retry = retry_count < 2
            # BUGFIX: include "total_tokens" — reply() unconditionally reads
            # result["total_tokens"], so omitting it turned every API failure
            # into a KeyError instead of the friendly error message.
            result = {"total_tokens": 0, "completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
            if isinstance(e, openai.error.RateLimitError):
                logger.warn("[OPEN_AI] RateLimitError: {}".format(e))
                result["content"] = "提问太快啦,请休息一下再问我吧"
                if need_retry:
                    time.sleep(20)
            elif isinstance(e, openai.error.Timeout):
                logger.warn("[OPEN_AI] Timeout: {}".format(e))
                result["content"] = "我没有收到你的消息"
                if need_retry:
                    time.sleep(5)
            elif isinstance(e, openai.error.APIConnectionError):
                logger.warn("[OPEN_AI] APIConnectionError: {}".format(e))
                need_retry = False
                result["content"] = "我连接不到你的网络"
            else:
                # Unknown error: drop the (possibly corrupt) session and give up.
                logger.warn("[OPEN_AI] Exception: {}".format(e))
                need_retry = False
                self.sessions.clear_session(session.session_id)

            if need_retry:
                logger.warn("[OPEN_AI] 第{}次重试".format(retry_count + 1))
                return self.reply_text(session, retry_count + 1)
            else:
                return result
|
|