cryogenic22 committed on
Commit
64b3386
·
verified ·
1 Parent(s): 5b139a9

Update utils/database.py

Browse files
Files changed (1) hide show
  1. utils/database.py +38 -32
utils/database.py CHANGED
@@ -1,18 +1,15 @@
1
  # utils/database.py
2
- import streamlit as st
3
- import sqlite3
4
- from sqlite3 import Error
5
- from datetime import datetime
6
- #from langchain.memory import ConversationBufferMemory
7
- from langchain.chat_models import ChatOpenAI
8
- import os
9
  from langchain.memory import ConversationBufferWindowMemory
10
- from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage
11
  from langchain.chains import ConversationalRetrievalChain
12
  from langchain.chat_models import ChatOpenAI
13
- from langchain.agents import initialize_agent
14
- from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder # utils/database.py
 
 
15
  from langchain_core.runnables import RunnablePassthrough
 
 
16
 
17
  def create_connection(db_file):
18
  try:
@@ -105,6 +102,19 @@ def format_chat_history(messages: list[BaseMessage]) -> list[dict]:
105
  formatted.append({"role": "system", "content": msg.content})
106
  return formatted
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  def initialize_qa_system(vector_store):
109
  """Initialize QA system with proper chat handling"""
110
  try:
@@ -121,48 +131,44 @@ def initialize_qa_system(vector_store):
121
  k=5
122
  )
123
 
124
- # Create the prompt template
125
  prompt = ChatPromptTemplate.from_messages([
126
  ("system", "You are a helpful assistant analyzing RFP documents."),
127
  MessagesPlaceholder(variable_name="chat_history"),
128
- ("human", "{question}"),
129
  MessagesPlaceholder(variable_name="agent_scratchpad")
130
  ])
131
 
132
- # Create the RAG chain with lambda function
133
- rag_chain = (
 
 
 
134
  {
135
- "context": lambda x: vector_store.as_retriever().get_relevant_documents(x["question"]),
136
- "question": RunnablePassthrough(),
137
- "chat_history": lambda x: memory.chat_memory.messages
 
138
  }
139
- | prompt
140
- | llm
 
141
  )
142
 
143
  # Create the agent executor
144
  agent_executor = AgentExecutor(
145
- agent=create_openai_tools_agent(
146
- llm=llm,
147
- tools=[
148
- Tool(
149
- name="RFP_Knowledge_Base",
150
- func=rag_chain.invoke,
151
- description="Use this tool to analyze RFP documents and answer questions about their content."
152
- )
153
- ],
154
- prompt=prompt
155
- ),
156
  tools=[
157
  Tool(
158
  name="RFP_Knowledge_Base",
159
- func=rag_chain.invoke,
160
  description="Use this tool to analyze RFP documents and answer questions about their content."
161
  )
162
  ],
163
  memory=memory,
164
  verbose=True,
165
- handle_parsing_errors=True
 
166
  )
167
 
168
  return agent_executor
 
1
  # utils/database.py
 
 
 
 
 
 
 
2
  from langchain.memory import ConversationBufferWindowMemory
3
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
4
  from langchain.chains import ConversationalRetrievalChain
5
  from langchain.chat_models import ChatOpenAI
6
+ from langchain.agents import AgentExecutor, Tool, create_openai_tools_agent
7
+ from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
8
+ from langchain.agents.format_scratchpad.tools import format_to_tool_messages
9
+ from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
10
  from langchain_core.runnables import RunnablePassthrough
11
+ import os
12
+ import streamlit as st
13
 
14
  def create_connection(db_file):
15
  try:
 
102
  formatted.append({"role": "system", "content": msg.content})
103
  return formatted
104
 
105
+ # utils/database.py
106
+ from langchain.memory import ConversationBufferWindowMemory
107
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
108
+ from langchain.chains import ConversationalRetrievalChain
109
+ from langchain.chat_models import ChatOpenAI
110
+ from langchain.agents import AgentExecutor, Tool, create_openai_tools_agent
111
+ from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
112
+ from langchain.agents.format_scratchpad.tools import format_to_tool_messages
113
+ from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
114
+ from langchain_core.runnables import RunnablePassthrough
115
+ import os
116
+ import streamlit as st
117
+
118
  def initialize_qa_system(vector_store):
119
  """Initialize QA system with proper chat handling"""
120
  try:
 
131
  k=5
132
  )
133
 
134
+ # Create the prompt template with the correct variable names
135
  prompt = ChatPromptTemplate.from_messages([
136
  ("system", "You are a helpful assistant analyzing RFP documents."),
137
  MessagesPlaceholder(variable_name="chat_history"),
138
+ ("user", "{input}\nContext: {context}"),
139
  MessagesPlaceholder(variable_name="agent_scratchpad")
140
  ])
141
 
142
+ # Create retriever function
143
+ retriever = vector_store.as_retriever(search_kwargs={"k": 2})
144
+
145
+ # Create the RAG pipeline
146
+ rag_pipe = (
147
  {
148
+ "context": lambda x: retriever.get_relevant_documents(x["input"]),
149
+ "input": lambda x: x["input"],
150
+ "chat_history": lambda x: memory.chat_memory.messages,
151
+ "agent_scratchpad": lambda x: format_to_tool_messages(x["intermediate_steps"])
152
  }
153
+ | prompt
154
+ | llm.bind(stop=["\nHuman:"])
155
+ | OpenAIToolsAgentOutputParser()
156
  )
157
 
158
  # Create the agent executor
159
  agent_executor = AgentExecutor(
160
+ agent=rag_pipe,
 
 
 
 
 
 
 
 
 
 
161
  tools=[
162
  Tool(
163
  name="RFP_Knowledge_Base",
164
+ func=lambda x: retriever.get_relevant_documents(x),
165
  description="Use this tool to analyze RFP documents and answer questions about their content."
166
  )
167
  ],
168
  memory=memory,
169
  verbose=True,
170
+ handle_parsing_errors=True,
171
+ return_intermediate_steps=True
172
  )
173
 
174
  return agent_executor