cryogenic22 commited on
Commit
0f519bc
·
verified ·
1 Parent(s): 80c5436

Update utils/database.py

Browse files
Files changed (1) hide show
  1. utils/database.py +8 -7
utils/database.py CHANGED
@@ -1646,18 +1646,19 @@ def initialize_qa_system(vector_store):
1646
  dict: QA system chain or None if initialization fails.
1647
  """
1648
  try:
1649
- llm = ChatOpenAI(
1650
  temperature=0.5,
1651
  model_name="gpt-4",
1652
- max_tokens=4000, # Explicitly set max tokens
1653
  api_key=os.environ.get("OPENAI_API_KEY")
1654
  )
1655
 
1656
- # Optimize retriever settings
1657
  retriever = vector_store.as_retriever(
1658
  search_kwargs={
1659
  "k": 3, # Retrieve fewer, more relevant chunks
1660
- "fetch_k": 5 # Consider more candidates before selecting top k
 
1661
  }
1662
  )
1663
 
@@ -1691,7 +1692,7 @@ Accuracy: Double-check all information for accuracy and completeness before prov
1691
  MessagesPlaceholder(variable_name="chat_history"),
1692
  ("human", "{input}\n\nContext: {context}")
1693
  ])
1694
-
1695
  def get_chat_history(inputs):
1696
  chat_history = inputs.get("chat_history", [])
1697
  if not isinstance(chat_history, list):
@@ -1708,8 +1709,8 @@ Accuracy: Double-check all information for accuracy and completeness before prov
1708
 
1709
  chain = (
1710
  {
1711
- "context": get_context,
1712
- "chat_history": get_chat_history,
1713
  "input": lambda x: x["input"]
1714
  }
1715
  | prompt
 
1646
  dict: QA system chain or None if initialization fails.
1647
  """
1648
  try:
1649
+ llm = ChatOpenAI(
1650
  temperature=0.5,
1651
  model_name="gpt-4",
1652
+ max_tokens=4000,
1653
  api_key=os.environ.get("OPENAI_API_KEY")
1654
  )
1655
 
1656
+ # Optimize retriever settings and add source tracking
1657
  retriever = vector_store.as_retriever(
1658
  search_kwargs={
1659
  "k": 3, # Retrieve fewer, more relevant chunks
1660
+ "fetch_k": 5, # Consider more candidates before selecting top k
1661
+ "include_metadata": True # Enable source tracking
1662
  }
1663
  )
1664
 
 
1692
  MessagesPlaceholder(variable_name="chat_history"),
1693
  ("human", "{input}\n\nContext: {context}")
1694
  ])
1695
+
1696
  def get_chat_history(inputs):
1697
  chat_history = inputs.get("chat_history", [])
1698
  if not isinstance(chat_history, list):
 
1709
 
1710
  chain = (
1711
  {
1712
+ "context": lambda x: get_context_with_sources(retriever, x["input"]),
1713
+ "chat_history": lambda x: format_chat_history(x["chat_history"]),
1714
  "input": lambda x: x["input"]
1715
  }
1716
  | prompt