# Hugging Face Space scrape header (the Space was showing "Runtime error"
# when this page was captured); the application code follows below.
import os
import random

import gradio as gr
import openai
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from pypinyin import lazy_pinyin

# SECURITY: the original source hardcoded an OpenAI API key here
# (os.environ["OPENAI_API_KEY"] = "sk-…"). A key committed to source is
# compromised and must be revoked. The key is now expected to be provided
# through the OPENAI_API_KEY environment variable before launch; both
# OpenAIEmbeddings and the openai client read it from the environment.

# Shared embedding model used for all Chroma vector-store lookups.
embedding = OpenAIEmbeddings()

# Paths of the source PDFs found by the most recent retrieval step; the
# report-QA branch of echo() reads this set, so it is module-global state.
target_files = set()

# Report topics shown in the UI dropdown. Index 16 ("不限主题") means
# "no topic restriction" and is treated specially in echo().
topics = ["农业", "宗教与文化", "建筑业与制造业", "医疗卫生保健", "国家治理", "法律法规", "财政税收", "教育", "金融", "贸易", "宏观经济", "社会发展", "科学技术", "能源环保", "国际关系", "国防安全", "不限主题"]
def get_path(target_string):
    """Return the path of the first entry in ``./vector_data`` whose file
    name starts with ``target_string``, or ``""`` when nothing matches.

    "First" follows ``os.listdir`` ordering, matching the original
    behavior. The original built a full list of matches and then used a
    loop that returned on its first iteration — a disguised conditional;
    a single scan expresses the same first-match semantics directly.
    """
    folder_path = "./vector_data"
    for file_name in os.listdir(folder_path):
        if file_name.startswith(target_string):
            return os.path.join(folder_path, file_name)
    return ""
def extract_partial_message(res_message, response):
    """Accumulate an OpenAI streamed chat completion into growing partials.

    Parameters:
        res_message: prefix string that every yielded partial starts with.
        response: iterable of streaming chunks shaped like
            ``{"choices": [{"delta": {...}}]}``.

    Yields the running concatenation after each content delta, so a Gradio
    streaming callback can re-render the message as it grows.

    Bug fixed: the original tested ``len(delta) != 0`` and then indexed
    ``delta["content"]`` — the first streamed delta typically carries only
    ``{"role": "assistant"}``, which made that lookup raise KeyError.
    Checking for the ``"content"`` key skips role-only deltas safely.
    """
    for chunk in response:
        delta = chunk["choices"][0]["delta"]
        if "content" in delta:
            res_message = res_message + delta["content"]
            yield res_message
def format_messages(sys_prompt, history, message):
    """Build an OpenAI chat ``messages`` list.

    Layout: one system message carrying ``sys_prompt``, then each prior
    ``(user, assistant)`` turn from ``history`` expanded into two entries,
    and finally the newest user ``message``.
    """
    turns = [
        {"role": role, "content": text}
        for user_turn, assistant_turn in history
        for role, text in (("user", user_turn), ("assistant", assistant_turn))
    ]
    return (
        [{"role": "system", "content": sys_prompt}]
        + turns
        + [{"role": "user", "content": message}]
    )
def get_domain(history, message):
    """Classify the user's question into one of the predefined topic
    categories by asking GPT-4 (non-streaming) and return the raw
    category string from the model's reply."""
    sys_prompt = """
    帮我根据用户的问题划分到以下几个类别,输出最匹配的一个类别:[宗教与文化, 农业, 建筑业与制造业, 医疗卫生保健, 国家治理, 法律法规, 财政税收, 教育, 金融, 贸易, 宏观经济, 社会发展, 科学技术, 能源环保, 国际关系, 国防安全]
    """
    messages = format_messages(sys_prompt, history, message)
    print("history_openai_format:", messages)
    completion = openai.ChatCompletion.create(
        model="gpt-4",
        messages=messages,
        temperature=1.0,
        stream=False,
    )
    domain = completion["choices"][0]["message"]["content"]
    print("匹配领域:", domain)
    return domain
