Spaces:
Runtime error
Runtime error
| """This is an example of how to use async langchain with fastapi and return a streaming response.""" | |
| import uvicorn, threading, logging, time, json, re, os | |
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| from langchain import LLMChain | |
| from starlette.types import Send | |
| from langchain.chat_models import ChatOpenAI | |
| from fastapi.responses import StreamingResponse | |
| from logging.handlers import RotatingFileHandler | |
| from langchain.memory import ConversationBufferMemory | |
| from typing import Any, Optional, Awaitable, Callable, Iterator, Union, List | |
| from langchain.callbacks.base import AsyncCallbackManager, AsyncCallbackHandler | |
| import openai | |
| os.environ["OPENAI_API_KEY"] = "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" | |
| openai.proxy = {"http": "http://127.0.0.1:7890", "https": "http://127.0.0.1:7890"} | |
| from utils import ( | |
| prompt_memory, prompt_chat_term, prompt_basic, prompt_reco, | |
| prompt_memory_character, prompt_chat_character, prompt_basic_character, prompt_reco_character, | |
| memory_chat, memory_basic, memory_basic_character, memory_reco, memory_reco_character, | |
| product, content_term, get_detailInfo, process_Info, | |
| ) | |
| from config import log_path, log_file, pre_key_words | |
# Canned replies for special dialogue states, pulled from the config module.
key_words = {}
# Keys: [done, greeting, configure, like, advertisement, explain, multi-face, unsuccessful]
keywords_to_extract = ['wan_bi', 'zhao_hu', 'pei_zhi', 'dian_zan', 'guang_gao', 'jiang_jie', 'multi_face', 'un_success']
for keyword in keywords_to_extract:
    key_words[keyword] = pre_key_words[keyword]
# Module logger writing to a size-rotated file next to this module.
logger = logging.getLogger('my_logger')
logger.setLevel(logging.DEBUG)
log_path = os.path.join(os.path.dirname(__file__), log_path)
if not os.path.exists(log_path):
    os.makedirs(log_path)
log_file = "{}/{}".format(log_path, log_file)  # Size-rotating log handler: max file size 20MB, keep 10 backup files
file_handler = RotatingFileHandler(log_file, maxBytes=20 * 1024 * 1024, backupCount=10)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
app = FastAPI()
# Async callable used to push one chunk (str or bytes) to the client.
Sender = Callable[[Union[str, bytes]], Awaitable[None]]


class EmptyIterator(Iterator[Union[str, bytes]]):
    """An iterator that yields nothing.

    Used as the placeholder body of the streaming response, whose real
    content is produced by the overridden ``stream_response``.
    """

    def __iter__(self):
        return self

    def __next__(self):
        # Exhausted from the start: no items, ever.
        raise StopIteration
