Upload 153 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- Dockerfile +3 -0
- LICENSE +19 -0
- app.py +57 -0
- bot/baidu/baidu_unit_bot.py +36 -0
- bot/baidu/baidu_wenxin.py +104 -0
- bot/baidu/baidu_wenxin_session.py +53 -0
- bot/bot.py +17 -0
- bot/bot_factory.py +42 -0
- bot/chatgpt/chat_gpt_bot.py +193 -0
- bot/chatgpt/chat_gpt_session.py +100 -0
- bot/linkai/link_ai_bot.py +116 -0
- bot/openai/open_ai_bot.py +122 -0
- bot/openai/open_ai_image.py +42 -0
- bot/openai/open_ai_session.py +73 -0
- bot/session_manager.py +91 -0
- bot/xunfei/xunfei_spark_bot.py +250 -0
- bridge/bridge.py +66 -0
- bridge/context.py +63 -0
- bridge/reply.py +25 -0
- channel/channel.py +43 -0
- channel/channel_factory.py +36 -0
- channel/chat_channel.py +367 -0
- channel/chat_message.py +85 -0
- channel/terminal/terminal_channel.py +92 -0
- channel/wechat/wechat_channel.py +210 -0
- channel/wechat/wechat_message.py +84 -0
- channel/wechat/wechaty_channel.py +129 -0
- channel/wechat/wechaty_message.py +89 -0
- channel/wechatcom/README.md +85 -0
- channel/wechatcom/wechatcomapp_channel.py +178 -0
- channel/wechatcom/wechatcomapp_client.py +21 -0
- channel/wechatcom/wechatcomapp_message.py +52 -0
- channel/wechatmp/README.md +100 -0
- channel/wechatmp/active_reply.py +75 -0
- channel/wechatmp/common.py +27 -0
- channel/wechatmp/passive_reply.py +209 -0
- channel/wechatmp/wechatmp_channel.py +216 -0
- channel/wechatmp/wechatmp_client.py +49 -0
- channel/wechatmp/wechatmp_message.py +56 -0
- common/const.py +11 -0
- common/dequeue.py +33 -0
- common/expired_dict.py +42 -0
- common/log.py +38 -0
- common/package_manager.py +36 -0
- common/singleton.py +9 -0
- common/sorted_dict.py +65 -0
- common/time_check.py +42 -0
- common/tmp_dir.py +18 -0
- common/token_bucket.py +45 -0
- common/utils.py +51 -0
Dockerfile
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Thin wrapper image: reuses the published upstream chatgpt-on-wechat image
# unchanged and only pins the entrypoint script baked into that base image.
FROM ghcr.io/zhayujie/chatgpt-on-wechat:latest

ENTRYPOINT ["/entrypoint.sh"]
|
LICENSE
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Copyright (c) 2022 zhayujie
|
| 2 |
+
|
| 3 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 4 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 5 |
+
in the Software without restriction, including without limitation the rights
|
| 6 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 7 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 8 |
+
furnished to do so, subject to the following conditions:
|
| 9 |
+
|
| 10 |
+
The above copyright notice and this permission notice shall be included in all
|
| 11 |
+
copies or substantial portions of the Software.
|
| 12 |
+
|
| 13 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 14 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 15 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 16 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 17 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 18 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 19 |
+
SOFTWARE.
|
app.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# encoding:utf-8
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import signal
|
| 5 |
+
import sys
|
| 6 |
+
|
| 7 |
+
from channel import channel_factory
|
| 8 |
+
from common.log import logger
|
| 9 |
+
from config import conf, load_config
|
| 10 |
+
from plugins import *
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def sigterm_handler_wrap(_signo):
    """Wrap the currently installed handler for signal *_signo* so that
    user data is persisted before the process exits.

    The replacement handler logs the signal, saves user data via
    ``conf().save_user_datas()``, then delegates to the previously
    registered handler when it is callable; otherwise it exits with
    status 0.
    """
    old_handler = signal.getsignal(_signo)

    def func(_signo, _stack_frame):
        logger.info("signal {} received, exiting...".format(_signo))
        # Persist in-memory user data before shutting down.
        conf().save_user_datas()
        if callable(old_handler):  # check old_handler
            # Delegate to the original handler (it may itself terminate
            # the process); its return value is propagated.
            return old_handler(_signo, _stack_frame)
        sys.exit(0)

    signal.signal(_signo, func)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def run():
    """Application entry point.

    Loads configuration, installs SIGINT/SIGTERM handlers, creates the
    configured channel, loads plugins for channel types that support
    them, and starts the channel. Any startup failure is logged rather
    than propagated.
    """
    try:
        # load config
        load_config()
        # ctrl + c
        sigterm_handler_wrap(signal.SIGINT)
        # kill signal
        sigterm_handler_wrap(signal.SIGTERM)

        # create channel; "wx" (personal WeChat) is the default type
        channel_name = conf().get("channel_type", "wx")

        # --cmd on the command line forces the interactive terminal channel
        if "--cmd" in sys.argv:
            channel_name = "terminal"

        if channel_name == "wxy":
            # Quiet the wechaty puppet's logging.
            os.environ["WECHATY_LOG"] = "warn"
            # os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'

        channel = channel_factory.create_channel(channel_name)
        # Plugins are only loaded for channel types known to support them.
        if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app"]:
            PluginManager().load_plugins()

        # startup channel (typically blocks until the channel terminates)
        channel.startup()
    except Exception as e:
        logger.error("App startup failed!")
        logger.exception(e)


if __name__ == "__main__":
    run()
|
bot/baidu/baidu_unit_bot.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# encoding:utf-8
|
| 2 |
+
|
| 3 |
+
import requests
|
| 4 |
+
|
| 5 |
+
from bot.bot import Bot
|
| 6 |
+
from bridge.reply import Reply, ReplyType
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# Baidu Unit对话接口 (可用, 但能力较弱)
|
| 10 |
+
class BaiduUnitBot(Bot):
    """Bot backed by the Baidu UNIT dialogue API (usable but limited)."""

    def reply(self, query, context=None):
        """Send *query* to the Baidu UNIT chat endpoint and return a TEXT Reply.

        Implicitly returns None when the HTTP response is falsy (request
        failed) — callers must tolerate a None reply.
        """
        token = self.get_token()
        url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + token
        # NOTE(review): service_id / skill_ids / log_id are hard-coded sample
        # values, and the JSON body is built by string concatenation — a query
        # containing a double quote would produce invalid JSON. Confirm queries
        # are sanitized upstream, or switch to json.dumps.
        post_data = (
            '{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"'
            + query
            + '", "hyper_params": {"chat_custom_bot_profile": 1}}}'
        )
        print(post_data)
        headers = {"content-type": "application/x-www-form-urlencoded"}
        response = requests.post(url, data=post_data.encode(), headers=headers)
        if response:
            reply = Reply(
                ReplyType.TEXT,
                response.json()["result"]["context"]["SYS_PRESUMED_HIST"][1],
            )
            return reply

    def get_token(self):
        """Fetch an OAuth access token from Baidu using the (placeholder)
        access/secret keys below.

        Implicitly returns None when the token request fails.
        """
        # NOTE(review): credentials are hard-coded placeholders; they must be
        # replaced (ideally moved into configuration) before this bot works.
        access_key = "YOUR_ACCESS_KEY"
        secret_key = "YOUR_SECRET_KEY"
        host = "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" + access_key + "&client_secret=" + secret_key
        response = requests.get(host)
        if response:
            print(response.json())
            return response.json()["access_token"]
|
bot/baidu/baidu_wenxin.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# encoding:utf-8
|
| 2 |
+
|
| 3 |
+
import requests, json
|
| 4 |
+
from bot.bot import Bot
|
| 5 |
+
from bot.session_manager import SessionManager
|
| 6 |
+
from bridge.context import ContextType
|
| 7 |
+
from bridge.reply import Reply, ReplyType
|
| 8 |
+
from common.log import logger
|
| 9 |
+
from config import conf
|
| 10 |
+
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
| 11 |
+
|
| 12 |
+
BAIDU_API_KEY = conf().get("baidu_wenxin_api_key")
|
| 13 |
+
BAIDU_SECRET_KEY = conf().get("baidu_wenxin_secret_key")
|
| 14 |
+
|
| 15 |
+
class BaiduWenxinBot(Bot):
    """Bot backed by the Baidu Wenxin Workshop (文心千帆) chat completion API."""

    def __init__(self):
        super().__init__()
        self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("baidu_wenxin_model") or "eb-instant")

    def reply(self, query, context=None):
        """Handle a TEXT or IMAGE_CREATE context and return a Reply.

        Implicitly returns None when *context* is missing, has no type,
        or has an unsupported type.
        """
        # acquire reply content
        if context and context.type:
            if context.type == ContextType.TEXT:
                logger.info("[BAIDU] query={}".format(query))
                session_id = context["session_id"]
                reply = None
                if query == "#清除记忆":
                    self.sessions.clear_session(session_id)
                    reply = Reply(ReplyType.INFO, "记忆已清除")
                elif query == "#清除所有":
                    self.sessions.clear_all_session()
                    reply = Reply(ReplyType.INFO, "所有人记忆已清除")
                else:
                    session = self.sessions.session_query(query, session_id)
                    result = self.reply_text(session)
                    total_tokens, completion_tokens, reply_content = (
                        result["total_tokens"],
                        result["completion_tokens"],
                        result["content"],
                    )
                    logger.debug(
                        "[BAIDU] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content, completion_tokens)
                    )

                    if total_tokens == 0:
                        # A zero total marks an error result from reply_text.
                        reply = Reply(ReplyType.ERROR, reply_content)
                    else:
                        self.sessions.session_reply(reply_content, session_id, total_tokens)
                        reply = Reply(ReplyType.TEXT, reply_content)
                return reply
            elif context.type == ContextType.IMAGE_CREATE:
                # NOTE(review): unlike ChatGPTBot this class does not mix in
                # OpenAIImage, so create_img is undefined here and this branch
                # raises AttributeError if reached — confirm intended behavior.
                ok, retstring = self.create_img(query, 0)
                reply = None
                if ok:
                    reply = Reply(ReplyType.IMAGE_URL, retstring)
                else:
                    reply = Reply(ReplyType.ERROR, retstring)
                return reply

    def reply_text(self, session: BaiduWenxinSession, retry_count=0):
        """Call the Wenxin chat endpoint for *session*.

        :param session: conversation session whose messages are sent
        :param retry_count: accepted for interface parity; retries are
            currently disabled for this bot
        :return: dict with keys "total_tokens", "completion_tokens" and
            "content"; on failure the token counts are 0 and "content"
            carries the error message
        """
        try:
            logger.info("[BAIDU] model={}".format(session.model))
            access_token = self.get_access_token()
            if access_token == 'None':
                logger.warn("[BAIDU] access token 获取失败")
                return {
                    "total_tokens": 0,
                    "completion_tokens": 0,
                    "content": 0,
                }
            url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + session.model + "?access_token=" + access_token
            headers = {
                'Content-Type': 'application/json'
            }
            payload = {'messages': session.messages}
            response = requests.request("POST", url, headers=headers, data=json.dumps(payload))
            response_text = json.loads(response.text)
            logger.info(f"[BAIDU] response text={response_text}")
            res_content = response_text["result"]
            total_tokens = response_text["usage"]["total_tokens"]
            completion_tokens = response_text["usage"]["completion_tokens"]
            logger.info("[BAIDU] reply={}".format(res_content))
            return {
                "total_tokens": total_tokens,
                "completion_tokens": completion_tokens,
                "content": res_content,
            }
        except Exception as e:
            logger.warn("[BAIDU] Exception: {}".format(e))
            self.sessions.clear_session(session.session_id)
            # Fix: the error result must include "total_tokens", because
            # reply() unpacks result["total_tokens"] — omitting it raised
            # KeyError on every API failure. 0 flags the result as an error.
            # (Dead need_retry assignments removed: retries were never taken.)
            result = {"total_tokens": 0, "completion_tokens": 0, "content": "出错了: {}".format(e)}
            return result

    def get_access_token(self):
        """
        Generate an auth signature (Access Token) from the configured AK/SK.
        :return: the access token string, or the string 'None' on failure
        """
        url = "https://aip.baidubce.com/oauth/2.0/token"
        params = {"grant_type": "client_credentials", "client_id": BAIDU_API_KEY, "client_secret": BAIDU_SECRET_KEY}
        return str(requests.post(url, params=params).json().get("access_token"))
|
bot/baidu/baidu_wenxin_session.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from bot.session_manager import Session
|
| 2 |
+
from common.log import logger
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
e.g. [
|
| 6 |
+
{"role": "user", "content": "Who won the world series in 2020?"},
|
| 7 |
+
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
|
| 8 |
+
{"role": "user", "content": "Where was it played?"}
|
| 9 |
+
]
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class BaiduWenxinSession(Session):
    """Session for Baidu Wenxin: a plain user/assistant message list.

    Wenxin does not support a system prompt, so unlike other sessions
    no system message is inserted on creation.
    """

    def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"):
        super().__init__(session_id, system_prompt)
        self.model = model
        # Baidu Wenxin does not support a system prompt
        # self.reset()

    def discard_exceeding(self, max_tokens, cur_tokens=None):
        """Drop the oldest user/assistant pair until the session fits
        within *max_tokens*.

        :param max_tokens: token budget for the whole message list
        :param cur_tokens: fallback count used when precise counting fails
        :return: the resulting token count (estimated when not precise)
        """
        precise = True
        try:
            cur_tokens = self.calc_tokens()
        except Exception as e:
            precise = False
            if cur_tokens is None:
                raise e
            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
        while cur_tokens > max_tokens:
            if len(self.messages) >= 2:
                # Messages come in user/assistant pairs; drop the oldest pair.
                self.messages.pop(0)
                self.messages.pop(0)
            else:
                logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
                break
            if precise:
                cur_tokens = self.calc_tokens()
            else:
                # Rough fallback when precise counting is unavailable.
                cur_tokens = cur_tokens - max_tokens
        return cur_tokens

    def calc_tokens(self):
        # Character-count estimate; see num_tokens_from_messages below.
        return num_tokens_from_messages(self.messages, self.model)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def num_tokens_from_messages(messages, model):
    """Estimate the number of tokens consumed by *messages*.

    Baidu's official token accounting is not published precisely
    (roughly "Chinese characters + other-language words x 1.3"), so this
    is a plain character-count estimate. It only affects when old history
    is discarded, not normal operation. *model* is accepted for interface
    parity but unused.
    """
    return sum(len(message["content"]) for message in messages)
|
bot/bot.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Auto-replay chat robot abstract class
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
from bridge.context import Context
|
| 7 |
+
from bridge.reply import Reply
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Bot(object):
    """Abstract base class for auto-reply chat bots."""

    def reply(self, query, context: Context = None) -> Reply:
        """
        bot auto-reply content
        :param query: received message text
        :param context: optional Context carrying message metadata
        :return: reply content
        :raises NotImplementedError: subclasses must override this method
        """
        raise NotImplementedError
|
bot/bot_factory.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
channel factory
|
| 3 |
+
"""
|
| 4 |
+
from common import const
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def create_bot(bot_type):
    """
    create a bot_type instance
    :param bot_type: bot type code (see common.const)
    :return: bot instance
    :raises RuntimeError: when bot_type matches no known bot
    """
    if bot_type == const.BAIDU:
        # Baidu Unit was replaced by the Baidu Wenxin (千帆) chat API
        # from bot.baidu.baidu_unit_bot import BaiduUnitBot
        # return BaiduUnitBot()
        from bot.baidu.baidu_wenxin import BaiduWenxinBot
        return BaiduWenxinBot()

    elif bot_type == const.CHATGPT:
        # ChatGPT web API
        from bot.chatgpt.chat_gpt_bot import ChatGPTBot
        return ChatGPTBot()

    elif bot_type == const.OPEN_AI:
        # official OpenAI chat-completion API
        from bot.openai.open_ai_bot import OpenAIBot
        return OpenAIBot()

    elif bot_type == const.CHATGPTONAZURE:
        # Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/
        from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot
        return AzureChatGPTBot()

    elif bot_type == const.XUNFEI:
        from bot.xunfei.xunfei_spark_bot import XunFeiBot
        return XunFeiBot()

    elif bot_type == const.LINKAI:
        from bot.linkai.link_ai_bot import LinkAIBot
        return LinkAIBot()
    # Fix: name the offending type instead of raising a bare RuntimeError,
    # which gave no clue why startup failed. Same exception type as before.
    raise RuntimeError("Unsupported bot type: {}".format(bot_type))
|
bot/chatgpt/chat_gpt_bot.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# encoding:utf-8
|
| 2 |
+
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
import openai
|
| 6 |
+
import openai.error
|
| 7 |
+
import requests
|
| 8 |
+
|
| 9 |
+
from bot.bot import Bot
|
| 10 |
+
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
| 11 |
+
from bot.openai.open_ai_image import OpenAIImage
|
| 12 |
+
from bot.session_manager import SessionManager
|
| 13 |
+
from bridge.context import ContextType
|
| 14 |
+
from bridge.reply import Reply, ReplyType
|
| 15 |
+
from common.log import logger
|
| 16 |
+
from common.token_bucket import TokenBucket
|
| 17 |
+
from config import conf, load_config
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# OpenAI对话模型API (可用)
|
| 21 |
+
class ChatGPTBot(Bot, OpenAIImage):
    """Bot backed by the OpenAI ChatCompletion API, with image creation
    inherited from OpenAIImage and optional token-bucket rate limiting."""

    def __init__(self):
        super().__init__()
        # set the default api_key
        openai.api_key = conf().get("open_ai_api_key")
        if conf().get("open_ai_api_base"):
            openai.api_base = conf().get("open_ai_api_base")
        proxy = conf().get("proxy")
        if proxy:
            openai.proxy = proxy
        if conf().get("rate_limit_chatgpt"):
            # Token bucket exists only when rate limiting is configured;
            # reply_text guards on the same config key before using it.
            self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))

        self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
        self.args = {
            "model": conf().get("model") or "gpt-3.5-turbo",  # name of the chat model
            "temperature": conf().get("temperature", 0.9),  # in [0,1]; higher means less deterministic replies
            # "max_tokens":4096,  # maximum number of tokens in the reply
            "top_p": conf().get("top_p", 1),
            "frequency_penalty": conf().get("frequency_penalty", 0.0),  # in [-2,2]; larger values favor novel content
            "presence_penalty": conf().get("presence_penalty", 0.0),  # in [-2,2]; larger values favor novel content
            "request_timeout": conf().get("request_timeout", None),  # request timeout; openai defaults to 600s, hard questions need longer
            "timeout": conf().get("request_timeout", None),  # retry timeout; automatic retries happen within this window
        }

    def reply(self, query, context=None):
        """Dispatch on context type: answer TEXT queries (including the
        built-in #-commands) or create an image for IMAGE_CREATE."""
        # acquire reply content
        if context.type == ContextType.TEXT:
            logger.info("[CHATGPT] query={}".format(query))

            session_id = context["session_id"]
            reply = None
            clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
            if query in clear_memory_commands:
                self.sessions.clear_session(session_id)
                reply = Reply(ReplyType.INFO, "记忆已清除")
            elif query == "#清除所有":
                self.sessions.clear_all_session()
                reply = Reply(ReplyType.INFO, "所有人记忆已清除")
            elif query == "#更新配置":
                load_config()
                reply = Reply(ReplyType.INFO, "配置已更新")
            if reply:
                return reply
            session = self.sessions.session_query(query, session_id)
            logger.debug("[CHATGPT] session query={}".format(session.messages))

            # Per-context overrides (e.g. per-user API key / model).
            api_key = context.get("openai_api_key")
            model = context.get("gpt_model")
            new_args = None
            if model:
                new_args = self.args.copy()
                new_args["model"] = model
            # if context.get('stream'):
            #     # reply in stream
            #     return self.reply_text_stream(query, new_query, session_id)

            reply_content = self.reply_text(session, api_key, args=new_args)
            logger.debug(
                "[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
                    session.messages,
                    session_id,
                    reply_content["content"],
                    reply_content["completion_tokens"],
                )
            )
            # completion_tokens == 0 with non-empty content marks an error
            # message produced by reply_text's exception handling.
            if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
                reply = Reply(ReplyType.ERROR, reply_content["content"])
            elif reply_content["completion_tokens"] > 0:
                self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
                reply = Reply(ReplyType.TEXT, reply_content["content"])
            else:
                reply = Reply(ReplyType.ERROR, reply_content["content"])
                logger.debug("[CHATGPT] reply {} used 0 tokens.".format(reply_content))
            return reply

        elif context.type == ContextType.IMAGE_CREATE:
            ok, retstring = self.create_img(query, 0)
            reply = None
            if ok:
                reply = Reply(ReplyType.IMAGE_URL, retstring)
            else:
                reply = Reply(ReplyType.ERROR, retstring)
            return reply
        else:
            reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
            return reply

    def reply_text(self, session: ChatGPTSession, api_key=None, args=None, retry_count=0) -> dict:
        """
        call openai's ChatCompletion to get the answer
        :param session: a conversation session
        :param api_key: per-call API key override (None uses the default)
        :param args: per-call completion-argument override (None uses self.args)
        :param retry_count: retry count
        :return: dict with "total_tokens", "completion_tokens", "content";
            on failure completion_tokens is 0 and content is a user-facing
            error message (note: "total_tokens" is absent on failure)
        """
        try:
            if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
                raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
            # if api_key == None, the default openai.api_key will be used
            if args is None:
                args = self.args
            response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args)
            # logger.debug("[CHATGPT] response={}".format(response))
            # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
            return {
                "total_tokens": response["usage"]["total_tokens"],
                "completion_tokens": response["usage"]["completion_tokens"],
                "content": response.choices[0]["message"]["content"],
            }
        except Exception as e:
            # Retryable errors sleep then recurse (at most 2 retries);
            # non-retryable ones clear the session and return the error text.
            need_retry = retry_count < 2
            result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
            if isinstance(e, openai.error.RateLimitError):
                logger.warn("[CHATGPT] RateLimitError: {}".format(e))
                result["content"] = "提问太快啦,请休息一下再问我吧"
                if need_retry:
                    time.sleep(20)
            elif isinstance(e, openai.error.Timeout):
                logger.warn("[CHATGPT] Timeout: {}".format(e))
                result["content"] = "我没有收到你的消息"
                if need_retry:
                    time.sleep(5)
            elif isinstance(e, openai.error.APIError):
                logger.warn("[CHATGPT] Bad Gateway: {}".format(e))
                result["content"] = "请再问我一次"
                if need_retry:
                    time.sleep(10)
            elif isinstance(e, openai.error.APIConnectionError):
                logger.warn("[CHATGPT] APIConnectionError: {}".format(e))
                need_retry = False
                result["content"] = "我连接不到你的网络"
            else:
                logger.exception("[CHATGPT] Exception: {}".format(e))
                need_retry = False
                self.sessions.clear_session(session.session_id)

            if need_retry:
                logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1))
                return self.reply_text(session, api_key, args, retry_count + 1)
            else:
                return result
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class AzureChatGPTBot(ChatGPTBot):
    """ChatGPTBot variant that talks to the Azure OpenAI service."""

    def __init__(self):
        super().__init__()
        openai.api_type = "azure"
        openai.api_version = conf().get("azure_api_version", "2023-06-01-preview")
        # Azure routes requests by deployment rather than model name.
        self.args["deployment_id"] = conf().get("azure_deployment_id")

    def create_img(self, query, retry_count=0, api_key=None):
        """Create an image via the Azure DALL-E preview endpoint.

        Submits the request, then polls the Operation-Location URL until
        status is "Succeeded".
        NOTE(review): the polling loop has no iteration cap; a terminal
        failure status only breaks out via an exception (e.g. a missing
        "result" key in the response) — confirm this is acceptable.

        :param query: image caption/prompt text
        :param retry_count: accepted for interface parity; unused here
        :param api_key: optional key override (defaults to openai.api_key)
        :return: (True, image_url) on success, (False, error message) otherwise
        """
        api_version = "2022-08-03-preview"
        url = "{}dalle/text-to-image?api-version={}".format(openai.api_base, api_version)
        api_key = api_key or openai.api_key
        headers = {"api-key": api_key, "Content-Type": "application/json"}
        try:
            body = {"caption": query, "resolution": conf().get("image_create_size", "256x256")}
            submission = requests.post(url, headers=headers, json=body)
            operation_location = submission.headers["Operation-Location"]
            retry_after = submission.headers["Retry-after"]
            status = ""
            image_url = ""
            while status != "Succeeded":
                logger.info("waiting for image create..., " + status + ",retry after " + retry_after + " seconds")
                time.sleep(int(retry_after))
                response = requests.get(operation_location, headers=headers)
                status = response.json()["status"]
                image_url = response.json()["result"]["contentUrl"]
            return True, image_url
        except Exception as e:
            logger.error("create image error: {}".format(e))
            return False, "图片生成失败"
