dnzblgn committed on
Commit a87cdfe · verified · 1 Parent(s): a1e454e

Update app.py

Files changed (1)
1. app.py +39 -82
app.py CHANGED
@@ -16,9 +16,6 @@ from torchvision import transforms
 from torchvision.models import resnet50, ResNet50_Weights
 from torchvision import transforms, models
 
-
-
-
 class GeometryImageClassifier:
     def __init__(self):
         # Load ResNet50 but only use it for feature extraction
@@ -101,9 +98,6 @@ class GeometryImageClassifier:
 # ✅ Use a strong sentence embedding model
 semantic_model = SentenceTransformer("all-MiniLM-L6-v2")
 
-
-
-
 def extract_text_from_docx(file_path):
     """ ✅ Extracts normal text & tables from a .docx file for better retrieval. """
     doc = docx.Document(file_path)
@@ -125,20 +119,14 @@ def extract_text_from_docx(file_path):
 
     return "\n".join(extracted_text)
 
-
-
-
 def load_documents():
     """ ✅ Loads & processes documents, ensuring table data is properly extracted. """
     file_paths = {
         "Fastener_Types_Manual": "Fastener_Types_Manual.docx",
         "Manufacturing_Expert_Manual": "Manufacturing Expert Manual.docx"
     }
-
-
     all_splits = []
 
-
     for doc_name, file_path in file_paths.items():
         if not os.path.exists(file_path):
             raise FileNotFoundError(f"Document not found: {file_path}")
@@ -161,118 +149,90 @@ def load_documents():
 
     return all_splits
 
-
-
-
 def create_db(splits):
     """ ✅ Creates a FAISS vector database from document splits. """
     embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")
     vectordb = FAISS.from_documents(splits, embeddings)
     return vectordb
 
-
-
-
 def retrieve_documents(query, retriever, embeddings):
-    """ ✅ Retrieves the most relevant documents & filters out low-relevance ones. """
-    query_embedding = np.array(embeddings.embed_query(query)).reshape(1, -1)
+    print("\n=== Document Retrieval Process ===")
+    print(f"Query: {query}")
+
     results = retriever.invoke(query)
-
-
+    print(f"Initial results count: {len(results)}")
+
     if not results:
+        print("No initial results found")
         return []
-
-
-    doc_embeddings = np.array([embeddings.embed_query(doc.page_content) for doc in results])
-    similarity_scores = cosine_similarity(query_embedding, doc_embeddings)[0]  # ✅ Proper cosine similarity
-
-
-    MIN_SIMILARITY = 0.5  # 🔥 Increased threshold to improve relevance
-    filtered_results = [(doc, sim) for doc, sim in zip(results, similarity_scores) if sim >= MIN_SIMILARITY]
-
-
-    # ✅ Debugging log
-    print(f"🔍 Query: {query}")
-    print(f"📄 Retrieved Docs (before filtering): {[(doc.metadata.get('source', 'Unknown'), sim) for doc, sim in zip(results, similarity_scores)]}")
-    print(f"✅ Filtered Docs (after threshold {MIN_SIMILARITY}): {[(doc.metadata.get('source', 'Unknown'), sim) for doc, sim in filtered_results]}")
-
-
+
+    reranked_results = rerank_documents(query, results, top_k=3)
+    print(f"Reranked results count: {len(reranked_results)}")
+
+    filtered_chunks = filter_relevant_chunks(query, reranked_results, embeddings, threshold=0.7)
+    print(f"Filtered chunks count: {len(filtered_chunks)}")
+
+    if not filtered_chunks:
+        print("No chunks passed filtering")
+        return []
+
+    doc_embeddings = np.array([embeddings.embed_query(doc.page_content) for doc in filtered_chunks])
+    query_embedding = np.array(embeddings.embed_query(query)).reshape(1, -1)
+    similarity_scores = cosine_similarity(query_embedding, doc_embeddings)[0]
+
+    print("\nSimilarity Scores:")
+    for doc, score in zip(filtered_chunks, similarity_scores):
+        print(f"Score: {score:.4f} | Source: {doc.metadata.get('source', 'Unknown')}")
+        print(f"Content Preview: {doc.page_content[:100]}...\n")
+
+    MIN_SIMILARITY = 0.5
+    filtered_results = [(doc, sim) for doc, sim in zip(filtered_chunks, similarity_scores) if sim >= MIN_SIMILARITY]
+    print(f"Final filtered results count: {len(filtered_results)}")
+
     return [doc for doc, _ in filtered_results] if filtered_results else []
 
-
-
-
 def validate_query_semantically(query, retrieved_docs):
-    """ ✅ Ensures the query meaning is covered in the retrieved documents. """
+    print("\n=== Semantic Validation ===")
     if not retrieved_docs:
+        print("No documents to validate")
         return False
 
-
     combined_text = " ".join([doc.page_content for doc in retrieved_docs])
     query_embedding = semantic_model.encode(query, normalize_embeddings=True)
    doc_embedding = semantic_model.encode(combined_text, normalize_embeddings=True)
-
-
-    similarity_score = np.dot(query_embedding, doc_embedding)  # ✅ Cosine similarity already normalized
-
-
-    print(f"🔍 Semantic Similarity Score: {similarity_score}")
-
-
-    return similarity_score >= 0.3  # 🔥 Stricter threshold to ensure correctness
-
-
+    similarity_score = np.dot(query_embedding, doc_embedding)
+
+    print(f"Query: {query}")
+    print(f"Semantic similarity score: {similarity_score:.4f}")
+    print(f"Validation {'passed' if similarity_score >= 0.3 else 'failed'}")
+
+    return similarity_score >= 0.3
 
 
 def handle_query(query, history, retriever, qa_chain, embeddings):
     """ ✅ Handles user queries & prevents hallucination. """
     retrieved_docs = retrieve_documents(query, retriever, embeddings)
-
-
     if not retrieved_docs or not validate_query_semantically(query, retrieved_docs):
         return history + [(query, "I couldn't find any relevant information.")], ""
-
-
     response = qa_chain.invoke({"question": query, "chat_history": history})
     assistant_response = response['answer'].strip()
-
-
-    # ✅ Final hallucination check
     if not validate_query_semantically(query, retrieved_docs):
         assistant_response = "I couldn't find any relevant information."
-
-
     assistant_response += f"\n\n📄 **Source:** {', '.join(set(doc.metadata.get('source', 'Unknown') for doc in retrieved_docs))}"
-
-
-    # ✅ Debugging logs
     print(f"🤖 LLM Response: {assistant_response[:300]}")  # ✅ Limit output for debugging
-
-
     history.append((query, assistant_response))
     return history, ""
 
-
-
-
 def initialize_chatbot(vector_db):
     """ ✅ Initializes chatbot with improved retrieval & processing. """
     memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key='answer')
-
-
     embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")
-
-
     retriever = vector_db.as_retriever(search_kwargs={"k": 5, "search_type": "similarity"})
-
-
     system_prompt = """You are an AI assistant that answers questions **ONLY based on the provided documents**.
     - **If no relevant documents are retrieved, respond with: "I couldn't find any relevant information."**
     - **If the meaning of the query does not match the retrieved documents, say "I couldn't find any relevant information."**
     - **Do NOT attempt to answer from general knowledge.**
     """
-
-
     llm = HuggingFaceEndpoint(
         repo_id="tiiuae/falcon-40b-instruct",
         huggingfacehub_api_token=os.environ.get("HUGGINGFACE_API_TOKEN"),
@@ -281,15 +241,12 @@ def initialize_chatbot(vector_db):
         task="text-generation",
         system_prompt=system_prompt)
 
-
     qa_chain = ConversationalRetrievalChain.from_llm(
         llm=llm,
         retriever=retriever,
         memory=memory,
         return_source_documents=True,
         verbose=False)
-
-
     return retriever, qa_chain, embeddings
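Note: the reworked retrieve_documents() calls rerank_documents() and filter_relevant_chunks(), neither of which is defined in the hunks above, so they must be provided elsewhere in app.py. Only their signatures can be inferred from the call sites; the bodies below, and the cross-encoder model choice, are a minimal sketch of what they might look like, not code from this commit:

    import numpy as np
    from sentence_transformers import CrossEncoder
    from sklearn.metrics.pairwise import cosine_similarity

    # Assumed reranker model; the actual choice is not visible in this diff.
    reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

    def rerank_documents(query, docs, top_k=3):
        # Score each (query, chunk) pair with the cross-encoder and keep the top_k.
        scores = reranker.predict([(query, doc.page_content) for doc in docs])
        ranked = sorted(zip(docs, scores), key=lambda pair: pair[1], reverse=True)
        return [doc for doc, _ in ranked[:top_k]]

    def filter_relevant_chunks(query, docs, embeddings, threshold=0.7):
        # Drop chunks whose cosine similarity to the query embedding is below threshold.
        query_vec = np.array(embeddings.embed_query(query)).reshape(1, -1)
        doc_vecs = np.array([embeddings.embed_query(doc.page_content) for doc in docs])
        sims = cosine_similarity(query_vec, doc_vecs)[0]
        return [doc for doc, sim in zip(docs, sims) if sim >= threshold]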
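For context, a minimal way to exercise the functions touched by this commit end to end, assuming a plain script entry point (the app's actual UI wiring sits outside the shown hunks, and the query string is only illustrative):

    if __name__ == "__main__":
        splits = load_documents()                    # parse the two .docx manuals
        vector_db = create_db(splits)                # build the FAISS index
        retriever, qa_chain, embeddings = initialize_chatbot(vector_db)

        history = []
        history, _ = handle_query("What fastener types are covered?",
                                  history, retriever, qa_chain, embeddings)
        print(history[-1][1])                        # assistant's answer plus source line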