cryogenic22 commited on
Commit
2329f67
·
verified ·
1 Parent(s): 87da96e

Update utils/database.py

Browse files
Files changed (1) hide show
  1. utils/database.py +30 -14
utils/database.py CHANGED
@@ -1,18 +1,20 @@
1
  # utils/database.py
2
- from langchain.memory import ConversationBufferWindowMemory
 
3
  from langchain_core.messages import (
4
  HumanMessage,
5
  AIMessage,
6
  SystemMessage,
7
- BaseMessage # Added this import
8
  )
 
 
9
  from langchain.chains import ConversationalRetrievalChain
10
  from langchain.chat_models import ChatOpenAI
11
  from langchain.agents import AgentExecutor, Tool, create_openai_tools_agent
12
- from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
13
  from langchain.agents.format_scratchpad.tools import format_to_tool_messages
14
  from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
15
- from langchain_core.runnables import RunnablePassthrough
16
  import os
17
  import streamlit as st
18
  import sqlite3
@@ -130,21 +132,35 @@ def initialize_qa_system(vector_store):
130
  api_key=os.environ.get("OPENAI_API_KEY")
131
  )
132
 
133
- # Create the prompt template
 
 
 
134
  prompt = ChatPromptTemplate.from_messages([
135
  ("system", "You are a helpful assistant analyzing RFP documents."),
136
  MessagesPlaceholder(variable_name="chat_history"),
137
- ("human", "{input}\nContext: {context}")
138
  ])
139
 
140
- # Create retriever function
141
- retriever = vector_store.as_retriever(search_kwargs={"k": 2})
142
-
143
- # Create the chain with proper chat history handling
144
- chain = RunnablePassthrough.assign(
145
- context=lambda x: "\n".join(doc.page_content for doc in retriever.get_relevant_documents(x["input"])),
146
- chat_history=lambda x: x.get("chat_history", [])
147
- ) | prompt | llm
 
 
 
 
 
 
 
 
 
 
 
148
 
149
  return chain
150
 
 
1
  # utils/database.py
2
+ # Update the imports first
3
+ from langchain_community.chat_models import ChatOpenAI
4
  from langchain_core.messages import (
5
  HumanMessage,
6
  AIMessage,
7
  SystemMessage,
8
+ BaseMessage
9
  )
10
+ from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
11
+ from langchain_core.runnables import RunnablePassthrough
12
  from langchain.chains import ConversationalRetrievalChain
13
  from langchain.chat_models import ChatOpenAI
14
  from langchain.agents import AgentExecutor, Tool, create_openai_tools_agent
 
15
  from langchain.agents.format_scratchpad.tools import format_to_tool_messages
16
  from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
17
+
18
  import os
19
  import streamlit as st
20
  import sqlite3
 
132
  api_key=os.environ.get("OPENAI_API_KEY")
133
  )
134
 
135
+ # Create retriever function
136
+ retriever = vector_store.as_retriever(search_kwargs={"k": 2})
137
+
138
+ # Create a simpler prompt template
139
  prompt = ChatPromptTemplate.from_messages([
140
  ("system", "You are a helpful assistant analyzing RFP documents."),
141
  MessagesPlaceholder(variable_name="chat_history"),
142
+ ("human", "{question}")
143
  ])
144
 
145
+ def get_chat_history(inputs):
146
+ chat_history = inputs.get("chat_history", [])
147
+ if not isinstance(chat_history, list):
148
+ return []
149
+ return [msg for msg in chat_history if isinstance(msg, BaseMessage)]
150
+
151
+ def get_context(inputs):
152
+ docs = retriever.get_relevant_documents(inputs["question"])
153
+ return "\n".join(doc.page_content for doc in docs)
154
+
155
+ chain = (
156
+ {
157
+ "question": lambda x: x["input"],
158
+ "chat_history": get_chat_history,
159
+ "context": get_context
160
+ }
161
+ | prompt
162
+ | llm
163
+ )
164
 
165
  return chain
166