|
bot/chatgpt/chat_gpt_session.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from bot.session_manager import Session
|
| 2 |
+
from common.log import logger
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
e.g. [
|
| 6 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
| 7 |
+
{"role": "user", "content": "Who won the world series in 2020?"},
|
| 8 |
+
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
|
| 9 |
+
{"role": "user", "content": "Where was it played?"}
|
| 10 |
+
]
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ChatGPTSession(Session):
    """Chat session holding an OpenAI-style message list (system prompt at
    index 0) with token-based trimming of old history."""

    def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"):
        super().__init__(session_id, system_prompt)
        self.model = model
        # reset() (from Session) installs the system prompt as message 0.
        self.reset()

    def discard_exceeding(self, max_tokens, cur_tokens=None):
        """Drop the oldest non-system messages until the session fits
        within *max_tokens*.

        :param max_tokens: token budget for the whole message list
        :param cur_tokens: fallback count used when precise counting fails
        :return: the resulting token count (estimated when not precise)
        """
        precise = True
        try:
            cur_tokens = self.calc_tokens()
        except Exception as e:
            precise = False
            if cur_tokens is None:
                raise e
            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
        while cur_tokens > max_tokens:
            if len(self.messages) > 2:
                # Index 0 is the system prompt; drop the oldest turn after it.
                self.messages.pop(1)
            elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
                self.messages.pop(1)
                if precise:
                    cur_tokens = self.calc_tokens()
                else:
                    cur_tokens = cur_tokens - max_tokens
                break
            elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
                # A single user message that alone exceeds the budget
                # cannot be trimmed further; warn and keep it.
                logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
                break
            else:
                logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
                break
            if precise:
                cur_tokens = self.calc_tokens()
            else:
                cur_tokens = cur_tokens - max_tokens
        return cur_tokens

    def calc_tokens(self):
        return num_tokens_from_messages(self.messages, self.model)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
| 56 |
+
def num_tokens_from_messages(messages, model):
    """Returns the number of tokens used by a list of messages.

    Non-OpenAI models ("wenxin", "xunfei") fall back to a character-count
    estimate; known model aliases recurse with their canonical base model;
    everything else is counted with tiktoken per the OpenAI cookbook rules.
    """

    if model in ["wenxin", "xunfei"]:
        return num_tokens_by_character(messages)

    # Lazy import: tiktoken is only needed for OpenAI-style models.
    import tiktoken

    # Map dated/Azure aliases onto their canonical base model.
    if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo"]:
        return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
    elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613",
                   "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k"]:
        return num_tokens_from_messages(messages, model="gpt-4")

    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        logger.debug("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    if model == "gpt-3.5-turbo":
        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
        tokens_per_name = -1  # if there's a name, the role is omitted
    elif model == "gpt-4":
        tokens_per_message = 3
        tokens_per_name = 1
    else:
        logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo.")
        return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def num_tokens_by_character(messages):
    """Returns the number of tokens used by a list of messages."""
    # Approximation used by models without a tokenizer: one token per character.
    return sum(len(msg["content"]) for msg in messages)
|
bot/linkai/link_ai_bot.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# access LinkAI knowledge base platform
|
| 2 |
+
# docs: https://link-ai.tech/platform/link-app/wechat
|
| 3 |
+
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
import requests
|
| 7 |
+
|
| 8 |
+
from bot.bot import Bot
|
| 9 |
+
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
| 10 |
+
from bot.openai.open_ai_image import OpenAIImage
|
| 11 |
+
from bot.session_manager import SessionManager
|
| 12 |
+
from bridge.context import Context, ContextType
|
| 13 |
+
from bridge.reply import Reply, ReplyType
|
| 14 |
+
from common.log import logger
|
| 15 |
+
from config import conf
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class LinkAIBot(Bot, OpenAIImage):
    """Bot backed by the LinkAI knowledge-base platform.

    TEXT contexts are forwarded to LinkAI's OpenAI-compatible chat-completions
    endpoint; IMAGE_CREATE contexts reuse the inherited OpenAIImage.create_img.
    """

    # authentication failed
    AUTH_FAILED_CODE = 401
    NO_QUOTA_CODE = 406

    def __init__(self):
        super().__init__()
        self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")

    def reply(self, query, context: Context = None) -> Reply:
        """Dispatch on context type: TEXT -> chat, IMAGE_CREATE -> image, else an ERROR reply."""
        if context.type == ContextType.TEXT:
            return self._chat(query, context)
        elif context.type == ContextType.IMAGE_CREATE:
            ok, res = self.create_img(query, 0)
            if ok:
                reply = Reply(ReplyType.IMAGE_URL, res)
            else:
                reply = Reply(ReplyType.ERROR, res)
            return reply
        else:
            reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
            return reply

    def _chat(self, query, context, retry_count=0) -> Reply:
        """Send one chat request to the LinkAI API.

        :param query: prompt text
        :param context: conversation context (carries session_id and optional app_code)
        :param retry_count: current recursive retry depth (gives up after 2 attempts)
        :return: Reply holding the model answer, or an ERROR reply
        """
        if retry_count >= 2:
            # exit from retry 2 times
            logger.warn("[LINKAI] failed after maximum number of retry times")
            return Reply(ReplyType.ERROR, "请再问我一次吧")

        try:
            # load config; a plugin that interrupted generation disables the app code
            if context.get("generate_breaked_by"):
                logger.info(f"[LINKAI] won't set appcode because a plugin ({context['generate_breaked_by']}) affected the context")
                app_code = None
            else:
                app_code = context.kwargs.get("app_code") or conf().get("linkai_app_code")
            linkai_api_key = conf().get("linkai_api_key")

            session_id = context["session_id"]

            session = self.sessions.session_query(query, session_id)
            model = conf().get("model") or "gpt-3.5-turbo"
            # remove the system message when an app code is used or the model is wenxin
            if session.messages[0].get("role") == "system":
                if app_code or model == "wenxin":
                    session.messages.pop(0)

            body = {
                "app_code": app_code,
                "messages": session.messages,
                "model": model,  # chat model name; supports gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin
                "temperature": conf().get("temperature"),
                "top_p": conf().get("top_p", 1),
                "frequency_penalty": conf().get("frequency_penalty", 0.0),  # in [-2,2]; larger favours novel content
                "presence_penalty": conf().get("presence_penalty", 0.0),  # in [-2,2]; larger favours novel content
            }
            logger.info(f"[LINKAI] query={query}, app_code={app_code}, mode={body.get('model')}")
            headers = {"Authorization": "Bearer " + linkai_api_key}

            # do http request
            base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
            res = requests.post(url=base_url + "/v1/chat/completions", json=body, headers=headers,
                                timeout=conf().get("request_timeout", 180))
            if res.status_code == 200:
                # execute success
                response = res.json()
                reply_content = response["choices"][0]["message"]["content"]
                total_tokens = response["usage"]["total_tokens"]
                logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
                self.sessions.session_reply(reply_content, session_id, total_tokens)
                return Reply(ReplyType.TEXT, reply_content)

            else:
                response = res.json()
                # BUGFIX: "error" may be absent from the body; fall back to an empty
                # dict so the .get() calls below cannot raise AttributeError on None.
                error = response.get("error") or {}
                logger.error(f"[LINKAI] chat failed, status_code={res.status_code}, "
                             f"msg={error.get('message')}, type={error.get('type')}")

                if res.status_code >= 500:
                    # server error, need retry
                    time.sleep(2)
                    logger.warn(f"[LINKAI] do retry, times={retry_count}")
                    return self._chat(query, context, retry_count + 1)

                return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧")

        except Exception as e:
            logger.exception(e)
            # retry
            time.sleep(2)
            logger.warn(f"[LINKAI] do retry, times={retry_count}")
            return self._chat(query, context, retry_count + 1)
|
bot/openai/open_ai_bot.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# encoding:utf-8
|
| 2 |
+
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
import openai
|
| 6 |
+
import openai.error
|
| 7 |
+
|
| 8 |
+
from bot.bot import Bot
|
| 9 |
+
from bot.openai.open_ai_image import OpenAIImage
|
| 10 |
+
from bot.openai.open_ai_session import OpenAISession
|
| 11 |
+
from bot.session_manager import SessionManager
|
| 12 |
+
from bridge.context import ContextType
|
| 13 |
+
from bridge.reply import Reply, ReplyType
|
| 14 |
+
from common.log import logger
|
| 15 |
+
from config import conf
|
| 16 |
+
|
| 17 |
+
user_session = dict()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# OpenAI对话模型API (可用)
|
| 21 |
+
class OpenAIBot(Bot, OpenAIImage):
    """Bot for OpenAI's legacy text-completion API (e.g. text-davinci-003)."""

    def __init__(self):
        super().__init__()
        openai.api_key = conf().get("open_ai_api_key")
        if conf().get("open_ai_api_base"):
            openai.api_base = conf().get("open_ai_api_base")
        proxy = conf().get("proxy")
        if proxy:
            openai.proxy = proxy

        self.sessions = SessionManager(OpenAISession, model=conf().get("model") or "text-davinci-003")
        self.args = {
            "model": conf().get("model") or "text-davinci-003",  # completion model name
            "temperature": conf().get("temperature", 0.9),  # in [0,1]; larger means less deterministic replies
            "max_tokens": 1200,  # maximum reply length
            "top_p": 1,
            "frequency_penalty": conf().get("frequency_penalty", 0.0),  # in [-2,2]; larger favours novel content
            "presence_penalty": conf().get("presence_penalty", 0.0),  # in [-2,2]; larger favours novel content
            "request_timeout": conf().get("request_timeout", None),  # per-request timeout (openai default is 600)
            "timeout": conf().get("request_timeout", None),  # overall retry timeout; retries happen within it
            "stop": ["\n\n\n"],
        }

    def reply(self, query, context=None):
        """Handle a TEXT or IMAGE_CREATE context; other types yield an ERROR reply."""
        # acquire reply content
        if context and context.type:
            if context.type == ContextType.TEXT:
                logger.info("[OPEN_AI] query={}".format(query))
                session_id = context["session_id"]
                reply = None
                if query == "#清除记忆":
                    self.sessions.clear_session(session_id)
                    reply = Reply(ReplyType.INFO, "记忆已清除")
                elif query == "#清除所有":
                    self.sessions.clear_all_session()
                    reply = Reply(ReplyType.INFO, "所有人记忆已清除")
                else:
                    session = self.sessions.session_query(query, session_id)
                    result = self.reply_text(session)
                    total_tokens, completion_tokens, reply_content = (
                        result["total_tokens"],
                        result["completion_tokens"],
                        result["content"],
                    )
                    logger.debug(
                        "[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)
                    )

                    # total_tokens == 0 marks a failed request (see reply_text)
                    if total_tokens == 0:
                        reply = Reply(ReplyType.ERROR, reply_content)
                    else:
                        self.sessions.session_reply(reply_content, session_id, total_tokens)
                        reply = Reply(ReplyType.TEXT, reply_content)
                return reply
            elif context.type == ContextType.IMAGE_CREATE:
                ok, retstring = self.create_img(query, 0)
                reply = None
                if ok:
                    reply = Reply(ReplyType.IMAGE_URL, retstring)
                else:
                    reply = Reply(ReplyType.ERROR, retstring)
                return reply
            else:
                # BUGFIX: previously fell through and implicitly returned None for
                # unsupported types; return an explicit ERROR reply like the
                # sibling bots (LinkAIBot, XunFeiBot) do.
                return Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))

    def reply_text(self, session: OpenAISession, retry_count=0):
        """Call the completion endpoint for the rendered session prompt.

        Retries up to twice on transient errors; on final failure returns a
        dict with total_tokens == 0 so the caller emits an error reply.
        """
        try:
            response = openai.Completion.create(prompt=str(session), **self.args)
            res_content = response.choices[0]["text"].strip().replace("<|endoftext|>", "")
            total_tokens = response["usage"]["total_tokens"]
            completion_tokens = response["usage"]["completion_tokens"]
            logger.info("[OPEN_AI] reply={}".format(res_content))
            return {
                "total_tokens": total_tokens,
                "completion_tokens": completion_tokens,
                "content": res_content,
            }
        except Exception as e:
            need_retry = retry_count < 2
            # BUGFIX: include "total_tokens": 0 — reply() unconditionally reads
            # result["total_tokens"] and would raise KeyError on this path.
            result = {"total_tokens": 0, "completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
            if isinstance(e, openai.error.RateLimitError):
                logger.warn("[OPEN_AI] RateLimitError: {}".format(e))
                result["content"] = "提问太快啦,请休息一下再问我吧"
                if need_retry:
                    time.sleep(20)
            elif isinstance(e, openai.error.Timeout):
                logger.warn("[OPEN_AI] Timeout: {}".format(e))
                result["content"] = "我没有收到你的消息"
                if need_retry:
                    time.sleep(5)
            elif isinstance(e, openai.error.APIConnectionError):
                logger.warn("[OPEN_AI] APIConnectionError: {}".format(e))
                need_retry = False
                result["content"] = "我连接不到你的网络"
            else:
                logger.warn("[OPEN_AI] Exception: {}".format(e))
                need_retry = False
                # unknown failure: drop the (possibly corrupted) session history
                self.sessions.clear_session(session.session_id)

            if need_retry:
                logger.warn("[OPEN_AI] 第{}次重试".format(retry_count + 1))
                return self.reply_text(session, retry_count + 1)
            else:
                return result
|
bot/openai/open_ai_image.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
|
| 3 |
+
import openai
|
| 4 |
+
import openai.error
|
| 5 |
+
|
| 6 |
+
from common.log import logger
|
| 7 |
+
from common.token_bucket import TokenBucket
|
| 8 |
+
from config import conf
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# OPENAI提供的画图接口
|
| 12 |
+
class OpenAIImage(object):
    """DALL·E image creation via the OpenAI Image API, with optional rate limiting."""

    def __init__(self):
        openai.api_key = conf().get("open_ai_api_key")
        # The token bucket exists only when a rate limit is configured.
        if conf().get("rate_limit_dalle"):
            self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50))

    def create_img(self, query, retry_count=0, api_key=None):
        """Generate one image for *query*; returns (ok, url_or_error_message)."""
        try:
            if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token():
                return False, "请求太快了,请休息一下再问我吧"
            logger.info("[OPEN_AI] image_query={}".format(query))
            response = openai.Image.create(
                api_key=api_key,
                prompt=query,  # image description
                n=1,  # number of images per request
                size=conf().get("image_create_size", "256x256"),  # one of 256x256, 512x512, 1024x1024
            )
            image_url = response["data"][0]["url"]
            logger.info("[OPEN_AI] image_url={}".format(image_url))
            return True, image_url
        except openai.error.RateLimitError as e:
            logger.warn(e)
            if retry_count < 1:
                time.sleep(5)
                logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1))
                return self.create_img(query, retry_count + 1)
            return False, "提问太快啦,请休息一下再问我吧"
        except Exception as e:
            logger.exception(e)
            return False, str(e)
|
bot/openai/open_ai_session.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from bot.session_manager import Session
|
| 2 |
+
from common.log import logger
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class OpenAISession(Session):
    """Session for the legacy completion API: history is rendered into a
    single Q:/A:-style prompt string instead of a chat message list."""

    def __init__(self, session_id, system_prompt=None, model="text-davinci-003"):
        super().__init__(session_id, system_prompt)
        self.model = model  # used to select the tiktoken encoding in calc_tokens()
        self.reset()

    def __str__(self):
        # Render the message history into the completion-model prompt.
        """
        e.g. Q: xxx
        A: xxx
        Q: xxx
        """
        prompt = ""
        for item in self.messages:
            if item["role"] == "system":
                prompt += item["content"] + "<|endoftext|>\n\n\n"
            elif item["role"] == "user":
                prompt += "Q: " + item["content"] + "\n"
            elif item["role"] == "assistant":
                prompt += "\n\nA: " + item["content"] + "<|endoftext|>\n"

        # When the last message is a user question, prime the model to answer.
        if len(self.messages) > 0 and self.messages[-1]["role"] == "user":
            prompt += "A: "
        return prompt

    def discard_exceeding(self, max_tokens, cur_tokens=None):
        """Drop oldest messages until the rendered prompt fits max_tokens.

        Returns the (possibly approximate) token count after trimming. When
        precise counting fails (e.g. tiktoken unavailable), len(str(self)) is
        used as an approximation; in that case a caller-supplied cur_tokens is
        required, otherwise the original exception is re-raised.
        """
        precise = True
        try:
            cur_tokens = self.calc_tokens()
        except Exception as e:
            precise = False
            if cur_tokens is None:
                raise e
            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
        # NOTE: cur_tokens is not recomputed in the len>1 branch, so this loop
        # keeps popping until one message remains and a break branch runs; the
        # recount after the loop produces the value actually returned.
        while cur_tokens > max_tokens:
            if len(self.messages) > 1:
                self.messages.pop(0)
            elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant":
                self.messages.pop(0)
                if precise:
                    cur_tokens = self.calc_tokens()
                else:
                    cur_tokens = len(str(self))
                break
            elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
                # A single oversized user question cannot be trimmed further.
                logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens))
                break
            else:
                logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
                break
        if precise:
            cur_tokens = self.calc_tokens()
        else:
            cur_tokens = len(str(self))
        return cur_tokens

    def calc_tokens(self):
        # Count tokens of the fully rendered prompt with tiktoken.
        return num_tokens_from_string(str(self), self.model)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def num_tokens_from_string(string: str, model: str) -> int:
    """Returns the number of tokens in a text string."""
    import tiktoken

    enc = tiktoken.encoding_for_model(model)
    # disallowed_special=() lets special-token text count as plain text.
    return len(enc.encode(string, disallowed_special=()))
|
bot/session_manager.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from common.expired_dict import ExpiredDict
|
| 2 |
+
from common.log import logger
|
| 3 |
+
from config import conf
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Session(object):
    """Base container for one conversation's message history.

    Subclasses implement the token accounting (discard_exceeding/calc_tokens).
    """

    def __init__(self, session_id, system_prompt=None):
        self.session_id = session_id
        self.messages = []
        if system_prompt is not None:
            self.system_prompt = system_prompt
        else:
            # fall back to the globally configured persona description
            self.system_prompt = conf().get("character_desc", "")

    def reset(self):
        """Drop all history, keeping only the system prompt."""
        self.messages = [{"role": "system", "content": self.system_prompt}]

    def set_system_prompt(self, system_prompt):
        """Replace the system prompt and reset the conversation."""
        self.system_prompt = system_prompt
        self.reset()

    def add_query(self, query):
        """Append a user message."""
        self.messages.append({"role": "user", "content": query})

    def add_reply(self, reply):
        """Append an assistant message."""
        self.messages.append({"role": "assistant", "content": reply})

    def discard_exceeding(self, max_tokens=None, cur_tokens=None):
        raise NotImplementedError

    def calc_tokens(self):
        raise NotImplementedError
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class SessionManager(object):
    """Creates and caches Session objects keyed by session_id.

    When "expires_in_seconds" is configured, idle sessions expire automatically
    via ExpiredDict; otherwise they live for the process lifetime.
    """

    def __init__(self, sessioncls, **session_args):
        if conf().get("expires_in_seconds"):
            self.sessions = ExpiredDict(conf().get("expires_in_seconds"))
        else:
            self.sessions = dict()
        self.sessioncls = sessioncls
        self.session_args = session_args

    def build_session(self, session_id, system_prompt=None):
        """
        Create the session for session_id on first use and cache it.
        A non-None system_prompt updates the cached session's prompt and
        resets its history; session_id None yields an uncached session.
        """
        if session_id is None:
            return self.sessioncls(session_id, system_prompt, **self.session_args)

        if session_id not in self.sessions:
            self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args)
        elif system_prompt is not None:
            self.sessions[session_id].set_system_prompt(system_prompt)
        return self.sessions[session_id]

    def session_query(self, query, session_id):
        """Record a user query and trim history to the configured token budget."""
        session = self.build_session(session_id)
        session.add_query(query)
        try:
            limit = conf().get("conversation_max_tokens", 1000)
            used = session.discard_exceeding(limit, None)
            logger.debug("prompt tokens used={}".format(used))
        except Exception as e:
            # token counting is best-effort; never fail the query over it
            logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))
        return session

    def session_reply(self, reply, session_id, total_tokens=None):
        """Record a model reply and trim history to the configured token budget."""
        session = self.build_session(session_id)
        session.add_reply(reply)
        try:
            limit = conf().get("conversation_max_tokens", 1000)
            tokens_cnt = session.discard_exceeding(limit, total_tokens)
            logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
        except Exception as e:
            logger.debug("Exception when counting tokens precisely for session: {}".format(str(e)))
        return session

    def clear_session(self, session_id):
        """Forget one cached session, if present."""
        if session_id in self.sessions:
            del self.sessions[session_id]

    def clear_all_session(self):
        """Forget every cached session."""
        self.sessions.clear()
