Spaces:
Sleeping
Sleeping
| import os | |
| import random | |
| import re | |
| import requests | |
| import argparse | |
| import string | |
| from datetime import timedelta | |
| from flask import Flask, session, request, jsonify, render_template | |
| from transformers.models.bert.tokenization_bert import BertTokenizer | |
| from bot.chatbot import ChatBot | |
| from bot.config import special_token_list | |
# Flask application object shared by all request handlers in this module.
app = Flask(__name__)
app.config.update(
    # NOTE(review): a fresh random key is generated on every process start, so session
    # cookies signed by a previous run (or by another worker process) will not validate.
    SECRET_KEY=os.urandom(74),
    # NOTE(review): Flask only applies this lifetime to sessions marked permanent; nothing
    # in this file sets `session.permanent`, so confirm whether this setting takes effect.
    PERMANENT_SESSION_LIFETIME=timedelta(days=7),
)

# Tokenizer used to encode/decode chat history; assigned in main() before the server starts.
tokenizer: BertTokenizer = None
# Process-wide chat history keyed by session hash (values are lists of token-id lists).
# Entries are added/overwritten but never removed in this file.
history_matrix: dict = {}
def move_history_from_session_to_global_memory() -> None:
    """Copy the current session's chat history into the process-wide ``history_matrix``.

    Nothing is copied unless the session has a truthy ``session_hash`` and a
    non-empty ``history``. A session that has a hash but no ``history`` key yet
    (e.g. right after ``talk()`` generated a new hash) is a no-op, not a KeyError.
    """
    session_hash = session.get("session_hash")
    # .get() instead of session["history"]: the key may legitimately be absent.
    history = session.get("history")
    if session_hash and history:
        # Mutating the dict's items needs no `global` declaration.
        history_matrix[session_hash] = history
def move_history_from_global_memory_to_session() -> None:
    """Load the history stored in ``history_matrix`` for this session's hash into the session.

    Does nothing if the session has no truthy ``session_hash``. If the global
    memory has no entry for the hash, ``session["history"]`` is set to ``None``.
    The stored list object itself (not a copy) is placed in the session.
    """
    current_hash = session.get("session_hash")
    if not current_hash:
        return
    session["history"] = history_matrix.get(current_hash)
