Spaces:
Sleeping
Sleeping
Update pages/2_🧠_context_aware_chatbot.py
Browse files
pages/2_🧠_context_aware_chatbot.py
CHANGED
|
@@ -6,6 +6,8 @@ from langchain_openai import ChatOpenAI
|
|
| 6 |
from langchain.chains.conversation.base import ConversationChain
|
| 7 |
from langchain.memory.buffer import ConversationBufferMemory
|
| 8 |
from langchain.memory import ConversationSummaryMemory
|
|
|
|
|
|
|
| 9 |
|
| 10 |
st.set_page_config(page_title="Context aware chatbot", page_icon="🧠")
|
| 11 |
st.header('Context aware chatbot')
|
|
@@ -18,11 +20,33 @@ class ContextChatbot:
|
|
| 18 |
|
| 19 |
@st.cache_resource
|
| 20 |
def setup_chain(_self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
summarizer = ChatOpenAI(model_name= "gpt-4o", temperature=0, streaming=True)
|
| 22 |
-
memory = ConversationSummaryMemory(llm = summarizer)
|
| 23 |
llm = ChatOpenAI(model_name=_self.openai_model, temperature=0, streaming=True)
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
@utils.enable_chat_history
|
| 28 |
def main(self):
|
|
|
|
| 6 |
from langchain.chains.conversation.base import ConversationChain
|
| 7 |
from langchain.memory.buffer import ConversationBufferMemory
|
| 8 |
from langchain.memory import ConversationSummaryMemory
|
| 9 |
+
from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
|
| 10 |
+
from langchain.prompts import PromptTemplate
|
| 11 |
|
| 12 |
st.set_page_config(page_title="Context aware chatbot", page_icon="🧠")
|
| 13 |
st.header('Context aware chatbot')
|
|
|
|
| 20 |
|
| 21 |
@st.cache_resource
|
| 22 |
def setup_chain(_self):
|
| 23 |
+
prompt_template = """
|
| 24 |
+
Use the following pieces of context to answer the user's question about sickle cell.
|
| 25 |
+
If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
| 26 |
+
----------------
|
| 27 |
+
{context}
|
| 28 |
+
|
| 29 |
+
Question: {question}
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
PROMPT = PromptTemplate(
|
| 33 |
+
template=prompt_template, input_variables=["context", "question"]
|
| 34 |
+
)
|
| 35 |
+
chain_type_kwargs = {"prompt": PROMPT}
|
| 36 |
summarizer = ChatOpenAI(model_name= "gpt-4o", temperature=0, streaming=True)
|
|
|
|
| 37 |
llm = ChatOpenAI(model_name=_self.openai_model, temperature=0, streaming=True)
|
| 38 |
+
|
| 39 |
+
chat_history = []
|
| 40 |
+
|
| 41 |
+
qa = ConversationalRetrievalChain.from_llm(
|
| 42 |
+
llm = llm,
|
| 43 |
+
chain_type = "stuff",
|
| 44 |
+
memory = ConversationSummaryMemory(llm = summarizer, memory_key='chat_history', input_key='question', output_key= 'answer', return_messages=True),
|
| 45 |
+
retriever = vectorstore.as_retriever(k = 3, search_type="mmr"),
|
| 46 |
+
return_source_documents=True,
|
| 47 |
+
combine_docs_chain_kwargs=chain_type_kwargs
|
| 48 |
+
)
|
| 49 |
+
return qa
|
| 50 |
|
| 51 |
@utils.enable_chat_history
|
| 52 |
def main(self):
|