Spaces:
Build error
Build error
Upload 2 files
Browse files
- app.py +12 -10
- retrieval.py +14 -15
app.py
CHANGED
|
@@ -78,14 +78,16 @@ if "time_taken_for_response" not in st.session_state:
|
|
| 78 |
st.session_state.time_taken_for_response = "N/A"
|
| 79 |
if "metrics" not in st.session_state:
|
| 80 |
st.session_state.metrics = {}
|
| 81 |
-
if "query_dataset" not in st.session_state:
|
| 82 |
st.session_state.query_dataset = ''
|
|
|
|
|
|
|
| 83 |
|
| 84 |
-
recent_questions = load_recent_questions()
|
| 85 |
-
print(recent_questions)
|
| 86 |
|
| 87 |
-
if recent_questions and "questions" in recent_questions and recent_questions["questions"]:
|
| 88 |
-
recent_qns = list(reversed(recent_questions["questions"]))
|
| 89 |
|
| 90 |
print(recent_qns)
|
| 91 |
|
|
@@ -98,7 +100,7 @@ if recent_questions and "questions" in recent_questions and recent_questions["qu
|
|
| 98 |
st.sidebar.title("Analytics")
|
| 99 |
|
| 100 |
# Extract response times and labels
|
| 101 |
-
response_time = [q["response_time"] for q in recent_qns]
|
| 102 |
labels = [f"Q{i+1}" for i in range(len(response_time))]
|
| 103 |
|
| 104 |
# Plot graph
|
|
@@ -130,10 +132,10 @@ if st.button("Submit"):
|
|
| 130 |
st.session_state.time_taken_for_response = end_time - start_time
|
| 131 |
|
| 132 |
# Store in session state
|
| 133 |
-
st.session_state.recent_questions.append({
|
| 134 |
-
"question": question,
|
| 135 |
-
"response_time": st.session_state.time_taken_for_response
|
| 136 |
-
})
|
| 137 |
|
| 138 |
# Display stored response
|
| 139 |
st.subheader("Response")
|
|
|
|
| 78 |
st.session_state.time_taken_for_response = "N/A"
|
| 79 |
if "metrics" not in st.session_state:
|
| 80 |
st.session_state.metrics = {}
|
| 81 |
+
if "query_dataset" not in st.session_state:
|
| 82 |
st.session_state.query_dataset = ''
|
| 83 |
+
if "recent_questions" not in st.session_state:
|
| 84 |
+
st.session_state.recent_questions = {}
|
| 85 |
|
| 86 |
+
st.session_state.recent_questions = load_recent_questions()
|
| 87 |
+
print(st.session_state.recent_questions )
|
| 88 |
|
| 89 |
+
if st.session_state.recent_questions and "questions" in st.session_state.recent_questions and st.session_state.recent_questions ["questions"]:
|
| 90 |
+
recent_qns = list(reversed(st.session_state.recent_questions ["questions"]))
|
| 91 |
|
| 92 |
print(recent_qns)
|
| 93 |
|
|
|
|
| 100 |
st.sidebar.title("Analytics")
|
| 101 |
|
| 102 |
# Extract response times and labels
|
| 103 |
+
response_time = [q['metrics']["response_time"] for q in recent_qns]
|
| 104 |
labels = [f"Q{i+1}" for i in range(len(response_time))]
|
| 105 |
|
| 106 |
# Plot graph
|
|
|
|
| 132 |
st.session_state.time_taken_for_response = end_time - start_time
|
| 133 |
|
| 134 |
# Store in session state
|
| 135 |
+
# st.session_state.recent_questions.append({
|
| 136 |
+
# "question": question,
|
| 137 |
+
# "response_time": st.session_state.time_taken_for_response
|
| 138 |
+
# })
|
| 139 |
|
| 140 |
# Display stored response
|
| 141 |
st.subheader("Response")
|
retrieval.py
CHANGED
|
@@ -5,12 +5,11 @@ import faiss
|
|
| 5 |
from rank_bm25 import BM25Okapi
|
| 6 |
from data_processing import embedding_model
|
| 7 |
from sentence_transformers import CrossEncoder
|
| 8 |
-
import string
|
| 9 |
-
import nltk
|
| 10 |
|
| 11 |
-
import
|
| 12 |
-
nltk
|
| 13 |
-
nltk.download('punkt')
|
|
|
|
| 14 |
|
| 15 |
from nltk.tokenize import word_tokenize
|
| 16 |
|
|
@@ -19,8 +18,8 @@ reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
|
|
| 19 |
retrieved_docs = None
|
| 20 |
|
| 21 |
# Tokenize the documents and remove punctuation
|
| 22 |
-
def preprocess(doc):
|
| 23 |
-
return [word.lower() for word in word_tokenize(doc) if word not in string.punctuation]
|
| 24 |
|
| 25 |
def retrieve_documents_hybrid(query, q_dataset, top_k=5):
|
| 26 |
with open( f"data_local/{q_dataset}_chunked_docs.json", "r") as f:
|
|
@@ -30,31 +29,31 @@ def retrieve_documents_hybrid(query, q_dataset, top_k=5):
|
|
| 30 |
index = faiss.read_index(faiss_index_path)
|
| 31 |
|
| 32 |
# Tokenize documents for BM25
|
| 33 |
-
tokenized_docs = [preprocess(doc) for doc in chunked_documents]
|
| 34 |
bm25 = BM25Okapi(tokenized_docs)
|
| 35 |
|
| 36 |
query_embedding = np.array(embedding_model.embed_documents([query]), dtype=np.float32)
|
| 37 |
query_embedding = query_embedding.reshape(1, -1)
|
| 38 |
|
| 39 |
# FAISS Search
|
| 40 |
-
faiss_distances, faiss_indices = index.search(query_embedding, top_k)
|
| 41 |
faiss_docs = [chunked_documents[i] for i in faiss_indices[0]]
|
| 42 |
|
| 43 |
# BM25 Search
|
| 44 |
-
tokenized_query = preprocess(query)
|
| 45 |
bm25_scores = bm25.get_scores(tokenized_query)
|
| 46 |
bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]
|
| 47 |
bm25_docs = [chunked_documents[i] for i in bm25_top_indices]
|
| 48 |
|
| 49 |
# Combine FAISS + BM25 scores and retrieve docs
|
| 50 |
-
combined_results = set(bm25_top_indices).union(set(faiss_indices[0]))
|
| 51 |
|
| 52 |
-
combined_scores = rerank_docs_bm25faiss_scores(combined_results,bm25_scores, faiss_distances,faiss_indices)
|
| 53 |
-
reranked_docs = [chunked_documents[result[0]] for result in combined_scores[:top_k]]
|
| 54 |
|
| 55 |
# Merge FAISS + BM25 Results and re-rank
|
| 56 |
-
|
| 57 |
-
|
| 58 |
|
| 59 |
return reranked_docs
|
| 60 |
|
|
|
|
| 5 |
from rank_bm25 import BM25Okapi
|
| 6 |
from data_processing import embedding_model
|
| 7 |
from sentence_transformers import CrossEncoder
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
#import string
|
| 10 |
+
# import nltk
|
| 11 |
+
# nltk.download('punkt')
|
| 12 |
+
# nltk.download('punkt_tab')
|
| 13 |
|
| 14 |
from nltk.tokenize import word_tokenize
|
| 15 |
|
|
|
|
| 18 |
retrieved_docs = None
|
| 19 |
|
| 20 |
# Tokenize the documents and remove punctuation
|
| 21 |
+
# def preprocess(doc):
|
| 22 |
+
# return [word.lower() for word in word_tokenize(doc) if word not in string.punctuation]
|
| 23 |
|
| 24 |
def retrieve_documents_hybrid(query, q_dataset, top_k=5):
|
| 25 |
with open( f"data_local/{q_dataset}_chunked_docs.json", "r") as f:
|
|
|
|
| 29 |
index = faiss.read_index(faiss_index_path)
|
| 30 |
|
| 31 |
# Tokenize documents for BM25
|
| 32 |
+
tokenized_docs = [doc.split() for doc in chunked_documents]
|
| 33 |
bm25 = BM25Okapi(tokenized_docs)
|
| 34 |
|
| 35 |
query_embedding = np.array(embedding_model.embed_documents([query]), dtype=np.float32)
|
| 36 |
query_embedding = query_embedding.reshape(1, -1)
|
| 37 |
|
| 38 |
# FAISS Search
|
| 39 |
+
_, faiss_indices = index.search(query_embedding, top_k)
|
| 40 |
faiss_docs = [chunked_documents[i] for i in faiss_indices[0]]
|
| 41 |
|
| 42 |
# BM25 Search
|
| 43 |
+
tokenized_query = query.split() #preprocess(query)
|
| 44 |
bm25_scores = bm25.get_scores(tokenized_query)
|
| 45 |
bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]
|
| 46 |
bm25_docs = [chunked_documents[i] for i in bm25_top_indices]
|
| 47 |
|
| 48 |
# Combine FAISS + BM25 scores and retrieve docs
|
| 49 |
+
# combined_results = set(bm25_top_indices).union(set(faiss_indices[0]))
|
| 50 |
|
| 51 |
+
# combined_scores = rerank_docs_bm25faiss_scores(combined_results,bm25_scores, faiss_distances,faiss_indices)
|
| 52 |
+
# reranked_docs = [chunked_documents[result[0]] for result in combined_scores[:top_k]]
|
| 53 |
|
| 54 |
# Merge FAISS + BM25 Results and re-rank
|
| 55 |
+
retrieved_docs = list(set(faiss_docs + bm25_docs))[:top_k]
|
| 56 |
+
reranked_docs = rerank_documents(query, retrieved_docs)
|
| 57 |
|
| 58 |
return reranked_docs
|
| 59 |
|