Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| from threading import Lock | |
| from typing import Any, Dict, Optional, Tuple | |
| import gradio as gr | |
| from langchain.chains import ConversationalRetrievalChain | |
| from langchain.chat_models import ChatOpenAI | |
| from langchain.memory import ConversationBufferMemory | |
| from langchain.prompts.chat import (ChatPromptTemplate, | |
| HumanMessagePromptTemplate, | |
| SystemMessagePromptTemplate) | |
| from src.core.chunking import chunk_file | |
| from src.core.embedding import embed_files | |
| from src.core.parsing import read_file | |
# --- Configuration constants ---
VECTOR_STORE = "faiss"              # vector-store backend passed to embed_files
EMBEDDING = "openai"                # embedding provider passed to embed_files
# NOTE(review): the original assigned MODEL twice ("openai", then the model
# name below); the first assignment was dead code and has been removed.
MODEL = "gpt-3.5-turbo-16k"         # chat model name for ChatOpenAI
K = 5                               # number of similarity-search hits fed to the chain
USE_VERBOSE = True                  # verbose chains + return_source_documents
API_KEY = os.environ["OPENAI_API_KEY"]  # raises KeyError if the env var is unset
# System prompt: constrains the model to answer only from the retrieved book
# excerpts ({context}) and prescribes the refusal / "Dale says:" phrasing.
# Fixed typos from the original prompt: "Dail Carnegie" -> "Dale Carnegie",
# "in, the book" -> "in the book".
system_template = """
The context below contains excerpts from 'How to Win Friends & Influence People,' by Dale Carnegie. You must only use the information in the context below to formulate your response. If there is not enough information to formulate a response, you must respond with
"I'm sorry, but I can't find the answer to your question in the book How to Win Friends & Influence People.". However, if there is enough information to formulate a response, you must start your response with "Dale says: ".
Begin context:
{context}
End context.
{chat_history}
"""
# Create the chat prompt templates: system message carries context + history,
# human message carries the user's question.
messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}"),
]
qa_prompt = ChatPromptTemplate.from_messages(messages)
class AnswerConversationBufferMemory(ConversationBufferMemory):
    """Memory adapter for ConversationalRetrievalChain.

    The chain emits its reply under the 'answer' key, while the base
    ConversationBufferMemory expects a 'response' key; this shim remaps it.
    """

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        # Zero-argument super() replaces the legacy two-argument form.
        return super().save_context(inputs, {'response': outputs['answer']})
def getretriever():
    """Parse, chunk, and embed the bundled PDF; return a similarity retriever.

    Returns a retriever over the embedded index configured for top-K
    similarity search. Propagates any parsing/embedding exception.
    """
    with open("./resources/How_To_Win_Friends_And_Influence_People_-_Dale_Carnegie.pdf", 'rb') as uploaded_file:
        try:
            file = read_file(uploaded_file)
        except Exception as e:
            # The original printed the error and fell through, which then
            # crashed with NameError on the undefined 'file'. Log and re-raise
            # so the failure surfaces at its real cause.
            print(e)
            raise
    chunked_file = chunk_file(file, chunk_size=512, chunk_overlap=0)
    folder_index = embed_files(
        files=[chunked_file],
        embedding=EMBEDDING,
        vector_store=VECTOR_STORE,
        openai_api_key=API_KEY,
    )
    return folder_index.index.as_retriever(
        verbose=True, search_type="similarity", search_kwargs={"k": K}
    )


# Built once at import time and shared by every chain.
retriever = getretriever()
def predict(message):
    """Answer one question from a JSON payload.

    `message` is a JSON string of the form
    {"question": str, "history": [[user_msg, bot_msg], ...]}.
    A fresh chain is built per call, with its memory seeded from the
    history, and the chain's answer string is returned.
    """
    print(message)
    msgJson = json.loads(message)
    print(msgJson)
    llm = ChatOpenAI(
        openai_api_key=API_KEY,
        model_name=MODEL,
        verbose=True)
    memory = AnswerConversationBufferMemory(memory_key="chat_history", return_messages=True)
    for user_msg, bot_msg in msgJson["history"]:
        memory.save_context({"input": user_msg}, {"answer": bot_msg})
    chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        return_source_documents=USE_VERBOSE,
        memory=memory,
        verbose=USE_VERBOSE,
        # Reuse the module-level qa_prompt; the original rebuilt an identical
        # prompt locally, shadowing it.
        combine_docs_chain_kwargs={"prompt": qa_prompt})
    chain.rephrase_question = False
    # NOTE(review): the original acquired a Lock() created inside this call,
    # which cannot serialize concurrent requests (each call had its own lock).
    # A shared module-level lock would be needed for real mutual exclusion;
    # the no-op per-call lock has been removed.
    try:
        output = chain({"question": msgJson["question"]})
    except Exception as e:
        print(e)
        raise  # bare raise preserves the original traceback
    return output["answer"]
