cryogenic22 commited on
Commit
028c9cb
·
verified ·
1 Parent(s): fa2e1ef

Update utils/database.py

Browse files
Files changed (1) hide show
  1. utils/database.py +33 -28
utils/database.py CHANGED
@@ -11,7 +11,8 @@ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, Base
11
  from langchain.chains import ConversationalRetrievalChain
12
  from langchain.chat_models import ChatOpenAI
13
  from langchain.agents import initialize_agent
14
- from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate
 
15
 
16
  def create_connection(db_file):
17
  try:
@@ -104,17 +105,13 @@ def format_chat_history(messages: list[BaseMessage]) -> list[dict]:
104
  formatted.append({"role": "system", "content": msg.content})
105
  return formatted
106
 
107
-
108
-
109
-
110
-
111
  def initialize_qa_system(vector_store):
112
  """Initialize QA system with proper chat handling"""
113
  try:
114
  llm = ChatOpenAI(
115
  temperature=0.5,
116
  model_name="gpt-4",
117
- api_key=os.environ.get("OPENAI_API_KEY"),
118
  )
119
 
120
  # Create chat memory
@@ -124,37 +121,45 @@ def initialize_qa_system(vector_store):
124
  k=5
125
  )
126
 
127
- # Create the base QA chain
128
- qa = ConversationalRetrievalChain.from_llm(
129
- llm=llm,
130
- retriever=vector_store.as_retriever(search_kwargs={"k": 2}),
131
- chain_type="stuff",
132
- )
133
-
134
- # Define the tools
135
- tools = [
136
- Tool(
137
- name="RFP_Knowledge_Base",
138
- func=qa.run,
139
- description="Use this tool to analyze RFP documents and answer questions about their content."
140
- )
141
- ]
142
-
143
  # Create the prompt template
144
  prompt = ChatPromptTemplate.from_messages([
145
  ("system", "You are a helpful assistant analyzing RFP documents."),
146
  MessagesPlaceholder(variable_name="chat_history"),
147
- ("human", "{input}"),
148
- MessagesPlaceholder(variable_name="agent_scratchpad"),
149
  ])
150
 
151
- # Create the agent
152
- agent = create_openai_tools_agent(llm, tools, prompt)
 
 
 
 
 
 
 
 
153
 
154
  # Create the agent executor
155
  agent_executor = AgentExecutor(
156
- agent=agent,
157
- tools=tools,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  memory=memory,
159
  verbose=True,
160
  handle_parsing_errors=True
 
11
  from langchain.chains import ConversationalRetrievalChain
12
  from langchain.chat_models import ChatOpenAI
13
  from langchain.agents import initialize_agent
14
+ from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, # utils/database.py
15
+ from langchain_core.runnables import RunnablePassthrough
16
 
17
  def create_connection(db_file):
18
  try:
 
105
  formatted.append({"role": "system", "content": msg.content})
106
  return formatted
107
 
 
 
 
 
108
  def initialize_qa_system(vector_store):
109
  """Initialize QA system with proper chat handling"""
110
  try:
111
  llm = ChatOpenAI(
112
  temperature=0.5,
113
  model_name="gpt-4",
114
+ api_key=os.environ.get("OPENAI_API_KEY")
115
  )
116
 
117
  # Create chat memory
 
121
  k=5
122
  )
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  # Create the prompt template
125
  prompt = ChatPromptTemplate.from_messages([
126
  ("system", "You are a helpful assistant analyzing RFP documents."),
127
  MessagesPlaceholder(variable_name="chat_history"),
128
+ ("human", "{question}"),
129
+ MessagesPlaceholder(variable_name="agent_scratchpad")
130
  ])
131
 
132
+ # Create the RAG chain with lambda function
133
+ rag_chain = (
134
+ {
135
+ "context": lambda x: vector_store.as_retriever().get_relevant_documents(x["question"]),
136
+ "question": RunnablePassthrough(),
137
+ "chat_history": lambda x: memory.chat_memory.messages
138
+ }
139
+ | prompt
140
+ | llm
141
+ )
142
 
143
  # Create the agent executor
144
  agent_executor = AgentExecutor(
145
+ agent=create_openai_tools_agent(
146
+ llm=llm,
147
+ tools=[
148
+ Tool(
149
+ name="RFP_Knowledge_Base",
150
+ func=rag_chain.invoke,
151
+ description="Use this tool to analyze RFP documents and answer questions about their content."
152
+ )
153
+ ],
154
+ prompt=prompt
155
+ ),
156
+ tools=[
157
+ Tool(
158
+ name="RFP_Knowledge_Base",
159
+ func=rag_chain.invoke,
160
+ description="Use this tool to analyze RFP documents and answer questions about their content."
161
+ )
162
+ ],
163
  memory=memory,
164
  verbose=True,
165
  handle_parsing_errors=True