Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -6,7 +6,7 @@ from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTem
|
|
| 6 |
from langchain_community.vectorstores import Chroma
|
| 7 |
from langchain_community.embeddings import OpenAIEmbeddings
|
| 8 |
from langchain.chat_models import ChatOpenAI
|
| 9 |
-
from langchain.schema import SystemMessage, HumanMessage
|
| 10 |
from PyPDF2 import PdfReader
|
| 11 |
import aiohttp
|
| 12 |
from io import BytesIO
|
|
@@ -15,7 +15,7 @@ from io import BytesIO
|
|
| 15 |
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
|
| 16 |
|
| 17 |
# Set up prompts
|
| 18 |
-
system_template = "
|
| 19 |
system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
|
| 20 |
|
| 21 |
human_template = "Context:\n{context}\n\nQuestion:\n{question}"
|
|
@@ -29,7 +29,7 @@ class RetrievalAugmentedQAPipeline:
|
|
| 29 |
self.llm = llm
|
| 30 |
self.vector_db = vector_db
|
| 31 |
|
| 32 |
-
async def arun_pipeline(self, user_query: str
|
| 33 |
context_docs = self.vector_db.similarity_search(user_query, k=2)
|
| 34 |
context_list = [doc.page_content for doc in context_docs]
|
| 35 |
context_prompt = "\n".join(context_list)
|
|
@@ -38,9 +38,7 @@ class RetrievalAugmentedQAPipeline:
|
|
| 38 |
if len(context_prompt) > max_context_length:
|
| 39 |
context_prompt = context_prompt[:max_context_length]
|
| 40 |
|
| 41 |
-
messages =
|
| 42 |
-
messages.extend(chat_history)
|
| 43 |
-
messages.append(HumanMessage(content=human_template.format(context=context_prompt, question=user_query)))
|
| 44 |
|
| 45 |
response = await self.llm.agenerate([messages])
|
| 46 |
return {"response": response.generations[0][0].text}
|
|
@@ -88,36 +86,13 @@ async def main():
|
|
| 88 |
# Streamlit UI
|
| 89 |
st.title("Ask About AI!")
|
| 90 |
|
| 91 |
-
# Initialize session state for chat history
|
| 92 |
-
if "chat_history" not in st.session_state:
|
| 93 |
-
st.session_state.chat_history = []
|
| 94 |
-
|
| 95 |
pipeline = initialize_pipeline()
|
| 96 |
|
| 97 |
-
# Display chat history
|
| 98 |
-
for message in st.session_state.chat_history:
|
| 99 |
-
if isinstance(message, HumanMessage):
|
| 100 |
-
st.write("You:", message.content)
|
| 101 |
-
elif isinstance(message, AIMessage):
|
| 102 |
-
st.write("AI:", message.content)
|
| 103 |
-
|
| 104 |
user_query = st.text_input("Enter your question about AI:")
|
| 105 |
|
| 106 |
if user_query:
|
| 107 |
-
# Add user message to chat history
|
| 108 |
-
st.session_state.chat_history.append(HumanMessage(content=user_query))
|
| 109 |
-
|
| 110 |
with st.spinner("Generating response..."):
|
| 111 |
-
result = asyncio.run(pipeline.arun_pipeline(user_query
|
| 112 |
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
st.session_state.chat_history.append(ai_message)
|
| 116 |
-
|
| 117 |
-
# Display the latest response
|
| 118 |
-
st.write("AI:", result["response"])
|
| 119 |
-
|
| 120 |
-
# Add a button to clear chat history
|
| 121 |
-
if st.button("Clear Chat History"):
|
| 122 |
-
st.session_state.chat_history = []
|
| 123 |
-
st.experimental_rerun()
|
|
|
|
| 6 |
from langchain_community.vectorstores import Chroma
|
| 7 |
from langchain_community.embeddings import OpenAIEmbeddings
|
| 8 |
from langchain.chat_models import ChatOpenAI
|
| 9 |
+
from langchain.schema import SystemMessage, HumanMessage
|
| 10 |
from PyPDF2 import PdfReader
|
| 11 |
import aiohttp
|
| 12 |
from io import BytesIO
|
|
|
|
| 15 |
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
|
| 16 |
|
| 17 |
# Set up prompts
|
| 18 |
+
system_template = "Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."
|
| 19 |
system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
|
| 20 |
|
| 21 |
human_template = "Context:\n{context}\n\nQuestion:\n{question}"
|
|
|
|
| 29 |
self.llm = llm
|
| 30 |
self.vector_db = vector_db
|
| 31 |
|
| 32 |
+
async def arun_pipeline(self, user_query: str):
|
| 33 |
context_docs = self.vector_db.similarity_search(user_query, k=2)
|
| 34 |
context_list = [doc.page_content for doc in context_docs]
|
| 35 |
context_prompt = "\n".join(context_list)
|
|
|
|
| 38 |
if len(context_prompt) > max_context_length:
|
| 39 |
context_prompt = context_prompt[:max_context_length]
|
| 40 |
|
| 41 |
+
messages = chat_prompt.format_prompt(context=context_prompt, question=user_query).to_messages()
|
|
|
|
|
|
|
| 42 |
|
| 43 |
response = await self.llm.agenerate([messages])
|
| 44 |
return {"response": response.generations[0][0].text}
|
|
|
|
| 86 |
# Streamlit UI
|
| 87 |
st.title("Ask About AI!")
|
| 88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
pipeline = initialize_pipeline()
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
user_query = st.text_input("Enter your question about AI:")
|
| 92 |
|
| 93 |
if user_query:
|
|
|
|
|
|
|
|
|
|
| 94 |
with st.spinner("Generating response..."):
|
| 95 |
+
result = asyncio.run(pipeline.arun_pipeline(user_query))
|
| 96 |
|
| 97 |
+
st.write("Response:")
|
| 98 |
+
st.write(result["response"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|