def getanswer(chain, question, history):
    """Gradio callback: run `question` through `chain` and extend the history.

    Each argument may arrive either as a raw value or wrapped in an object
    exposing `.value` (gr.State / component), so both forms are unwrapped.
    Returns (chatbot_history, state_history, textbox_clear_update).
    Exceptions from the chain propagate to the caller.
    """
    if hasattr(chain, "value"):
        chain = chain.value
    if hasattr(history, "value"):
        history = history.value
    if hasattr(question, "value"):
        question = question.value
    history = history or []
    # NOTE(review): the original wrapped the call in a Lock() created per
    # invocation, which guards nothing (every call had its own lock); it also
    # re-raised with `raise e`, truncating the traceback. Both removed.
    output = chain({"question": question})
    history.append((question, output["answer"]))
    return history, history, gr.update(value="")
def load_chain(inputs = None):
    """Build and return a fresh ConversationalRetrievalChain.

    Uses the shared module-level retriever and qa_prompt; a new LLM client
    and a new answer-remapping memory are created per call. The `inputs`
    argument is accepted (for gr.State compatibility) but ignored.
    """
    memory = AnswerConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
    )
    model = ChatOpenAI(
        openai_api_key=API_KEY,
        model_name=MODEL,
        verbose=True,
    )
    return ConversationalRetrievalChain.from_llm(
        model,
        retriever=retriever,
        return_source_documents=USE_VERBOSE,
        memory=memory,
        verbose=USE_VERBOSE,
        combine_docs_chain_kwargs={"prompt": qa_prompt},
    )
# Gradio UI: a two-column layout — chat on the left, suggested questions on
# the right — plus a hidden button exposing the JSON `predict` endpoint.
with gr.Blocks() as block:
    with gr.Row():
        with gr.Column(scale=0.75):
            with gr.Row():
                gr.Markdown("<h1>How to Win Friends & Influence People</h1>")
            with gr.Row():
                gr.Markdown("by Dale Carnegie")
            # NOTE(review): .style() is the legacy Gradio 3.x styling API —
            # confirm against the pinned gradio version before upgrading.
            chatbot = gr.Chatbot(elem_id="chatbot").style(height=600)
            with gr.Row():
                message = gr.Textbox(
                    label="",
                    placeholder="How to Win Friends...",
                    lines=1,
                )
            with gr.Row():
                submit = gr.Button(value="Send", variant="primary", scale=1)
            # chain_state lazily builds the shared chain; state holds the
            # (question, answer) history tuples shown in the chatbot.
            state = gr.State()
            chain_state = gr.State(load_chain)
            # Both the Send button and pressing Enter route through getanswer.
            submit.click(getanswer, inputs=[chain_state, message, state], outputs=[chatbot, state, message])
            message.submit(getanswer, inputs=[chain_state, message, state], outputs=[chatbot, state, message])
        with gr.Column(scale=0.25):
            with gr.Row():
                gr.Markdown("<h1><center>Suggestions</center></h1>")
            # Each suggestion Button is passed as an *input* so its label
            # string becomes the question fed to getanswer.
            ex1 = gr.Button(value="How do I know if I'm talking about myself too much?", variant="primary")
            ex1.click(getanswer, inputs=[chain_state, ex1, state], outputs=[chatbot, state, message])
            ex2 = gr.Button(value="What do people enjoy talking about the most?", variant="primary")
            ex2.click(getanswer, inputs=[chain_state, ex2, state], outputs=[chatbot, state, message])
            ex4 = gr.Button(value="Why should I try to get along with people better?", variant="primary")
            ex4.click(getanswer, inputs=[chain_state, ex4, state], outputs=[chatbot, state, message])
            ex5 = gr.Button(value="How do I cite a Reddit thread?", variant="primary")
            ex5.click(getanswer, inputs=[chain_state, ex5, state], outputs=[chatbot, state, message])
    # Hidden button: lets external callers hit predict() with a JSON payload
    # via the API without showing a control in the UI.
    predictBtn = gr.Button(value="Predict", visible=False)
    predictBtn.click(predict, inputs=[message], outputs=[message])

block.launch(debug=True)