| import gradio as gr | |
| import os | |
| from langchain import OpenAI, ConversationChain | |
| from langchain.prompts import PromptTemplate | |
| from langchain.text_splitter import CharacterTextSplitter | |
| from langchain.vectorstores import Chroma | |
| from langchain.docstore.document import Document | |
| from langchain.embeddings import HuggingFaceInstructEmbeddings | |
| from langchain.chains.conversation.memory import ConversationBufferMemory | |
| from langchain.chains import RetrievalQAWithSourcesChain | |
| from langchain.chains.conversation.memory import ConversationEntityMemory | |
| from langchain.chains.conversation.prompt import ENTITY_MEMORY_CONVERSATION_TEMPLATE | |
| from langchain import LLMChain | |
# Retrieval-QA pipeline configuration. The Chroma directory and the LLM model
# can be overridden through environment variables; the defaults are the values
# this app was originally written against.
persist_directory = os.environ.get("BMW_CHAT_PERSIST_DIR", "db")
# NOTE(review): OpenAI retired text-davinci-003 (shut down 2024-01-04). If the
# default fails, set BMW_CHAT_LLM_MODEL to a completion model that your
# openai/langchain versions still support.
llm = OpenAI(
    model_name=os.environ.get("BMW_CHAT_LLM_MODEL", "text-davinci-003"),
    temperature=0,  # deterministic answers for the same retrieved context
)

# Instructor embeddings take separate instructions for indexing and querying.
model_name = "hkunlp/instructor-large"
embed_instruction = "Represent the text from the BMW website for retrieval"
query_instruction = "Query the most relevant text from the BMW website"
embeddings = HuggingFaceInstructEmbeddings(
    model_name=model_name,
    embed_instruction=embed_instruction,
    query_instruction=query_instruction,
)

# Open the (already populated) Chroma store persisted at `persist_directory`.
vectordb = Chroma(persist_directory=persist_directory, embedding_function=embeddings)

# "stuff" chain: retrieved documents are concatenated into a single prompt;
# the chain returns a dict with "answer" and "sources" keys.
chain = RetrievalQAWithSourcesChain.from_chain_type(
    llm,
    chain_type="stuff",
    retriever=vectordb.as_retriever(),
)
def chat(message, history):
    """Answer *message* with the retrieval-QA chain and append it to the chat.

    Args:
        message: The question typed by the user.
        history: The conversation so far as a list of ``(user, bot)`` tuples,
            or ``None`` on the first call (the initial ``gr.State`` value).

    Returns:
        The updated history twice: once for the ``Chatbot`` component and once
        for the ``State`` that carries it into the next call.
    """
    history = history or []

    # Don't spend an LLM call on an empty / whitespace-only question.
    if not message or not message.strip():
        return history, history

    try:
        response = chain({"question": message}, return_only_outputs=True)
        print('got response')
        markdown = generate_markdown(response)
    except Exception as e:
        # UI boundary: log the failure and show the user a readable notice
        # instead of an empty bot reply.
        print(f"Error: {e!r}")
        markdown = "Sorry, something went wrong while answering your question. Please try again."

    history.append((message, markdown))
    return history, history
def generate_markdown(obj):
    """Render a RetrievalQAWithSourcesChain result as Markdown.

    Args:
        obj: Mapping returned by the chain. The optional ``'answer'`` and
            ``'sources'`` keys are rendered; any other keys are ignored.

    Returns:
        A Markdown string with an "Answer" section (when ``'answer'`` is
        present) and a numbered "Sources" list (only when at least one
        non-blank source line exists). Returns ``""`` when neither section
        is produced.
    """
    print('generating markdown')
    sections = []

    if 'answer' in obj:
        sections.append(f"**Answer:**\n\n{str(obj['answer']).strip()}\n")

    # Sources are split one per line; blank lines are dropped so an empty
    # value doesn't render a lone "Sources" header with an empty item.
    # NOTE(review): the chain may emit several sources comma-separated on a
    # single line — confirm the format if the list looks wrong.
    raw_sources = obj.get('sources') or ""
    sources = [line.strip() for line in str(raw_sources).splitlines() if line.strip()]
    if sources:
        items = "".join(f"{i}. {src}\n" for i, src in enumerate(sources, start=1))
        sections.append(f"**Sources:**\n\n{items}")

    # Join with an extra newline so each section starts its own Markdown
    # paragraph (a single "\n" would be a soft break that merges them).
    return "\n".join(sections)
with gr.Blocks() as demo:
    gr.Markdown("<h3><center>BMW Chat Bot</center></h3>")
    gr.Markdown("<p><center>Ask questions about BMW</center></p>")
    chatbot = gr.Chatbot()

    with gr.Row():
        inp = gr.Textbox(placeholder="Question", label=None)
        # `.style()` is the Gradio 3.x styling API (removed in Gradio 4).
        btn = gr.Button("Run").style(full_width=False)

    # Conversation history carried between calls; starts as None and is
    # initialised by `chat`.
    state = gr.State()

    # Submit via the Run button or by pressing Enter in the textbox.
    btn.click(chat, [inp, state], [chatbot, state])
    inp.submit(chat, [inp, state], [chatbot, state])

    gr.Examples(
        examples=[
            "What is BMW doing about sustainability?",
            "What is the future of BMW?",
        ],
        inputs=inp,
    )

if __name__ == '__main__':
    demo.launch()