class AsyncStreamCallbackHandler(AsyncCallbackHandler):
    """Callback handler for streaming, inheritance from AsyncCallbackHandler.

    Buffers generated tokens and flushes one sentence to the client each time
    a punctuation token arrives.
    """

    def __init__(self, send: Sender, time: int, start_time: float):
        """
        Args:
            send: async callable that pushes one chunk to the client.
            time: request timestamp echoed back in every response payload.
            start_time: wall-clock time the user request arrived.
        """
        super().__init__()
        self.send = send
        self.tokens = []  # buffer of tokens accumulated for the current sentence
        self.time = time
        self.start_time = start_time  # time the user request arrived
        self.previous_time = start_time  # time the previous sentence was returned

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Collect *token*; when it carries sentence punctuation, emit the buffered sentence."""
        async def send_response(sentence):
            response = generate_response(self.time, self.start_time, self.previous_time, content=sentence, broadcast=True, end=False)
            self.previous_time = time.time()
            # BUGFIX: time.sleep(0.2) here blocked the entire event loop (and
            # every other in-flight request); use the non-blocking asyncio sleep.
            await asyncio.sleep(0.2)
            logger.info(f"--return body--:{json.dumps(response, ensure_ascii=False)}")
            await self.send(f"{json.dumps(response, ensure_ascii=False)}\n")

        self.tokens.append(token.strip())
        if any(punctuation in token for punctuation in [",", ",", "。", "!", "?", ":", ":", ";", ";"]):
            if len(token) == 1:
                sentence = "".join(self.tokens)
                self.tokens = []
                await send_response(sentence)
            else:
                # OpenAI sometimes glues punctuation to the next token (",请" / ",把"),
                # which makes the TTS sound abrupt — flush everything before the glued
                # token and keep that token as the start of the next sentence.
                sentence = "".join(self.tokens[:-1])
                self.tokens = [self.tokens[-1]]
                await send_response(sentence)
def extract_product(text):
    """Regex-extract the recommended insurance name from model output, then look up
    its details among the six predefined Sunshine insurance products (`product`)."""
    found = re.search(r"推荐:(.+?)[,。]", text)
    if found:
        insurance_name = found.group(1)
    else:
        insurance_name = "未匹配到保险名称"
    # Fall back to the raw name when it is not a key in the product table.
    return {"product": product.get(insurance_name, insurance_name)}
def final_info(res):
    """Arrange collected family members into the fixed UI order
    [user, father, mother, spouse, child], leaving an empty slot for any
    member that was not collected. Could be optimized; the current UI
    requires exactly this order."""
    empty_slot = {'name': '', 'age': '', 'career': '', 'health': '', 'live': ''}
    aliases_by_role = [
        ["父亲", "爸爸"],
        ["母亲", "妈妈"],
        ["老公", "丈夫", "太太", "配偶", "妻子", "夫人", "老婆"],
        ["孩子", "儿子", "女儿", "小孩儿", "小孩"],
    ]
    ordered = [res[0]]
    for aliases in aliases_by_role:
        slot = empty_slot.copy()
        for member in res[1:]:
            if member['name'] in aliases:
                slot = member
                break
        ordered.append(slot)
    return ordered
def generate_response(
    stamp,
    start_time,
    previous_time,
    content: str = "",
    broadcast: bool = True,
    exitInsureProductInfo: bool = False,
    existFamilyInfo: bool = False,
    insureProductInfos: Optional[Any] = None,
    familyInfos: Optional[Any] = None,
    end: bool = False
) -> dict:
    """Build the response payload returned to the client.

    Args:
        stamp: request timestamp echoed back to the client.
        start_time: wall-clock time the request arrived (for totalElapsedTime).
        previous_time: time of the previous chunk (for elapsedTimeSinceLast).
        content: sentence text of this chunk.
        broadcast: whether the client should voice this chunk.
        exitInsureProductInfo / existFamilyInfo: table-presence flags for the UI.
        insureProductInfos / familyInfos: table payloads; default empty list.
        end: True on the terminating chunk of a response.

    Returns:
        dict with the fixed client-facing field layout.
    """
    # BUGFIX: the defaults were shared mutable lists (and mis-annotated as dict);
    # use None sentinels so no caller can mutate a module-level default.
    return {
        "time": stamp,
        "totalElapsedTime": round(time.time() - start_time, 4),
        "elapsedTimeSinceLast": round(time.time() - previous_time, 4),
        "content": content,
        "broadcast": broadcast,
        "exitInsureProductInfo": exitInsureProductInfo,
        "existFamilyInfo": existFamilyInfo,
        "insureProductInfos": insureProductInfos if insureProductInfos is not None else [],
        "familyInfos": familyInfos if familyInfos is not None else [],
        "end": end
    }
class ChatOpenAIStreamingResponse(StreamingResponse):
    """Streaming response for openai chat model, inheritance from StreamingResponse."""

    def __init__(
        self,
        generate: Callable[[Sender], Awaitable[None]],
        message: str,
        time: int,
        action: int,
        start_time: float,
        status_code: int = 200,
        media_type: Optional[str] = None,
    ) -> None:
        # `generate` is the coroutine factory produced by send_message(); it drives
        # the LLM and pushes chunks through the sender we hand it in stream_response.
        super().__init__(content=EmptyIterator(), status_code=status_code, media_type=media_type)
        self.generate = generate
        self.message = message
        self.time = time
        self.action = action
        self.start_time = start_time
        self.response_data = ''  # accumulates the full generated text across streamed chunks

    async def stream_response(self, send: Send) -> None:
        """Rewrite stream_response to send response to client."""
        await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers})
        async def send_chunk(chunk: Union[str, bytes]):
            # Forward one chunk to the client and mirror its `content` field into
            # self.response_data so the branches below can inspect the full answer.
            if not isinstance(chunk, bytes):
                chunk = chunk.encode(self.charset)
            await send({"type": "http.response.body", "body": chunk, "more_body": True})
            dict_data = json.loads(chunk.decode(self.charset).strip())
            self.response_data += dict_data['content']
        async def send_word_response(self, response):
            # Send one intermediate (more_body=True) JSON payload and log it.
            logger.info(f"--return body--:{json.dumps(response, ensure_ascii=False)}")
            await send({"type": "http.response.body", "body": json.dumps(response, ensure_ascii=False).encode("utf-8"), "more_body": True})
        async def send_end_response(self, previous_time):
            # Send the terminating payload (end=True, more_body=False) that closes the stream.
            response = generate_response(self.time, self.start_time, previous_time, content = "", broadcast = False, end=True)
            logger.info(f"--return body--:{json.dumps(response, ensure_ascii=False)}\n{'*' * 300}")
            await send({"type": "http.response.body", "body": json.dumps(response, ensure_ascii=False).encode("utf-8"), "more_body": False})
        async def process_response(self, prefix):
            """Split a canned reply into short sentences and stream them one by one."""
            delimiter_pattern = r"[,。?!\.\?!]"
            sentences = re.split(delimiter_pattern, prefix)
            surfix = [i.strip() for i in sentences if len(i)]
            previous_time = time.time()
            for item in surfix:
                response = generate_response(self.time, self.start_time, previous_time, content = item, broadcast = True, end=False)
                await send_word_response(self, response)
                previous_time = time.time()
                # NOTE(review): time.sleep() blocks the event loop inside this async
                # handler — consider asyncio.sleep(); left unchanged here.
                time.sleep(0.5)
            await send_end_response(self, previous_time)
        print("self.message:", self.message)
        global flag1, flag2, all_Info
        previous_time = time.time()
        # NOTE(review): the next two conditions are separate `if`s, not `elif`s, so an
        # action==2 empty message can fall through into the chain below after already
        # sending an end response — confirm this is intended.
        if self.action == 2 and len(self.message) == 0:
            await send_end_response(self, previous_time)
        if self.action == 4:
            # Multiple faces detected (noisy environment): optionally run the model,
            # then play the canned multi-face prompt.
            await self.generate(send_chunk) if self.message else None
            await process_response(self, prefix = key_words.get('multi_face', ''))
        elif "赞" in self.message:
            # "Like" keyword: canned thumbs-up reply.
            await process_response(self, prefix = key_words.get('dian_zan', ''))
        elif "招呼" in self.message:
            # Greeting keyword: canned greeting reply.
            await process_response(self, prefix = key_words.get('zhao_hu', ''))
        elif "讲解" in self.message:
            # "Explain" keyword: stream the model answer, then the canned follow-up.
            await self.generate(send_chunk)
            await process_response(self, prefix = key_words.get('jiang_jie', ''))
        elif "广告" in self.message:
            # Advertisement keyword: canned advertisement reply.
            await process_response(self, prefix = key_words.get('guang_gao', ''))
        elif "配置" in self.message:
            # "Configure" keyword: start family-info collection from scratch.
            all_Info = []
            await self.generate(send_chunk)
            await process_response(self, prefix = key_words.get('pei_zhi', ''))
        elif "首要" in self.message:
            # "Primary recommendation" keyword: needs collected family info.
            if len(all_Info) == 0:
                await process_response(self, prefix = "抱歉,未能正确记录您的家庭信息,让我们重新开始搜集您的家庭信息之后,再为您推荐保险吧。")
            else:
                await self.generate(send_chunk)
                previous_time = time.time()
                print("response_data:", self.response_data)
                result = extract_product(self.response_data)
                all_Info[0]["name"] = "客户"
                doll = {'name': '', 'age': '', 'career': '', 'health': '', 'live': ''}
                user_info = [all_Info[0], doll, doll, doll, doll]
                response = generate_response(self.time, self.start_time, previous_time, content = "请查看以下表格:", broadcast = True, exitInsureProductInfo=True, existFamilyInfo=True, insureProductInfos=result['product'], familyInfos=user_info, end=False)
                await send_word_response(self, response)
                time.sleep(0.5)
                previous_time = time.time()
                await process_response(self, prefix = "您可以向我咨询具体的保险条款,或直接向我获取代理人联系方式,他们能提供更加详细专业的回答。〞如已阅读完毕,请按方向键上键,我将继续回答您关于保险产品的问题。")
        else:
            # Default: stream the model answer, then post-process the full text.
            await self.generate(send_chunk)
            previous_time = time.time()
            print("response_data:", self.response_data)
            if "完毕" in self.response_data:
                thread1 = threading.Thread(target = thread_function)
                thread1.start()
                thread1.join()  # join() blocks this coroutine until thread1 finishes extracting family info before we continue
                if all_Info:
                    all_Info[0]["name"] = "客户"
                    all_Info = final_info(all_Info)
                    response = generate_response(self.time, self.start_time, previous_time, content = "以下表格是根据您提供的家庭结构和成员信息,", broadcast = True, existFamilyInfo=True, familyInfos=all_Info, end=False)
                    await send_word_response(self, response)
                    time.sleep(1)
                    previous_time = time.time()
                    await process_response(self, prefix = key_words['wan_bi'])
                else:  # The user provided no family info, or none was collected (all_Info == []); fall back to chit-chat mode
                    restart_function()
                    await process_response(self, prefix = key_words['un_success'])
            else:
                await send_end_response(self, previous_time)
        if flag1 == 1 and flag2 == 0:
            # Family collection in progress: refresh all_Info in the background.
            thread1 = threading.Thread(target = thread_function)
            thread1.start()
def thread_function():
    """Background worker: run the raw family description (family_info) through the
    extraction pipeline and append usable member records to the global all_Info."""
    global family_info, all_Info
    if not family_info:
        return
    detail_Info = get_detailInfo(family_info)
    print("detail_Info_1:\n", detail_Info)
    if "人物" not in detail_Info:
        return
    # One entry per non-blank line of the extraction output.
    detail_Info = [i.strip() for i in detail_Info.strip().split('\n') if i.strip()]
    # Members with 4+ unknown fields are too sparse to count.
    count = sum(1 for item in detail_Info if item.count('未知') >= 4)
    print("detail_Info_2:\n", len(detail_Info), detail_Info)
    print("超过4个信息是未知的家人数(不计入最终统计)count:", count)
    if len(detail_Info) != count and "*" in detail_Info[0]:
        final = process_Info(detail_Info)
        if final:
            all_Info.extend(final)
            print("all_Info: ", all_Info)
def restart_function():
    """Reset the whole conversation: clear the state flags and collected family
    info, and rebuild every langchain memory buffer from scratch."""
    global flag1, flag2, family_info, all_Info, memory_chat, memory_reco, memory_reco_character, memory_basic, memory_basic_character
    flag1 = 0
    flag2 = 0
    all_Info = []
    family_info = ""
    memory_chat = ConversationBufferMemory(memory_key="chat_history", ai_prefix="")
    memory_basic = ConversationBufferMemory(memory_key="chat_history", ai_prefix="")
    memory_basic_character = ConversationBufferMemory(memory_key="chat_history", ai_prefix="")
    memory_reco = ConversationBufferMemory(memory_key="chat_history", input_key="human_input")
    memory_reco_character = ConversationBufferMemory(memory_key="chat_history", input_key="human_input")
async def openai_function(send, time, start_time, prompt, message, model_name, memory=None, context: str = ""):
    """Run one streaming LLM chain call.

    Args:
        send: sender callable used by the streaming callback handler.
        time: request timestamp echoed back to the client.
        start_time: wall-clock time the request arrived.
        prompt: langchain prompt template for the chain.
        message: user message, bound to the prompt's `human_input` variable.
        model_name: OpenAI chat model identifier (e.g. "gpt-4").
        memory: optional langchain memory to attach to the chain.
        context: optional extra context, bound to the prompt's `context` variable.
    """
    model = ChatOpenAI(
        request_timeout=8 * 60,
        model_name=model_name,
        streaming=True,
        callback_manager=AsyncCallbackManager([AsyncStreamCallbackHandler(send, time, start_time)]),
        verbose=True,
        temperature=0.7,
    )
    # `memory if memory else None` was redundant — pass memory straight through.
    chain = LLMChain(llm=model, prompt=prompt, verbose=True, memory=memory)
    if context:
        await chain.apredict(human_input=message, context=context)
    else:
        await chain.apredict(human_input=message)
# Module-level conversation state (single-session service):
#   switch      - 0 = default persona, 1 = alternate "精神小伙" persona
#   flag1       - 1 once family-info collection ("配置") has started
#   flag2       - 1 once insurance recommendation ("首要") has started
#   all_Info    - collected family member records
#   family_info - raw user message holding the family description
switch, flag1, flag2, all_Info, family_info = 0, 0, 0, [], ""
def send_message(message: str, time: int, action: int, dialogue_memory: str, start_time: float, switch: int) -> Callable[[Sender], Awaitable[None]]:
    """Select the generator coroutine matching the current action and dialogue state.

    Args:
        message: user utterance.
        time: request timestamp echoed back to the client.
        action: dialogue action code (see the legend at the bottom of this file).
        dialogue_memory: joined prior turns, used as context for action == 5.
        start_time: wall-clock time the request arrived.
        switch: 0 = default persona, 1 = alternate persona (picks the *_character variant).

    Returns:
        An async callable taking a Sender, to be driven by ChatOpenAIStreamingResponse.
    """
    async def generate_memory(send: Sender):
        await openai_function(model_name="gpt-4", send=send, time=time, start_time=start_time, prompt=prompt_memory, message=message, context=dialogue_memory)
    async def generate_memory_character(send: Sender):
        # BUGFIX: this previously passed `model="gpt-4"` (no such parameter on
        # openai_function) and omitted the required `time` argument, so invoking
        # it raised TypeError.
        await openai_function(model_name="gpt-4", send=send, time=time, start_time=start_time, prompt=prompt_memory_character, message=message, context=dialogue_memory)
    async def generate_hello(send: Sender):
        await openai_function(model_name="gpt-3.5-turbo-16k", send=send, time=time, start_time=start_time, prompt=prompt_chat_term, message=message, context=content_term)
    async def generate_hello_character(send: Sender):
        await openai_function(model_name="gpt-4", send=send, time=time, start_time=start_time, memory=memory_chat, prompt=prompt_chat_character, message=message)
    async def generate_basic(send: Sender):
        # Record the raw family description for the background extraction thread.
        global family_info
        family_info = message
        await openai_function(model_name="gpt-4", send=send, time=time, start_time=start_time, memory=memory_basic, prompt=prompt_basic, message=message)
    async def generate_basic_character(send: Sender):
        global family_info
        family_info = message
        await openai_function(model_name="gpt-4", send=send, time=time, start_time=start_time, memory=memory_basic_character, prompt=prompt_basic_character, message=message)
    async def generate_recommend(send: Sender):
        global all_Info
        # BUGFIX: the old one-liner `all_Info[0]["name"] = "客户" if all_Info else
        # logger.info(...)` evaluated all_Info[0] unconditionally, raising
        # IndexError on an empty list instead of logging the error.
        if all_Info:
            all_Info[0]["name"] = "客户"
        else:
            logger.info("--error--:len(all_Info)==0, 家庭信息未正确搜集!")
        model = ChatOpenAI(request_timeout=8 * 60, model_name="gpt-4", streaming=True, callback_manager=AsyncCallbackManager([AsyncStreamCallbackHandler(send, time, start_time)]), verbose=True, temperature=0)
        llm_chain = LLMChain(llm=model, prompt=prompt_reco, verbose=True, memory=memory_reco)
        await llm_chain.apredict(human_input=message, context=all_Info[0] if all_Info else "", product=product)
    async def generate_recommend_character(send: Sender):
        global all_Info
        # Same IndexError fix as generate_recommend above.
        if all_Info:
            all_Info[0]["name"] = "客户"
        else:
            logger.info("--error--:len(all_Info)==0, 家庭信息未正确搜集!")
        model = ChatOpenAI(request_timeout=8 * 60, model_name="gpt-4", streaming=True, callback_manager=AsyncCallbackManager([AsyncStreamCallbackHandler(send, time, start_time)]), verbose=True, temperature=0)
        llm_chain = LLMChain(llm=model, prompt=prompt_reco_character, verbose=True, memory=memory_reco_character)
        await llm_chain.apredict(human_input=message, context=all_Info[0] if all_Info else "", product=product)

    global flag1, flag2
    if action == 5:
        return generate_memory_character if switch else generate_memory
    if action == 2:
        restart_function()
    if "配置" in message:
        flag1 = 1
        return generate_basic_character if switch else generate_basic
    if "首要" in message:
        flag2 = 1
        return generate_recommend_character if switch else generate_recommend
    if flag1 == 0 and flag2 == 0:
        return generate_hello_character if switch else generate_hello
    if flag1 == 1 and flag2 == 0:
        return generate_basic_character if switch else generate_basic
    # flag2 == 1 (with either flag1 value) and the final fallback all return the
    # recommender — the original's three identical branches collapsed into one.
    return generate_recommend_character if switch else generate_recommend
class StreamRequest(BaseModel):
    """Request body for streaming."""
    # User utterance (may be rewritten server-side for some actions).
    message: str
    # Client timestamp echoed back in every response payload.
    time: int
    # Dialogue action code; see the action legend at the bottom of this file.
    action: int
    # Prior dialogue turns, joined into the prompt context when action == 5.
    historyMessage: list[str]
def stream(body: StreamRequest):
    """HTTP entry point: route the request to a canned reply or a streaming LLM response.

    NOTE(review): no FastAPI route decorator is visible on this function in this
    file — confirm the @app.post(...) registration exists.
    """
    logger.info(f"--request.body--:{body}")
    start_time = time.time()

    def _fixed_reply(content: str) -> dict:
        # Build a complete (end=True) canned payload via the shared helper and log it.
        # CONSISTENCY FIX: the original built four identical response dicts by hand,
        # duplicating generate_response field-for-field.
        response = generate_response(body.time, start_time, start_time, content=content, broadcast=True, end=True)
        logger.info(f"--return body--:{response}\n{'*' * 300}")
        return response

    if body.action == 1:
        # Auto-advance to the recommendation step on the user's behalf.
        body.message = "根据我之前提供的家庭信息,首要推荐我考虑的是什么险种?"
    elif body.action == 3:
        return _fixed_reply("您对以上推荐的保险方案是否有疑惑需要我进行解答呢?")
    if "讲解" in body.message:
        # Cap the explanation length to keep the spoken answer short.
        body.message = body.message + "回答不超过3句话"
        logger.info(f"--request.body--:{body}")
    if any(keyword in body.message for keyword in ["研发", "开发"]):
        return _fixed_reply("我是由杭州华鲤智能科技有限公司研发的AI私人保险助理,具体细节请咨询我们团队技术人员。")
    global switch
    if "切换" in body.message:
        # Toggle persona and reset the whole conversation.
        if switch == 0:
            switch = 1
            restart_function()
            return _fixed_reply("已为您切换到精神小伙人设,老铁,请问您有什么问题要咨询的吗?")
        else:
            switch = 0
            restart_function()
            return _fixed_reply("已为您切换到默认人设,请问您有什么问题要咨询的吗?")
    else:
        dialogue_memory = '\n'.join(body.historyMessage) if body.historyMessage else ""
        return ChatOpenAIStreamingResponse(send_message(body.message, body.time, body.action, dialogue_memory, start_time, switch=switch), media_type="text/event-stream", message=body.message, time=body.time, action=body.action, start_time=start_time)
| if __name__ == "__main__": | |
| uvicorn.run(host="0.0.0.0", port=8086, app=app) | |
| ''' | |
| action = 1 : 自动切换保险推荐,算法发送“根据我之前提供的家庭信息,首要推荐我考虑的是什么险种?” | |
| action = 2 : 会话全部重置 | |
| action = 3 : 默认回复“您对以上推荐的保险方案是否有疑惑需要我进行解答呢?” | |
| action = 4 : 出现多张脸的情况,在一个嘈杂环境中 | |
| action = 5 : 用户切换 | |
| action = 0 or -1 : 默认 | |
| ''' | |