def echo(message, history, flag1, flag2):
    """Chat handler for the Gradio ChatInterface; streams partial replies.

    Parameters:
        message: latest user input from the textbox.
        history: list of (user, assistant) turn pairs maintained by Gradio.
        flag1: selected values of the "Function" CheckboxGroup
            (subset of ["研报问答", "研报生成"]).
        flag2: selected topic index from the Dropdown (``type="index"``),
            or None when nothing is selected.

    Three modes, chosen by ``flag1``:
      * []             — report retrieval + GPT-4 summary; records the
                         source PDFs in the global ``target_files``;
      * ["研报问答"]   — Q&A over the PDFs collected by a prior retrieval;
      * ["研报生成"]   — free-form report generation;
      * both selected  — asks the user to pick exactly one option.

    Yields progressively longer partial messages for streaming display.
    """
    global target_files, topics
    print("flag1:", flag1)
    print("flag2:", flag2)
    print("history:", history)
    print("message:", message)
    if len(flag1) == 0:  # neither report-QA nor report-generation: retrieval mode
        # A fresh retrieval invalidates previously collected sources and
        # the conversation so far.
        target_files.clear()
        history.clear()
        if flag2 not in [None, 16]:
            # An explicit topic was picked; index 16 is "不限主题"
            # (no topic restriction) and is handled like "no selection".
            domain = topics[flag2]
            message = f"{domain}领域相关内容"
        elif flag2 in [None, 16]:
            # No usable topic selection: for an empty message pick a random
            # topic, otherwise let GPT-4 classify the question.
            # NOTE(review): randint(0, 16) can also draw index 16
            # ("不限主题"), producing "不限主题领域相关内容" — looks
            # unintended; confirm whether the range should be (0, 15).
            message = f"{topics[random.randint(0, 16)]}领域相关内容" if message == "" else message
            domain = get_domain(history, message)
        # Vector stores under ./vector_data are named by the pinyin
        # transliteration of the topic (see get_path).
        persist_vector_path = get_path("".join(lazy_pinyin(domain)))
        db = Chroma(persist_directory=persist_vector_path, embedding_function=embedding)
        docs = db.similarity_search_with_score(query=message, k=5)
        # similarity_search_with_score returns (document, score) pairs;
        # keep only the documents.
        contents = [doc[0] for doc in docs]
        relevance = " ".join(doc.page_content for doc in contents)
        source = [doc.metadata for doc in contents]
        for item in source:
            # Record each hit's source PDF path for the report-QA mode.
            target_files.add(item['source'])
        print("研报搜索结果:", target_files)
        sys_prompt = """
        你是一个研报助手,根据这篇文章:{}
        来回复用户的问题生成总结,你需要严格按照这种格式回复:以上文章总结了*,主要观点是*。你只能回复中文。
        如果这篇文章无法回答用户的问题,你必须根据你的知识面直接编造内容进行回答,不能回复不知道,不能说这段文字没有提供等话语。
        """
        sys_prompt = sys_prompt.format(relevance)
        history_openai_format = format_messages(sys_prompt, history, message)
        print("history_openai_format:", history_openai_format)
        response = openai.ChatCompletion.create(model="gpt-4", messages=history_openai_format, temperature=1.0, stream=True)
        # Prefix the streamed summary with the list of source reports.
        partial_message = "搜索结果来自以下研报:\n" + '\n'.join(i for i in target_files) + '\n\n'
        for result_message in extract_partial_message(partial_message, response):
            yield result_message
    elif flag1 == ['研报问答']:  # report Q&A over previously retrieved PDFs
        print("target_files:", target_files)
        QA_pages = []
        if not target_files:
            # Retrieval has not run yet, so there is nothing to answer from.
            yield "请取消选中研报问答,先进行研报检索,再进行问答。"
        else:
            # Load and split every retrieved PDF, then build an in-memory
            # Chroma index over the chunks for this question.
            for item in target_files:
                loader = PyPDFLoader(item)
                QA_pages.extend(loader.load_and_split())
            text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
            documents = text_splitter.split_documents(QA_pages)
            db = Chroma.from_documents(documents, OpenAIEmbeddings())
            docs = db.similarity_search_with_score(query=message, k=3)
            contents = [doc[0] for doc in docs]
            relevance = " ".join(doc.page_content for doc in contents)
            sys_prompt = """
            你是一个研报助手,根据这篇文章:{}
            来回复用户的问题,如果这篇文章无法回答用户的问题,你必须根据你的知识面来编造进行专业的回答,
            不能回复不知道,不能回复这篇文章不能回答的这种话语,你只能回复中文。
            """
            sys_prompt = sys_prompt.format(relevance)
            history_openai_format = format_messages(sys_prompt, history, message)
            print("history_openai_format:", history_openai_format)
            response = openai.ChatCompletion.create(model="gpt-4", messages=history_openai_format, temperature=1.0, stream=True)
            for result_message in extract_partial_message("", response):
                yield result_message
    elif flag1 == ['研报生成']:  # free-form report generation
        # Generation does not use retrieved sources; drop any stale ones.
        target_files.clear()
        sys_prompt = """
        你是一个研报助手,请根据用户的要求回复问题。
        """
        history_openai_format = format_messages(sys_prompt, history, message)
        print("history_openai_format:", history_openai_format)
        response = openai.ChatCompletion.create(model="gpt-4", messages=history_openai_format, temperature=1.0, stream=True)
        for result_message in extract_partial_message("", response):
            yield result_message
    elif len(flag1) == 2:  # both checkboxes selected: ambiguous request
        yield "请选中一个选项,进行相关问答。"
# Gradio chat UI wiring. echo() receives the textbox message, the chat
# history, and the two additional inputs (function checkboxes, topic
# dropdown index) in order. .queue() enables streaming of the generator
# output. Commented-out experimental widget code was removed.
demo = gr.ChatInterface(
    echo,
    chatbot=gr.Chatbot(height=430, label="ChatReport"),
    textbox=gr.Textbox(placeholder="请输入问题", container=False, scale=7),
    title="研报助手",
    description="清芬院研报助手",
    theme="soft",
    additional_inputs=[
        gr.CheckboxGroup(["研报问答", "研报生成"], label="Function"),
        # type="index" makes the dropdown deliver the topic's list index
        # (what echo's flag2 expects), not the topic string.
        gr.Dropdown(topics, type="index"),
    ],
    undo_btn="清空输入框",
    clear_btn="清空聊天记录",
).queue()

if __name__ == "__main__":
    # share=True exposes a public Gradio link in addition to localhost.
    demo.launch(share=True)
| ''' | |
| flag1: ['研报问答'] | |
| flag2: None | |
| history: [] | |
| message: gg | |
| target_files: set() | |
| ''' |