# NOTE: file originally carried Hugging Face upload-page residue here
# ("Upload 9 files", commit 20d3dd7); converted to a comment so the module parses.
from typing import List
import numpy as np
import faiss
from rank_bm25 import BM25Okapi
from langchain.text_splitter import RecursiveCharacterTextSplitter
class RAGPipeline:
    """A pipeline for Retrieval-Augmented Generation.

    Builds a hybrid index over research documents — dense (FAISS, L2) plus
    sparse lexical (BM25) — then retrieves candidates from both and re-ranks
    the union with a cross-encoder-style reranker.
    """

    def __init__(self, embedding_model, reranker):
        """Initialize the pipeline.

        Args:
            embedding_model: object exposing ``encode(list[str],
                convert_to_tensor=False)`` returning a 2-D embeddings array.
            reranker: object exposing ``predict(list[[query, passage]])``
                returning one relevance score per pair.
        """
        # chunk_size / chunk_overlap are measured in characters.
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
        self.embedding_model = embedding_model
        self.reranker = reranker
        self.chunks_with_meta = []  # list of {'content': str, 'source': str}, parallel to all_chunks
        self.faiss_index = None     # dense index; built by index_research()
        self.bm25_index = None      # sparse index; built by index_research()
        self.all_chunks = []        # raw chunk texts, parallel to chunks_with_meta

    def index_research(self, research_items: List[dict]) -> None:
        """Create an index of research material for fast retrieval.

        Rebuilds both indexes from scratch; previous state is discarded.

        Args:
            research_items: dicts with at least 'content' (text to index)
                and 'source' (provenance label) keys.
        """
        self.chunks_with_meta = []
        self.all_chunks = []
        for item in research_items:
            for chunk in self.text_splitter.split_text(item['content']):
                self.chunks_with_meta.append({'content': chunk, 'source': item['source']})
                self.all_chunks.append(chunk)
        if not self.all_chunks:
            print("Warning: No chunks to index.")
            return
        print(f"--> Embedding {len(self.all_chunks)} chunks...")
        embeddings = self.embedding_model.encode(self.all_chunks, convert_to_tensor=False)
        self.faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
        self.faiss_index.add(np.array(embeddings, dtype=np.float32))
        # FIX: split() instead of split(" ") — split(" ") yields empty tokens on
        # runs of spaces and leaves tabs/newlines embedded in tokens, which
        # silently degrades BM25 scoring.
        tokenized_corpus = [doc.split() for doc in self.all_chunks]
        self.bm25_index = BM25Okapi(tokenized_corpus)

    def retrieve_and_rerank(self, query: str, top_k: int = 10) -> List[dict]:
        """Retrieve relevant chunks and rerank them for the final context.

        Args:
            query: natural-language search query.
            top_k: maximum number of chunks to return.

        Returns:
            Up to ``top_k`` chunk dicts ({'content', 'source'}) ordered by
            descending reranker score; [] if nothing is indexed.
        """
        if not self.chunks_with_meta or self.faiss_index is None or self.bm25_index is None:
            return []
        print(f"--> Retrieving and re-ranking for query: '{query[:50]}...'")
        # Over-retrieve (2x) from each channel, capped by corpus size.
        n_candidates = min(top_k * 2, len(self.all_chunks))
        query_embedding = self.embedding_model.encode([query], convert_to_tensor=False)
        _, faiss_indices = self.faiss_index.search(
            np.array(query_embedding, dtype=np.float32), k=n_candidates)
        bm25_scores = self.bm25_index.get_scores(query.split())  # tokenize same as corpus
        bm25_indices = np.argsort(bm25_scores)[::-1][:n_candidates]
        # FIX: materialize the candidate union into a list ONCE. The original
        # iterated the same set twice (pair building, then zip with scores),
        # implicitly relying on stable set iteration order to keep each score
        # paired with its index. int() normalizes numpy integer types; the
        # i >= 0 guard skips FAISS's -1 padding for unfilled result slots.
        candidates = sorted({int(i) for i in faiss_indices[0] if i >= 0}
                            | {int(i) for i in bm25_indices})
        if not candidates:
            return []
        rerank_pairs = [[query, self.chunks_with_meta[idx]['content']] for idx in candidates]
        scores = self.reranker.predict(rerank_pairs)
        ranked = sorted(zip(scores, candidates), key=lambda pair: pair[0], reverse=True)
        return [self.chunks_with_meta[idx] for _, idx in ranked[:top_k]]