|
bot/xunfei/xunfei_spark_bot.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# encoding:utf-8
|
| 2 |
+
|
| 3 |
+
import requests, json
|
| 4 |
+
from bot.bot import Bot
|
| 5 |
+
from bot.session_manager import SessionManager
|
| 6 |
+
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
| 7 |
+
from bridge.context import ContextType, Context
|
| 8 |
+
from bridge.reply import Reply, ReplyType
|
| 9 |
+
from common.log import logger
|
| 10 |
+
from config import conf
|
| 11 |
+
from common import const
|
| 12 |
+
import time
|
| 13 |
+
import _thread as thread
|
| 14 |
+
import datetime
|
| 15 |
+
from datetime import datetime
|
| 16 |
+
from wsgiref.handlers import format_date_time
|
| 17 |
+
from urllib.parse import urlencode
|
| 18 |
+
import base64
|
| 19 |
+
import ssl
|
| 20 |
+
import hashlib
|
| 21 |
+
import hmac
|
| 22 |
+
import json
|
| 23 |
+
from time import mktime
|
| 24 |
+
from urllib.parse import urlparse
|
| 25 |
+
import websocket
|
| 26 |
+
import queue
|
| 27 |
+
import threading
|
| 28 |
+
import random
|
| 29 |
+
|
| 30 |
+
# 消息队列 map
|
| 31 |
+
queue_map = dict()
|
| 32 |
+
|
| 33 |
+
# 响应队列 map
|
| 34 |
+
reply_map = dict()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class XunFeiBot(Bot):
    """Bot for the iFlytek Spark model, driven over a websocket.

    A worker thread opens the websocket and streams reply chunks into a
    per-request queue (module-level queue_map); reply() polls that queue and
    assembles the final answer in the module-level reply_map.
    """

    def __init__(self):
        super().__init__()
        self.app_id = conf().get("xunfei_app_id")
        self.api_key = conf().get("xunfei_api_key")
        self.api_secret = conf().get("xunfei_api_secret")
        # Defaults target the v2.0 API; for the 1.5 model set this to "general".
        self.domain = "generalv2"
        # Defaults target the v2.0 API; 1.5 uses "ws://spark-api.xf-yun.com/v1.1/chat".
        self.spark_url = "ws://spark-api.xf-yun.com/v2.1/chat"
        self.host = urlparse(self.spark_url).netloc
        self.path = urlparse(self.spark_url).path
        # Reuse the wenxin session implementation for history management.
        self.sessions = SessionManager(BaiduWenxinSession, model=const.XUNFEI)

    def reply(self, query, context: Context = None) -> Reply:
        """Answer a TEXT context by streaming from Spark; other types get an ERROR reply."""
        if context.type == ContextType.TEXT:
            logger.info("[XunFei] query={}".format(query))
            session_id = context["session_id"]
            request_id = self.gen_request_id(session_id)
            reply_map[request_id] = ""
            session = self.sessions.session_query(query, session_id)
            # NOTE: create_web_socket's second parameter is named "session_id"
            # but actually receives request_id — all queue_map bookkeeping is
            # keyed by request_id, matching the polling below.
            threading.Thread(target=self.create_web_socket, args=(session.messages, request_id)).start()
            depth = 0
            time.sleep(0.1)
            t1 = time.time()
            usage = {}
            # Poll the queue for up to ~300 iterations (~30s of idle waits).
            while depth <= 300:
                try:
                    data_queue = queue_map.get(request_id)
                    if not data_queue:
                        # websocket thread has not registered its queue yet
                        depth += 1
                        time.sleep(0.1)
                        continue
                    data_item = data_queue.get(block=True, timeout=0.1)
                    if data_item.is_end:
                        # end of stream: collect final chunk and usage, stop polling
                        del queue_map[request_id]
                        if data_item.reply:
                            reply_map[request_id] += data_item.reply
                        usage = data_item.usage
                        break

                    reply_map[request_id] += data_item.reply
                    depth += 1
                except Exception as e:
                    # queue.Empty (and malformed items) just count as one idle tick
                    depth += 1
                    continue
            t2 = time.time()
            logger.info(f"[XunFei-API] response={reply_map[request_id]}, time={t2 - t1}s, usage={usage}")
            self.sessions.session_reply(reply_map[request_id], session_id, usage.get("total_tokens"))
            reply = Reply(ReplyType.TEXT, reply_map[request_id])
            del reply_map[request_id]
            return reply
        else:
            reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
            return reply

    def create_web_socket(self, prompt, session_id, temperature=0.5):
        """Open the Spark websocket and block until it closes.

        :param prompt: message list to send
        :param session_id: queue key — callers actually pass the request_id here
        :param temperature: sampling temperature forwarded to the request
        """
        logger.info(f"[XunFei] start connect, prompt={prompt}")
        websocket.enableTrace(False)
        wsUrl = self.create_url()
        ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close,
                                    on_open=on_open)
        data_queue = queue.Queue(1000)
        queue_map[session_id] = data_queue
        # Attributes stashed on the ws object are read back by the module-level
        # callbacks (on_open/run/on_message).
        ws.appid = self.app_id
        ws.question = prompt
        ws.domain = self.domain
        ws.session_id = session_id
        ws.temperature = temperature
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

    def gen_request_id(self, session_id: str):
        # Unique-enough per-request key: session id + timestamp + random suffix.
        return session_id + "_" + str(int(time.time())) + "" + str(random.randint(0, 100))

    # build the authenticated websocket url
    def create_url(self):
        # RFC1123-formatted timestamp required by the signature scheme
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        # canonical string to sign: host, date and the request line
        signature_origin = "host: " + self.host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + self.path + " HTTP/1.1"

        # HMAC-SHA256 signature of the canonical string
        signature_sha = hmac.new(self.api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()

        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", ' \
                               f'signature="{signature_sha_base64}"'

        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        # assemble the authentication query parameters
        v = {
            "authorization": authorization,
            "date": date,
            "host": self.host
        }
        # append them to the base websocket url
        url = self.spark_url + '?' + urlencode(v)
        # when debugging, print this url and compare with a reference implementation
        return url

    def gen_params(self, appid, domain, question):
        """
        Build the request payload from the app id and the user question.

        NOTE(review): appears unused within this file — run() calls the
        module-level gen_params (which also takes a temperature) instead.
        """
        data = {
            "header": {
                "app_id": appid,
                "uid": "1234"
            },
            "parameter": {
                "chat": {
                    "domain": domain,
                    "random_threshold": 0.5,
                    "max_tokens": 2048,
                    "auditing": "default"
                }
            },
            "payload": {
                "message": {
                    "text": question
                }
            }
        }
        return data
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class ReplyItem:
    """One streamed reply chunk plus end-of-stream flag and usage metadata."""

    def __init__(self, reply, usage=None, is_end=False):
        self.reply = reply
        self.usage = usage
        self.is_end = is_end
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# handle a websocket error event: log it
def on_error(ws, error):
    logger.error(f"[XunFei] error: {str(error)}")
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
# handle websocket close: signal end-of-stream to the waiting consumer
def on_close(ws, one, two):
    data_queue = queue_map.get(ws.session_id)
    # The consumer deletes the queue after a normal end-of-reply message, so
    # it may already be gone here; guard against calling .put on None.
    if data_queue is None:
        return
    # BUGFIX: previously put the bare string "END" — the consumer in
    # XunFeiBot.reply() reads .is_end/.reply/.usage on queue items, so a raw
    # string only raised a swallowed AttributeError and left it polling until
    # its 300-iteration timeout. Put a real end marker instead.
    data_queue.put(ReplyItem("", usage={}, is_end=True))
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# handle websocket connection established: send the request from a worker thread
def on_open(ws):
    logger.info(f"[XunFei] Start websocket, session_id={ws.session_id}")
    thread.start_new_thread(run, (ws,))
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def run(ws, *args):
    # Serialize the request payload and send it over the freshly opened socket.
    payload = gen_params(appid=ws.appid, domain=ws.domain, question=ws.question, temperature=ws.temperature)
    ws.send(json.dumps(payload))
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
# Websocket operations
# handle an incoming websocket message: push each streamed chunk onto the
# per-request queue; status == 2 marks the final chunk (carries usage stats)
def on_message(ws, message):
    data = json.loads(message)
    code = data['header']['code']
    if code != 0:
        # non-zero code: the request failed; close the connection
        logger.error(f'请求错误: {code}, {data}')
        ws.close()
    else:
        choices = data["payload"]["choices"]
        status = choices["status"]
        content = choices["text"][0]["content"]
        # ws.session_id actually holds the request_id set in create_web_socket
        data_queue = queue_map.get(ws.session_id)
        if not data_queue:
            logger.error(f"[XunFei] can't find data queue, session_id={ws.session_id}")
            return
        reply_item = ReplyItem(content)
        if status == 2:
            # final frame: attach token usage and mark end-of-stream
            usage = data["payload"].get("usage")
            reply_item = ReplyItem(content, usage)
            reply_item.is_end = True
            ws.close()
        data_queue.put(reply_item)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def gen_params(appid, domain, question, temperature=0.5):
    """Build the Spark chat request payload from the app id and user question."""
    chat_config = {
        "domain": domain,
        "temperature": temperature,
        "random_threshold": 0.5,
        "max_tokens": 2048,
        "auditing": "default",
    }
    return {
        "header": {"app_id": appid, "uid": "1234"},
        "parameter": {"chat": chat_config},
        "payload": {"message": {"text": question}},
    }
|
bridge/bridge.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from bot.bot_factory import create_bot
|
| 2 |
+
from bridge.context import Context
|
| 3 |
+
from bridge.reply import Reply
|
| 4 |
+
from common import const
|
| 5 |
+
from common.log import logger
|
| 6 |
+
from common.singleton import singleton
|
| 7 |
+
from config import conf
|
| 8 |
+
from translate.factory import create_translator
|
| 9 |
+
from voice.factory import create_voice
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@singleton
|
| 13 |
+
class Bridge(object):
|
| 14 |
+
def __init__(self):
|
| 15 |
+
self.btype = {
|
| 16 |
+
"chat": const.CHATGPT,
|
| 17 |
+
"voice_to_text": conf().get("voice_to_text", "openai"),
|
| 18 |
+
"text_to_voice": conf().get("text_to_voice", "google"),
|
| 19 |
+
"translate": conf().get("translate", "baidu"),
|
| 20 |
+
}
|
| 21 |
+
model_type = conf().get("model")
|
| 22 |
+
if model_type in ["text-davinci-003"]:
|
| 23 |
+
self.btype["chat"] = const.OPEN_AI
|
| 24 |
+
if conf().get("use_azure_chatgpt", False):
|
| 25 |
+
self.btype["chat"] = const.CHATGPTONAZURE
|
| 26 |
+
if model_type in ["wenxin"]:
|
| 27 |
+
self.btype["chat"] = const.BAIDU
|
| 28 |
+
if model_type in ["xunfei"]:
|
| 29 |
+
self.btype["chat"] = const.XUNFEI
|
| 30 |
+
if conf().get("use_linkai") and conf().get("linkai_api_key"):
|
| 31 |
+
self.btype["chat"] = const.LINKAI
|
| 32 |
+
self.bots = {}
|
| 33 |
+
|
| 34 |
+
def get_bot(self, typename):
|
| 35 |
+
if self.bots.get(typename) is None:
|
| 36 |
+
logger.info("create bot {} for {}".format(self.btype[typename], typename))
|
| 37 |
+
if typename == "text_to_voice":
|
| 38 |
+
self.bots[typename] = create_voice(self.btype[typename])
|
| 39 |
+
elif typename == "voice_to_text":
|
| 40 |
+
self.bots[typename] = create_voice(self.btype[typename])
|
| 41 |
+
elif typename == "chat":
|
| 42 |
+
self.bots[typename] = create_bot(self.btype[typename])
|
| 43 |
+
elif typename == "translate":
|
| 44 |
+
self.bots[typename] = create_translator(self.btype[typename])
|
| 45 |
+
return self.bots[typename]
|
| 46 |
+
|
| 47 |
+
def get_bot_type(self, typename):
|
| 48 |
+
return self.btype[typename]
|
| 49 |
+
|
| 50 |
+
def fetch_reply_content(self, query, context: Context) -> Reply:
|
| 51 |
+
return self.get_bot("chat").reply(query, context)
|
| 52 |
+
|
| 53 |
+
def fetch_voice_to_text(self, voiceFile) -> Reply:
|
| 54 |
+
return self.get_bot("voice_to_text").voiceToText(voiceFile)
|
| 55 |
+
|
| 56 |
+
def fetch_text_to_voice(self, text) -> Reply:
|
| 57 |
+
return self.get_bot("text_to_voice").textToVoice(text)
|
| 58 |
+
|
| 59 |
+
def fetch_translate(self, text, from_lang="", to_lang="en") -> Reply:
|
| 60 |
+
return self.get_bot("translate").translate(text, from_lang, to_lang)
|
| 61 |
+
|
| 62 |
+
def reset_bot(self):
|
| 63 |
+
"""
|
| 64 |
+
重置bot路由
|
| 65 |
+
"""
|
| 66 |
+
self.__init__()
|
bridge/context.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# encoding:utf-8
|
| 2 |
+
|
| 3 |
+
from enum import Enum
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ContextType(Enum):
|
| 7 |
+
TEXT = 1 # 文本消息
|
| 8 |
+
VOICE = 2 # 音频消息
|
| 9 |
+
IMAGE = 3 # 图片消息
|
| 10 |
+
IMAGE_CREATE = 10 # 创建图片命令
|
| 11 |
+
JOIN_GROUP = 20 # 加入群聊
|
| 12 |
+
PATPAT = 21 # 拍了拍
|
| 13 |
+
|
| 14 |
+
def __str__(self):
|
| 15 |
+
return self.name
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Context:
|
| 19 |
+
def __init__(self, type: ContextType = None, content=None, kwargs=dict()):
|
| 20 |
+
self.type = type
|
| 21 |
+
self.content = content
|
| 22 |
+
self.kwargs = kwargs
|
| 23 |
+
|
| 24 |
+
def __contains__(self, key):
|
| 25 |
+
if key == "type":
|
| 26 |
+
return self.type is not None
|
| 27 |
+
elif key == "content":
|
| 28 |
+
return self.content is not None
|
| 29 |
+
else:
|
| 30 |
+
return key in self.kwargs
|
| 31 |
+
|
| 32 |
+
def __getitem__(self, key):
|
| 33 |
+
if key == "type":
|
| 34 |
+
return self.type
|
| 35 |
+
elif key == "content":
|
| 36 |
+
return self.content
|
| 37 |
+
else:
|
| 38 |
+
return self.kwargs[key]
|
| 39 |
+
|
| 40 |
+
def get(self, key, default=None):
|
| 41 |
+
try:
|
| 42 |
+
return self[key]
|
| 43 |
+
except KeyError:
|
| 44 |
+
return default
|
| 45 |
+
|
| 46 |
+
def __setitem__(self, key, value):
|
| 47 |
+
if key == "type":
|
| 48 |
+
self.type = value
|
| 49 |
+
elif key == "content":
|
| 50 |
+
self.content = value
|
| 51 |
+
else:
|
| 52 |
+
self.kwargs[key] = value
|
| 53 |
+
|
| 54 |
+
def __delitem__(self, key):
|
| 55 |
+
if key == "type":
|
| 56 |
+
self.type = None
|
| 57 |
+
elif key == "content":
|
| 58 |
+
self.content = None
|
| 59 |
+
else:
|
| 60 |
+
del self.kwargs[key]
|
| 61 |
+
|
| 62 |
+
def __str__(self):
|
| 63 |
+
return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)
|
bridge/reply.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# encoding:utf-8
|
| 2 |
+
|
| 3 |
+
from enum import Enum
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ReplyType(Enum):
|
| 7 |
+
TEXT = 1 # 文本
|
| 8 |
+
VOICE = 2 # 音频文件
|
| 9 |
+
IMAGE = 3 # 图片文件
|
| 10 |
+
IMAGE_URL = 4 # 图片URL
|
| 11 |
+
|
| 12 |
+
INFO = 9
|
| 13 |
+
ERROR = 10
|
| 14 |
+
|
| 15 |
+
def __str__(self):
|
| 16 |
+
return self.name
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class Reply:
|
| 20 |
+
def __init__(self, type: ReplyType = None, content=None):
|
| 21 |
+
self.type = type
|
| 22 |
+
self.content = content
|
| 23 |
+
|
| 24 |
+
def __str__(self):
|
| 25 |
+
return "Reply(type={}, content={})".format(self.type, self.content)
|
channel/channel.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Message sending channel abstract class
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from bridge.bridge import Bridge
|
| 6 |
+
from bridge.context import Context
|
| 7 |
+
from bridge.reply import *
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Channel(object):
|
| 11 |
+
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE]
|
| 12 |
+
|
| 13 |
+
def startup(self):
|
| 14 |
+
"""
|
| 15 |
+
init channel
|
| 16 |
+
"""
|
| 17 |
+
raise NotImplementedError
|
| 18 |
+
|
| 19 |
+
def handle_text(self, msg):
|
| 20 |
+
"""
|
| 21 |
+
process received msg
|
| 22 |
+
:param msg: message object
|
| 23 |
+
"""
|
| 24 |
+
raise NotImplementedError
|
| 25 |
+
|
| 26 |
+
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
| 27 |
+
def send(self, reply: Reply, context: Context):
|
| 28 |
+
"""
|
| 29 |
+
send message to user
|
| 30 |
+
:param msg: message content
|
| 31 |
+
:param receiver: receiver channel account
|
| 32 |
+
:return:
|
| 33 |
+
"""
|
| 34 |
+
raise NotImplementedError
|
| 35 |
+
|
| 36 |
+
def build_reply_content(self, query, context: Context = None) -> Reply:
|
| 37 |
+
return Bridge().fetch_reply_content(query, context)
|
| 38 |
+
|
| 39 |
+
def build_voice_to_text(self, voice_file) -> Reply:
|
| 40 |
+
return Bridge().fetch_voice_to_text(voice_file)
|
| 41 |
+
|
| 42 |
+
def build_text_to_voice(self, text) -> Reply:
|
| 43 |
+
return Bridge().fetch_text_to_voice(text)
|
channel/channel_factory.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
channel factory
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def create_channel(channel_type):
|
| 7 |
+
"""
|
| 8 |
+
create a channel instance
|
| 9 |
+
:param channel_type: channel type code
|
| 10 |
+
:return: channel instance
|
| 11 |
+
"""
|
| 12 |
+
if channel_type == "wx":
|
| 13 |
+
from channel.wechat.wechat_channel import WechatChannel
|
| 14 |
+
|
| 15 |
+
return WechatChannel()
|
| 16 |
+
elif channel_type == "wxy":
|
| 17 |
+
from channel.wechat.wechaty_channel import WechatyChannel
|
| 18 |
+
|
| 19 |
+
return WechatyChannel()
|
| 20 |
+
elif channel_type == "terminal":
|
| 21 |
+
from channel.terminal.terminal_channel import TerminalChannel
|
| 22 |
+
|
| 23 |
+
return TerminalChannel()
|
| 24 |
+
elif channel_type == "wechatmp":
|
| 25 |
+
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
| 26 |
+
|
| 27 |
+
return WechatMPChannel(passive_reply=True)
|
| 28 |
+
elif channel_type == "wechatmp_service":
|
| 29 |
+
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
| 30 |
+
|
| 31 |
+
return WechatMPChannel(passive_reply=False)
|
| 32 |
+
elif channel_type == "wechatcom_app":
|
| 33 |
+
from channel.wechatcom.wechatcomapp_channel import WechatComAppChannel
|
| 34 |
+
|
| 35 |
+
return WechatComAppChannel()
|
| 36 |
+
raise RuntimeError
|
channel/chat_channel.py
ADDED
|
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import threading
|
| 4 |
+
import time
|
| 5 |
+
from asyncio import CancelledError
|
| 6 |
+
from concurrent.futures import Future, ThreadPoolExecutor
|
| 7 |
+
|
| 8 |
+
from bridge.context import *
|
| 9 |
+
from bridge.reply import *
|
| 10 |
+
from channel.channel import Channel
|
| 11 |
+
from common.dequeue import Dequeue
|
| 12 |
+
from common.log import logger
|
| 13 |
+
from config import conf
|
| 14 |
+
from plugins import *
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
from voice.audio_convert import any_to_wav
|
| 18 |
+
except Exception as e:
|
| 19 |
+
pass
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# 抽象类, 它包含了与消息通道无关的通用处理逻辑
|
| 23 |
+
class ChatChannel(Channel):
|
| 24 |
+
name = None # 登录的用户名
|
| 25 |
+
user_id = None # 登录的用户id
|
| 26 |
+
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
|
| 27 |
+
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
|
| 28 |
+
lock = threading.Lock() # 用于控制对sessions的访问
|
| 29 |
+
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
|
| 30 |
+
|
| 31 |
+
def __init__(self):
|
| 32 |
+
_thread = threading.Thread(target=self.consume)
|
| 33 |
+
_thread.setDaemon(True)
|
| 34 |
+
_thread.start()
|
| 35 |
+
|
| 36 |
+
# 根据消息构造context,消息内容相关的触发项写在这里
|
| 37 |
+
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
| 38 |
+
context = Context(ctype, content)
|
| 39 |
+
context.kwargs = kwargs
|
| 40 |
+
# context首次传入时,origin_ctype是None,
|
| 41 |
+
# 引入的起因是:当输入语音时,会嵌套生成两个context,第一步语音转文本,第二步通过文本生成文字回复。
|
| 42 |
+
# origin_ctype用于第二步文本回复时,判断是否需要匹配前缀,如果是私聊的语音,就不需要匹配前缀
|
| 43 |
+
if "origin_ctype" not in context:
|
| 44 |
+
context["origin_ctype"] = ctype
|
| 45 |
+
# context首次传入时,receiver是None,根据类型设置receiver
|
| 46 |
+
first_in = "receiver" not in context
|
| 47 |
+
# 群名匹配过程,设置session_id和receiver
|
| 48 |
+
if first_in: # context首次传入时,receiver是None,根据类型设置receiver
|
| 49 |
+
config = conf()
|
| 50 |
+
cmsg = context["msg"]
|
| 51 |
+
user_data = conf().get_user_data(cmsg.from_user_id)
|
| 52 |
+
context["openai_api_key"] = user_data.get("openai_api_key")
|
| 53 |
+
context["gpt_model"] = user_data.get("gpt_model")
|
| 54 |
+
if context.get("isgroup", False):
|
| 55 |
+
group_name = cmsg.other_user_nickname
|
| 56 |
+
group_id = cmsg.other_user_id
|
| 57 |
+
|
| 58 |
+
group_name_white_list = config.get("group_name_white_list", [])
|
| 59 |
+
group_name_keyword_white_list = config.get("group_name_keyword_white_list", [])
|
| 60 |
+
if any(
|
| 61 |
+
[
|
| 62 |
+
group_name in group_name_white_list,
|
| 63 |
+
"ALL_GROUP" in group_name_white_list,
|
| 64 |
+
check_contain(group_name, group_name_keyword_white_list),
|
| 65 |
+
]
|
| 66 |
+
):
|
| 67 |
+
group_chat_in_one_session = conf().get("group_chat_in_one_session", [])
|
| 68 |
+
session_id = cmsg.actual_user_id
|
| 69 |
+
if any(
|
| 70 |
+
[
|
| 71 |
+
group_name in group_chat_in_one_session,
|
| 72 |
+
"ALL_GROUP" in group_chat_in_one_session,
|
| 73 |
+
]
|
| 74 |
+
):
|
| 75 |
+
session_id = group_id
|
| 76 |
+
else:
|
| 77 |
+
return None
|
| 78 |
+
context["session_id"] = session_id
|
| 79 |
+
context["receiver"] = group_id
|
| 80 |
+
else:
|
| 81 |
+
context["session_id"] = cmsg.other_user_id
|
| 82 |
+
context["receiver"] = cmsg.other_user_id
|
| 83 |
+
e_context = PluginManager().emit_event(EventContext(Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context}))
|
| 84 |
+
context = e_context["context"]
|
| 85 |
+
if e_context.is_pass() or context is None:
|
| 86 |
+
return context
|
| 87 |
+
if cmsg.from_user_id == self.user_id and not config.get("trigger_by_self", True):
|
| 88 |
+
logger.debug("[WX]self message skipped")
|
| 89 |
+
return None
|
| 90 |
+
|
| 91 |
+
# 消息内容匹配过程,并处理content
|
| 92 |
+
if ctype == ContextType.TEXT:
|
| 93 |
+
if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息
|
| 94 |
+
logger.debug("[WX]reference query skipped")
|
| 95 |
+
return None
|
| 96 |
+
|
| 97 |
+
if context.get("isgroup", False): # 群聊
|
| 98 |
+
# 校验关键字
|
| 99 |
+
match_prefix = check_prefix(content, conf().get("group_chat_prefix"))
|
| 100 |
+
match_contain = check_contain(content, conf().get("group_chat_keyword"))
|
| 101 |
+
flag = False
|
| 102 |
+
if match_prefix is not None or match_contain is not None:
|
| 103 |
+
flag = True
|
| 104 |
+
if match_prefix:
|
| 105 |
+
content = content.replace(match_prefix, "", 1).strip()
|
| 106 |
+
if context["msg"].is_at:
|
| 107 |
+
logger.info("[WX]receive group at")
|
| 108 |
+
if not conf().get("group_at_off", False):
|
| 109 |
+
flag = True
|
| 110 |
+
pattern = f"@{re.escape(self.name)}(\u2005|\u0020)"
|
| 111 |
+
subtract_res = re.sub(pattern, r"", content)
|
| 112 |
+
if subtract_res == content and context["msg"].self_display_name:
|
| 113 |
+
# 前缀移除后没有变化,使用群昵称再次移除
|
| 114 |
+
pattern = f"@{re.escape(context['msg'].self_display_name)}(\u2005|\u0020)"
|
| 115 |
+
subtract_res = re.sub(pattern, r"", content)
|
| 116 |
+
content = subtract_res
|
| 117 |
+
if not flag:
|
| 118 |
+
if context["origin_ctype"] == ContextType.VOICE:
|
| 119 |
+
logger.info("[WX]receive group voice, but checkprefix didn't match")
|
| 120 |
+
return None
|
| 121 |
+
else: # 单聊
|
| 122 |
+
match_prefix = check_prefix(content, conf().get("single_chat_prefix", [""]))
|
| 123 |
+
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
|
| 124 |
+
content = content.replace(match_prefix, "", 1).strip()
|
| 125 |
+
elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
|
| 126 |
+
pass
|
| 127 |
+
else:
|
| 128 |
+
return None
|
| 129 |
+
content = content.strip()
|
| 130 |
+
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
|
| 131 |
+
if img_match_prefix:
|
| 132 |
+
content = content.replace(img_match_prefix, "", 1)
|
| 133 |
+
context.type = ContextType.IMAGE_CREATE
|
| 134 |
+
else:
|
| 135 |
+
context.type = ContextType.TEXT
|
| 136 |
+
context.content = content.strip()
|
| 137 |
+
if "desire_rtype" not in context and conf().get("always_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
| 138 |
+
context["desire_rtype"] = ReplyType.VOICE
|
| 139 |
+
elif context.type == ContextType.VOICE:
|
| 140 |
+
if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
| 141 |
+
context["desire_rtype"] = ReplyType.VOICE
|
| 142 |
+
|
| 143 |
+
return context
|
| 144 |
+
|
| 145 |
+
def _handle(self, context: Context):
|
| 146 |
+
if context is None or not context.content:
|
| 147 |
+
return
|
| 148 |
+
logger.debug("[WX] ready to handle context: {}".format(context))
|
| 149 |
+
# reply的构建步骤
|
| 150 |
+
reply = self._generate_reply(context)
|
| 151 |
+
|
| 152 |
+
logger.debug("[WX] ready to decorate reply: {}".format(reply))
|
| 153 |
+
# reply的包装步骤
|
| 154 |
+
reply = self._decorate_reply(context, reply)
|
| 155 |
+
|
| 156 |
+
# reply的发送步骤
|
| 157 |
+
self._send_reply(context, reply)
|
| 158 |
+
|
| 159 |
+
def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
|
| 160 |
+
e_context = PluginManager().emit_event(
|
| 161 |
+
EventContext(
|
| 162 |
+
Event.ON_HANDLE_CONTEXT,
|
| 163 |
+
{"channel": self, "context": context, "reply": reply},
|
| 164 |
+
)
|
| 165 |
+
)
|
| 166 |
+
reply = e_context["reply"]
|
| 167 |
+
if not e_context.is_pass():
|
| 168 |
+
logger.debug("[WX] ready to handle context: type={}, content={}".format(context.type, context.content))
|
| 169 |
+
if e_context.is_break():
|
| 170 |
+
context["generate_breaked_by"] = e_context["breaked_by"]
|
| 171 |
+
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
|
| 172 |
+
reply = super().build_reply_content(context.content, context)
|
| 173 |
+
elif context.type == ContextType.VOICE: # 语音消息
|
| 174 |
+
cmsg = context["msg"]
|
| 175 |
+
cmsg.prepare()
|
| 176 |
+
file_path = context.content
|
| 177 |
+
wav_path = os.path.splitext(file_path)[0] + ".wav"
|
| 178 |
+
try:
|
| 179 |
+
any_to_wav(file_path, wav_path)
|
| 180 |
+
except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别
|
| 181 |
+
logger.warning("[WX]any to wav error, use raw path. " + str(e))
|
| 182 |
+
wav_path = file_path
|
| 183 |
+
# 语音识别
|
| 184 |
+
reply = super().build_voice_to_text(wav_path)
|
| 185 |
+
# 删除临时文件
|
| 186 |
+
try:
|
| 187 |
+
os.remove(file_path)
|
| 188 |
+
if wav_path != file_path:
|
| 189 |
+
os.remove(wav_path)
|
| 190 |
+
except Exception as e:
|
| 191 |
+
pass
|
| 192 |
+
# logger.warning("[WX]delete temp file error: " + str(e))
|
| 193 |
+
|
| 194 |
+
if reply.type == ReplyType.TEXT:
|
| 195 |
+
new_context = self._compose_context(ContextType.TEXT, reply.content, **context.kwargs)
|
| 196 |
+
if new_context:
|
| 197 |
+
reply = self._generate_reply(new_context)
|
| 198 |
+
else:
|
| 199 |
+
return
|
| 200 |
+
elif context.type == ContextType.IMAGE: # 图片消息,当前无默认逻辑
|
| 201 |
+
pass
|
| 202 |
+
else:
|
| 203 |
+
logger.error("[WX] unknown context type: {}".format(context.type))
|
| 204 |
+
return
|
| 205 |
+
return reply
|
| 206 |
+
|
| 207 |
+
def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
|
| 208 |
+
if reply and reply.type:
|
| 209 |
+
e_context = PluginManager().emit_event(
|
| 210 |
+
EventContext(
|
| 211 |
+
Event.ON_DECORATE_REPLY,
|
| 212 |
+
{"channel": self, "context": context, "reply": reply},
|
| 213 |
+
)
|
| 214 |
+
)
|
| 215 |
+
reply = e_context["reply"]
|
| 216 |
+
desire_rtype = context.get("desire_rtype")
|
| 217 |
+
if not e_context.is_pass() and reply and reply.type:
|
| 218 |
+
if reply.type in self.NOT_SUPPORT_REPLYTYPE:
|
| 219 |
+
logger.error("[WX]reply type not support: " + str(reply.type))
|
| 220 |
+
reply.type = ReplyType.ERROR
|
| 221 |
+
reply.content = "不支持发送的消息类型: " + str(reply.type)
|
| 222 |
+
|
| 223 |
+
if reply.type == ReplyType.TEXT:
|
| 224 |
+
reply_text = reply.content
|
| 225 |
+
if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
| 226 |
+
reply = super().build_text_to_voice(reply.content)
|
| 227 |
+
return self._decorate_reply(context, reply)
|
| 228 |
+
if context.get("isgroup", False):
|
| 229 |
+
reply_text = "@" + context["msg"].actual_user_nickname + "\n" + reply_text.strip()
|
| 230 |
+
reply_text = conf().get("group_chat_reply_prefix", "") + reply_text + conf().get("group_chat_reply_suffix", "")
|
| 231 |
+
else:
|
| 232 |
+
reply_text = conf().get("single_chat_reply_prefix", "") + reply_text + conf().get("single_chat_reply_suffix", "")
|
| 233 |
+
reply.content = reply_text
|
| 234 |
+
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
| 235 |
+
reply.content = "[" + str(reply.type) + "]\n" + reply.content
|
| 236 |
+
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
|
| 237 |
+
pass
|
| 238 |
+
else:
|
| 239 |
+
logger.error("[WX] unknown reply type: {}".format(reply.type))
|
| 240 |
+
return
|
| 241 |
+
if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
|
| 242 |
+
logger.warning("[WX] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type))
|
| 243 |
+
return reply
|
| 244 |
+
|
| 245 |
+
def _send_reply(self, context: Context, reply: Reply):
|
| 246 |
+
if reply and reply.type:
|
| 247 |
+
e_context = PluginManager().emit_event(
|
| 248 |
+
EventContext(
|
| 249 |
+
Event.ON_SEND_REPLY,
|
| 250 |
+
{"channel": self, "context": context, "reply": reply},
|
| 251 |
+
)
|
| 252 |
+
)
|
| 253 |
+
reply = e_context["reply"]
|
| 254 |
+
if not e_context.is_pass() and reply and reply.type:
|
| 255 |
+
logger.debug("[WX] ready to send reply: {}, context: {}".format(reply, context))
|
| 256 |
+
self._send(reply, context)
|
| 257 |
+
|
| 258 |
+
def _send(self, reply: Reply, context: Context, retry_cnt=0):
|
| 259 |
+
try:
|
| 260 |
+
self.send(reply, context)
|
| 261 |
+
except Exception as e:
|
| 262 |
+
logger.error("[WX] sendMsg error: {}".format(str(e)))
|
| 263 |
+
if isinstance(e, NotImplementedError):
|
| 264 |
+
return
|
| 265 |
+
logger.exception(e)
|
| 266 |
+
if retry_cnt < 2:
|
| 267 |
+
time.sleep(3 + 3 * retry_cnt)
|
| 268 |
+
self._send(reply, context, retry_cnt + 1)
|
| 269 |
+
|
| 270 |
+
def _success_callback(self, session_id, **kwargs): # 线程正常结束时的回调函数
|
| 271 |
+
logger.debug("Worker return success, session_id = {}".format(session_id))
|
| 272 |
+
|
| 273 |
+
def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数
|
| 274 |
+
logger.exception("Worker return exception: {}".format(exception))
|
| 275 |
+
|
| 276 |
+
def _thread_pool_callback(self, session_id, **kwargs):
|
| 277 |
+
def func(worker: Future):
|
| 278 |
+
try:
|
| 279 |
+
worker_exception = worker.exception()
|
| 280 |
+
if worker_exception:
|
| 281 |
+
self._fail_callback(session_id, exception=worker_exception, **kwargs)
|
| 282 |
+
else:
|
| 283 |
+
self._success_callback(session_id, **kwargs)
|
| 284 |
+
except CancelledError as e:
|
| 285 |
+
logger.info("Worker cancelled, session_id = {}".format(session_id))
|
| 286 |
+
except Exception as e:
|
| 287 |
+
logger.exception("Worker raise exception: {}".format(e))
|
| 288 |
+
with self.lock:
|
| 289 |
+
self.sessions[session_id][1].release()
|
| 290 |
+
|
| 291 |
+
return func
|
| 292 |
+
|
| 293 |
+
def produce(self, context: Context):
|
| 294 |
+
session_id = context["session_id"]
|
| 295 |
+
with self.lock:
|
| 296 |
+
if session_id not in self.sessions:
|
| 297 |
+
self.sessions[session_id] = [
|
| 298 |
+
Dequeue(),
|
| 299 |
+
threading.BoundedSemaphore(conf().get("concurrency_in_session", 4)),
|
| 300 |
+
]
|
| 301 |
+
if context.type == ContextType.TEXT and context.content.startswith("#"):
|
| 302 |
+
self.sessions[session_id][0].putleft(context) # 优先处理管理��令
|
| 303 |
+
else:
|
| 304 |
+
self.sessions[session_id][0].put(context)
|
| 305 |
+
|
| 306 |
+
# 消费者函数,单独线程,用于从消息队列中取出消息并处理
|
| 307 |
+
def consume(self):
|
| 308 |
+
while True:
|
| 309 |
+
with self.lock:
|
| 310 |
+
session_ids = list(self.sessions.keys())
|
| 311 |
+
for session_id in session_ids:
|
| 312 |
+
context_queue, semaphore = self.sessions[session_id]
|
| 313 |
+
if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除
|
| 314 |
+
if not context_queue.empty():
|
| 315 |
+
context = context_queue.get()
|
| 316 |
+
logger.debug("[WX] consume context: {}".format(context))
|
| 317 |
+
future: Future = self.handler_pool.submit(self._handle, context)
|
| 318 |
+
future.add_done_callback(self._thread_pool_callback(session_id, context=context))
|
| 319 |
+
if session_id not in self.futures:
|
| 320 |
+
self.futures[session_id] = []
|
| 321 |
+
self.futures[session_id].append(future)
|
| 322 |
+
elif semaphore._initial_value == semaphore._value + 1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
|
| 323 |
+
self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
|
| 324 |
+
assert len(self.futures[session_id]) == 0, "thread pool error"
|
| 325 |
+
del self.sessions[session_id]
|
| 326 |
+
else:
|
| 327 |
+
semaphore.release()
|
| 328 |
+
time.sleep(0.1)
|
| 329 |
+
|
| 330 |
+
# 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务
|
| 331 |
+
def cancel_session(self, session_id):
|
| 332 |
+
with self.lock:
|
| 333 |
+
if session_id in self.sessions:
|
| 334 |
+
for future in self.futures[session_id]:
|
| 335 |
+
future.cancel()
|
| 336 |
+
cnt = self.sessions[session_id][0].qsize()
|
| 337 |
+
if cnt > 0:
|
| 338 |
+
logger.info("Cancel {} messages in session {}".format(cnt, session_id))
|
| 339 |
+
self.sessions[session_id][0] = Dequeue()
|
| 340 |
+
|
| 341 |
+
def cancel_all_session(self):
|
| 342 |
+
with self.lock:
|
| 343 |
+
for session_id in self.sessions:
|
| 344 |
+
for future in self.futures[session_id]:
|
| 345 |
+
future.cancel()
|
| 346 |
+
cnt = self.sessions[session_id][0].qsize()
|
| 347 |
+
if cnt > 0:
|
| 348 |
+
logger.info("Cancel {} messages in session {}".format(cnt, session_id))
|
| 349 |
+
self.sessions[session_id][0] = Dequeue()
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def check_prefix(content, prefix_list):
|
| 353 |
+
if not prefix_list:
|
| 354 |
+
return None
|
| 355 |
+
for prefix in prefix_list:
|
| 356 |
+
if content.startswith(prefix):
|
| 357 |
+
return prefix
|
| 358 |
+
return None
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def check_contain(content, keyword_list):
|
| 362 |
+
if not keyword_list:
|
| 363 |
+
return None
|
| 364 |
+
for ky in keyword_list:
|
| 365 |
+
if content.find(ky) != -1:
|
| 366 |
+
return True
|
| 367 |
+
return None
|
channel/chat_message.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
本类表示聊天消息,用于对itchat和wechaty的消息进行统一的封装。
|
| 3 |
+
|
| 4 |
+
填好必填项(群聊6个,非群聊8个),即可接入ChatChannel,并支持插件,参考TerminalChannel
|
| 5 |
+
|
| 6 |
+
ChatMessage
|
| 7 |
+
msg_id: 消息id (必填)
|
| 8 |
+
create_time: 消息创建时间
|
| 9 |
+
|
| 10 |
+
ctype: 消息类型 : ContextType (必填)
|
| 11 |
+
content: 消息内容, 如果是声音/图片,这里是文件路径 (必填)
|
| 12 |
+
|
| 13 |
+
from_user_id: 发送者id (必填)
|
| 14 |
+
from_user_nickname: 发送者昵称
|
| 15 |
+
to_user_id: 接收者id (必填)
|
| 16 |
+
to_user_nickname: 接收者昵称
|
| 17 |
+
|
| 18 |
+
other_user_id: 对方的id,如果你是发送者,那这个就是接收者id,如果你是接收者,那这个就是发送者id,如果是群消息,那这一直是群id (必填)
|
| 19 |
+
other_user_nickname: 同上
|
| 20 |
+
|
| 21 |
+
is_group: 是否是群消息 (群聊必填)
|
| 22 |
+
is_at: 是否被at
|
| 23 |
+
|
| 24 |
+
- (群消息时,一般会存在实际发送者,是群内某个成员的id和昵称,下列项仅在群消息时存在)
|
| 25 |
+
actual_user_id: 实际发送者id (群聊必填)
|
| 26 |
+
actual_user_nickname:实际发送者昵称
|
| 27 |
+
self_display_name: 自身的展示名,设置群昵称时,该字段表示群昵称
|
| 28 |
+
|
| 29 |
+
_prepare_fn: 准备函数,用于准备消息的内容,比如下载图片等,
|
| 30 |
+
_prepared: 是否已经调用过准备函数
|
| 31 |
+
_rawmsg: 原始消息对象
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class ChatMessage(object):
|
| 37 |
+
msg_id = None
|
| 38 |
+
create_time = None
|
| 39 |
+
|
| 40 |
+
ctype = None
|
| 41 |
+
content = None
|
| 42 |
+
|
| 43 |
+
from_user_id = None
|
| 44 |
+
from_user_nickname = None
|
| 45 |
+
to_user_id = None
|
| 46 |
+
to_user_nickname = None
|
| 47 |
+
other_user_id = None
|
| 48 |
+
other_user_nickname = None
|
| 49 |
+
my_msg = False
|
| 50 |
+
self_display_name = None
|
| 51 |
+
|
| 52 |
+
is_group = False
|
| 53 |
+
is_at = False
|
| 54 |
+
actual_user_id = None
|
| 55 |
+
actual_user_nickname = None
|
| 56 |
+
|
| 57 |
+
_prepare_fn = None
|
| 58 |
+
_prepared = False
|
| 59 |
+
_rawmsg = None
|
| 60 |
+
|
| 61 |
+
def __init__(self, _rawmsg):
|
| 62 |
+
self._rawmsg = _rawmsg
|
| 63 |
+
|
| 64 |
+
def prepare(self):
|
| 65 |
+
if self._prepare_fn and not self._prepared:
|
| 66 |
+
self._prepared = True
|
| 67 |
+
self._prepare_fn()
|
| 68 |
+
|
| 69 |
+
def __str__(self):
|
| 70 |
+
return "ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}".format(
|
| 71 |
+
self.msg_id,
|
| 72 |
+
self.create_time,
|
| 73 |
+
self.ctype,
|
| 74 |
+
self.content,
|
| 75 |
+
self.from_user_id,
|
| 76 |
+
self.from_user_nickname,
|
| 77 |
+
self.to_user_id,
|
| 78 |
+
self.to_user_nickname,
|
| 79 |
+
self.other_user_id,
|
| 80 |
+
self.other_user_nickname,
|
| 81 |
+
self.is_group,
|
| 82 |
+
self.is_at,
|
| 83 |
+
self.actual_user_id,
|
| 84 |
+
self.actual_user_nickname,
|
| 85 |
+
)
|
channel/terminal/terminal_channel.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
|
| 3 |
+
from bridge.context import *
|
| 4 |
+
from bridge.reply import Reply, ReplyType
|
| 5 |
+
from channel.chat_channel import ChatChannel, check_prefix
|
| 6 |
+
from channel.chat_message import ChatMessage
|
| 7 |
+
from common.log import logger
|
| 8 |
+
from config import conf
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TerminalMessage(ChatMessage):
    """ChatMessage variant for the interactive terminal channel."""

    def __init__(self, msg_id, content, ctype=ContextType.TEXT,
                 from_user_id="User", to_user_id="Chatgpt", other_user_id="Chatgpt"):
        # Terminal messages carry no raw channel payload; fields are set directly.
        self.msg_id = msg_id
        self.ctype = ctype
        self.content = content
        self.from_user_id = from_user_id
        self.to_user_id = to_user_id
        self.other_user_id = other_user_id
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class TerminalChannel(ChatChannel):
    """Interactive console channel: reads prompts from stdin, prints replies to stdout."""

    # Voice replies cannot be rendered on a terminal.
    NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE]

    def send(self, reply: Reply, context: Context):
        """Print *reply* to the console; image replies are opened with PIL when possible."""
        print("\nBot:")
        if reply.type == ReplyType.IMAGE:
            from PIL import Image

            image_storage = reply.content
            image_storage.seek(0)
            img = Image.open(image_storage)
            print("<IMAGE>")
            img.show()
        elif reply.type == ReplyType.IMAGE_URL:  # download the image from the web
            import io

            import requests
            from PIL import Image

            img_url = reply.content
            pic_res = requests.get(img_url, stream=True)
            image_storage = io.BytesIO()
            for block in pic_res.iter_content(1024):
                image_storage.write(block)
            image_storage.seek(0)
            img = Image.open(image_storage)
            print(img_url)
            img.show()
        else:
            print(reply.content)
        print("\nUser:", end="")
        sys.stdout.flush()
        return

    def startup(self):
        """Console REPL loop: read a line, wrap it in a Context and hand it to produce()."""
        logger.setLevel("WARN")  # keep the console clean for the conversation itself
        print("\nPlease input your question:\nUser:", end="")
        sys.stdout.flush()
        msg_id = 0
        while True:
            try:
                prompt = self.get_input()
            except KeyboardInterrupt:
                print("\nExiting...")
                sys.exit()
            msg_id += 1
            trigger_prefixs = conf().get("single_chat_prefix", [""])
            if check_prefix(prompt, trigger_prefixs) is None:
                # Prepend the first trigger prefix so untriggered console input is still handled.
                prompt = trigger_prefixs[0] + prompt

            context = self._compose_context(ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt))
            if context:
                self.produce(context)
            else:
                raise Exception("context is None")

    def get_input(self):
        """Read one line of user input from stdin."""
        sys.stdout.flush()
        return input()
