kamkol committed on
Commit
db2c124
·
1 Parent(s): 3f051e5

Fix metadata association to display correct page numbers in sources

Browse files
Files changed (1) hide show
  1. app.py +26 -11
app.py CHANGED
@@ -28,18 +28,20 @@ Question:
28
  user_role_prompt = UserRolePrompt(user_prompt_template)
29
 
30
  class RetrievalAugmentedQAPipeline:
31
- def __init__(self, llm: ChatOpenAI(), vector_db_retriever: VectorDatabase, metadata: List[Dict[str, Any]] = None) -> None:
32
  self.llm = llm
33
  self.vector_db_retriever = vector_db_retriever
34
  self.metadata = metadata or []
35
  self.text_to_metadata = {}
36
 
37
- # Create lookup for text to metadata
38
- if metadata:
39
- texts = [key for key in self.vector_db_retriever.vectors.keys()]
40
  for i, text in enumerate(texts):
41
- if i < len(metadata):
42
- self.text_to_metadata[text] = metadata[i]
 
 
43
 
44
  async def arun_pipeline(self, user_query: str):
45
  context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
@@ -55,7 +57,18 @@ class RetrievalAugmentedQAPipeline:
55
  if text in self.text_to_metadata:
56
  sources.append(self.text_to_metadata[text])
57
  else:
58
- sources.append({"filename": "unknown", "page": "unknown"})
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  formatted_system_prompt = system_role_prompt.create_message()
61
 
@@ -85,10 +98,11 @@ def load_preprocessed_data():
85
  for key, vector in data['vectors'].items():
86
  vector_db.insert(key, vector)
87
 
88
- # Get metadata if available
89
  metadata = data.get('metadata', [])
 
90
 
91
- return vector_db, metadata
92
 
93
  @cl.on_chat_start
94
  async def on_chat_start():
@@ -121,7 +135,7 @@ The application requires preprocessing of PDF documents to build a knowledge bas
121
 
122
  # Load pre-processed data
123
  start_time = time.time()
124
- vector_db, metadata = load_preprocessed_data()
125
  load_time = time.time() - start_time
126
  print(f"Loaded vector database in {load_time:.2f} seconds")
127
 
@@ -131,7 +145,8 @@ The application requires preprocessing of PDF documents to build a knowledge bas
131
  retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
132
  vector_db_retriever=vector_db,
133
  llm=chat_openai,
134
- metadata=metadata
 
135
  )
136
 
137
  # Let the user know that the system is ready
 
28
  user_role_prompt = UserRolePrompt(user_prompt_template)
29
 
30
  class RetrievalAugmentedQAPipeline:
31
+ def __init__(self, llm: ChatOpenAI(), vector_db_retriever: VectorDatabase, metadata: List[Dict[str, Any]] = None, texts: List[str] = None) -> None:
32
  self.llm = llm
33
  self.vector_db_retriever = vector_db_retriever
34
  self.metadata = metadata or []
35
  self.text_to_metadata = {}
36
 
37
+ # Ensure we have the original texts that match the metadata
38
+ if metadata and texts and len(texts) == len(metadata):
39
+ # Create a direct mapping from text to its metadata using the original texts
40
  for i, text in enumerate(texts):
41
+ self.text_to_metadata[text] = metadata[i]
42
+ print(f"Successfully mapped {len(self.text_to_metadata)} text chunks to metadata")
43
+ else:
44
+ print(f"Warning: Metadata mapping not created. Metadata: {len(metadata) if metadata else 0}, Texts: {len(texts) if texts else 0}")
45
 
46
  async def arun_pipeline(self, user_query: str):
47
  context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
 
57
  if text in self.text_to_metadata:
58
  sources.append(self.text_to_metadata[text])
59
  else:
60
+ # If exact text not found, try finding most similar text
61
+ # This is a fallback mechanism
62
+ found = False
63
+ for orig_text, meta in self.text_to_metadata.items():
64
+ # Simple overlap check - if 80% of the text matches
65
+ if len(set(text.split()).intersection(set(orig_text.split()))) / max(len(set(text.split())), 1) > 0.8:
66
+ sources.append(meta)
67
+ found = True
68
+ break
69
+
70
+ if not found:
71
+ sources.append({"filename": "unknown", "page": "unknown"})
72
 
73
  formatted_system_prompt = system_role_prompt.create_message()
74
 
 
98
  for key, vector in data['vectors'].items():
99
  vector_db.insert(key, vector)
100
 
101
+ # Get metadata and original texts if available
102
  metadata = data.get('metadata', [])
103
+ texts = data.get('texts', [])
104
 
105
+ return vector_db, metadata, texts
106
 
107
  @cl.on_chat_start
108
  async def on_chat_start():
 
135
 
136
  # Load pre-processed data
137
  start_time = time.time()
138
+ vector_db, metadata, texts = load_preprocessed_data()
139
  load_time = time.time() - start_time
140
  print(f"Loaded vector database in {load_time:.2f} seconds")
141
 
 
145
  retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
146
  vector_db_retriever=vector_db,
147
  llm=chat_openai,
148
+ metadata=metadata,
149
+ texts=texts
150
  )
151
 
152
  # Let the user know that the system is ready