dnzblgn committed on
Commit
7aa6142
·
verified ·
1 Parent(s): a87cdfe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -1
app.py CHANGED
@@ -15,6 +15,10 @@ from PIL import Image
15
  from torchvision import transforms
16
  from torchvision.models import resnet50, ResNet50_Weights
17
  from torchvision import transforms, models
 
 
 
 
18
 
19
  class GeometryImageClassifier:
20
  def __init__(self):
@@ -155,6 +159,23 @@ def create_db(splits):
155
  vectordb = FAISS.from_documents(splits, embeddings)
156
  return vectordb
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  def retrieve_documents(query, retriever, embeddings):
159
  print("\n=== Document Retrieval Process ===")
160
  print(f"Query: {query}")
@@ -208,7 +229,6 @@ def validate_query_semantically(query, retrieved_docs):
208
 
209
  return similarity_score >= 0.3
210
 
211
-
212
  def handle_query(query, history, retriever, qa_chain, embeddings):
213
  """ ✅ Handles user queries & prevents hallucination. """
214
  retrieved_docs = retrieve_documents(query, retriever, embeddings)
 
15
  from torchvision import transforms
16
  from torchvision.models import resnet50, ResNet50_Weights
17
  from torchvision import transforms, models
18
+ from sentence_transformers import CrossEncoder
19
+
20
+
21
+ reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
22
 
23
  class GeometryImageClassifier:
24
  def __init__(self):
 
159
  vectordb = FAISS.from_documents(splits, embeddings)
160
  return vectordb
161
 
162
def rerank_documents(query, docs, top_k=3):
    """Return the documents the cross-encoder scores highest for ``query``.

    Each document is paired with the query and scored by the module-level
    ``reranker``; the documents are then ordered by descending score.

    Args:
        query: Query text paired with every document for scoring.
        docs: Iterable of objects exposing a ``page_content`` attribute.
        top_k: Maximum number of documents to return.

    Returns:
        A list of at most ``top_k`` documents, best score first. An empty
        list when ``docs`` is empty.
    """
    # Materialise once: the documents are traversed twice below (to build
    # the pairs and to zip with the scores), which a one-shot iterable
    # would not survive.
    docs = list(docs)
    if not docs:
        # Nothing to score, so skip the model call entirely.
        # NOTE(review): the reranker's predict() may not accept an empty
        # batch -- confirm against the installed sentence-transformers.
        return []

    pairs = [[query, doc.page_content] for doc in docs]
    scores = reranker.predict(pairs)
    # sorted() is stable, so equally scored documents keep retrieval order.
    ranked = sorted(zip(docs, scores), key=lambda pair: pair[1], reverse=True)
    return [doc for doc, _ in ranked[:top_k]]
168
+
169
def filter_relevant_chunks(query, chunks, embeddings, threshold=0.5):
    """Keep the chunks whose embedding is similar enough to the query's.

    Args:
        query: Query text to compare every chunk against.
        chunks: Iterable of objects exposing a ``page_content`` attribute.
        embeddings: Object providing ``embed_query(text)``, used for both
            the query and each chunk's text.
        threshold: A chunk is kept only when its cosine similarity to the
            query is strictly greater than this value.

    Returns:
        A list of the retained chunks in their original order. An empty
        list when ``chunks`` is empty.
    """
    chunks = list(chunks)
    if not chunks:
        # No candidates: avoid paying for a query embedding that would
        # never be compared against anything.
        return []

    query_embedding = embeddings.embed_query(query)
    chunk_embeddings = [
        embeddings.embed_query(chunk.page_content) for chunk in chunks
    ]
    # One batched call yields a (1, len(chunks)) similarity matrix; row 0
    # holds the query-vs-chunk scores in chunk order.
    # NOTE(review): cosine_similarity is expected to come from the file's
    # imports (not visible in this excerpt) -- presumably scikit-learn's.
    similarities = cosine_similarity([query_embedding], chunk_embeddings)[0]
    return [
        chunk
        for chunk, similarity in zip(chunks, similarities)
        if similarity > threshold
    ]
178
+
179
  def retrieve_documents(query, retriever, embeddings):
180
  print("\n=== Document Retrieval Process ===")
181
  print(f"Query: {query}")
 
229
 
230
  return similarity_score >= 0.3
231
 
 
232
  def handle_query(query, history, retriever, qa_chain, embeddings):
233
  """ ✅ Handles user queries & prevents hallucination. """
234
  retrieved_docs = retrieve_documents(query, retriever, embeddings)