Spaces:
Build error
Build error
Update utils/database.py
Browse files- utils/database.py +31 -8
utils/database.py
CHANGED
|
@@ -7,10 +7,10 @@ from datetime import datetime
|
|
| 7 |
from langchain.chat_models import ChatOpenAI
|
| 8 |
import os
|
| 9 |
from langchain.memory import ConversationBufferWindowMemory
|
| 10 |
-
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
| 11 |
from langchain.chains import ConversationalRetrievalChain
|
| 12 |
from langchain.chat_models import ChatOpenAI
|
| 13 |
-
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 14 |
import os
|
| 15 |
|
| 16 |
def create_connection(db_file):
|
|
@@ -91,6 +91,27 @@ def insert_document(conn, doc_name, doc_content):
|
|
| 91 |
st.error(f"Error inserting document: {e}")
|
| 92 |
return False
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
def initialize_qa_system(vector_store):
|
| 95 |
"""Initialize QA system with proper chat handling"""
|
| 96 |
try:
|
|
@@ -100,12 +121,14 @@ def initialize_qa_system(vector_store):
|
|
| 100 |
api_key=os.environ.get("OPENAI_API_KEY"),
|
| 101 |
)
|
| 102 |
|
| 103 |
-
# Create
|
| 104 |
prompt = ChatPromptTemplate.from_messages([
|
| 105 |
-
(
|
|
|
|
|
|
|
|
|
|
| 106 |
MessagesPlaceholder(variable_name="chat_history"),
|
| 107 |
-
("
|
| 108 |
-
MessagesPlaceholder(variable_name="agent_scratchpad")
|
| 109 |
])
|
| 110 |
|
| 111 |
# Initialize memory with proper configuration
|
|
@@ -119,9 +142,9 @@ def initialize_qa_system(vector_store):
|
|
| 119 |
llm=llm,
|
| 120 |
retriever=vector_store.as_retriever(search_kwargs={"k": 2}),
|
| 121 |
memory=memory,
|
| 122 |
-
combine_docs_chain_kwargs={"prompt": prompt},
|
| 123 |
verbose=True,
|
| 124 |
-
return_source_documents=True
|
|
|
|
| 125 |
)
|
| 126 |
|
| 127 |
return qa_chain
|
|
|
|
| 7 |
from langchain.chat_models import ChatOpenAI
|
| 8 |
import os
|
| 9 |
from langchain.memory import ConversationBufferWindowMemory
|
| 10 |
+
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage
|
| 11 |
from langchain.chains import ConversationalRetrievalChain
|
| 12 |
from langchain.chat_models import ChatOpenAI
|
| 13 |
+
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate
|
| 14 |
import os
|
| 15 |
|
| 16 |
def create_connection(db_file):
|
|
|
|
| 91 |
st.error(f"Error inserting document: {e}")
|
| 92 |
return False
|
| 93 |
|
| 94 |
+
# utils/database.py
|
| 95 |
+
from langchain.memory import ConversationBufferWindowMemory
|
| 96 |
+
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage
|
| 97 |
+
from langchain.chains import ConversationalRetrievalChain
|
| 98 |
+
from langchain.chat_models import ChatOpenAI
|
| 99 |
+
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate
|
| 100 |
+
import os
|
| 101 |
+
import streamlit as st
|
| 102 |
+
|
| 103 |
+
def format_chat_history(messages: list[BaseMessage]) -> list[dict]:
|
| 104 |
+
"""Convert chat history to the format expected by langchain"""
|
| 105 |
+
formatted = []
|
| 106 |
+
for msg in messages:
|
| 107 |
+
if isinstance(msg, HumanMessage):
|
| 108 |
+
formatted.append({"role": "user", "content": msg.content})
|
| 109 |
+
elif isinstance(msg, AIMessage):
|
| 110 |
+
formatted.append({"role": "assistant", "content": msg.content})
|
| 111 |
+
elif isinstance(msg, SystemMessage):
|
| 112 |
+
formatted.append({"role": "system", "content": msg.content})
|
| 113 |
+
return formatted
|
| 114 |
+
|
| 115 |
def initialize_qa_system(vector_store):
|
| 116 |
"""Initialize QA system with proper chat handling"""
|
| 117 |
try:
|
|
|
|
| 121 |
api_key=os.environ.get("OPENAI_API_KEY"),
|
| 122 |
)
|
| 123 |
|
| 124 |
+
# Create a custom prompt template
|
| 125 |
prompt = ChatPromptTemplate.from_messages([
|
| 126 |
+
SystemMessage(content=(
|
| 127 |
+
"You are an AI assistant analyzing RFP documents. "
|
| 128 |
+
"Provide clear and concise answers based on the document content."
|
| 129 |
+
)),
|
| 130 |
MessagesPlaceholder(variable_name="chat_history"),
|
| 131 |
+
HumanMessagePromptTemplate.from_template("{input}"),
|
|
|
|
| 132 |
])
|
| 133 |
|
| 134 |
# Initialize memory with proper configuration
|
|
|
|
| 142 |
llm=llm,
|
| 143 |
retriever=vector_store.as_retriever(search_kwargs={"k": 2}),
|
| 144 |
memory=memory,
|
|
|
|
| 145 |
verbose=True,
|
| 146 |
+
return_source_documents=True,
|
| 147 |
+
combine_docs_chain_kwargs={"prompt": prompt}
|
| 148 |
)
|
| 149 |
|
| 150 |
return qa_chain
|