| import os
|
# Opt out of Chroma usage telemetry. These are set ahead of the Chroma import
# below so they are already in the environment when the client is configured.
# ANONYMIZED_TELEMETRY is the opt-out variable named in Chroma's documentation;
# the original CHROMA_TELEMETRY_DISABLED flag is kept as-is.
os.environ["CHROMA_TELEMETRY_DISABLED"] = "true"
os.environ["ANONYMIZED_TELEMETRY"] = "False"
|
| from dotenv import load_dotenv, find_dotenv
|
| from zhipuLLM import ZhipuaiLLM
|
| from zhipuEmbedding import ZhipuAiEmbeddings
|
| from langchain_community.vectorstores import Chroma
|
| from langchain_core.runnables import RunnablePassthrough
|
| from langchain_core.output_parsers import StrOutputParser
|
| from langchain_core.prompts import ChatPromptTemplate
|
| from langchain_core.runnables import RunnableBranch
|
|
|
|
|
| import gradio as gr
|
|
|
|
|
# Locate a .env file and load its variables into the process environment;
# the return value of load_dotenv is not used.
_ = load_dotenv(find_dotenv())

# ZhipuAI credential, read once at import time. A KeyError here means the
# variable is missing from both the environment and the .env file.
api_key = os.environ["ZHIPUAI_API_KEY"]
|
|
|
def combine_docs(docs):
    """Merge the retrieved documents into a single context string.

    ``docs`` is the chain-state mapping; the documents are read from its
    ``"context"`` entry and their ``page_content`` values are joined with a
    blank line between them.
    """
    page_texts = [document.page_content for document in docs["context"]]
    return "\n\n".join(page_texts)
|
|
|
|
|
def show_switch_status(switch_state):
    """Return the switch state exactly as received (identity callback)."""
    return switch_state
|
|
|
|
|
def format_chat_history(chatbot):
    """Convert Gradio tuple-style chat history into LangChain role/content pairs.

    Args:
        chatbot: Sequence of ``(user_message, bot_message)`` pairs as held by
            the Gradio ``Chatbot`` component, or ``None`` when there is no
            history yet.

    Returns:
        A flat list of ``("human", text)`` / ``("ai", text)`` tuples in
        conversation order. An empty or ``None`` history yields ``[]``.
    """
    formatted_history = []
    for human_msg, ai_msg in chatbot or ():
        # Either side of a turn may be None in tuple-style history (for
        # example a bot-only greeting); a None content is not a valid chat
        # message, so that side is skipped.
        if human_msg is not None:
            formatted_history.append(("human", human_msg))
        if ai_msg is not None:
            formatted_history.append(("ai", ai_msg))
    return formatted_history
|
|
|
def _build_rag_chain(llm):
    """Assemble the history-aware retrieval + answer chain around ``llm``."""
    system_prompt = (
        "你是一个问答任务的助手。 "
        "请使用检索到的上下文片段回答这个问题。 "
        "如果你不知道答案就说不知道。 "
        "请使用简洁的话语回答用户。"
        "\n\n"
        "{context}"
    )
    qa_prompt = ChatPromptTemplate(
        [
            ("system", system_prompt),
            ("placeholder", "{chat_history}"),
            ("human", "{input}"),
        ]
    )
    # combine_docs turns the retrieved documents into the text that fills the
    # {context} slot of the system prompt.
    qa_chain = (
        RunnablePassthrough.assign(context=combine_docs)
        | qa_prompt
        | llm
        | StrOutputParser()
    )

    vectordb = Chroma(
        persist_directory='data_base/vector_db/chroma',
        embedding_function=ZhipuAiEmbeddings()
    )
    retriever = vectordb.as_retriever(search_kwargs={"k": 1})

    condense_question_system_template = (
        "请根据聊天记录完善用户最新的问题,"
        "如果用户最新的问题不需要完善则返回用户的问题。"
    )
    condense_question_prompt = ChatPromptTemplate([
        ("system", condense_question_system_template),
        ("placeholder", "{chat_history}"),
        ("human", "{input}"),
    ])

    retrieve_docs = RunnableBranch(
        # No history: search with the raw user question.
        (lambda x: not x.get("chat_history", False), (lambda x: x["input"]) | retriever, ),
        # With history: let the model rewrite the question first, then search.
        condense_question_prompt | llm | StrOutputParser() | retriever,
    )

    # The first assign stores the retrieved documents under "context"; the
    # second stores the generated reply under "answer".
    return RunnablePassthrough.assign(
        context=retrieve_docs
    ).assign(answer=qa_chain)


def chatbot_response(input, chatbot, isUseRAG):
    """Answer the user's message and append the turn to the chat history.

    Args:
        input: Text typed by the user. The name shadows the builtin but is
            kept so existing callers continue to work.
        chatbot: Tuple-style history, a list of ``(user, bot)`` pairs;
            ``None`` is treated as an empty history.
        isUseRAG: When truthy, answer through the retrieval chain; otherwise
            send the message straight to the model.

    Returns:
        ``[chatbot, input]``: the updated history and the unchanged input text.
    """
    if chatbot is None:
        chatbot = []

    # Ignore blank submissions so no model call is made for an empty prompt.
    if not input or not input.strip():
        return [chatbot, input]

    llm = ZhipuaiLLM(model_name="glm-4-plus", temperature=0.1, api_key=api_key)

    if isUseRAG:
        result = _build_rag_chain(llm).invoke({
            "input": input,
            "chat_history": format_chat_history(chatbot),
        })
        print(result)
        answer = result["answer"]
    else:
        # NOTE(review): direct mode sends only the latest message; earlier
        # turns in `chatbot` are not forwarded to the model.
        result = llm.invoke(input)
        print(chatbot)
        answer = result.content

    chatbot.append((input, answer))
    return [chatbot, input]
|
|
|
|
|
# Page layout: title, chat history, RAG toggle, question box and send button.
with gr.Blocks() as demo:
    gr.Markdown("""
    # 🤖 RAG 智能聊天机器人
    支持直接调用大模型或结合本地知识库(RAG)回答问题
    """)

    chatbot = gr.Chatbot(
        label="对话历史",
        height=500,
        avatar_images=(None, "https://gradio.s3-us-west-2.amazonaws.com/guides/robot.png")
    )

    with gr.Row():
        # Checked: answer through the retrieval chain; unchecked: call the
        # model directly (passed to chatbot_response as isUseRAG).
        chebox = gr.Checkbox(
            label="RAG",
            value=False
        )

    with gr.Row():
        input = gr.Textbox(
            label="输入你的问题",
            placeholder="例如:",
            lines=2,
            container=False
        )
        # gr.Button's `icon` argument expects an image file path or URL, so
        # the emoji is placed in the button label instead.
        submit_btn = gr.Button("📤 发送", variant="primary")

    # chatbot_response returns [history, text], matching the two outputs.
    submit_btn.click(
        fn=chatbot_response,
        inputs=[input, chatbot, chebox],
        outputs=[chatbot, input]
    )
|
|
|
|
|
if __name__ == "__main__":
    # Start the app only when run as a script: no share link requested, and
    # show_error enabled.
    launch_options = {"share": False, "show_error": True}
    demo.launch(**launch_options)
|
|
|
|
|
|
|
|
|
|
|
|
|