Spaces:
Paused
Paused
Update utils/database.py
Browse files- utils/database.py +30 -14
utils/database.py
CHANGED
|
@@ -1,18 +1,20 @@
|
|
| 1 |
# utils/database.py
|
| 2 |
-
|
|
|
|
| 3 |
from langchain_core.messages import (
|
| 4 |
HumanMessage,
|
| 5 |
AIMessage,
|
| 6 |
SystemMessage,
|
| 7 |
-
BaseMessage
|
| 8 |
)
|
|
|
|
|
|
|
| 9 |
from langchain.chains import ConversationalRetrievalChain
|
| 10 |
from langchain.chat_models import ChatOpenAI
|
| 11 |
from langchain.agents import AgentExecutor, Tool, create_openai_tools_agent
|
| 12 |
-
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 13 |
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
|
| 14 |
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
|
| 15 |
-
|
| 16 |
import os
|
| 17 |
import streamlit as st
|
| 18 |
import sqlite3
|
|
@@ -130,21 +132,35 @@ def initialize_qa_system(vector_store):
|
|
| 130 |
api_key=os.environ.get("OPENAI_API_KEY")
|
| 131 |
)
|
| 132 |
|
| 133 |
-
# Create
|
|
|
|
|
|
|
|
|
|
| 134 |
prompt = ChatPromptTemplate.from_messages([
|
| 135 |
("system", "You are a helpful assistant analyzing RFP documents."),
|
| 136 |
MessagesPlaceholder(variable_name="chat_history"),
|
| 137 |
-
("human", "{
|
| 138 |
])
|
| 139 |
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
return chain
|
| 150 |
|
|
|
|
| 1 |
# utils/database.py
|
| 2 |
+
# Update the imports first
|
| 3 |
+
from langchain_community.chat_models import ChatOpenAI
|
| 4 |
from langchain_core.messages import (
|
| 5 |
HumanMessage,
|
| 6 |
AIMessage,
|
| 7 |
SystemMessage,
|
| 8 |
+
BaseMessage
|
| 9 |
)
|
| 10 |
+
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 11 |
+
from langchain_core.runnables import RunnablePassthrough
|
| 12 |
from langchain.chains import ConversationalRetrievalChain
|
| 13 |
from langchain.chat_models import ChatOpenAI
|
| 14 |
from langchain.agents import AgentExecutor, Tool, create_openai_tools_agent
|
|
|
|
| 15 |
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
|
| 16 |
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
|
| 17 |
+
|
| 18 |
import os
|
| 19 |
import streamlit as st
|
| 20 |
import sqlite3
|
|
|
|
| 132 |
api_key=os.environ.get("OPENAI_API_KEY")
|
| 133 |
)
|
| 134 |
|
| 135 |
+
# Create retriever function
|
| 136 |
+
retriever = vector_store.as_retriever(search_kwargs={"k": 2})
|
| 137 |
+
|
| 138 |
+
# Create a simpler prompt template
|
| 139 |
prompt = ChatPromptTemplate.from_messages([
|
| 140 |
("system", "You are a helpful assistant analyzing RFP documents."),
|
| 141 |
MessagesPlaceholder(variable_name="chat_history"),
|
| 142 |
+
("human", "{question}")
|
| 143 |
])
|
| 144 |
|
| 145 |
+
def get_chat_history(inputs):
|
| 146 |
+
chat_history = inputs.get("chat_history", [])
|
| 147 |
+
if not isinstance(chat_history, list):
|
| 148 |
+
return []
|
| 149 |
+
return [msg for msg in chat_history if isinstance(msg, BaseMessage)]
|
| 150 |
+
|
| 151 |
+
def get_context(inputs):
|
| 152 |
+
docs = retriever.get_relevant_documents(inputs["question"])
|
| 153 |
+
return "\n".join(doc.page_content for doc in docs)
|
| 154 |
+
|
| 155 |
+
chain = (
|
| 156 |
+
{
|
| 157 |
+
"question": lambda x: x["input"],
|
| 158 |
+
"chat_history": get_chat_history,
|
| 159 |
+
"context": get_context
|
| 160 |
+
}
|
| 161 |
+
| prompt
|
| 162 |
+
| llm
|
| 163 |
+
)
|
| 164 |
|
| 165 |
return chain
|
| 166 |
|