import os

import gradio as gr
from loguru import logger
from transformers import AutoTokenizer, AutoModel
from duckduckgo_search import ddg

from langchain.document_loaders import UnstructuredFileLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.llms import OpenAI
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA, LLMChain
from langchain.prompts import PromptTemplate

# Cache the tokenizer/model pair so the 6B weights are loaded once, not on every chat turn.
_glm = None


def load_model():
    global _glm
    if _glm is None:
        # The tokenizer must come from the same repo as the int4 weights.
        tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
        model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).quantize(bits=4, compile_parallel_kernel=True, parallel_num=2).float()
        _glm = (tokenizer, model.eval())
    return _glm
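# Note: `.float()` keeps inference on CPU. On a CUDA machine the usual ChatGLM
# pattern would be `.half().cuda()` instead; that variant is only an alternative
# sketch, not something this demo configures.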


def chat_glm(input, history=None):
    if history is None:
        history = []

    tokenizer, model = load_model()
    response, history = model.chat(tokenizer, input, history)
    # loguru formats with {} placeholders; bare extra args would be dropped.
    logger.debug("chatglm: {} -> {}", input, response)
    return history, history


def search_web(query):
    logger.debug("searchweb: {}", query)
    # Each hit from the (pre-3.x) `ddg` helper is a dict whose "body" field
    # holds the text snippet; the snippets are concatenated into one context string.
    results = ddg(query)
    web_content = ''
    if results:
        for result in results:
            web_content += result['body']
    return web_content


def search_vec(query):
    logger.debug("searchvec: {}", query)
    embedding_model_name = 'GanymedeNil/text2vec-large-chinese'
    vec_path = 'cache'
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
    vector_store = FAISS.load_local(vec_path, embeddings)

    # RetrievalQA is the current replacement for the deprecated VectorDBQA wrapper.
    qa = RetrievalQA.from_chain_type(llm=OpenAI(),
                                     chain_type="stuff",
                                     retriever=vector_store.as_retriever(),
                                     return_source_documents=True)
    result = qa({"query": query})
    return result['result']
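

# search_vec() loads a FAISS index prebuilt under `cache`. A minimal sketch of
# how such an index could be built from a local document; the file path, chunk
# size and overlap below are assumptions, not part of the original project:
def build_vector_store(doc_path, vec_path='cache'):
    # Load one local file and split it into chunks small enough to embed.
    documents = UnstructuredFileLoader(doc_path).load()
    splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    docs = splitter.split_documents(documents)
    embeddings = HuggingFaceEmbeddings(model_name='GanymedeNil/text2vec-large-chinese')
    # Embed every chunk, index the vectors, and persist the index to disk.
    FAISS.from_documents(docs, embeddings).save_local(vec_path)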


def chat_gpt(input, use_web, use_vec, history=None):
    if history is None:
        history = []

    # The Radio components deliver the strings "True"/"False", not booleans,
    # so compare against the string. Fall back to web snippets when no vector
    # store is requested; with neither, the context stays "无".
    context = "无"
    if use_vec == "True":
        context = search_vec(input)
    elif use_web == "True":
        context = search_web(input)

    prompt_template = f"""基于以下已知信息,请简洁并专业地回答用户的问题。
如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息"。若答案中存在编造成分,请在该部分开头添加“据我推测”。另外,答案请使用中文。
已知内容:
{context}""" + """
问题:
{question}"""

    prompt = PromptTemplate(template=prompt_template, input_variables=["question"])
    llm = OpenAI(temperature=0.2)
    chain = LLMChain(llm=llm, prompt=prompt)
    return chain.run(input)


def predict(input,
            large_language_model,
            use_web,
            use_vec,
            openai_key,
            history=None):
    logger.debug("predict: {} web={} vec={}", large_language_model, use_web, use_vec)
    if history is None:
        history = []

    # Only the OpenAI-backed options need a key; an empty gr.Textbox yields "".
    if openai_key:
        os.environ['OPENAI_API_KEY'] = openai_key
    elif large_language_model in ("GPT-3.5-turbo", "Search VectorStore"):
        history.append((input, "You forgot OpenAI API key"))
        return '', history, history

    if large_language_model == "GPT-3.5-turbo":
        resp = chat_gpt(input, use_web, use_vec, history)
    elif large_language_model == "ChatGLM-6B-int4":
        _, resp = chat_glm(input, history)
        resp = resp[-1][1]
    elif large_language_model == "Search Web":
        resp = search_web(input)
    elif large_language_model == "Search VectorStore":
        resp = search_vec(input)

    history.append((input, resp))
    return '', history, history


def clear_session():
    # Reset the Chatbot display and the stored conversation history.
    return None, None


block = gr.Blocks()
with block as demo:
    gr.Markdown("""<h1><center>MedKBQA(demo)</center></h1>
    <center><font size=3>
    本项目基于LangChain、ChatGLM以及OpenAI接口, 提供基于本地医药知识的自动问答应用. <br>
    </font></center>
    """)
    with gr.Row():
        with gr.Column(scale=1):
            model_choose = gr.Accordion("模型选择")
            with model_choose:
                large_language_model = gr.Dropdown(
                    ["ChatGLM-6B-int4", "GPT-3.5-turbo", "Search Web", "Search VectorStore"],
                    label="large language model",
                    value="ChatGLM-6B-int4")
                use_web = gr.Radio(["True", "False"],
                                   label="Web Search",
                                   value="False")
                use_vec = gr.Radio(["True", "False"],
                                   label="VectorStore Search",
                                   value="False")
                openai_key = gr.Textbox(label="请输入OpenAI API key", type="password")
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label='ChatLLM').style(height=600)
            message = gr.Textbox(label='请输入问题')
            state = gr.State()

    with gr.Row():
        clear_history = gr.Button("🧹 清除历史对话")
        send = gr.Button("🚀 发送")

    send.click(predict,
               inputs=[
                   message, large_language_model, use_web, use_vec, openai_key, state
               ],
               outputs=[message, chatbot, state])
    clear_history.click(fn=clear_session,
                        inputs=[],
                        outputs=[chatbot, state],
                        queue=False)

    message.submit(predict,
                   inputs=[
                       message, large_language_model, use_web, use_vec, openai_key, state
                   ],
                   outputs=[message, chatbot, state])
    gr.Markdown("""提醒:<br>
    1. 使用时请先选择使用ChatGLM或者GPT进行问答. <br>
    2. 使用GPT-3.5-turbo或Search VectorStore时需要输入您的OpenAI API key.
    """)
demo.queue().launch(server_name='0.0.0.0', share=False)