Spaces:
Build error
Build error
Update utils/database.py
Browse files- utils/database.py +8 -7
utils/database.py
CHANGED
|
@@ -1646,18 +1646,19 @@ def initialize_qa_system(vector_store):
|
|
| 1646 |
dict: QA system chain or None if initialization fails.
|
| 1647 |
"""
|
| 1648 |
try:
|
| 1649 |
-
|
| 1650 |
temperature=0.5,
|
| 1651 |
model_name="gpt-4",
|
| 1652 |
-
max_tokens=4000,
|
| 1653 |
api_key=os.environ.get("OPENAI_API_KEY")
|
| 1654 |
)
|
| 1655 |
|
| 1656 |
-
# Optimize retriever settings
|
| 1657 |
retriever = vector_store.as_retriever(
|
| 1658 |
search_kwargs={
|
| 1659 |
"k": 3, # Retrieve fewer, more relevant chunks
|
| 1660 |
-
"fetch_k": 5 # Consider more candidates before selecting top k
|
|
|
|
| 1661 |
}
|
| 1662 |
)
|
| 1663 |
|
|
@@ -1691,7 +1692,7 @@ Accuracy: Double-check all information for accuracy and completeness before prov
|
|
| 1691 |
MessagesPlaceholder(variable_name="chat_history"),
|
| 1692 |
("human", "{input}\n\nContext: {context}")
|
| 1693 |
])
|
| 1694 |
-
|
| 1695 |
def get_chat_history(inputs):
|
| 1696 |
chat_history = inputs.get("chat_history", [])
|
| 1697 |
if not isinstance(chat_history, list):
|
|
@@ -1708,8 +1709,8 @@ Accuracy: Double-check all information for accuracy and completeness before prov
|
|
| 1708 |
|
| 1709 |
chain = (
|
| 1710 |
{
|
| 1711 |
-
"context":
|
| 1712 |
-
"chat_history":
|
| 1713 |
"input": lambda x: x["input"]
|
| 1714 |
}
|
| 1715 |
| prompt
|
|
|
|
| 1646 |
dict: QA system chain or None if initialization fails.
|
| 1647 |
"""
|
| 1648 |
try:
|
| 1649 |
+
llm = ChatOpenAI(
|
| 1650 |
temperature=0.5,
|
| 1651 |
model_name="gpt-4",
|
| 1652 |
+
max_tokens=4000,
|
| 1653 |
api_key=os.environ.get("OPENAI_API_KEY")
|
| 1654 |
)
|
| 1655 |
|
| 1656 |
+
# Optimize retriever settings and add source tracking
|
| 1657 |
retriever = vector_store.as_retriever(
|
| 1658 |
search_kwargs={
|
| 1659 |
"k": 3, # Retrieve fewer, more relevant chunks
|
| 1660 |
+
"fetch_k": 5, # Consider more candidates before selecting top k
|
| 1661 |
+
"include_metadata": True # Enable source tracking
|
| 1662 |
}
|
| 1663 |
)
|
| 1664 |
|
|
|
|
| 1692 |
MessagesPlaceholder(variable_name="chat_history"),
|
| 1693 |
("human", "{input}\n\nContext: {context}")
|
| 1694 |
])
|
| 1695 |
+
|
| 1696 |
def get_chat_history(inputs):
|
| 1697 |
chat_history = inputs.get("chat_history", [])
|
| 1698 |
if not isinstance(chat_history, list):
|
|
|
|
| 1709 |
|
| 1710 |
chain = (
|
| 1711 |
{
|
| 1712 |
+
"context": lambda x: get_context_with_sources(retriever, x["input"]),
|
| 1713 |
+
"chat_history": lambda x: format_chat_history(x["chat_history"]),
|
| 1714 |
"input": lambda x: x["input"]
|
| 1715 |
}
|
| 1716 |
| prompt
|