Spaces:
Sleeping
Sleeping
Ajaykanth Maddi commited on
Commit ·
784ce37
1
Parent(s): 2a1bbfd
Code Changes - Upload Results
Browse files- app.py +7 -3
- ragbench.py +48 -12
app.py
CHANGED
|
@@ -384,7 +384,7 @@ def _evaluate_using_groq(context_docs, question, generated_answer):
|
|
| 384 |
|
| 385 |
|
| 386 |
def run_rag_pipeline(subset, question, custom_question, chunking, embed_model, retriever, chunk_count, retriever_type,
|
| 387 |
-
reranking, evaluator):
|
| 388 |
final_question = custom_question if custom_question.strip() else question
|
| 389 |
print(f"The query is {final_question}")
|
| 390 |
|
|
@@ -395,13 +395,17 @@ def run_rag_pipeline(subset, question, custom_question, chunking, embed_model, r
|
|
| 395 |
logging.info("Starting RAG Pipeline using logging")
|
| 396 |
gr.Info("Starting RAG Pipeline using gradio") # Shows as a toast notification in UI
|
| 397 |
|
|
|
|
|
|
|
|
|
|
| 398 |
rag = RAGSystem(
|
| 399 |
subset=subset,
|
| 400 |
dataset_type="test",
|
| 401 |
strategy=chunking,
|
| 402 |
chunks=[], # Not needed for loading
|
| 403 |
generator_model_name=retriever,
|
| 404 |
-
retriever_model_name=embed_model
|
|
|
|
| 405 |
)
|
| 406 |
|
| 407 |
# 3. Load or use stored vector DB
|
|
@@ -621,7 +625,7 @@ with gr.Blocks(
|
|
| 621 |
subset_dropdown, question_dropdown, custom_question_input,
|
| 622 |
chunking_dropdown, embed_dropdown, generator_dropdown,
|
| 623 |
chunk_count, retriever_type,
|
| 624 |
-
reranking_checkbox, evaluator_dropdown
|
| 625 |
],
|
| 626 |
outputs=[gen_ans_display, y_pred_metrics_display, chunks_retrieved_display, evaluator_json_output, download_file]
|
| 627 |
)
|
|
|
|
| 384 |
|
| 385 |
|
| 386 |
def run_rag_pipeline(subset, question, custom_question, chunking, embed_model, retriever, chunk_count, retriever_type,
|
| 387 |
+
reranking, reranking_dropdown, evaluator):
|
| 388 |
final_question = custom_question if custom_question.strip() else question
|
| 389 |
print(f"The query is {final_question}")
|
| 390 |
|
|
|
|
| 395 |
logging.info("Starting RAG Pipeline using logging")
|
| 396 |
gr.Info("Starting RAG Pipeline using gradio") # Shows as a toast notification in UI
|
| 397 |
|
| 398 |
+
ranking_method = reranking_dropdown if reranking else None
|
| 399 |
+
print(f"Using reranking: {reranking}, method: {ranking_method}")
|
| 400 |
+
|
| 401 |
rag = RAGSystem(
|
| 402 |
subset=subset,
|
| 403 |
dataset_type="test",
|
| 404 |
strategy=chunking,
|
| 405 |
chunks=[], # Not needed for loading
|
| 406 |
generator_model_name=retriever,
|
| 407 |
+
retriever_model_name=embed_model,
|
| 408 |
+
reranker_model_name=ranking_method
|
| 409 |
)
|
| 410 |
|
| 411 |
# 3. Load or use stored vector DB
|
|
|
|
| 625 |
subset_dropdown, question_dropdown, custom_question_input,
|
| 626 |
chunking_dropdown, embed_dropdown, generator_dropdown,
|
| 627 |
chunk_count, retriever_type,
|
| 628 |
+
reranking_checkbox, reranking_dropdown, evaluator_dropdown
|
| 629 |
],
|
| 630 |
outputs=[gen_ans_display, y_pred_metrics_display, chunks_retrieved_display, evaluator_json_output, download_file]
|
| 631 |
)
|
ragbench.py
CHANGED
|
@@ -65,6 +65,7 @@ class RAGSystem:
|
|
| 65 |
chunk_overlap: int = 50,
|
| 66 |
generator_model_name: str = "mistralai/Mistral-7B-Instruct-v0.2",
|
| 67 |
retriever_model_name: str = "BAAI/bge-large-en-v1.5",
|
|
|
|
| 68 |
hf_api_token: str = None
|
| 69 |
):
|
| 70 |
self.subset = subset
|
|
@@ -74,6 +75,7 @@ class RAGSystem:
|
|
| 74 |
self.chunk_overlap = chunk_overlap
|
| 75 |
self.generator_model_name = generator_model_name
|
| 76 |
self.retriever_model_name = retriever_model_name
|
|
|
|
| 77 |
self.chunks = chunks
|
| 78 |
self.hf_api_token = hf_api_token or os.getenv("HF_API_TOKEN")
|
| 79 |
|
|
@@ -377,24 +379,58 @@ class RAGSystem:
|
|
| 377 |
except Exception as e:
|
| 378 |
print(f"Generation failed: {str(e)}")
|
| 379 |
return "I couldn't generate an answer."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
|
| 381 |
def retrieve(self, query: str, top_k: int = 10) -> List[Chunk]:
|
| 382 |
"""Retrieve relevant chunks using HYDE"""
|
| 383 |
pseudo_answer = self.generate_hypothetical_answer(query)
|
| 384 |
docs = self.hybrid_retriever.invoke(pseudo_answer)
|
| 385 |
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
|
|
|
|
|
|
| 398 |
|
| 399 |
def generate(self, question: str, context: List[str] = None) -> str:
|
| 400 |
"""Generate final answer with RAG context"""
|
|
|
|
| 65 |
chunk_overlap: int = 50,
|
| 66 |
generator_model_name: str = "mistralai/Mistral-7B-Instruct-v0.2",
|
| 67 |
retriever_model_name: str = "BAAI/bge-large-en-v1.5",
|
| 68 |
+
reranker_model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2",
|
| 69 |
hf_api_token: str = None
|
| 70 |
):
|
| 71 |
self.subset = subset
|
|
|
|
| 75 |
self.chunk_overlap = chunk_overlap
|
| 76 |
self.generator_model_name = generator_model_name
|
| 77 |
self.retriever_model_name = retriever_model_name
|
| 78 |
+
self.reranker_model_name = reranker_model_name
|
| 79 |
self.chunks = chunks
|
| 80 |
self.hf_api_token = hf_api_token or os.getenv("HF_API_TOKEN")
|
| 81 |
|
|
|
|
| 379 |
except Exception as e:
|
| 380 |
print(f"Generation failed: {str(e)}")
|
| 381 |
return "I couldn't generate an answer."
|
| 382 |
+
|
| 383 |
+
def _use_reranker(self, docs: List[LangchainDocument], query: str, top_k: int) -> List[LangchainDocument]:
|
| 384 |
+
"""Use the reranker model to re-rank retrieved documents"""
|
| 385 |
+
if not self.reranker_model_name:
|
| 386 |
+
return docs
|
| 387 |
+
|
| 388 |
+
sentence_chunks = []
|
| 389 |
+
for doc in docs:
|
| 390 |
+
for sentence in doc.page_content.strip().split("."):
|
| 391 |
+
sentence = sentence.strip()
|
| 392 |
+
if len(sentence) > 15:
|
| 393 |
+
sentence_chunks.append((sentence, doc.metadata))
|
| 394 |
+
|
| 395 |
+
pairs = [[query, sent] for sent, _ in sentence_chunks]
|
| 396 |
+
scores = self.reranker.predict(pairs)
|
| 397 |
+
|
| 398 |
+
top_pairs = sorted(zip(sentence_chunks, scores), key=lambda x: x[1], reverse=True)[:top_k]
|
| 399 |
+
|
| 400 |
+
top_chunks = []
|
| 401 |
+
for (sentence, meta), score in top_pairs:
|
| 402 |
+
top_chunks.append(Chunk(
|
| 403 |
+
chunk_id=meta.get("chunk_id", ""),
|
| 404 |
+
text=sentence,
|
| 405 |
+
doc_id=meta.get("doc_id", ""),
|
| 406 |
+
source=meta.get("source", ""),
|
| 407 |
+
chunk_num=meta.get("chunk_num", -1),
|
| 408 |
+
total_chunks=meta.get("total_chunks", -1),
|
| 409 |
+
metadata={**meta, "reranker_score": score}
|
| 410 |
+
))
|
| 411 |
+
|
| 412 |
+
print(f"Reranked {len(top_chunks)} chunks from {len(docs)} documents")
|
| 413 |
+
return top_chunks
|
| 414 |
|
| 415 |
def retrieve(self, query: str, top_k: int = 10) -> List[Chunk]:
|
| 416 |
"""Retrieve relevant chunks using HYDE"""
|
| 417 |
pseudo_answer = self.generate_hypothetical_answer(query)
|
| 418 |
docs = self.hybrid_retriever.invoke(pseudo_answer)
|
| 419 |
|
| 420 |
+
if self.reranker_model_name is not None:
|
| 421 |
+
return self._use_reranker(docs, query, top_k)
|
| 422 |
+
else:
|
| 423 |
+
return [
|
| 424 |
+
Chunk(
|
| 425 |
+
chunk_id=doc.metadata.get("chunk_id", ""),
|
| 426 |
+
text=doc.page_content,
|
| 427 |
+
doc_id=doc.metadata.get("doc_id", ""),
|
| 428 |
+
source=doc.metadata.get("source", ""),
|
| 429 |
+
chunk_num=doc.metadata.get("chunk_num", -1),
|
| 430 |
+
total_chunks=doc.metadata.get("total_chunks", -1),
|
| 431 |
+
metadata=doc.metadata
|
| 432 |
+
) for doc in docs[:top_k]
|
| 433 |
+
]
|
| 434 |
|
| 435 |
def generate(self, question: str, context: List[str] = None) -> str:
|
| 436 |
"""Generate final answer with RAG context"""
|