| |
|
|
| from bot.bot import Bot |
| from config import conf |
| from common.log import logger |
| import openai |
| import time |
|
|
# Conversation histories keyed by user id; each value is the list of chat
# messages that Session builds, extends and clears.
user_session = {}
|
|
| |
class ChatGPTBot(Bot):
    """Bot that answers through the OpenAI chat and image endpoints.

    Written against the module-level ``openai`` API of the pre-1.0 SDK
    (``openai.ChatCompletion``, ``openai.Image``, ``openai.error``).
    Per-user conversation history is managed by ``Session``.
    """

    # Used when the config has no (or an empty) ``open_ai_api_base`` entry.
    DEFAULT_API_BASE = "https://api.gueai.com/v1"
    # Number of retries after a rate-limit error, and the pause (seconds)
    # before each retry.
    RATE_LIMIT_RETRIES = 1
    RATE_LIMIT_RETRY_DELAY = 5

    def __init__(self):
        # Both settings are assigned on the ``openai`` module itself, so they
        # affect every request this process makes through it.
        openai.api_key = conf().get('open_ai_api_key')
        # Previously hard-coded; now overridable from config, same default.
        openai.api_base = conf().get('open_ai_api_base') or self.DEFAULT_API_BASE

    def reply(self, query, context=None):
        """Produce a reply for ``query`` based on ``context['type']``.

        A missing context, or one without a type, is handled as ``'TEXT'``.

        :param query: the user's message (text chat) or image prompt
            (``'IMAGE_CREATE'``)
        :param context: optional dict-like; ``'type'`` selects the handler and
            ``'from_user_id'`` keys the conversation history
        :return: reply text, an image URL, a fallback message string, or None
            (image creation failed or the context type is unsupported)
        """
        context = context or {}
        context_type = context.get('type')

        if not context_type or context_type == 'TEXT':
            logger.info("[OPEN_AI] query={}".format(query))
            # ``.get`` instead of ``[...]``: the guard above explicitly accepts
            # a missing context, which previously crashed here. Such requests
            # share the history stored under the ``None`` key.
            from_user_id = context.get('from_user_id')
            if query == '#清除记忆':
                Session.clear_session(from_user_id)
                return '记忆已清除'

            new_query = Session.build_session_query(query, from_user_id)
            logger.debug("[OPEN_AI] session query={}".format(new_query))

            reply_content = self.reply_text(new_query, from_user_id, 0)
            logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
            if reply_content:
                Session.save_session(query, reply_content, from_user_id)
            return reply_content

        if context_type == 'IMAGE_CREATE':
            return self.create_img(query, 0)

        # Any other context type is not handled by this bot.
        return None

    def reply_text(self, query, user_id, retry_count=0):
        """Send a chat-completion request and return the assistant's text.

        :param query: list of chat messages (role/content dicts), as built by
            ``Session.build_session_query``
        :param user_id: owner of the conversation; their history is cleared
            when the request fails with a non-rate-limit error
        :param retry_count: number of rate-limit retries already made
        :return: the reply text, or a fallback message string on failure
        """
        try:
            response = openai.ChatCompletion.create(
                model="gpt-4o-2024-08-06",
                messages=query,
                temperature=0.5,
                max_tokens=1500,
                top_p=1,
                frequency_penalty=0.5,
                presence_penalty=0.5,
            )
            content = response.choices[0]['message']['content']
            logger.info(content)
            return content
        except openai.error.RateLimitError as e:
            logger.warning(e)
            if retry_count < self.RATE_LIMIT_RETRIES:
                time.sleep(self.RATE_LIMIT_RETRY_DELAY)
                logger.warning("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count + 1))
                return self.reply_text(query, user_id, retry_count + 1)
            return "慢点,先让时间停止一下,稍等"
        except Exception as e:
            logger.exception(e)
            # Start the user's conversation over; presumably so a history the
            # API rejects (e.g. too long) does not make every later turn fail.
            Session.clear_session(user_id)
            return "看到我说明链接出现问题。"

    def create_img(self, query, retry_count=0):
        """Generate one 1024x1024 image for the prompt ``query``.

        :param query: image prompt
        :param retry_count: number of rate-limit retries already made
        :return: the image URL, a fallback message string when still
            rate-limited after retrying, or None on any other error
        """
        try:
            logger.info("[OPEN_AI] image_query={}".format(query))
            response = openai.Image.create(
                prompt=query,
                n=1,
                size="1024x1024"
            )
            image_url = response['data'][0]['url']
            logger.info("[OPEN_AI] image_url={}".format(image_url))
            return image_url
        except openai.error.RateLimitError as e:
            logger.warning(e)
            if retry_count < self.RATE_LIMIT_RETRIES:
                time.sleep(self.RATE_LIMIT_RETRY_DELAY)
                logger.warning("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1))
                # Retry the image request itself. This used to call
                # reply_text(query, retry_count + 1), which sent the prompt
                # string as chat messages and the retry count as a user id.
                return self.create_img(query, retry_count + 1)
            return "慢点,先让时间停止一下,稍等。"
        except Exception as e:
            logger.exception(e)
            return None
|
|
class Session:
    """Per-user chat history kept in the module-level ``user_session`` dict.

    Each history is a list of chat messages (dicts with 'role' and 'content'),
    keyed by user id.
    """

    @staticmethod
    def build_session_query(query, user_id):
        """Append ``query`` as a user message to ``user_id``'s history and return it.

        A history that is missing or empty is first seeded with a system
        message whose content is the ``character_desc`` config value (default
        ``""``). The returned list is the stored history itself, not a copy,
        e.g.::

            [
                {"role": "system", "content": "You are a helpful assistant,let's think step by step in multiple different ways."},
                {"role": "user", "content": "Who won the world series in 2020?"},
                {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
                {"role": "user", "content": "Where was it played?"}
            ]

        :param query: text of the new user message
        :param user_id: id of the user who sent it
        :return: the user's full message history, ending with ``query``
        """
        history = user_session.get(user_id, [])
        if not history:
            # Fresh (or cleared) conversation: open with the system prompt.
            history.append({'role': 'system', 'content': conf().get("character_desc", "")})
            user_session[user_id] = history
        history.append({'role': 'user', 'content': query})
        return history

    @staticmethod
    def save_session(query, answer, user_id):
        """Record ``answer`` as an assistant message in ``user_id``'s history.

        Nothing is stored when the user has no history or it is empty (e.g.
        right after ``clear_session``). ``query`` is not used here: the user
        message was already appended by ``build_session_query``.
        """
        history = user_session.get(user_id)
        if not history:
            return
        history.append({'role': 'assistant', 'content': answer})

    @staticmethod
    def clear_session(user_id):
        """Reset ``user_id``'s history to an empty list.

        The next ``build_session_query`` call then starts over with the system
        prompt.
        """
        user_session[user_id] = []
|
|
|
|