def set_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the server's command-line options.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which preserves the original behavior; pass an explicit
            list to call this programmatically or from tests.

    Returns:
        Namespace with ``vocab_path`` (default ``None``) and ``model_path``
        (default ``"lewiswu1209/Winnie"``).
    """
    parser: argparse.ArgumentParser = argparse.ArgumentParser()
    # help text: "choose vocabulary"
    parser.add_argument("--vocab_path", default=None, type=str, required=False, help="选择词库")
    # help text: "dialogue model path"
    parser.add_argument("--model_path", default="lewiswu1209/Winnie", type=str, required=False, help="对话模型路径")
    return parser.parse_args(argv)
def get_history_list() -> str:
    """Return the current session's chat history as a JSON list of strings.

    The session history is first refreshed from ``history_matrix``. Each stored
    turn is a list of token ids: it is mapped back to tokens, the WordPiece
    continuation prefix ``"##"`` is stripped, and the pieces are concatenated with
    no separator. Returns the ``jsonify`` response (despite the ``str`` annotation).
    """
    move_history_from_global_memory_to_session()

    stored_turns = session.get("history")
    if stored_turns is None:
        stored_turns = []

    def _detokenize(token_ids) -> str:
        # NOTE(review): joining with "" presumably suits Chinese text; separate
        # English words will run together.
        pieces = tokenizer.convert_ids_to_tokens(token_ids)
        return "".join(piece[2:] if piece.startswith("##") else piece for piece in pieces)

    return jsonify([_detokenize(turn) for turn in stored_turns])
# Prediction endpoint of the remote Winnie Space that produces the replies.
_WINNIE_API_URL = "https://hf.space/embed/lewiswu1209/Winnie/+/api/predict/"
# Seconds to wait for the remote Space; without a timeout a stalled request blocks forever.
# Kept generous because a sleeping Space may need time to wake up.
_WINNIE_API_TIMEOUT = 120
# Command that clears both the local history and the remote model's memory.
_ERASE_COMMAND = "ERASE MEMORY"
# Role-setting commands such as "<NAME>=Vicky"; these turns are not recorded in history.
_TAG_ASSIGNMENT_RE = re.compile(r"^<.+>=.+$")
# CSPRNG-backed generator: anyone who knows a session hash can adopt it via ?hash=,
# so hashes must not be predictable (the default Mersenne Twister is).
_SYSTEM_RANDOM = random.SystemRandom()
# Usage lines returned verbatim when the user sends "help" (any letter case).
_HELP_INFO_LIST = [
    "输入任意文字,Winnie会和你对话",
    "输入ERASE MEMORY,Winnie会清空记忆",
    "输入\"<TAG>=<VALUE>\",Winnie会记录你的角色信息",
    "例如:<NAME>=Vicky,Winnie会修改自己的名字",
    "可以修改的角色信息有:",
    "<NAME>, <GENDER>, <YEAROFBIRTH>, <MONTHOFBIRTH>, <DAYOFBIRTH>, <ZODIAC>, <AGE>",
    "输入“上联:XXXXXXX”,Winnie会和你对对联",
    "输入“写诗:XXXXXXX”,Winnie会以XXXXXXX为开头写诗",
    "以\"请问\"开头并以问号结尾,Winnie会回答该问题",
]


def _new_session_hash() -> str:
    """Return a new 11-character hash of distinct lowercase letters and digits."""
    return "".join(_SYSTEM_RANDOM.sample(string.ascii_lowercase + string.digits, 11))


def _post_to_winnie(text: str, session_hash: str) -> str:
    """Send ``text`` to the remote Winnie Space and return ``data[0]`` of its JSON reply.

    Raises:
        requests.Timeout: if the Space does not answer within ``_WINNIE_API_TIMEOUT``.
        requests.HTTPError: if the Space answers with a non-2xx status.
    """
    response = requests.post(
        url=_WINNIE_API_URL,
        json={"data": [text], "session_hash": session_hash},
        timeout=_WINNIE_API_TIMEOUT,
    )
    response.raise_for_status()
    return response.json()["data"][0]


def talk() -> str:
    """Handle one chat turn and return a ``jsonify`` response holding a list.

    Query parameters:
        hash: optional session hash to adopt; its history is loaded from ``history_matrix``.
        text: the user's message. ``HELP`` (any case) returns usage lines,
            ``ERASE MEMORY`` clears memory, and ``<TAG>=<VALUE>`` sets role info.

    If the session has no hash, a new random one is generated. With no ``text``,
    ``[""]`` is returned. When the local history is empty (or the user sends
    ``ERASE MEMORY``), the remote memory is erased first. The user's turn and the
    reply are recorded only after the remote call succeeds, so a failed request
    never leaves an unanswered turn in the session or in ``history_matrix``.
    """
    if request.args.get("hash"):
        session["session_hash"] = request.args.get("hash")
        move_history_from_global_memory_to_session()
    if session.get("session_hash") is None:
        session["session_hash"] = _new_session_hash()

    input_text = request.args.get("text")
    if not input_text:
        return jsonify([""])

    if input_text.upper() == "HELP":
        return jsonify(_HELP_INFO_LIST)

    session_hash = session["session_hash"]
    # May be the very list object stored in history_matrix, so mutations are shared.
    history_list = session.get("history")

    if not history_list or input_text == _ERASE_COMMAND:
        history_list = []
        output_text = _post_to_winnie(_ERASE_COMMAND, session_hash)

    if input_text != _ERASE_COMMAND:
        output_text = _post_to_winnie(input_text, session_hash)
        if not _TAG_ASSIGNMENT_RE.match(input_text):
            history_list.append(tokenizer.encode(input_text, add_special_tokens=False))
            history_list.append(tokenizer.encode(output_text, add_special_tokens=False))

    session["history"] = history_list
    history_matrix[session_hash] = history_list
    return jsonify([output_text])
def index() -> str:
    """Return the fixed plain-text greeting ``"Hello world!"``."""
    greeting = "Hello world!"
    return greeting
def get_hash() -> str:
    """Return the session hash, first adopting the ``hash`` query parameter if given.

    Adopting a hash also loads that hash's history from ``history_matrix`` into the
    session. Returns a single space when the session has no truthy hash.
    """
    requested_hash = request.args.get("hash")
    if requested_hash:
        session["session_hash"] = requested_hash
        move_history_from_global_memory_to_session()
    current_hash = session.get("session_hash")
    return current_hash if current_hash else " "
def chitchat() -> str:
    """Render and return the ``chat_template.html`` page."""
    template_name = "chat_template.html"
    return render_template(template_name)
def main() -> None:
    """Load the tokenizer selected on the command line, then serve the app on 127.0.0.1:8080."""
    # Rebinds the module-level tokenizer that the request handlers read.
    global tokenizer
    options = set_args()
    tokenizer_options = {
        "vocab_path": options.vocab_path,
        "special_token_list": special_token_list,
    }
    tokenizer = ChatBot.get_tokenizer(options.model_path, **tokenizer_options)
    app.run(host="127.0.0.1", port=8080)


if __name__ == "__main__":
    main()