|
channel/wechat/wechat_channel.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# encoding:utf-8
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
wechat channel
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import io
|
| 8 |
+
import json
|
| 9 |
+
import os
|
| 10 |
+
import threading
|
| 11 |
+
import time
|
| 12 |
+
|
| 13 |
+
import requests
|
| 14 |
+
|
| 15 |
+
from bridge.context import *
|
| 16 |
+
from bridge.reply import *
|
| 17 |
+
from channel.chat_channel import ChatChannel
|
| 18 |
+
from channel.wechat.wechat_message import *
|
| 19 |
+
from common.expired_dict import ExpiredDict
|
| 20 |
+
from common.log import logger
|
| 21 |
+
from common.singleton import singleton
|
| 22 |
+
from common.time_check import time_checker
|
| 23 |
+
from config import conf, get_appdata_dir
|
| 24 |
+
from lib import itchat
|
| 25 |
+
from lib.itchat.content import *
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@itchat.msg_register([TEXT, VOICE, PICTURE, NOTE])
def handler_single_msg(msg):
    """itchat callback for private (1:1) messages: wrap and hand off to WechatChannel."""
    try:
        cmsg = WechatMessage(msg, False)
    except NotImplementedError as e:
        # Unsupported message subtype (e.g. an unhandled NOTE) is skipped quietly.
        logger.debug("[WX]single message {} skipped: {}".format(msg["MsgId"], e))
        return None
    WechatChannel().handle_single(cmsg)
    return None
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@itchat.msg_register([TEXT, VOICE, PICTURE, NOTE], isGroupChat=True)
def handler_group_msg(msg):
    """itchat callback for group-chat messages: wrap and hand off to WechatChannel."""
    try:
        cmsg = WechatMessage(msg, True)
    except NotImplementedError as e:
        # Unsupported message subtype is skipped quietly.
        logger.debug("[WX]group message {} skipped: {}".format(msg["MsgId"], e))
        return None
    WechatChannel().handle_group(cmsg)
    return None
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _check(func):
    """Decorator for WechatChannel.handle_* methods.

    Drops messages that are duplicates (already in self.receivedMsgs), stale
    history replayed after a hot-reload login, or sent by the bot itself in a
    private chat; otherwise forwards to the wrapped handler.
    """

    def wrapper(self, cmsg: ChatMessage):
        msgId = cmsg.msg_id
        if msgId in self.receivedMsgs:
            logger.info("Wechat message {} already received, ignore".format(msgId))
            return
        self.receivedMsgs[msgId] = True
        create_time = cmsg.create_time  # message timestamp
        if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60:  # skip history messages older than 1 minute
            logger.debug("[WX]history message {} skipped".format(msgId))
            return
        if cmsg.my_msg and not cmsg.is_group:
            logger.debug("[WX]my message {} skipped".format(msgId))
            return
        return func(self, cmsg)

    return wrapper
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# Public QR-code rendering services that can draw the login URL:
# https://api.qrserver.com/v1/create-qr-code/?size=400×400&data=https://www.abc.com
# https://api.isoyu.com/qr/?m=1&e=L&p=20&url=https://www.abc.com
def qrCallback(uuid, status, qrcode):
    """itchat login hook: present the login QR code.

    On status "0" it tries to open the QR image in a PIL window (best effort),
    prints several web rendering links, and finally draws an ASCII QR code
    directly in the terminal.
    """
    # logger.debug("qrCallback: {} {}".format(uuid,status))
    if status == "0":
        try:
            from PIL import Image

            img = Image.open(io.BytesIO(qrcode))
            _thread = threading.Thread(target=img.show, args=("QRCode",))
            # Thread.setDaemon() is deprecated since Python 3.10; set the attribute instead.
            _thread.daemon = True
            _thread.start()
        except Exception as e:
            # Opening a window is best-effort (headless hosts); fall through to the
            # link/ASCII renderings below instead of failing silently.
            logger.debug("qrCallback: show QR image failed: {}".format(e))

        import qrcode

        url = f"https://login.weixin.qq.com/l/{uuid}"

        qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
        qr_api2 = "https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url)
        qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url)
        qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url)
        print("You can also scan QRCode in any website below:")
        print(qr_api3)
        print(qr_api4)
        print(qr_api2)
        print(qr_api1)

        qr = qrcode.QRCode(border=1)
        qr.add_data(url)
        qr.make(fit=True)
        qr.print_ascii(invert=True)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@singleton
