Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -15,6 +15,10 @@ from PIL import Image
|
|
| 15 |
from torchvision import transforms
|
| 16 |
from torchvision.models import resnet50, ResNet50_Weights
|
| 17 |
from torchvision import transforms, models
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
class GeometryImageClassifier:
|
| 20 |
def __init__(self):
|
|
@@ -155,6 +159,23 @@ def create_db(splits):
|
|
| 155 |
vectordb = FAISS.from_documents(splits, embeddings)
|
| 156 |
return vectordb
|
| 157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
def retrieve_documents(query, retriever, embeddings):
|
| 159 |
print("\n=== Document Retrieval Process ===")
|
| 160 |
print(f"Query: {query}")
|
|
@@ -208,7 +229,6 @@ def validate_query_semantically(query, retrieved_docs):
|
|
| 208 |
|
| 209 |
return similarity_score >= 0.3
|
| 210 |
|
| 211 |
-
|
| 212 |
def handle_query(query, history, retriever, qa_chain, embeddings):
|
| 213 |
""" ✅ Handles user queries & prevents hallucination. """
|
| 214 |
retrieved_docs = retrieve_documents(query, retriever, embeddings)
|
|
|
|
| 15 |
from torchvision import transforms
|
| 16 |
from torchvision.models import resnet50, ResNet50_Weights
|
| 17 |
from torchvision import transforms, models
|
| 18 |
# Cross-encoder re-ranker: scores (query, passage) pairs jointly, which is
# more accurate than the bi-encoder similarity used for first-stage retrieval.
from sentence_transformers import CrossEncoder

# Loaded once at module import so every request reuses the same model instance.
# NOTE(review): this downloads/loads the model at import time — confirm that is
# acceptable for the app's startup path.
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
|
| 22 |
|
| 23 |
class GeometryImageClassifier:
|
| 24 |
def __init__(self):
|
|
|
|
| 159 |
vectordb = FAISS.from_documents(splits, embeddings)
|
| 160 |
return vectordb
|
| 161 |
|
| 162 |
+
def rerank_documents(query, docs, top_k=3):
    """Re-rank retrieved documents against *query* with the cross-encoder.

    Args:
        query: the user's query string.
        docs: retrieved document objects exposing ``page_content``.
        top_k: how many of the highest-scoring documents to keep.

    Returns:
        The ``top_k`` documents ordered by descending cross-encoder score.
    """
    # Score every (query, passage) pair in one batched predict call.
    candidate_pairs = [[query, document.page_content] for document in docs]
    relevance = reranker.predict(candidate_pairs)
    # Sort document indices by score, highest first; sorted() is stable, so
    # equally-scored documents keep their original retrieval order.
    order = sorted(range(len(docs)), key=lambda i: relevance[i], reverse=True)
    return [docs[i] for i in order[:top_k]]
|
| 168 |
+
|
| 169 |
+
def filter_relevant_chunks(query, chunks, embeddings, threshold=0.5):
    """Drop chunks whose embedding is not similar enough to the query.

    Args:
        query: the user's query string.
        chunks: document chunks exposing ``page_content``.
        embeddings: embedding model exposing ``embed_query``.
        threshold: minimum cosine similarity (exclusive) to keep a chunk.

    Returns:
        The subset of ``chunks`` with similarity strictly above ``threshold``,
        in their original order.
    """
    # Embed the query once; each chunk is compared against this vector.
    query_vec = embeddings.embed_query(query)

    # NOTE(review): chunks are embedded one at a time via embed_query; a
    # batched embed_documents call would likely be faster — confirm the
    # embedding backend before changing.
    def _is_relevant(chunk):
        chunk_vec = embeddings.embed_query(chunk.page_content)
        return cosine_similarity([query_vec], [chunk_vec])[0][0] > threshold

    return [chunk for chunk in chunks if _is_relevant(chunk)]
|
| 178 |
+
|
| 179 |
def retrieve_documents(query, retriever, embeddings):
|
| 180 |
print("\n=== Document Retrieval Process ===")
|
| 181 |
print(f"Query: {query}")
|
|
|
|
| 229 |
|
| 230 |
return similarity_score >= 0.3
|
| 231 |
|
|
|
|
| 232 |
def handle_query(query, history, retriever, qa_chain, embeddings):
|
| 233 |
""" ✅ Handles user queries & prevents hallucination. """
|
| 234 |
retrieved_docs = retrieve_documents(query, retriever, embeddings)
|