cryogenic22 committed on
Commit
cbeca91
·
verified ·
1 Parent(s): d39a1aa

Update utils/database.py

Browse files
Files changed (1) hide show
  1. utils/database.py +21 -13
utils/database.py CHANGED
@@ -337,15 +337,17 @@ def display_vector_store_info():
337
  st.error(traceback.format_exc())
338
 
339
  def initialize_qa_system(vector_store):
340
- """Initialize QA system with clean response formatting."""
341
  try:
342
  llm = ChatOpenAI(
343
  temperature=0.5,
344
  model_name="gpt-4",
345
  api_key=os.environ.get("OPENAI_API_KEY")
346
  )
347
- # Create retriever function
 
348
  retriever = vector_store.as_retriever(search_kwargs={"k": 2})
 
349
  # Create a template that enforces clean formatting
350
  prompt = ChatPromptTemplate.from_messages([
351
  ("system", """You are a helpful assistant analyzing RFP documents.
@@ -362,6 +364,22 @@ def initialize_qa_system(vector_store):
362
  ("human", "{input}\n\nContext: {context}")
363
  ])
364
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
  def format_response(response_text):
366
  """Clean up the response formatting."""
367
  # Remove technical metadata
@@ -375,20 +393,10 @@ def initialize_qa_system(vector_store):
375
  response_text = response_text.replace('\\n', '\n')
376
 
377
  return response_text
378
- def get_context(inputs):
379
- docs = retriever.get_relevant_documents(inputs["input"])
380
- context_parts = []
381
- for doc in docs:
382
- source = doc.metadata.get('source', 'Unknown source')
383
- context_parts.append(f"\nFrom {source}:\n{doc.page_content}")
384
- return "\n".join(context_parts)
385
 
386
  chain = (
387
  {
388
- "context": lambda x: "\n".join(
389
- f"\nFrom {doc.metadata.get('source', 'Unknown')}:\n{doc.page_content}"
390
- for doc in retriever.get_relevant_documents(x["input"])
391
- ),
392
  "chat_history": get_chat_history,
393
  "input": lambda x: x["input"]
394
  }
 
337
  st.error(traceback.format_exc())
338
 
339
  def initialize_qa_system(vector_store):
340
+ """Initialize QA system with proper chat handling."""
341
  try:
342
  llm = ChatOpenAI(
343
  temperature=0.5,
344
  model_name="gpt-4",
345
  api_key=os.environ.get("OPENAI_API_KEY")
346
  )
347
+
348
+ # Create retriever function
349
  retriever = vector_store.as_retriever(search_kwargs={"k": 2})
350
+
351
  # Create a template that enforces clean formatting
352
  prompt = ChatPromptTemplate.from_messages([
353
  ("system", """You are a helpful assistant analyzing RFP documents.
 
364
  ("human", "{input}\n\nContext: {context}")
365
  ])
366
 
367
+ def get_chat_history(inputs):
368
+ """Get formatted chat history."""
369
+ chat_history = inputs.get("chat_history", [])
370
+ if not isinstance(chat_history, list):
371
+ return []
372
+ return [msg for msg in chat_history if isinstance(msg, BaseMessage)]
373
+
374
+ def get_context(inputs):
375
+ """Get formatted context from documents."""
376
+ docs = retriever.get_relevant_documents(inputs["input"])
377
+ context_parts = []
378
+ for doc in docs:
379
+ source = doc.metadata.get('source', 'Unknown source')
380
+ context_parts.append(f"\nFrom {source}:\n{doc.page_content}")
381
+ return "\n".join(context_parts)
382
+
383
  def format_response(response_text):
384
  """Clean up the response formatting."""
385
  # Remove technical metadata
 
393
  response_text = response_text.replace('\\n', '\n')
394
 
395
  return response_text
 
 
 
 
 
 
 
396
 
397
  chain = (
398
  {
399
+ "context": get_context,
 
 
 
400
  "chat_history": get_chat_history,
401
  "input": lambda x: x["input"]
402
  }