class WechatChannel(ChatChannel):
    """Personal-WeChat channel implemented on top of itchat (web WeChat protocol)."""

    NOT_SUPPORT_REPLYTYPE = []

    def __init__(self):
        super().__init__()
        # Deduplication cache for incoming message ids; entries expire after one hour.
        self.receivedMsgs = ExpiredDict(60 * 60)

    def startup(self):
        """Log in via QR-code scan and start the blocking itchat message loop."""
        itchat.instance.receivingRetryCount = 600  # extend the disconnect/receive retry window
        # login by scan QRCode
        hotReload = conf().get("hot_reload", False)
        status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
        itchat.auto_login(
            enableCmdQR=2,
            hotReload=hotReload,
            statusStorageDir=status_path,
            qrCallback=qrCallback,
        )
        self.user_id = itchat.instance.storageClass.userName
        self.name = itchat.instance.storageClass.nickName
        logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
        # start message listener
        itchat.run()

    # The handle_* methods build a Context from an incoming message and pass it to
    # produce(), which processes it and sends the reply. A Context carries:
    # type     message type: TEXT, VOICE or IMAGE_CREATE
    # content  text for TEXT; voice file name for VOICE; image-generation prompt for IMAGE_CREATE
    # kwargs   extra parameters:
    #     session_id: conversation id
    #     isgroup: whether the message came from a group chat
    #     receiver: who the reply should be sent to
    #     msg: the ChatMessage object
    #     origin_ctype: original message type; after speech-to-text, a private message
    #                   that fails the prefix match may still trigger if it began as voice
    #     desire_rtype: desired reply type; defaults to text, ReplyType.VOICE for a voice reply

    @time_checker
    @_check
    def handle_single(self, cmsg: ChatMessage):
        """Dispatch a private (1:1) message into the processing pipeline."""
        if cmsg.ctype == ContextType.VOICE:
            # Voice messages are ignored unless speech recognition is enabled.
            if conf().get("speech_recognition") != True:
                return
            logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
        elif cmsg.ctype == ContextType.IMAGE:
            logger.debug("[WX]receive image msg: {}".format(cmsg.content))
        elif cmsg.ctype == ContextType.PATPAT:
            logger.debug("[WX]receive patpat msg: {}".format(cmsg.content))
        elif cmsg.ctype == ContextType.TEXT:
            logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
        else:
            logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg))
        context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
        if context:
            self.produce(context)

    @time_checker
    @_check
    def handle_group(self, cmsg: ChatMessage):
        """Dispatch a group-chat message into the processing pipeline."""
        if cmsg.ctype == ContextType.VOICE:
            # Group voice messages need their own config switch.
            if conf().get("group_speech_recognition") != True:
                return
            logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
        elif cmsg.ctype == ContextType.IMAGE:
            logger.debug("[WX]receive image for group msg: {}".format(cmsg.content))
        elif cmsg.ctype in [ContextType.JOIN_GROUP, ContextType.PATPAT]:
            logger.debug("[WX]receive note msg: {}".format(cmsg.content))
        elif cmsg.ctype == ContextType.TEXT:
            # logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
            pass
        else:
            logger.debug("[WX]receive group msg: {}".format(cmsg.content))
        context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
        if context:
            self.produce(context)

    # Unified sender implemented per channel; dispatches on reply.type.
    def send(self, reply: Reply, context: Context):
        """Send *reply* to the receiver recorded in *context* via itchat."""
        receiver = context["receiver"]
        if reply.type == ReplyType.TEXT:
            itchat.send(reply.content, toUserName=receiver)
            logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
        elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
            itchat.send(reply.content, toUserName=receiver)
            logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
        elif reply.type == ReplyType.VOICE:
            itchat.send_file(reply.content, toUserName=receiver)
            logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver))
        elif reply.type == ReplyType.IMAGE_URL:  # download the image from the web first
            img_url = reply.content
            logger.debug(f"[WX] start download image, img_url={img_url}")
            pic_res = requests.get(img_url, stream=True)
            image_storage = io.BytesIO()
            size = 0
            for block in pic_res.iter_content(1024):
                size += len(block)
                image_storage.write(block)
            logger.info(f"[WX] download image success, size={size}, img_url={img_url}")
            image_storage.seek(0)
            itchat.send_image(image_storage, toUserName=receiver)
            logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
        elif reply.type == ReplyType.IMAGE:  # image already in a file-like object
            image_storage = reply.content
            image_storage.seek(0)
            itchat.send_image(image_storage, toUserName=receiver)
            logger.info("[WX] sendImage, receiver={}".format(receiver))
|
channel/wechat/wechat_message.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
from bridge.context import ContextType
|
| 4 |
+
from channel.chat_message import ChatMessage
|
| 5 |
+
from common.log import logger
|
| 6 |
+
from common.tmp_dir import TmpDir
|
| 7 |
+
from lib import itchat
|
| 8 |
+
from lib.itchat.content import *
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class WechatMessage(ChatMessage):
    """ChatMessage built from a raw itchat message dict.

    Raises NotImplementedError for message (sub)types the channel does not handle,
    so callers can skip them.
    """

    def __init__(self, itchat_msg, is_group=False):
        super().__init__(itchat_msg)
        self.msg_id = itchat_msg["MsgId"]
        self.create_time = itchat_msg["CreateTime"]
        self.is_group = is_group

        if itchat_msg["Type"] == TEXT:
            self.ctype = ContextType.TEXT
            self.content = itchat_msg["Text"]
        elif itchat_msg["Type"] == VOICE:
            self.ctype = ContextType.VOICE
            self.content = TmpDir().path() + itchat_msg["FileName"]  # content holds the temp-dir file path
            self._prepare_fn = lambda: itchat_msg.download(self.content)  # lazy media download
        elif itchat_msg["Type"] == PICTURE and itchat_msg["MsgType"] == 3:
            self.ctype = ContextType.IMAGE
            self.content = TmpDir().path() + itchat_msg["FileName"]  # content holds the temp-dir file path
            self._prepare_fn = lambda: itchat_msg.download(self.content)  # lazy media download
        elif itchat_msg["Type"] == NOTE and itchat_msg["MsgType"] == 10000:
            # System notes: member-joined-group and "pat-pat" notifications.
            if is_group and ("加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"]):
                self.ctype = ContextType.JOIN_GROUP
                self.content = itchat_msg["Content"]
                # Only the nickname can be recovered here; actual_user_id stays the bot's id.
                if "加入了群聊" in itchat_msg["Content"]:
                    self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[-1]
                elif "加入群聊" in itchat_msg["Content"]:
                    self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
            elif "拍了拍我" in itchat_msg["Content"]:
                self.ctype = ContextType.PATPAT
                self.content = itchat_msg["Content"]
                if is_group:
                    self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
            else:
                raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"])
        else:
            raise NotImplementedError("Unsupported message type: Type:{} MsgType:{}".format(itchat_msg["Type"], itchat_msg["MsgType"]))

        self.from_user_id = itchat_msg["FromUserName"]
        self.to_user_id = itchat_msg["ToUserName"]

        user_id = itchat.instance.storageClass.userName
        nickname = itchat.instance.storageClass.nickName

        # from_user_id/to_user_id are rarely used, but they are filled for consistency.
        # The following is tedious; in short: everything that can be filled, is.
        if self.from_user_id == user_id:
            self.from_user_nickname = nickname
        if self.to_user_id == user_id:
            self.to_user_nickname = nickname
        try:  # the "User" field may be absent for strangers
            # my_msg is True when the bot account itself sent the message
            self.my_msg = itchat_msg["ToUserName"] == itchat_msg["User"]["UserName"] and \
                itchat_msg["ToUserName"] != itchat_msg["FromUserName"]
            self.other_user_id = itchat_msg["User"]["UserName"]
            self.other_user_nickname = itchat_msg["User"]["NickName"]
            if self.other_user_id == self.from_user_id:
                self.from_user_nickname = self.other_user_nickname
            if self.other_user_id == self.to_user_id:
                self.to_user_nickname = self.other_user_nickname
            if itchat_msg["User"].get("Self"):
                # The bot's own display name; equals the group alias when one is set.
                self.self_display_name = itchat_msg["User"].get("Self").get("DisplayName")
        except KeyError as e:  # fall back when peer info is occasionally missing
            logger.warn("[WX]get other_user_id failed: " + str(e))
            if self.from_user_id == user_id:
                self.other_user_id = self.to_user_id
            else:
                self.other_user_id = self.from_user_id

        if self.is_group:
            self.is_at = itchat_msg["IsAt"]
            self.actual_user_id = itchat_msg["ActualUserName"]
            # JOIN_GROUP/PATPAT already derived the nickname from the note text above.
            if self.ctype not in [ContextType.JOIN_GROUP, ContextType.PATPAT]:
                self.actual_user_nickname = itchat_msg["ActualNickName"]
|
channel/wechat/wechaty_channel.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# encoding:utf-8
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
wechaty channel
|
| 5 |
+
Python Wechaty - https://github.com/wechaty/python-wechaty
|
| 6 |
+
"""
|
| 7 |
+
import asyncio
|
| 8 |
+
import base64
|
| 9 |
+
import os
|
| 10 |
+
import time
|
| 11 |
+
|
| 12 |
+
from wechaty import Contact, Wechaty
|
| 13 |
+
from wechaty.user import Message
|
| 14 |
+
from wechaty_puppet import FileBox
|
| 15 |
+
|
| 16 |
+
from bridge.context import *
|
| 17 |
+
from bridge.context import Context
|
| 18 |
+
from bridge.reply import *
|
| 19 |
+
from channel.chat_channel import ChatChannel
|
| 20 |
+
from channel.wechat.wechaty_message import WechatyMessage
|
| 21 |
+
from common.log import logger
|
| 22 |
+
from common.singleton import singleton
|
| 23 |
+
from config import conf
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
from voice.audio_convert import any_to_sil
|
| 27 |
+
except Exception as e:
|
| 28 |
+
pass
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@singleton
class WechatyChannel(ChatChannel):
    """Personal-WeChat channel built on python-wechaty (puppet-service protocol)."""

    NOT_SUPPORT_REPLYTYPE = []

    def __init__(self):
        super().__init__()

    def startup(self):
        """Export the puppet-service token and run the async bot until stopped."""
        config = conf()
        token = config.get("wechaty_puppet_service_token")
        os.environ["WECHATY_PUPPET_SERVICE_TOKEN"] = token
        asyncio.run(self.main())

    async def main(self):
        """Create the Wechaty bot, register event handlers and start it."""
        loop = asyncio.get_event_loop()
        # Hand the asyncio loop to the worker threads so they can schedule coroutines on it.
        self.handler_pool._initializer = lambda: asyncio.set_event_loop(loop)
        self.bot = Wechaty()
        self.bot.on("login", self.on_login)
        self.bot.on("message", self.on_message)
        await self.bot.start()

    async def on_login(self, contact: Contact):
        """Record the logged-in account's id and name."""
        self.user_id = contact.contact_id
        self.name = contact.name
        logger.info("[WX] login user={}".format(contact))

    # Unified sender implemented per channel; dispatches on reply.type.
    # Wechaty is async, so coroutines are scheduled on the bot loop from this thread.
    def send(self, reply: Reply, context: Context):
        """Send *reply* to the room or contact recorded in *context*."""
        receiver_id = context["receiver"]
        loop = asyncio.get_event_loop()
        if context["isgroup"]:
            receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id), loop).result()
        else:
            receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id), loop).result()
        msg = None
        if reply.type == ReplyType.TEXT:
            msg = reply.content
            asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
            logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
        elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
            msg = reply.content
            asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
            logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
        elif reply.type == ReplyType.VOICE:
            voiceLength = None
            file_path = reply.content
            sil_file = os.path.splitext(file_path)[0] + ".sil"
            # Convert to SILK; any_to_sil returns the clip duration in milliseconds.
            voiceLength = int(any_to_sil(file_path, sil_file))
            if voiceLength >= 60000:
                voiceLength = 60000
                logger.info("[WX] voice too long, length={}, set to 60s".format(voiceLength))
            # send the voice clip
            t = int(time.time())
            msg = FileBox.from_file(sil_file, name=str(t) + ".sil")
            if voiceLength is not None:
                msg.metadata["voiceLength"] = voiceLength
            asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
            try:
                # Best-effort cleanup of the temporary audio files.
                os.remove(file_path)
                if sil_file != file_path:
                    os.remove(sil_file)
            except Exception as e:
                pass
            logger.info("[WX] sendVoice={}, receiver={}".format(reply.content, receiver))
        elif reply.type == ReplyType.IMAGE_URL:  # image delivered by URL
            img_url = reply.content
            t = int(time.time())
            msg = FileBox.from_url(url=img_url, name=str(t) + ".png")
            asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
            logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
        elif reply.type == ReplyType.IMAGE:  # image already in a file-like object
            image_storage = reply.content
            image_storage.seek(0)
            t = int(time.time())
            msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + ".png")
            asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
            logger.info("[WX] sendImage, receiver={}".format(receiver))

    async def on_message(self, msg: Message):
        """
        Listen for incoming message events, wrap them and feed the pipeline.
        """
        try:
            cmsg = await WechatyMessage(msg)
        except NotImplementedError as e:
            # Unsupported message type: skip quietly.
            logger.debug("[WX] {}".format(e))
            return
        except Exception as e:
            logger.exception("[WX] {}".format(e))
            return
        logger.debug("[WX] message:{}".format(cmsg))
        room = msg.room()  # the room the message came from; None for private chats
        isgroup = room is not None
        ctype = cmsg.ctype
        context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg)
        if context:
            logger.info("[WX] receiveMsg={}, context={}".format(cmsg, context))
            self.produce(context)
|
channel/wechat/wechaty_message.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
from wechaty import MessageType
|
| 5 |
+
from wechaty.user import Message
|
| 6 |
+
|
| 7 |
+
from bridge.context import ContextType
|
| 8 |
+
from channel.chat_message import ChatMessage
|
| 9 |
+
from common.log import logger
|
| 10 |
+
from common.tmp_dir import TmpDir
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class aobject(object):
    """Base class whose subclasses may define ``async def __init__``.

    Instantiation becomes awaitable: ``obj = await MyClass(params)``.
    """

    async def __new__(cls, *args, **kwargs):
        # Allocate the instance synchronously, then await its (async) initializer.
        obj = super().__new__(cls)
        await obj.__init__(*args, **kwargs)
        return obj

    async def __init__(self):
        # Default no-op async initializer for subclasses that define none.
        pass
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class WechatyMessage(ChatMessage, aobject):
    """ChatMessage built asynchronously from a wechaty Message.

    Construct with ``cmsg = await WechatyMessage(msg)`` (aobject provides the
    awaitable async __init__). Raises NotImplementedError for unsupported types.
    """

    async def __init__(self, wechaty_msg: Message):
        super().__init__(wechaty_msg)

        room = wechaty_msg.room()

        self.msg_id = wechaty_msg.message_id
        self.create_time = wechaty_msg.payload.timestamp
        self.is_group = room is not None

        if wechaty_msg.type() == MessageType.MESSAGE_TYPE_TEXT:
            self.ctype = ContextType.TEXT
            self.content = wechaty_msg.text()
        elif wechaty_msg.type() == MessageType.MESSAGE_TYPE_AUDIO:
            self.ctype = ContextType.VOICE
            voice_file = await wechaty_msg.to_file_box()
            self.content = TmpDir().path() + voice_file.name  # content holds the temp-dir file path

            def func():
                # Lazy download: runs later from a worker thread, so schedule the
                # coroutine back onto the bot's event loop.
                loop = asyncio.get_event_loop()
                asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content), loop).result()

            self._prepare_fn = func

        else:
            raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type()))

        from_contact = wechaty_msg.talker()  # the actual sender of the message
        self.from_user_id = from_contact.contact_id
        self.from_user_nickname = from_contact.name

        # Group from/to semantics differ between wechaty and itchat:
        # wechaty: from is the actual sender, to is the room
        # itchat:  if the bot sends a group message, from/to are the bot and the room;
        #          for other senders, from/to are the room and the bot
        # The difference does not matter here: in groups we only use from to detect
        # self-sent messages, and actual_user_id to identify the real sender.

        if self.is_group:
            self.to_user_id = room.room_id
            self.to_user_nickname = await room.topic()
        else:
            to_contact = wechaty_msg.to()
            self.to_user_id = to_contact.contact_id
            self.to_user_nickname = to_contact.name

        if self.is_group or wechaty_msg.is_self():  # groups: other_user is the room; self-sent private messages: the peer
            self.other_user_id = self.to_user_id
            self.other_user_nickname = self.to_user_nickname
        else:
            self.other_user_id = self.from_user_id
            self.other_user_nickname = self.from_user_nickname

        if self.is_group:  # in wechaty groups the actual sender is simply from_user
            self.is_at = await wechaty_msg.mention_self()
            if not self.is_at:  # copy-pasted text can contain @xxx without counting as a mention; detect that too
                name = wechaty_msg.wechaty.user_self().name
                pattern = f"@{re.escape(name)}(\u2005|\u0020)"
                if re.search(pattern, self.content):
                    logger.debug(f"wechaty message {self.msg_id} include at")
                    self.is_at = True

            self.actual_user_id = self.from_user_id
            self.actual_user_nickname = self.from_user_nickname
