Spaces:
Paused
Paused
Update utils/database.py
Browse files- utils/database.py +27 -27
utils/database.py
CHANGED
|
@@ -11,6 +11,7 @@ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, Base
|
|
| 11 |
from langchain.chains import ConversationalRetrievalChain
|
| 12 |
from langchain.chat_models import ChatOpenAI
|
| 13 |
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate
|
|
|
|
| 14 |
|
| 15 |
def create_connection(db_file):
|
| 16 |
try:
|
|
@@ -90,14 +91,6 @@ def insert_document(conn, doc_name, doc_content):
|
|
| 90 |
st.error(f"Error inserting document: {e}")
|
| 91 |
return False
|
| 92 |
|
| 93 |
-
# utils/database.py
|
| 94 |
-
from langchain.memory import ConversationBufferWindowMemory
|
| 95 |
-
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage
|
| 96 |
-
from langchain.chains import ConversationalRetrievalChain
|
| 97 |
-
from langchain.chat_models import ChatOpenAI
|
| 98 |
-
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate
|
| 99 |
-
import os
|
| 100 |
-
import streamlit as st
|
| 101 |
|
| 102 |
def format_chat_history(messages: list[BaseMessage]) -> list[dict]:
|
| 103 |
"""Convert chat history to the format expected by langchain"""
|
|
@@ -111,47 +104,54 @@ def format_chat_history(messages: list[BaseMessage]) -> list[dict]:
|
|
| 111 |
formatted.append({"role": "system", "content": msg.content})
|
| 112 |
return formatted
|
| 113 |
|
|
|
|
|
|
|
| 114 |
def initialize_qa_system(vector_store):
|
| 115 |
"""Initialize QA system with proper chat handling"""
|
| 116 |
try:
|
| 117 |
llm = ChatOpenAI(
|
| 118 |
-
temperature=0,
|
| 119 |
model_name="gpt-4",
|
| 120 |
api_key=os.environ.get("OPENAI_API_KEY"),
|
| 121 |
)
|
| 122 |
|
| 123 |
-
# Create
|
| 124 |
-
prompt = ChatPromptTemplate.from_messages([
|
| 125 |
-
SystemMessage(content=(
|
| 126 |
-
"You are an AI assistant analyzing RFP documents. "
|
| 127 |
-
"Provide clear and concise answers based on the document content."
|
| 128 |
-
)),
|
| 129 |
-
MessagesPlaceholder(variable_name="chat_history"),
|
| 130 |
-
HumanMessagePromptTemplate.from_template("{input}"),
|
| 131 |
-
])
|
| 132 |
-
|
| 133 |
-
# Initialize memory with proper configuration
|
| 134 |
memory = ConversationBufferWindowMemory(
|
| 135 |
memory_key="chat_history",
|
| 136 |
return_messages=True,
|
| 137 |
-
k=5
|
| 138 |
)
|
| 139 |
|
| 140 |
-
|
|
|
|
| 141 |
llm=llm,
|
| 142 |
retriever=vector_store.as_retriever(search_kwargs={"k": 2}),
|
| 143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
verbose=True,
|
| 145 |
-
|
| 146 |
-
|
|
|
|
| 147 |
)
|
| 148 |
|
| 149 |
-
return
|
| 150 |
|
| 151 |
except Exception as e:
|
| 152 |
st.error(f"Error initializing QA system: {e}")
|
| 153 |
return None
|
| 154 |
-
|
| 155 |
def initialize_faiss(embeddings, documents, document_names):
|
| 156 |
"""Initialize FAISS vector store"""
|
| 157 |
try:
|
|
|
|
| 11 |
from langchain.chains import ConversationalRetrievalChain
|
| 12 |
from langchain.chat_models import ChatOpenAI
|
| 13 |
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate
|
| 14 |
+
from langchain.agents import initialize_agent
|
| 15 |
|
| 16 |
def create_connection(db_file):
|
| 17 |
try:
|
|
|
|
| 91 |
st.error(f"Error inserting document: {e}")
|
| 92 |
return False
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
def format_chat_history(messages: list[BaseMessage]) -> list[dict]:
|
| 96 |
"""Convert chat history to the format expected by langchain"""
|
|
|
|
| 104 |
formatted.append({"role": "system", "content": msg.content})
|
| 105 |
return formatted
|
| 106 |
|
| 107 |
+
|
| 108 |
+
|
| 109 |
def initialize_qa_system(vector_store):
|
| 110 |
"""Initialize QA system with proper chat handling"""
|
| 111 |
try:
|
| 112 |
llm = ChatOpenAI(
|
| 113 |
+
temperature=0.5,
|
| 114 |
model_name="gpt-4",
|
| 115 |
api_key=os.environ.get("OPENAI_API_KEY"),
|
| 116 |
)
|
| 117 |
|
| 118 |
+
# Create chat memory
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
memory = ConversationBufferWindowMemory(
|
| 120 |
memory_key="chat_history",
|
| 121 |
return_messages=True,
|
| 122 |
+
k=5
|
| 123 |
)
|
| 124 |
|
| 125 |
+
# Create retrieval QA chain
|
| 126 |
+
qa = ConversationalRetrievalChain.from_llm(
|
| 127 |
llm=llm,
|
| 128 |
retriever=vector_store.as_retriever(search_kwargs={"k": 2}),
|
| 129 |
+
chain_type="stuff",
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# Create tool for the agent
|
| 133 |
+
qa_tool = Tool(
|
| 134 |
+
name='Knowledge Base',
|
| 135 |
+
func=qa.run,
|
| 136 |
+
description='use this tool when answering questions about the RFP documents'
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
# Initialize agent
|
| 140 |
+
agent = initialize_agent(
|
| 141 |
+
agent='chat-conversational-react-description',
|
| 142 |
+
tools=[qa_tool],
|
| 143 |
+
llm=llm,
|
| 144 |
verbose=True,
|
| 145 |
+
max_iterations=3,
|
| 146 |
+
early_stopping_method='generate',
|
| 147 |
+
memory=memory,
|
| 148 |
)
|
| 149 |
|
| 150 |
+
return agent
|
| 151 |
|
| 152 |
except Exception as e:
|
| 153 |
st.error(f"Error initializing QA system: {e}")
|
| 154 |
return None
|
|
|
|
| 155 |
def initialize_faiss(embeddings, documents, document_names):
|
| 156 |
"""Initialize FAISS vector store"""
|
| 157 |
try:
|