# NOTE(review): the lines below are Hugging Face Spaces page chrome captured
# when this file was scraped ("Spaces: Sleeping"); they are not program text.
# Spaces: Sleeping
# Sleeping
# Sleeping
| from typing import Callable, Optional | |
| import gradio as gr | |
| from langchain.embeddings.openai import OpenAIEmbeddings | |
| from langchain.vectorstores import Zilliz | |
| from langchain.document_loaders import TextLoader | |
| from langchain.text_splitter import CharacterTextSplitter | |
| from langchain.chains import RetrievalQAWithSourcesChain | |
| import uuid | |
| from project.llm.zhipuai_llm import ZhipuAILLM | |
| from project.prompt.answer_by_private_prompt import ( | |
| COMBINE_PROMPT, | |
| EXAMPLE_PROMPT, | |
| QUESTION_PROMPT, | |
| DEFAULT_TEXT_QA_PROMPT, | |
| DEFAULT_REFINE_PROMPT | |
| ) | |
| from langchain.chains.combine_documents.refine import RefineDocumentsChain | |
| from langchain.chains.llm import LLMChain | |
| from langchain.chains.combine_documents import create_stuff_documents_chain | |
| from langchain.chains import StuffDocumentsChain | |
| from langchain_core.prompts import PromptTemplate | |
| import hashlib | |
| from project.embeddings.zhipuai_embedding import ZhipuAIEmbeddings | |
| import os | |
# Lazily-built QA chain: web_loader() assigns a RetrievalQAWithSourcesChain
# here; query() reads it and refuses to answer until it exists.
chain: Optional[Callable] = None

# Zilliz/Milvus connection details and the ZhipuAI API key come from the
# environment; each is None when the variable is unset (os.getenv default),
# so misconfiguration surfaces at connection time, not import time.
db_host = os.getenv("DB_HOST")
db_user = os.getenv("DB_USER")
db_password = os.getenv("DB_PASSWORD")
zhipuai_api_key = os.getenv("ZHIPU_AI_KEY")
def generate_article_id(content):
    """Return a deterministic document ID: the SHA-256 hex digest of *content*.

    Identical chunk text always maps to the same ID, which lets the vector
    store overwrite rather than duplicate entries on re-upload.
    """
    return hashlib.sha256(content.encode("utf-8")).hexdigest()
def web_loader(file):
    """Load a text file, chunk it, index it in Zilliz, and build the global QA chain.

    Args:
        file: Local path of the uploaded knowledge-base file (from gr.File).

    Returns:
        A status string displayed in the Gradio "load status" textbox.
    """
    if not file:
        return "please upload file"

    # Load the document and split it into ~512-character chunks for embedding.
    loader = TextLoader(file)
    docs = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=512, chunk_overlap=0)
    docs = text_splitter.split_documents(docs)

    # A constructor failure raises, so no truthiness check is needed here
    # (the old `if not embeddings` branch was unreachable).
    embeddings = ZhipuAIEmbeddings(zhipuai_api_key=zhipuai_api_key)

    # Content-derived IDs (SHA-256 of chunk text) make re-uploads idempotent:
    # identical chunks overwrite their previous entry instead of duplicating.
    article_ids = [generate_article_id(d.page_content) for d in docs]

    docsearch = Zilliz.from_documents(
        docs,
        embedding=embeddings,
        ids=article_ids,
        connection_args={
            "uri": db_host,
            "user": db_user,
            "password": db_password,
            "secure": True,
        },
        collection_name="LangChainCollectionYin",
    )
    if not docsearch:
        return "docsearch not"

    global chain
    llm = ZhipuAILLM(model="glm-3-turbo", temperature=0.1, zhipuai_api_key=zhipuai_api_key)

    # Each retrieved chunk is rendered verbatim into the {context} slot.
    document_prompt = PromptTemplate(
        input_variables=["page_content"],
        template="{page_content}",
    )
    document_variable_name = "context"
    # The QA prompt must expose `document_variable_name` ({context}) as an
    # input variable; it instructs the model to answer in Chinese strictly
    # from the supplied Q&A-formatted documents.
    prompt = PromptTemplate.from_template(
        """你是资深的技术支持工程师,请使用提供给你的文档内容去恢复客户问题,不需要编造或者虚构答案,也不需要回答文档之外的内容。
请用中文回答。
下边是我给你提供的文档,其中文档格式都是一问一答,不允许组装多个答案回答一个问题,并且问题答案也完全来自所提供的回答:
{context}
问题: {question}
答:"""
    )
    llm_chain = LLMChain(llm=llm, prompt=prompt)
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=llm_chain,
        document_prompt=document_prompt,
        document_variable_name=document_variable_name,
    )
    # "stuff" strategy over the top-3 retrieved chunks.
    chain = RetrievalQAWithSourcesChain(
        combine_documents_chain=combine_documents_chain,
        retriever=docsearch.as_retriever(search_kwargs={"k": 3}),
    )
    return "success to load data"
def query(question):
    """Answer *question* via the retrieval chain built by web_loader().

    Returns a fallback prompt string when the knowledge base has not been
    loaded yet, or when the chain produces no "answer" key.
    """
    global chain
    # Guard: the chain only exists after a successful web_loader() call.
    if not chain:
        return "please load the data first"
    result = chain(inputs={"question": question}, return_only_outputs=True)
    return result.get("answer", "fail to get answer")
if __name__ == "__main__":
    # Wire up the Gradio UI: file upload -> load status, question -> answer.
    with gr.Blocks() as demo:
        gr.Markdown(
            """
<h1><center>Langchain And Zilliz App</center></h1>
v.2.28.15.3
"""
        )
        upload = gr.File(
            label='请上传知识库文件\n可以处理 .txt, .md, .docx, .pdf 结尾的文件',
            file_types=['.txt', '.md', '.docx', '.pdf'],
        )
        load_status = gr.Textbox(label="load status")
        load_button = gr.Button("Load Data")
        load_button.click(
            fn=web_loader,
            inputs=[upload],
            outputs=load_status,
            api_name="web_load",
        )
        question_box = gr.Textbox(
            label="question",
            lines=3,
            placeholder="What is milvus?",
        )
        answer_box = gr.Textbox(label="question answer", lines=3)
        generate_button = gr.Button("Generate")
        generate_button.click(
            fn=query,
            inputs=[question_box],
            outputs=answer_box,
            api_name="generate_answer",
        )
    # Bind to all interfaces so the app is reachable inside a container.
    demo.queue().launch(server_name="0.0.0.0", share=False)