|
channel/wechatcom/README.md
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 企业微信应用号channel
|
| 2 |
+
|
| 3 |
+
企业微信官方提供了客服、应用等API,本channel使用的是企业微信的自建应用API的能力。
|
| 4 |
+
|
| 5 |
+
因为未来可能还会开发客服能力,所以本channel的类型名叫作`wechatcom_app`。
|
| 6 |
+
|
| 7 |
+
`wechatcom_app` channel支持插件系统和图片声音交互等能力,除了无法加入群聊,作为个人使用的私人助理已绰绰有余。
|
| 8 |
+
|
| 9 |
+
## 开始之前
|
| 10 |
+
|
| 11 |
+
- 在企业中确认自己拥有在企业内自建应用的权限。
|
| 12 |
+
- 如果没有权限或者是个人用户,也可创建未认证的企业。操作方式:登录手机企业微信,选择`创建/加入企业`来创建企业,类型请选择企业,企业名称可随意填写。
|
| 13 |
+
未认证的企业有100人的服务人数上限,其他功能与认证企业没有差异。
|
| 14 |
+
|
| 15 |
+
本channel需安装的依赖与公众号一致,需要安装`wechatpy`和`web.py`,它们包含在`requirements-optional.txt`中。
|
| 16 |
+
|
| 17 |
+
此外,如果你是`Linux`系统,除了`ffmpeg`还需要安装`amr`编码器,否则会出现找不到编码器的错误,无法正常使用语音功能。
|
| 18 |
+
|
| 19 |
+
- Ubuntu/Debian
|
| 20 |
+
|
| 21 |
+
```bash
|
| 22 |
+
apt-get install libavcodec-extra
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
- Alpine
|
| 26 |
+
|
| 27 |
+
需自行编译`ffmpeg`,在编译参数里加入`amr`编码器的支持
|
| 28 |
+
|
| 29 |
+
## 使用方法
|
| 30 |
+
|
| 31 |
+
1.查看企业ID
|
| 32 |
+
|
| 33 |
+
- 扫码登陆[企业微信后台](https://work.weixin.qq.com)
|
| 34 |
+
- 选择`我的企业`,点击`企业信息`,记住该`企业ID`
|
| 35 |
+
|
| 36 |
+
2.创建自建应用
|
| 37 |
+
|
| 38 |
+
- 选择应用管理, 在自建区选创建应用来创建企业自建应用
|
| 39 |
+
- 上传应用logo,填写应用名称等项
|
| 40 |
+
- 创建应用后进入应用详情页面,记住`AgentId`和`Secret`
|
| 41 |
+
|
| 42 |
+
3.配置应用
|
| 43 |
+
|
| 44 |
+
- 在详情页点击`企业可信IP`的配置(没看到可以不管),填入你服务器的公网IP,如果不知道可以先不填
|
| 45 |
+
- 点击`接收消息`下的启用API接收消息
|
| 46 |
+
- `URL`填写格式为`http://url:port/wxcomapp`,`port`是程序监听的端口,默认是9898
|
| 47 |
+
如果是未认证的企业,url可直接使用服务器的IP。如果是认证企业,需要使用备案的域名,可使用二级域名。
|
| 48 |
+
- `Token`可随意填写,停留在这个页面
|
| 49 |
+
- 在程序根目录`config.json`中增加配置(**去掉注释**),`wechatcomapp_aes_key`是当前页面的`EncodingAESKey`
|
| 50 |
+
|
| 51 |
+
```python
|
| 52 |
+
"channel_type": "wechatcom_app",
|
| 53 |
+
"wechatcom_corp_id": "", # 企业微信公司的corpID
|
| 54 |
+
"wechatcomapp_token": "", # 企业微信app的token
|
| 55 |
+
"wechatcomapp_port": 9898, # 企业微信app的服务端口, 不需要端口转发
|
| 56 |
+
"wechatcomapp_secret": "", # 企业微信app的secret
|
| 57 |
+
"wechatcomapp_agent_id": "", # 企业微信app的agent_id
|
| 58 |
+
"wechatcomapp_aes_key": "", # 企业微信app的aes_key
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
- 运行程序,在页面中点击保存,保存成功说明验证成功
|
| 62 |
+
|
| 63 |
+
4.连接个人微信
|
| 64 |
+
|
| 65 |
+
选择`我的企业`,点击`微信插件`,下面有个邀请关注的二维码。微信扫码后,即可在微信中看到对应企业,在这里你便可以和机器人沟通。
|
| 66 |
+
|
| 67 |
+
向机器人发送消息,如果日志里出现报错:
|
| 68 |
+
|
| 69 |
+
```bash
|
| 70 |
+
Error code: 60020, message: "not allow to access from your ip, ...from ip: xx.xx.xx.xx"
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
意思是IP不可信,需要参考上一步的`企业可信IP`配置,把这里的IP加进去。
|
| 74 |
+
|
| 75 |
+
~~### Railway部署方式~~(2023-06-08已失效)
|
| 76 |
+
|
| 77 |
+
~~公众号不能在`Railway`上部署,但企业微信应用[可以](https://railway.app/template/-FHS--?referralCode=RC3znh)!~~
|
| 78 |
+
|
| 79 |
+
~~填写配置后,将部署完成后的网址```**.railway.app/wxcomapp```,填写在上一步的URL中。发送信息后观察日志,把报错的IP加入到可信IP。(每次重启后都需要加入可信IP)~~
|
| 80 |
+
|
| 81 |
+
## 测试体验
|
| 82 |
+
|
| 83 |
+
AIGC开放社区中已经部署了多个可免费使用的Bot,扫描下方的二维码会自动邀请你来体验。
|
| 84 |
+
|
| 85 |
+
<img width="200" src="../../docs/images/aigcopen.png">
|
channel/wechatcom/wechatcomapp_channel.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding=utf-8 -*-
|
| 2 |
+
import io
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
import requests
|
| 7 |
+
import web
|
| 8 |
+
from wechatpy.enterprise import create_reply, parse_message
|
| 9 |
+
from wechatpy.enterprise.crypto import WeChatCrypto
|
| 10 |
+
from wechatpy.enterprise.exceptions import InvalidCorpIdException
|
| 11 |
+
from wechatpy.exceptions import InvalidSignatureException, WeChatClientException
|
| 12 |
+
|
| 13 |
+
from bridge.context import Context
|
| 14 |
+
from bridge.reply import Reply, ReplyType
|
| 15 |
+
from channel.chat_channel import ChatChannel
|
| 16 |
+
from channel.wechatcom.wechatcomapp_client import WechatComAppClient
|
| 17 |
+
from channel.wechatcom.wechatcomapp_message import WechatComAppMessage
|
| 18 |
+
from common.log import logger
|
| 19 |
+
from common.singleton import singleton
|
| 20 |
+
from common.utils import compress_imgfile, fsize, split_string_by_utf8_length
|
| 21 |
+
from config import conf, subscribe_msg
|
| 22 |
+
from voice.audio_convert import any_to_amr, split_audio
|
| 23 |
+
|
| 24 |
+
MAX_UTF8_LEN = 2048
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@singleton
|
| 28 |
+
class WechatComAppChannel(ChatChannel):
|
| 29 |
+
NOT_SUPPORT_REPLYTYPE = []
|
| 30 |
+
|
| 31 |
+
def __init__(self):
|
| 32 |
+
super().__init__()
|
| 33 |
+
self.corp_id = conf().get("wechatcom_corp_id")
|
| 34 |
+
self.secret = conf().get("wechatcomapp_secret")
|
| 35 |
+
self.agent_id = conf().get("wechatcomapp_agent_id")
|
| 36 |
+
self.token = conf().get("wechatcomapp_token")
|
| 37 |
+
self.aes_key = conf().get("wechatcomapp_aes_key")
|
| 38 |
+
print(self.corp_id, self.secret, self.agent_id, self.token, self.aes_key)
|
| 39 |
+
logger.info(
|
| 40 |
+
"[wechatcom] init: corp_id: {}, secret: {}, agent_id: {}, token: {}, aes_key: {}".format(self.corp_id, self.secret, self.agent_id, self.token, self.aes_key)
|
| 41 |
+
)
|
| 42 |
+
self.crypto = WeChatCrypto(self.token, self.aes_key, self.corp_id)
|
| 43 |
+
self.client = WechatComAppClient(self.corp_id, self.secret)
|
| 44 |
+
|
| 45 |
+
def startup(self):
|
| 46 |
+
# start message listener
|
| 47 |
+
urls = ("/wxcomapp", "channel.wechatcom.wechatcomapp_channel.Query")
|
| 48 |
+
app = web.application(urls, globals(), autoreload=False)
|
| 49 |
+
port = conf().get("wechatcomapp_port", 9898)
|
| 50 |
+
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
|
| 51 |
+
|
| 52 |
+
def send(self, reply: Reply, context: Context):
|
| 53 |
+
receiver = context["receiver"]
|
| 54 |
+
if reply.type in [ReplyType.TEXT, ReplyType.ERROR, ReplyType.INFO]:
|
| 55 |
+
reply_text = reply.content
|
| 56 |
+
texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN)
|
| 57 |
+
if len(texts) > 1:
|
| 58 |
+
logger.info("[wechatcom] text too long, split into {} parts".format(len(texts)))
|
| 59 |
+
for i, text in enumerate(texts):
|
| 60 |
+
self.client.message.send_text(self.agent_id, receiver, text)
|
| 61 |
+
if i != len(texts) - 1:
|
| 62 |
+
time.sleep(0.5) # 休眠0.5秒,防止发送过快乱序
|
| 63 |
+
logger.info("[wechatcom] Do send text to {}: {}".format(receiver, reply_text))
|
| 64 |
+
elif reply.type == ReplyType.VOICE:
|
| 65 |
+
try:
|
| 66 |
+
media_ids = []
|
| 67 |
+
file_path = reply.content
|
| 68 |
+
amr_file = os.path.splitext(file_path)[0] + ".amr"
|
| 69 |
+
any_to_amr(file_path, amr_file)
|
| 70 |
+
duration, files = split_audio(amr_file, 60 * 1000)
|
| 71 |
+
if len(files) > 1:
|
| 72 |
+
logger.info("[wechatcom] voice too long {}s > 60s , split into {} parts".format(duration / 1000.0, len(files)))
|
| 73 |
+
for path in files:
|
| 74 |
+
response = self.client.media.upload("voice", open(path, "rb"))
|
| 75 |
+
logger.debug("[wechatcom] upload voice response: {}".format(response))
|
| 76 |
+
media_ids.append(response["media_id"])
|
| 77 |
+
except WeChatClientException as e:
|
| 78 |
+
logger.error("[wechatcom] upload voice failed: {}".format(e))
|
| 79 |
+
return
|
| 80 |
+
try:
|
| 81 |
+
os.remove(file_path)
|
| 82 |
+
if amr_file != file_path:
|
| 83 |
+
os.remove(amr_file)
|
| 84 |
+
except Exception:
|
| 85 |
+
pass
|
| 86 |
+
for media_id in media_ids:
|
| 87 |
+
self.client.message.send_voice(self.agent_id, receiver, media_id)
|
| 88 |
+
time.sleep(1)
|
| 89 |
+
logger.info("[wechatcom] sendVoice={}, receiver={}".format(reply.content, receiver))
|
| 90 |
+
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
| 91 |
+
img_url = reply.content
|
| 92 |
+
pic_res = requests.get(img_url, stream=True)
|
| 93 |
+
image_storage = io.BytesIO()
|
| 94 |
+
for block in pic_res.iter_content(1024):
|
| 95 |
+
image_storage.write(block)
|
| 96 |
+
sz = fsize(image_storage)
|
| 97 |
+
if sz >= 10 * 1024 * 1024:
|
| 98 |
+
logger.info("[wechatcom] image too large, ready to compress, sz={}".format(sz))
|
| 99 |
+
image_storage = compress_imgfile(image_storage, 10 * 1024 * 1024 - 1)
|
| 100 |
+
logger.info("[wechatcom] image compressed, sz={}".format(fsize(image_storage)))
|
| 101 |
+
image_storage.seek(0)
|
| 102 |
+
try:
|
| 103 |
+
response = self.client.media.upload("image", image_storage)
|
| 104 |
+
logger.debug("[wechatcom] upload image response: {}".format(response))
|
| 105 |
+
except WeChatClientException as e:
|
| 106 |
+
logger.error("[wechatcom] upload image failed: {}".format(e))
|
| 107 |
+
return
|
| 108 |
+
|
| 109 |
+
self.client.message.send_image(self.agent_id, receiver, response["media_id"])
|
| 110 |
+
logger.info("[wechatcom] sendImage url={}, receiver={}".format(img_url, receiver))
|
| 111 |
+
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
| 112 |
+
image_storage = reply.content
|
| 113 |
+
sz = fsize(image_storage)
|
| 114 |
+
if sz >= 10 * 1024 * 1024:
|
| 115 |
+
logger.info("[wechatcom] image too large, ready to compress, sz={}".format(sz))
|
| 116 |
+
image_storage = compress_imgfile(image_storage, 10 * 1024 * 1024 - 1)
|
| 117 |
+
logger.info("[wechatcom] image compressed, sz={}".format(fsize(image_storage)))
|
| 118 |
+
image_storage.seek(0)
|
| 119 |
+
try:
|
| 120 |
+
response = self.client.media.upload("image", image_storage)
|
| 121 |
+
logger.debug("[wechatcom] upload image response: {}".format(response))
|
| 122 |
+
except WeChatClientException as e:
|
| 123 |
+
logger.error("[wechatcom] upload image failed: {}".format(e))
|
| 124 |
+
return
|
| 125 |
+
self.client.message.send_image(self.agent_id, receiver, response["media_id"])
|
| 126 |
+
logger.info("[wechatcom] sendImage, receiver={}".format(receiver))
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class Query:
|
| 130 |
+
def GET(self):
|
| 131 |
+
channel = WechatComAppChannel()
|
| 132 |
+
params = web.input()
|
| 133 |
+
logger.info("[wechatcom] receive params: {}".format(params))
|
| 134 |
+
try:
|
| 135 |
+
signature = params.msg_signature
|
| 136 |
+
timestamp = params.timestamp
|
| 137 |
+
nonce = params.nonce
|
| 138 |
+
echostr = params.echostr
|
| 139 |
+
echostr = channel.crypto.check_signature(signature, timestamp, nonce, echostr)
|
| 140 |
+
except InvalidSignatureException:
|
| 141 |
+
raise web.Forbidden()
|
| 142 |
+
return echostr
|
| 143 |
+
|
| 144 |
+
def POST(self):
|
| 145 |
+
channel = WechatComAppChannel()
|
| 146 |
+
params = web.input()
|
| 147 |
+
logger.info("[wechatcom] receive params: {}".format(params))
|
| 148 |
+
try:
|
| 149 |
+
signature = params.msg_signature
|
| 150 |
+
timestamp = params.timestamp
|
| 151 |
+
nonce = params.nonce
|
| 152 |
+
message = channel.crypto.decrypt_message(web.data(), signature, timestamp, nonce)
|
| 153 |
+
except (InvalidSignatureException, InvalidCorpIdException):
|
| 154 |
+
raise web.Forbidden()
|
| 155 |
+
msg = parse_message(message)
|
| 156 |
+
logger.debug("[wechatcom] receive message: {}, msg= {}".format(message, msg))
|
| 157 |
+
if msg.type == "event":
|
| 158 |
+
if msg.event == "subscribe":
|
| 159 |
+
reply_content = subscribe_msg()
|
| 160 |
+
if reply_content:
|
| 161 |
+
reply = create_reply(reply_content, msg).render()
|
| 162 |
+
res = channel.crypto.encrypt_message(reply, nonce, timestamp)
|
| 163 |
+
return res
|
| 164 |
+
else:
|
| 165 |
+
try:
|
| 166 |
+
wechatcom_msg = WechatComAppMessage(msg, client=channel.client)
|
| 167 |
+
except NotImplementedError as e:
|
| 168 |
+
logger.debug("[wechatcom] " + str(e))
|
| 169 |
+
return "success"
|
| 170 |
+
context = channel._compose_context(
|
| 171 |
+
wechatcom_msg.ctype,
|
| 172 |
+
wechatcom_msg.content,
|
| 173 |
+
isgroup=False,
|
| 174 |
+
msg=wechatcom_msg,
|
| 175 |
+
)
|
| 176 |
+
if context:
|
| 177 |
+
channel.produce(context)
|
| 178 |
+
return "success"
|
channel/wechatcom/wechatcomapp_client.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import threading
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
from wechatpy.enterprise import WeChatClient
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class WechatComAppClient(WeChatClient):
|
| 8 |
+
def __init__(self, corp_id, secret, access_token=None, session=None, timeout=None, auto_retry=True):
|
| 9 |
+
super(WechatComAppClient, self).__init__(corp_id, secret, access_token, session, timeout, auto_retry)
|
| 10 |
+
self.fetch_access_token_lock = threading.Lock()
|
| 11 |
+
|
| 12 |
+
def fetch_access_token(self): # 重载父类方法,加锁避免多线程重复获取access_token
|
| 13 |
+
with self.fetch_access_token_lock:
|
| 14 |
+
access_token = self.session.get(self.access_token_key)
|
| 15 |
+
if access_token:
|
| 16 |
+
if not self.expires_at:
|
| 17 |
+
return access_token
|
| 18 |
+
timestamp = time.time()
|
| 19 |
+
if self.expires_at - timestamp > 60:
|
| 20 |
+
return access_token
|
| 21 |
+
return super().fetch_access_token()
|
channel/wechatcom/wechatcomapp_message.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from wechatpy.enterprise import WeChatClient
|
| 2 |
+
|
| 3 |
+
from bridge.context import ContextType
|
| 4 |
+
from channel.chat_message import ChatMessage
|
| 5 |
+
from common.log import logger
|
| 6 |
+
from common.tmp_dir import TmpDir
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class WechatComAppMessage(ChatMessage):
|
| 10 |
+
def __init__(self, msg, client: WeChatClient, is_group=False):
|
| 11 |
+
super().__init__(msg)
|
| 12 |
+
self.msg_id = msg.id
|
| 13 |
+
self.create_time = msg.time
|
| 14 |
+
self.is_group = is_group
|
| 15 |
+
|
| 16 |
+
if msg.type == "text":
|
| 17 |
+
self.ctype = ContextType.TEXT
|
| 18 |
+
self.content = msg.content
|
| 19 |
+
elif msg.type == "voice":
|
| 20 |
+
self.ctype = ContextType.VOICE
|
| 21 |
+
self.content = TmpDir().path() + msg.media_id + "." + msg.format # content直接存临时目录路径
|
| 22 |
+
|
| 23 |
+
def download_voice():
|
| 24 |
+
# 如果响应状态码是200,则将响应内容写入本地文件
|
| 25 |
+
response = client.media.download(msg.media_id)
|
| 26 |
+
if response.status_code == 200:
|
| 27 |
+
with open(self.content, "wb") as f:
|
| 28 |
+
f.write(response.content)
|
| 29 |
+
else:
|
| 30 |
+
logger.info(f"[wechatcom] Failed to download voice file, {response.content}")
|
| 31 |
+
|
| 32 |
+
self._prepare_fn = download_voice
|
| 33 |
+
elif msg.type == "image":
|
| 34 |
+
self.ctype = ContextType.IMAGE
|
| 35 |
+
self.content = TmpDir().path() + msg.media_id + ".png" # content直接存临时目录路径
|
| 36 |
+
|
| 37 |
+
def download_image():
|
| 38 |
+
# 如果响应状态码是200,则将响应内容写入本地文件
|
| 39 |
+
response = client.media.download(msg.media_id)
|
| 40 |
+
if response.status_code == 200:
|
| 41 |
+
with open(self.content, "wb") as f:
|
| 42 |
+
f.write(response.content)
|
| 43 |
+
else:
|
| 44 |
+
logger.info(f"[wechatcom] Failed to download image file, {response.content}")
|
| 45 |
+
|
| 46 |
+
self._prepare_fn = download_image
|
| 47 |
+
else:
|
| 48 |
+
raise NotImplementedError("Unsupported message type: Type:{} ".format(msg.type))
|
| 49 |
+
|
| 50 |
+
self.from_user_id = msg.source
|
| 51 |
+
self.to_user_id = msg.target
|
| 52 |
+
self.other_user_id = msg.source
|
channel/wechatmp/README.md
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 微信公众号channel
|
| 2 |
+
|
| 3 |
+
鉴于个人微信号在服务器上通过itchat登录有封号风险,这里新增了微信公众号channel,提供无风险的服务。
|
| 4 |
+
目前支持订阅号和服务号两种类型的公众号,它们都支持文本交互,语音和图片输入。其中个人主体的微信订阅号由于无法通过微信认证,存在回复时间限制,每天的图片和声音回复次数也有限制。
|
| 5 |
+
|
| 6 |
+
## 使用方法(订阅号,服务号类似)
|
| 7 |
+
|
| 8 |
+
在开始部署前,你需要一个拥有公网IP的服务器,以提供微信服务器和我们自己服务器的连接。或者你需要进行内网穿透,否则微信服务器无法将消息发送给我们的服务器。
|
| 9 |
+
|
| 10 |
+
此外,需要在我们的服务器上安装python的web框架web.py和wechatpy。
|
| 11 |
+
以ubuntu为例(在ubuntu 22.04上测试):
|
| 12 |
+
```
|
| 13 |
+
pip3 install web.py
|
| 14 |
+
pip3 install wechatpy
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
然后在[微信公众平台](https://mp.weixin.qq.com)注册一个自己的公众号,类型选择订阅号,主体为个人即可。
|
| 18 |
+
|
| 19 |
+
然后根据[接入指南](https://developers.weixin.qq.com/doc/offiaccount/Basic_Information/Access_Overview.html)的说明,在[微信公众平台](https://mp.weixin.qq.com)的“设置与开发”-“基本配置”-“服务器配置”中填写服务器地址`URL`和令牌`Token`。`URL`填写格式为`http://url/wx`,可使用IP(成功几率看脸),`Token`是你自己编的一个特定的令牌。消息加解密方式如果选择了需要加密的模式,需要在配置中填写`wechatmp_aes_key`。
|
| 20 |
+
|
| 21 |
+
相关的服务器验证代码已经写好,你不需要再添加任何代码。你只需要在本项目根目录的`config.json`中添加
|
| 22 |
+
```
|
| 23 |
+
"channel_type": "wechatmp", # 如果通过了微信认证,将"wechatmp"替换为"wechatmp_service",可极大的优化使用体验
|
| 24 |
+
"wechatmp_token": "xxxx", # 微信公众平台的Token
|
| 25 |
+
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443
|
| 26 |
+
"wechatmp_app_id": "xxxx", # 微信公众平台的appID
|
| 27 |
+
"wechatmp_app_secret": "xxxx", # 微信公众平台的appsecret
|
| 28 |
+
"wechatmp_aes_key": "", # 微信公众平台的EncodingAESKey,加密模式需要
|
| 29 |
+
"single_chat_prefix": [""], # 推荐设置,任意对话都可以触发回复,不添加前缀
|
| 30 |
+
"single_chat_reply_prefix": "", # 推荐设置,回复不设置前缀
|
| 31 |
+
"plugin_trigger_prefix": "&", # 推荐设置,在手机微信客户端中,$%^等符号与中文连在一起时会自动显示一段较大的间隔,用户体验不好。请不要使用管理员指令前缀"#",这会造成未知问题。
|
| 32 |
+
```
|
| 33 |
+
然后运行`python3 app.py`启动web服务器。这里会默认监听8080端口,但是微信公众号的服务器配置只支持80/443端口,有两种方法来解决这个问题。第一个是推荐的方法,使用端口转发命令将80端口转发到8080端口:
|
| 34 |
+
```
|
| 35 |
+
sudo iptables -t nat -A PREROUTING -p tcp --dport 80 -j REDIRECT --to-port 8080
|
| 36 |
+
sudo iptables-save > /etc/iptables/rules.v4
|
| 37 |
+
```
|
| 38 |
+
第二个方法是让python程序直接监听80端口,在配置文件中设置`"wechatmp_port": 80` ,在linux上需要使用`sudo python3 app.py`启动程序。然而这会导致一系列环境和权限问题,因此不是推荐的方法。
|
| 39 |
+
|
| 40 |
+
443端口同理,注意需要支持SSL,也就是https的访问,在`wechatmp_channel.py`中需要修改相应的证书路径。
|
| 41 |
+
|
| 42 |
+
程序启动并监听端口后,在刚才的“服务器配置”中点击`提交`即可验证你的服务器。
|
| 43 |
+
随后在[微信公众平台](https://mp.weixin.qq.com)启用服务器,关闭手动填写规则的自动回复,即可实现ChatGPT的自动回复。
|
| 44 |
+
|
| 45 |
+
之后需要在公众号开发信息下将本机IP加入到IP白名单。
|
| 46 |
+
|
| 47 |
+
不然在启用后,发送语音、图片等消息可能会遇到如下报错:
|
| 48 |
+
```
|
| 49 |
+
'errcode': 40164, 'errmsg': 'invalid ip xx.xx.xx.xx not in whitelist rid
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
## 个人微信公众号的限制
|
| 54 |
+
由于人微信公众号不能通过微信认证,所以没有客服接口,因此公众号无法主动发出消息,只能被动回复。而微信官方对被动回复有5秒的时间限制,最多重试2次,因此最多只有15秒的自动回复时间窗口。因此如果问题比较复杂或者我们的服务器比较忙,ChatGPT的回答就没办法及时回复给用户。为了解决这个问题,这里做了回答缓存,它需要你在回复超时后,再次主动发送任意文字(例如1)来尝试拿到回答缓存。为了优化使用体验,目前设置了两分钟(120秒)的timeout,用户在至多两分钟后即可得到查询到回复或者错误原因。
|
| 55 |
+
|
| 56 |
+
另外,由于微信官方的限制,自动回复有长度限制。因此这里将ChatGPT的回答进行了拆分,以满足限制。
|
| 57 |
+
|
| 58 |
+
## 私有api_key
|
| 59 |
+
公共api有访问频率限制(免费账号每分钟最多3次ChatGPT的API调用),这在服务多人的时候会遇到问题。因此这里多加了一个设置私有api_key的功能。目前通过godcmd插件的命令来设置私有api_key。
|
| 60 |
+
|
| 61 |
+
## 语音输入
|
| 62 |
+
利用微信自带的语音识别功能,提供语音输入能力。需要在公众号管理页面的“设置与开发”->“接口权限”页面开启“接收语音识别结果”。
|
| 63 |
+
|
| 64 |
+
## 语音回复
|
| 65 |
+
请在配置文件中添加以下词条:
|
| 66 |
+
```
|
| 67 |
+
"voice_reply_voice": true,
|
| 68 |
+
```
|
| 69 |
+
这样公众号将会用语音回复语音消息,实现语音对话。
|
| 70 |
+
|
| 71 |
+
默认的语音合成引擎是`google`,它是免费使用的。
|
| 72 |
+
|
| 73 |
+
如果要选择其他的语音合成引擎,请添加以下配置项:
|
| 74 |
+
```
|
| 75 |
+
"text_to_voice": "pytts"
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
pytts是本地的语音合成引擎。还支持baidu,azure,这些你需要自行配置相关的依赖和key。
|
| 79 |
+
|
| 80 |
+
如果使用pytts,在ubuntu上需要安装如下依赖:
|
| 81 |
+
```
|
| 82 |
+
sudo apt update
|
| 83 |
+
sudo apt install espeak
|
| 84 |
+
sudo apt install ffmpeg
|
| 85 |
+
python3 -m pip install pyttsx3
|
| 86 |
+
```
|
| 87 |
+
不是很建议开启pytts语音回复,因为它是离线本地计算,算的慢会拖垮服务器,且声音不好听。
|
| 88 |
+
|
| 89 |
+
## 图片回复
|
| 90 |
+
现在认证公众号和非认证公众号都可以实现的图片和语音回复。但是非认证公众号使用了永久素材接口,每天有1000次的调用上限(每个月有10次重置机会,程序中已设定遇到上限会自动重置),且永久素材库存也有上限。因此对于非认证公众号,我们会在回复图片或者语音消息后的10秒内从永久素材库存内删除该素材。
|
| 91 |
+
|
| 92 |
+
## 测试
|
| 93 |
+
目前在`RoboStyle`这个公众号上进行了测试(基于[wechatmp分支](https://github.com/JS00000/chatgpt-on-wechat/tree/wechatmp)),感兴趣的可以关注并体验。开启了godcmd, Banwords, role, dungeon, finish这五个插件,其他的插件还没有详尽测试。百度的接口暂未测试。[wechatmp-stable分支](https://github.com/JS00000/chatgpt-on-wechat/tree/wechatmp-stable)是较稳定的上个版本,但也缺少最新的功能支持。
|
| 94 |
+
|
| 95 |
+
## TODO
|
| 96 |
+
- [x] 语音输入
|
| 97 |
+
- [x] 图片输入
|
| 98 |
+
- [x] 使用临时素材接口提供认证公众号的图片和语音回复
|
| 99 |
+
- [x] 使用永久素材接口提供未认证公众号的图片和语音回复
|
| 100 |
+
- [ ] 高并发支持
|
channel/wechatmp/active_reply.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
|
| 3 |
+
import web
|
| 4 |
+
from wechatpy import parse_message
|
| 5 |
+
from wechatpy.replies import create_reply
|
| 6 |
+
|
| 7 |
+
from bridge.context import *
|
| 8 |
+
from bridge.reply import *
|
| 9 |
+
from channel.wechatmp.common import *
|
| 10 |
+
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
| 11 |
+
from channel.wechatmp.wechatmp_message import WeChatMPMessage
|
| 12 |
+
from common.log import logger
|
| 13 |
+
from config import conf, subscribe_msg
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# This class is instantiated once per query
|
| 17 |
+
class Query:
|
| 18 |
+
def GET(self):
|
| 19 |
+
return verify_server(web.input())
|
| 20 |
+
|
| 21 |
+
def POST(self):
|
| 22 |
+
# Make sure to return the instance that first created, @singleton will do that.
|
| 23 |
+
try:
|
| 24 |
+
args = web.input()
|
| 25 |
+
verify_server(args)
|
| 26 |
+
channel = WechatMPChannel()
|
| 27 |
+
message = web.data()
|
| 28 |
+
encrypt_func = lambda x: x
|
| 29 |
+
if args.get("encrypt_type") == "aes":
|
| 30 |
+
logger.debug("[wechatmp] Receive encrypted post data:\n" + message.decode("utf-8"))
|
| 31 |
+
if not channel.crypto:
|
| 32 |
+
raise Exception("Crypto not initialized, Please set wechatmp_aes_key in config.json")
|
| 33 |
+
message = channel.crypto.decrypt_message(message, args.msg_signature, args.timestamp, args.nonce)
|
| 34 |
+
encrypt_func = lambda x: channel.crypto.encrypt_message(x, args.nonce, args.timestamp)
|
| 35 |
+
else:
|
| 36 |
+
logger.debug("[wechatmp] Receive post data:\n" + message.decode("utf-8"))
|
| 37 |
+
msg = parse_message(message)
|
| 38 |
+
if msg.type in ["text", "voice", "image"]:
|
| 39 |
+
wechatmp_msg = WeChatMPMessage(msg, client=channel.client)
|
| 40 |
+
from_user = wechatmp_msg.from_user_id
|
| 41 |
+
content = wechatmp_msg.content
|
| 42 |
+
message_id = wechatmp_msg.msg_id
|
| 43 |
+
|
| 44 |
+
logger.info(
|
| 45 |
+
"[wechatmp] {}:{} Receive post query {} {}: {}".format(
|
| 46 |
+
web.ctx.env.get("REMOTE_ADDR"),
|
| 47 |
+
web.ctx.env.get("REMOTE_PORT"),
|
| 48 |
+
from_user,
|
| 49 |
+
message_id,
|
| 50 |
+
content,
|
| 51 |
+
)
|
| 52 |
+
)
|
| 53 |
+
if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False):
|
| 54 |
+
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg)
|
| 55 |
+
else:
|
| 56 |
+
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg)
|
| 57 |
+
if context:
|
| 58 |
+
channel.produce(context)
|
| 59 |
+
# The reply will be sent by channel.send() in another thread
|
| 60 |
+
return "success"
|
| 61 |
+
elif msg.type == "event":
|
| 62 |
+
logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source))
|
| 63 |
+
if msg.event in ["subscribe", "subscribe_scan"]:
|
| 64 |
+
reply_text = subscribe_msg()
|
| 65 |
+
if reply_text:
|
| 66 |
+
replyPost = create_reply(reply_text, msg)
|
| 67 |
+
return encrypt_func(replyPost.render())
|
| 68 |
+
else:
|
| 69 |
+
return "success"
|
| 70 |
+
else:
|
| 71 |
+
logger.info("暂且不处理")
|
| 72 |
+
return "success"
|
| 73 |
+
except Exception as exc:
|
| 74 |
+
logger.exception(exc)
|
| 75 |
+
return exc
|
channel/wechatmp/common.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import web
|
| 2 |
+
from wechatpy.crypto import WeChatCrypto
|
| 3 |
+
from wechatpy.exceptions import InvalidSignatureException
|
| 4 |
+
from wechatpy.utils import check_signature
|
| 5 |
+
|
| 6 |
+
from config import conf
|
| 7 |
+
|
| 8 |
+
MAX_UTF8_LEN = 2048
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class WeChatAPIException(Exception):
|
| 12 |
+
pass
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def verify_server(data):
|
| 16 |
+
try:
|
| 17 |
+
signature = data.signature
|
| 18 |
+
timestamp = data.timestamp
|
| 19 |
+
nonce = data.nonce
|
| 20 |
+
echostr = data.get("echostr", None)
|
| 21 |
+
token = conf().get("wechatmp_token") # 请按照公众平台官网\基本配置中信息填写
|
| 22 |
+
check_signature(token, signature, timestamp, nonce)
|
| 23 |
+
return echostr
|
| 24 |
+
except InvalidSignatureException:
|
| 25 |
+
raise web.Forbidden("Invalid signature")
|
| 26 |
+
except Exception as e:
|
| 27 |
+
raise web.Forbidden(str(e))
|
channel/wechatmp/passive_reply.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
import web
|
| 5 |
+
from wechatpy import parse_message
|
| 6 |
+
from wechatpy.replies import ImageReply, VoiceReply, create_reply
|
| 7 |
+
import textwrap
|
| 8 |
+
from bridge.context import *
|
| 9 |
+
from bridge.reply import *
|
| 10 |
+
from channel.wechatmp.common import *
|
| 11 |
+
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
| 12 |
+
from channel.wechatmp.wechatmp_message import WeChatMPMessage
|
| 13 |
+
from common.log import logger
|
| 14 |
+
from common.utils import split_string_by_utf8_length
|
| 15 |
+
from config import conf, subscribe_msg
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# This class is instantiated once per query
|
| 19 |
+
class Query:
|
| 20 |
+
def GET(self):
|
| 21 |
+
return verify_server(web.input())
|
| 22 |
+
|
| 23 |
+
def POST(self):
|
| 24 |
+
try:
|
| 25 |
+
args = web.input()
|
| 26 |
+
verify_server(args)
|
| 27 |
+
request_time = time.time()
|
| 28 |
+
channel = WechatMPChannel()
|
| 29 |
+
message = web.data()
|
| 30 |
+
encrypt_func = lambda x: x
|
| 31 |
+
if args.get("encrypt_type") == "aes":
|
| 32 |
+
logger.debug("[wechatmp] Receive encrypted post data:\n" + message.decode("utf-8"))
|
| 33 |
+
if not channel.crypto:
|
| 34 |
+
raise Exception("Crypto not initialized, Please set wechatmp_aes_key in config.json")
|
| 35 |
+
message = channel.crypto.decrypt_message(message, args.msg_signature, args.timestamp, args.nonce)
|
| 36 |
+
encrypt_func = lambda x: channel.crypto.encrypt_message(x, args.nonce, args.timestamp)
|
| 37 |
+
else:
|
| 38 |
+
logger.debug("[wechatmp] Receive post data:\n" + message.decode("utf-8"))
|
| 39 |
+
msg = parse_message(message)
|
| 40 |
+
if msg.type in ["text", "voice", "image"]:
|
| 41 |
+
wechatmp_msg = WeChatMPMessage(msg, client=channel.client)
|
| 42 |
+
from_user = wechatmp_msg.from_user_id
|
| 43 |
+
content = wechatmp_msg.content
|
| 44 |
+
message_id = wechatmp_msg.msg_id
|
| 45 |
+
|
| 46 |
+
supported = True
|
| 47 |
+
if "【收到不支持的消息类型,暂无法显示】" in content:
|
| 48 |
+
supported = False # not supported, used to refresh
|
| 49 |
+
|
| 50 |
+
# New request
|
| 51 |
+
if (
|
| 52 |
+
from_user not in channel.cache_dict
|
| 53 |
+
and from_user not in channel.running
|
| 54 |
+
or content.startswith("#")
|
| 55 |
+
and message_id not in channel.request_cnt # insert the godcmd
|
| 56 |
+
):
|
| 57 |
+
# The first query begin
|
| 58 |
+
if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False):
|
| 59 |
+
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg)
|
| 60 |
+
else:
|
| 61 |
+
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg)
|
| 62 |
+
logger.debug("[wechatmp] context: {} {} {}".format(context, wechatmp_msg, supported))
|
| 63 |
+
|
| 64 |
+
if supported and context:
|
| 65 |
+
channel.running.add(from_user)
|
| 66 |
+
channel.produce(context)
|
| 67 |
+
else:
|
| 68 |
+
trigger_prefix = conf().get("single_chat_prefix", [""])[0]
|
| 69 |
+
if trigger_prefix or not supported:
|
| 70 |
+
if trigger_prefix:
|
| 71 |
+
reply_text = textwrap.dedent(
|
| 72 |
+
f"""\
|
| 73 |
+
请输入'{trigger_prefix}'接你想说的话跟我说话。
|
| 74 |
+
例如:
|
| 75 |
+
{trigger_prefix}你好,很高兴见到你。"""
|
| 76 |
+
)
|
| 77 |
+
else:
|
| 78 |
+
reply_text = textwrap.dedent(
|
| 79 |
+
"""\
|
| 80 |
+
你好,很高兴见到你。
|
| 81 |
+
请跟我说话吧。"""
|
| 82 |
+
)
|
| 83 |
+
else:
|
| 84 |
+
logger.error(f"[wechatmp] unknown error")
|
| 85 |
+
reply_text = textwrap.dedent(
|
| 86 |
+
"""\
|
| 87 |
+
未知错误,请稍后再试"""
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
replyPost = create_reply(reply_text, msg)
|
| 91 |
+
return encrypt_func(replyPost.render())
|
| 92 |
+
|
| 93 |
+
# Wechat official server will request 3 times (5 seconds each), with the same message_id.
|
| 94 |
+
# Because the interval is 5 seconds, here assumed that do not have multithreading problems.
|
| 95 |
+
request_cnt = channel.request_cnt.get(message_id, 0) + 1
|
| 96 |
+
channel.request_cnt[message_id] = request_cnt
|
| 97 |
+
logger.info(
|
| 98 |
+
"[wechatmp] Request {} from {} {} {}:{}\n{}".format(
|
| 99 |
+
request_cnt, from_user, message_id, web.ctx.env.get("REMOTE_ADDR"), web.ctx.env.get("REMOTE_PORT"), content
|
| 100 |
+
)
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
task_running = True
|
| 104 |
+
waiting_until = request_time + 4
|
| 105 |
+
while time.time() < waiting_until:
|
| 106 |
+
if from_user in channel.running:
|
| 107 |
+
time.sleep(0.1)
|
| 108 |
+
else:
|
| 109 |
+
task_running = False
|
| 110 |
+
break
|
| 111 |
+
|
| 112 |
+
reply_text = ""
|
| 113 |
+
if task_running:
|
| 114 |
+
if request_cnt < 3:
|
| 115 |
+
# waiting for timeout (the POST request will be closed by Wechat official server)
|
| 116 |
+
time.sleep(2)
|
| 117 |
+
# and do nothing, waiting for the next request
|
| 118 |
+
return "success"
|
| 119 |
+
else: # request_cnt == 3:
|
| 120 |
+
# return timeout message
|
| 121 |
+
reply_text = "【正在思考中,回复任意文字尝试获取回复】"
|
| 122 |
+
replyPost = create_reply(reply_text, msg)
|
| 123 |
+
return encrypt_func(replyPost.render())
|
| 124 |
+
|
| 125 |
+
# reply is ready
|
| 126 |
+
channel.request_cnt.pop(message_id)
|
| 127 |
+
|
| 128 |
+
# no return because of bandwords or other reasons
|
| 129 |
+
if from_user not in channel.cache_dict and from_user not in channel.running:
|
| 130 |
+
return "success"
|
| 131 |
+
|
| 132 |
+
# Only one request can access to the cached data
|
| 133 |
+
try:
|
| 134 |
+
(reply_type, reply_content) = channel.cache_dict.pop(from_user)
|
| 135 |
+
except KeyError:
|
| 136 |
+
return "success"
|
| 137 |
+
|
| 138 |
+
if reply_type == "text":
|
| 139 |
+
if len(reply_content.encode("utf8")) <= MAX_UTF8_LEN:
|
| 140 |
+
reply_text = reply_content
|
| 141 |
+
else:
|
| 142 |
+
continue_text = "\n【未完待续,回复任意文字以继续】"
|
| 143 |
+
splits = split_string_by_utf8_length(
|
| 144 |
+
reply_content,
|
| 145 |
+
MAX_UTF8_LEN - len(continue_text.encode("utf-8")),
|
| 146 |
+
max_split=1,
|
| 147 |
+
)
|
| 148 |
+
reply_text = splits[0] + continue_text
|
| 149 |
+
channel.cache_dict[from_user] = ("text", splits[1])
|
| 150 |
+
|
| 151 |
+
logger.info(
|
| 152 |
+
"[wechatmp] Request {} do send to {} {}: {}\n{}".format(
|
| 153 |
+
request_cnt,
|
| 154 |
+
from_user,
|
| 155 |
+
message_id,
|
| 156 |
+
content,
|
| 157 |
+
reply_text,
|
| 158 |
+
)
|
| 159 |
+
)
|
| 160 |
+
replyPost = create_reply(reply_text, msg)
|
| 161 |
+
return encrypt_func(replyPost.render())
|
| 162 |
+
|
| 163 |
+
elif reply_type == "voice":
|
| 164 |
+
media_id = reply_content
|
| 165 |
+
asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop)
|
| 166 |
+
logger.info(
|
| 167 |
+
"[wechatmp] Request {} do send to {} {}: {} voice media_id {}".format(
|
| 168 |
+
request_cnt,
|
| 169 |
+
from_user,
|
| 170 |
+
message_id,
|
| 171 |
+
content,
|
| 172 |
+
media_id,
|
| 173 |
+
)
|
| 174 |
+
)
|
| 175 |
+
replyPost = VoiceReply(message=msg)
|
| 176 |
+
replyPost.media_id = media_id
|
| 177 |
+
return encrypt_func(replyPost.render())
|
| 178 |
+
|
| 179 |
+
elif reply_type == "image":
|
| 180 |
+
media_id = reply_content
|
| 181 |
+
asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop)
|
| 182 |
+
logger.info(
|
| 183 |
+
"[wechatmp] Request {} do send to {} {}: {} image media_id {}".format(
|
| 184 |
+
request_cnt,
|
| 185 |
+
from_user,
|
| 186 |
+
message_id,
|
| 187 |
+
content,
|
| 188 |
+
media_id,
|
| 189 |
+
)
|
| 190 |
+
)
|
| 191 |
+
replyPost = ImageReply(message=msg)
|
| 192 |
+
replyPost.media_id = media_id
|
| 193 |
+
return encrypt_func(replyPost.render())
|
| 194 |
+
|
| 195 |
+
elif msg.type == "event":
|
| 196 |
+
logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source))
|
| 197 |
+
if msg.event in ["subscribe", "subscribe_scan"]:
|
| 198 |
+
reply_text = subscribe_msg()
|
| 199 |
+
if reply_text:
|
| 200 |
+
replyPost = create_reply(reply_text, msg)
|
| 201 |
+
return encrypt_func(replyPost.render())
|
| 202 |
+
else:
|
| 203 |
+
return "success"
|
| 204 |
+
else:
|
| 205 |
+
logger.info("暂且不处理")
|
| 206 |
+
return "success"
|
| 207 |
+
except Exception as exc:
|
| 208 |
+
logger.exception(exc)
|
| 209 |
+
return exc
|
channel/wechatmp/wechatmp_channel.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
import asyncio
|
| 3 |
+
import imghdr
|
| 4 |
+
import io
|
| 5 |
+
import os
|
| 6 |
+
import threading
|
| 7 |
+
import time
|
| 8 |
+
|
| 9 |
+
import requests
|
| 10 |
+
import web
|
| 11 |
+
from wechatpy.crypto import WeChatCrypto
|
| 12 |
+
from wechatpy.exceptions import WeChatClientException
|
| 13 |
+
|
| 14 |
+
from bridge.context import *
|
| 15 |
+
from bridge.reply import *
|
| 16 |
+
from channel.chat_channel import ChatChannel
|
| 17 |
+
from channel.wechatmp.common import *
|
| 18 |
+
from channel.wechatmp.wechatmp_client import WechatMPClient
|
| 19 |
+
from common.log import logger
|
| 20 |
+
from common.singleton import singleton
|
| 21 |
+
from common.utils import split_string_by_utf8_length
|
| 22 |
+
from config import conf
|
| 23 |
+
from voice.audio_convert import any_to_mp3
|
| 24 |
+
|
| 25 |
+
# If using SSL, uncomment the following lines, and modify the certificate path.
|
| 26 |
+
# from cheroot.server import HTTPServer
|
| 27 |
+
# from cheroot.ssl.builtin import BuiltinSSLAdapter
|
| 28 |
+
# HTTPServer.ssl_adapter = BuiltinSSLAdapter(
|
| 29 |
+
# certificate='/ssl/cert.pem',
|
| 30 |
+
# private_key='/ssl/cert.key')
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@singleton
class WechatMPChannel(ChatChannel):
    """WeChat Official Account channel.

    Two delivery modes:
    - passive_reply=True: the generated reply is cached and rendered into the
      HTTP response of WeChat's push request (channel.wechatmp.passive_reply.Query).
    - passive_reply=False: the reply is pushed actively through the
      customer-service message API (channel.wechatmp.active_reply.Query).
    """

    def __init__(self, passive_reply=True):
        super().__init__()
        self.passive_reply = passive_reply
        self.NOT_SUPPORT_REPLYTYPE = []
        appid = conf().get("wechatmp_app_id")
        secret = conf().get("wechatmp_app_secret")
        token = conf().get("wechatmp_token")
        aes_key = conf().get("wechatmp_aes_key")
        self.client = WechatMPClient(appid, secret)
        self.crypto = None
        if aes_key:
            # Encrypted (safe) mode is enabled only when an AES key is configured.
            self.crypto = WeChatCrypto(token, aes_key, appid)
        if self.passive_reply:
            # Cache the reply to the user's first message
            self.cache_dict = dict()
            # Record whether the current message is being processed
            self.running = set()
            # Count the requests from the wechat official server by message_id
            self.request_cnt = dict()
            # Permanent media must be deleted afterwards to stay under the media count limit
            self.delete_media_loop = asyncio.new_event_loop()
            t = threading.Thread(target=self.start_loop, args=(self.delete_media_loop,))
            t.daemon = True  # fixed: Thread.setDaemon() is deprecated since Python 3.10
            t.start()

    def startup(self):
        """Start the web.py HTTP server that receives WeChat's push requests."""
        if self.passive_reply:
            urls = ("/wx", "channel.wechatmp.passive_reply.Query")
        else:
            urls = ("/wx", "channel.wechatmp.active_reply.Query")
        app = web.application(urls, globals(), autoreload=False)
        port = conf().get("wechatmp_port", 8080)
        web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))

    def start_loop(self, loop):
        """Run *loop* forever; entry point of the media-deletion worker thread."""
        asyncio.set_event_loop(loop)
        loop.run_forever()

    async def delete_media(self, media_id):
        """Delete a permanent media item 10 seconds after it has been served."""
        logger.debug("[wechatmp] permanent media {} will be deleted in 10s".format(media_id))
        await asyncio.sleep(10)
        self.client.material.delete(media_id)
        logger.info("[wechatmp] permanent media {} has been deleted".format(media_id))

    def send(self, reply: Reply, context: Context):
        """Deliver *reply* to the receiver: cache it (passive mode) or push it via the API (active mode)."""
        receiver = context["receiver"]
        if self.passive_reply:
            if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
                reply_text = reply.content
                logger.info("[wechatmp] text cached, receiver {}\n{}".format(receiver, reply_text))
                self.cache_dict[receiver] = ("text", reply_text)
            elif reply.type == ReplyType.VOICE:
                try:
                    voice_file_path = reply.content
                    with open(voice_file_path, "rb") as f:
                        # support: <2M, <60s, mp3/wma/wav/amr
                        response = self.client.material.add("voice", f)
                        logger.debug("[wechatmp] upload voice response: {}".format(response))
                        # Rough estimate of WeChat's automatic audit time based on file size;
                        # returning before the audit finishes makes the voice unplayable.
                        # NOTE(review): the estimate is unverified.
                        f_size = os.fstat(f.fileno()).st_size
                        time.sleep(1.0 + 2 * f_size / 1024 / 1024)
                        # todo check media_id
                except WeChatClientException as e:
                    logger.error("[wechatmp] upload voice failed: {}".format(e))
                    return
                media_id = response["media_id"]
                logger.info("[wechatmp] voice uploaded, receiver {}, media_id {}".format(receiver, media_id))
                self.cache_dict[receiver] = ("voice", media_id)

            elif reply.type == ReplyType.IMAGE_URL:  # download the image from the network
                img_url = reply.content
                pic_res = requests.get(img_url, stream=True)
                image_storage = io.BytesIO()
                for block in pic_res.iter_content(1024):
                    image_storage.write(block)
                image_storage.seek(0)
                image_type = imghdr.what(image_storage)
                filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
                content_type = "image/" + image_type
                try:
                    response = self.client.material.add("image", (filename, image_storage, content_type))
                    logger.debug("[wechatmp] upload image response: {}".format(response))
                except WeChatClientException as e:
                    logger.error("[wechatmp] upload image failed: {}".format(e))
                    return
                media_id = response["media_id"]
                logger.info("[wechatmp] image uploaded, receiver {}, media_id {}".format(receiver, media_id))
                self.cache_dict[receiver] = ("image", media_id)
            elif reply.type == ReplyType.IMAGE:  # read the image from a file-like object
                image_storage = reply.content
                image_storage.seek(0)
                image_type = imghdr.what(image_storage)
                filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
                content_type = "image/" + image_type
                try:
                    response = self.client.material.add("image", (filename, image_storage, content_type))
                    logger.debug("[wechatmp] upload image response: {}".format(response))
                except WeChatClientException as e:
                    logger.error("[wechatmp] upload image failed: {}".format(e))
                    return
                media_id = response["media_id"]
                logger.info("[wechatmp] image uploaded, receiver {}, media_id {}".format(receiver, media_id))
                self.cache_dict[receiver] = ("image", media_id)
        else:
            if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
                reply_text = reply.content
                texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN)
                if len(texts) > 1:
                    logger.info("[wechatmp] text too long, split into {} parts".format(len(texts)))
                for i, text in enumerate(texts):
                    self.client.message.send_text(receiver, text)
                    if i != len(texts) - 1:
                        time.sleep(0.5)  # brief pause so the parts arrive in order
                logger.info("[wechatmp] Do send text to {}: {}".format(receiver, reply_text))
            elif reply.type == ReplyType.VOICE:
                try:
                    file_path = reply.content
                    file_name = os.path.basename(file_path)
                    file_type = os.path.splitext(file_name)[1]
                    if file_type == ".mp3":
                        file_type = "audio/mpeg"
                    elif file_type == ".amr":
                        file_type = "audio/amr"
                    else:
                        # Any other format: transcode to mp3 first.
                        mp3_file = os.path.splitext(file_path)[0] + ".mp3"
                        any_to_mp3(file_path, mp3_file)
                        file_path = mp3_file
                        file_name = os.path.basename(file_path)
                        file_type = "audio/mpeg"
                    logger.info("[wechatmp] file_name: {}, file_type: {} ".format(file_name, file_type))
                    # support: <2M, <60s, AMR\MP3
                    response = self.client.media.upload("voice", (file_name, open(file_path, "rb"), file_type))
                    logger.debug("[wechatmp] upload voice response: {}".format(response))
                except WeChatClientException as e:
                    logger.error("[wechatmp] upload voice failed: {}".format(e))
                    return
                self.client.message.send_voice(receiver, response["media_id"])
                logger.info("[wechatmp] Do send voice to {}".format(receiver))
            elif reply.type == ReplyType.IMAGE_URL:  # download the image from the network
                img_url = reply.content
                pic_res = requests.get(img_url, stream=True)
                image_storage = io.BytesIO()
                for block in pic_res.iter_content(1024):
                    image_storage.write(block)
                image_storage.seek(0)
                image_type = imghdr.what(image_storage)
                filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
                content_type = "image/" + image_type
                try:
                    response = self.client.media.upload("image", (filename, image_storage, content_type))
                    logger.debug("[wechatmp] upload image response: {}".format(response))
                except WeChatClientException as e:
                    logger.error("[wechatmp] upload image failed: {}".format(e))
                    return
                self.client.message.send_image(receiver, response["media_id"])
                logger.info("[wechatmp] Do send image to {}".format(receiver))
            elif reply.type == ReplyType.IMAGE:  # read the image from a file-like object
                image_storage = reply.content
                image_storage.seek(0)
                image_type = imghdr.what(image_storage)
                filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
                content_type = "image/" + image_type
                try:
                    response = self.client.media.upload("image", (filename, image_storage, content_type))
                    logger.debug("[wechatmp] upload image response: {}".format(response))
                except WeChatClientException as e:
                    logger.error("[wechatmp] upload image failed: {}".format(e))
                    return
                self.client.message.send_image(receiver, response["media_id"])
                logger.info("[wechatmp] Do send image to {}".format(receiver))
            return

    def _success_callback(self, session_id, context, **kwargs):  # called when the worker finishes normally
        logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context["msg"].msg_id))
        if self.passive_reply:
            self.running.remove(session_id)

    def _fail_callback(self, session_id, exception, context, **kwargs):  # called when the worker raises
        logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context["msg"].msg_id, exception))
        if self.passive_reply:
            assert session_id not in self.cache_dict
            self.running.remove(session_id)
|
channel/wechatmp/wechatmp_client.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import threading
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
from wechatpy.client import WeChatClient
|
| 5 |
+
from wechatpy.exceptions import APILimitedException
|
| 6 |
+
|
| 7 |
+
from channel.wechatmp.common import *
|
| 8 |
+
from common.log import logger
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class WechatMPClient(WeChatClient):
    """WeChatClient with thread-safe token refresh and automatic quota clearing.

    Overrides:
    - fetch_access_token: serialized with a lock so concurrent threads do not
      refresh the token redundantly.
    - _request: on APILimitedException, clears the API quota (at most once per
      60 seconds) and retries the request once.
    """

    def __init__(self, appid, secret, access_token=None, session=None, timeout=None, auto_retry=True):
        super(WechatMPClient, self).__init__(appid, secret, access_token, session, timeout, auto_retry)
        self.fetch_access_token_lock = threading.Lock()
        self.clear_quota_lock = threading.Lock()
        self.last_clear_quota_time = -1  # epoch seconds of last quota clear; -1 = never

    def clear_quota(self):
        """Clear the API call quota (legacy endpoint)."""
        return self.post("clear_quota", data={"appid": self.appid})

    def clear_quota_v2(self):
        """Clear the API call quota (v2 endpoint, authenticated by appid + secret)."""
        return self.post("clear_quota/v2", params={"appid": self.appid, "appsecret": self.secret})

    def fetch_access_token(self):  # override: lock to avoid concurrent token refreshes
        with self.fetch_access_token_lock:
            access_token = self.session.get(self.access_token_key)
            if access_token:
                if not self.expires_at:
                    return access_token
                timestamp = time.time()
                # Reuse the cached token while it is valid for at least 60 more seconds.
                if self.expires_at - timestamp > 60:
                    return access_token
            return super().fetch_access_token()

    def _request(self, method, url_or_endpoint, **kwargs):  # override: clear quota and retry once on rate limit
        try:
            return super()._request(method, url_or_endpoint, **kwargs)
        except APILimitedException as e:
            logger.error("[wechatmp] API quata has been used up. {}".format(e))
            if self.last_clear_quota_time == -1 or time.time() - self.last_clear_quota_time > 60:
                with self.clear_quota_lock:
                    # Double-checked under the lock: only one thread clears per window.
                    if self.last_clear_quota_time == -1 or time.time() - self.last_clear_quota_time > 60:
                        self.last_clear_quota_time = time.time()
                        response = self.clear_quota_v2()
                        logger.debug("[wechatmp] API quata has been cleard, {}".format(response))
                return super()._request(method, url_or_endpoint, **kwargs)
            else:
                # fixed: the original string had a {} placeholder but never called .format()
                logger.error("[wechatmp] last clear quota time is {}, less than 60s, skip clear quota".format(self.last_clear_quota_time))
                raise e
|
channel/wechatmp/wechatmp_message.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-#
|
| 2 |
+
|
| 3 |
+
from bridge.context import ContextType
|
| 4 |
+
from channel.chat_message import ChatMessage
|
| 5 |
+
from common.log import logger
|
| 6 |
+
from common.tmp_dir import TmpDir
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class WeChatMPMessage(ChatMessage):
    """Adapt a wechatpy push message to the framework's ChatMessage.

    Supported msg.type values: "text", "voice" (with optional speech
    recognition result), "image". Media payloads are downloaded lazily via
    *client* when the message is prepared.
    """

    def __init__(self, msg, client=None):
        super().__init__(msg)
        self.msg_id = msg.id
        self.create_time = msg.time
        self.is_group = False  # official-account chats are always one-on-one here

        if msg.type == "text":
            self.ctype = ContextType.TEXT
            self.content = msg.content
        elif msg.type == "voice":
            if msg.recognition is None:  # fixed: identity check instead of "== None"
                self.ctype = ContextType.VOICE
                self.content = TmpDir().path() + msg.media_id + "." + msg.format  # content holds the local tmp path

                def download_voice():
                    # Download the media and write it to the local file on HTTP 200.
                    response = client.media.download(msg.media_id)
                    if response.status_code == 200:
                        with open(self.content, "wb") as f:
                            f.write(response.content)
                    else:
                        logger.info(f"[wechatmp] Failed to download voice file, {response.content}")

                self._prepare_fn = download_voice
            else:
                # WeChat already transcribed the voice message: treat it as text.
                self.ctype = ContextType.TEXT
                self.content = msg.recognition
        elif msg.type == "image":
            self.ctype = ContextType.IMAGE
            self.content = TmpDir().path() + msg.media_id + ".png"  # content holds the local tmp path

            def download_image():
                # Download the media and write it to the local file on HTTP 200.
                response = client.media.download(msg.media_id)
                if response.status_code == 200:
                    with open(self.content, "wb") as f:
                        f.write(response.content)
                else:
                    logger.info(f"[wechatmp] Failed to download image file, {response.content}")

            self._prepare_fn = download_image
        else:
            raise NotImplementedError("Unsupported message type: Type:{} ".format(msg.type))

        self.from_user_id = msg.source
        self.to_user_id = msg.target
        self.other_user_id = msg.source
|
common/const.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# bot_type identifiers, matched against the "bot_type" config value by bot_factory
OPEN_AI = "openAI"
CHATGPT = "chatGPT"
BAIDU = "baidu"
XUNFEI = "xunfei"
CHATGPTONAZURE = "chatGPTOnAzure"
LINKAI = "linkai"

# application version string
VERSION = "1.3.0"

# model names accepted by the "model" config option
MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "xunfei"]
|
common/dequeue.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from queue import Full, Queue
|
| 2 |
+
from time import monotonic as time
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Queue subclass that adds putleft(): insert an item at the HEAD of the queue
# so it is dequeued before everything already waiting.
class Dequeue(Queue):
    def putleft(self, item, block=True, timeout=None):
        """Put *item* at the front of the queue.

        Mirrors the blocking/timeout semantics of Queue.put(): with block=False
        raise Full immediately when the queue is full; with a timeout wait at
        most that many seconds before raising Full.
        """
        with self.not_full:
            if self.maxsize > 0:
                if not block:
                    if self._qsize() >= self.maxsize:
                        raise Full
                elif timeout is None:
                    # Wait indefinitely for a free slot.
                    while self._qsize() >= self.maxsize:
                        self.not_full.wait()
                elif timeout < 0:
                    raise ValueError("'timeout' must be a non-negative number")
                else:
                    # Bounded wait using a monotonic deadline.
                    endtime = time() + timeout
                    while self._qsize() >= self.maxsize:
                        remaining = endtime - time()
                        if remaining <= 0.0:
                            raise Full
                        self.not_full.wait(remaining)
            self._putleft(item)
            self.unfinished_tasks += 1
            self.not_empty.notify()

    def putleft_nowait(self, item):
        """Non-blocking putleft(); raises Full if no slot is available."""
        return self.putleft(item, block=False)

    def _putleft(self, item):
        # Queue stores items in a collections.deque, so appendleft is O(1).
        self.queue.appendleft(item)
|
common/expired_dict.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime, timedelta
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class ExpiredDict(dict):
    """A dict whose entries expire after a fixed time-to-live.

    Each value is stored together with its expiry deadline. Reading a live
    entry refreshes its deadline (sliding expiration); expired entries are
    purged lazily on access.
    """

    def __init__(self, expires_in_seconds):
        super().__init__()
        self.expires_in_seconds = expires_in_seconds

    def __getitem__(self, key):
        value, deadline = super().__getitem__(key)
        if datetime.now() > deadline:
            # Entry is stale: drop it and behave as if it never existed.
            del self[key]
            raise KeyError("expired {}".format(key))
        # Touch the entry so its TTL restarts from now.
        self[key] = value
        return value

    def __setitem__(self, key, value):
        deadline = datetime.now() + timedelta(seconds=self.expires_in_seconds)
        super().__setitem__(key, (value, deadline))

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default

    def __contains__(self, key):
        try:
            self[key]
        except KeyError:
            return False
        return True

    def keys(self):
        # Snapshot the raw keys first: the membership test may delete entries.
        return [k for k in list(super().keys()) if k in self]

    def items(self):
        return [(k, self[k]) for k in self.keys()]

    def __iter__(self):
        return iter(self.keys())
|
common/log.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def _reset_logger(log):
|
| 6 |
+
for handler in log.handlers:
|
| 7 |
+
handler.close()
|
| 8 |
+
log.removeHandler(handler)
|
| 9 |
+
del handler
|
| 10 |
+
log.handlers.clear()
|
| 11 |
+
log.propagate = False
|
| 12 |
+
console_handle = logging.StreamHandler(sys.stdout)
|
| 13 |
+
console_handle.setFormatter(
|
| 14 |
+
logging.Formatter(
|
| 15 |
+
"[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s",
|
| 16 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
| 17 |
+
)
|
| 18 |
+
)
|
| 19 |
+
file_handle = logging.FileHandler("run.log", encoding="utf-8")
|
| 20 |
+
file_handle.setFormatter(
|
| 21 |
+
logging.Formatter(
|
| 22 |
+
"[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s",
|
| 23 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
| 24 |
+
)
|
| 25 |
+
)
|
| 26 |
+
log.addHandler(file_handle)
|
| 27 |
+
log.addHandler(console_handle)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _get_logger():
|
| 31 |
+
log = logging.getLogger("log")
|
| 32 |
+
_reset_logger(log)
|
| 33 |
+
log.setLevel(logging.INFO)
|
| 34 |
+
return log
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# 日志句柄
|
| 38 |
+
logger = _get_logger()
|
common/package_manager.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
|
| 3 |
+
import pip
|
| 4 |
+
from pip._internal import main as pipmain
|
| 5 |
+
|
| 6 |
+
from common.log import _reset_logger, logger
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def install(package):
    """Install a single pip *package* via pip's internal main entry point."""
    pipmain(["install", package])
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def install_requirements(file):
    """Install/upgrade everything listed in the requirements *file*."""
    pipmain(["install", "-r", file, "--upgrade"])
    # NOTE(review): presumably restores our handlers after pip touches logging — confirm.
    _reset_logger(logger)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def check_dulwich():
    """Ensure the dulwich package is importable, installing it (with one
    delayed retry) if missing.

    Raises ImportError when dulwich still cannot be imported afterwards.
    """
    needwait = False
    for _ in range(2):  # one immediate attempt plus one retry after a pause
        if needwait:
            time.sleep(3)
            needwait = False
        try:
            import dulwich

            return
        except ImportError:
            try:
                install("dulwich")
            # fixed: narrowed from a bare "except:" which also swallowed
            # SystemExit / KeyboardInterrupt
            except Exception:
                needwait = True
    try:
        import dulwich
    except ImportError:
        raise ImportError("Unable to import dulwich")
|
common/singleton.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def singleton(cls):
    """Class decorator: every call to the decorated class returns one shared
    instance, constructed lazily from the first call's arguments."""
    _instances = {}

    def get_instance(*args, **kwargs):
        try:
            return _instances[cls]
        except KeyError:
            obj = cls(*args, **kwargs)
            _instances[cls] = obj
            return obj

    return get_instance
|
common/sorted_dict.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import heapq
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class SortedDict(dict):
    """A dict that keeps its keys ordered by a priority computed from (key, value).

    A heap of (priority, key) pairs mirrors the dict contents; keys()/items()
    return entries sorted by priority (descending when reverse=True). The
    sorted key list is cached and invalidated on every mutation. Only the
    overridden methods keep the heap in sync — dict.update()/pop() would not.
    """

    def __init__(self, sort_func=lambda k, v: k, init_dict=None, reverse=False):
        if init_dict is None:
            init_dict = []
        if isinstance(init_dict, dict):
            init_dict = init_dict.items()
        self.sort_func = sort_func  # maps (key, value) -> sort priority
        self.sorted_keys = None  # cached sorted key list; None means stale
        self.reverse = reverse
        self.heap = []  # heap of (priority, key) pairs mirroring the dict
        for k, v in init_dict:
            self[k] = v

    def __setitem__(self, key, value):
        if key in self:
            super().__setitem__(key, value)
            # Recompute the existing heap entry's priority in place (linear scan).
            for i, (priority, k) in enumerate(self.heap):
                if k == key:
                    self.heap[i] = (self.sort_func(key, value), key)
                    heapq.heapify(self.heap)
                    break
            self.sorted_keys = None
        else:
            super().__setitem__(key, value)
            heapq.heappush(self.heap, (self.sort_func(key, value), key))
            self.sorted_keys = None

    def __delitem__(self, key):
        super().__delitem__(key)
        # Linear scan: the heap keeps no index by key.
        for i, (priority, k) in enumerate(self.heap):
            if k == key:
                del self.heap[i]
                heapq.heapify(self.heap)
                break
        self.sorted_keys = None

    def keys(self):
        if self.sorted_keys is None:
            self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)]
        return self.sorted_keys

    def items(self):
        if self.sorted_keys is None:
            self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)]
        sorted_items = [(k, self[k]) for k in self.sorted_keys]
        return sorted_items

    def _update_heap(self, key):
        # Re-sync one heap entry after its value was mutated externally.
        for i, (priority, k) in enumerate(self.heap):
            if k == key:
                new_priority = self.sort_func(key, self[key])
                if new_priority != priority:
                    self.heap[i] = (new_priority, key)
                    heapq.heapify(self.heap)
                    self.sorted_keys = None
                break

    def __iter__(self):
        return iter(self.keys())

    def __repr__(self):
        return f"{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})"
|
common/time_check.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hashlib
|
| 2 |
+
import re
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
import config
|
| 6 |
+
from common.log import logger
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def time_checker(f):
    """Decorator: invoke the wrapped handler only inside the configured
    service window.

    When "chat_time_module" is enabled in config, calls arriving outside
    [chat_start_time, chat_stop_time] are dropped, except the "#更新配置"
    command which is always allowed so the window itself can be updated.
    """

    def _time_checker(self, *args, **kwargs):
        _config = config.conf()
        chat_time_module = _config.get("chat_time_module", False)
        if chat_time_module:
            chat_start_time = _config.get("chat_start_time", "00:00")
            chat_stopt_time = _config.get("chat_stop_time", "24:00")
            time_regex = re.compile(r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$")  # HH:MM, allowing 24:00

            starttime_format_check = time_regex.match(chat_start_time)  # validate start-time format
            stoptime_format_check = time_regex.match(chat_stopt_time)  # validate stop-time format
            chat_time_check = chat_start_time < chat_stopt_time  # start must precede stop

            # Format validation: the window comparison below is a plain string
            # comparison, so it relies on zero-padded HH:MM values.
            if not (starttime_format_check and stoptime_format_check and chat_time_check):
                logger.warn("时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(starttime_format_check, stoptime_format_check))
                if chat_start_time > "23:59":
                    logger.error("启动时间可能存在问题,请修改!")

            # Service-window check against the current local time.
            now_time = time.strftime("%H:%M", time.localtime())
            if chat_start_time <= now_time <= chat_stopt_time:  # inside the window: handle normally
                f(self, *args, **kwargs)
                return None
            else:
                # NOTE(review): assumes args[0] is a dict-like message with a
                # "Content" field — confirm against callers.
                if args[0]["Content"] == "#更新配置":  # config refresh is allowed outside the window
                    f(self, *args, **kwargs)
                else:
                    logger.info("非服务时间内,不接受访问")
                    return None
        else:
            f(self, *args, **kwargs)  # module disabled: always handle

    return _time_checker
|
common/tmp_dir.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pathlib
|
| 3 |
+
|
| 4 |
+
from config import conf
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class TmpDir(object):
    """Provide a shared ./tmp/ working directory, created on first use.

    Note: despite the historical docstring, the directory is never deleted —
    it persists across instances and process restarts.
    """

    tmpFilePath = pathlib.Path("./tmp/")

    def __init__(self):
        # fixed: exist_ok avoids the check-then-create race when several
        # threads instantiate TmpDir concurrently.
        os.makedirs(self.tmpFilePath, exist_ok=True)

    def path(self):
        """Return the tmp directory as a string path ending with '/'."""
        return str(self.tmpFilePath) + "/"
|
common/token_bucket.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import threading
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class TokenBucket:
    """Simple token-bucket rate limiter producing *tpm* tokens per minute.

    A background thread refills the bucket; get_token() blocks (optionally up
    to *timeout* seconds) until a token is available.
    """

    def __init__(self, tpm, timeout=None):
        self.capacity = int(tpm)  # bucket capacity
        self.tokens = 0  # the bucket starts empty
        self.rate = int(tpm) / 60  # tokens generated per second
        self.timeout = timeout  # max seconds to wait in get_token(); None = wait forever
        self.cond = threading.Condition()
        self.is_running = True
        # fixed: run the producer as a daemon thread — otherwise a forgotten
        # close() keeps the whole process alive at interpreter exit.
        t = threading.Thread(target=self._generate_tokens, daemon=True)
        t.start()

    def _generate_tokens(self):
        """Producer loop: add one token every 1/rate seconds until close()."""
        while self.is_running:
            with self.cond:
                if self.tokens < self.capacity:
                    self.tokens += 1
                self.cond.notify()  # wake one waiting consumer
            # Sleep outside the lock so consumers can acquire it meanwhile.
            time.sleep(1 / self.rate)

    def get_token(self):
        """Block until a token is available and consume it.

        Returns True on success, False if *timeout* elapsed without a token.
        """
        with self.cond:
            while self.tokens <= 0:
                flag = self.cond.wait(self.timeout)
                if not flag:  # timed out
                    return False
            self.tokens -= 1
            return True

    def close(self):
        """Stop the producer thread (it exits after its current sleep)."""
        self.is_running = False
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
if __name__ == "__main__":
    # Manual smoke test for the limiter.
    token_bucket = TokenBucket(20, None)  # a bucket producing 20 tokens per minute
    # token_bucket = TokenBucket(20, 0.1)
    for i in range(3):
        if token_bucket.get_token():
            print(f"第{i+1}次请求成功")
    token_bucket.close()
|
common/utils.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
from PIL import Image
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def fsize(file):
    """Return the size in bytes of *file*.

    Accepts an io.BytesIO, a filesystem path string, or any seekable
    file-like object; raises TypeError for anything else. The stream's
    current position is preserved.
    """
    if isinstance(file, io.BytesIO):
        return file.getbuffer().nbytes
    if isinstance(file, str):
        return os.path.getsize(file)
    if hasattr(file, "seek") and hasattr(file, "tell"):
        saved_pos = file.tell()
        file.seek(0, os.SEEK_END)
        total = file.tell()
        file.seek(saved_pos)  # restore the caller's position
        return total
    raise TypeError("Unsupported type")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def compress_imgfile(file, max_size):
    """Re-encode an image as JPEG, lowering quality until it fits *max_size* bytes.

    Returns *file* unchanged if it is already small enough, otherwise a
    BytesIO holding the compressed JPEG (best effort at the quality floor).
    """
    if fsize(file) <= max_size:
        return file
    file.seek(0)
    img = Image.open(file)
    rgb_image = img.convert("RGB")  # JPEG cannot carry an alpha channel
    quality = 95
    while True:
        out_buf = io.BytesIO()
        rgb_image.save(out_buf, "JPEG", quality=quality)
        # fixed: quality floor — the original loop could drive quality to zero
        # and below (failing inside PIL) when max_size is unreachably small.
        if fsize(out_buf) <= max_size or quality <= 5:
            return out_buf
        quality -= 5
+
|
| 36 |
+
|
| 37 |
+
def split_string_by_utf8_length(string, max_length, max_split=0):
    """Split *string* into pieces of at most *max_length* UTF-8 bytes without
    cutting a multi-byte character in half.

    When max_split > 0, at most that many pieces are produced before the
    remainder is appended whole. Returns a list of str.
    """
    encoded = string.encode("utf-8")
    start, end = 0, 0
    result = []
    while end < len(encoded):
        if max_split > 0 and len(result) >= max_split:
            result.append(encoded[start:].decode("utf-8"))
            break
        end = min(start + max_length, len(encoded))
        # If the cut lands inside a multi-byte character, back off to the
        # character's lead byte (continuation bytes are 0b10xxxxxx).
        while end < len(encoded) and (encoded[end] & 0b11000000) == 0b10000000:
            end -= 1
        if end <= start:
            # fixed: max_length is smaller than the next character's byte
            # length — the original code appended empty strings forever.
            # Take the whole character anyway.
            end = start + 1
            while end < len(encoded) and (encoded[end] & 0b11000000) == 0b10000000:
                end += 1
        result.append(encoded[start:end].decode("utf-8"))
        start = end
    return result
|