from typing import List

import numpy as np
import faiss
from rank_bm25 import BM25Okapi
from langchain.text_splitter import RecursiveCharacterTextSplitter


class RAGPipeline:
    """A pipeline for Retrieval-Augmented Generation.

    Indexes research text with both a dense (FAISS, L2) and a sparse (BM25)
    index, retrieves candidates from both, and reranks the union with a
    cross-encoder-style reranker.
    """

    def __init__(self, embedding_model, reranker):
        """
        Args:
            embedding_model: object exposing ``encode(texts, convert_to_tensor=False)``
                returning a 2-D array of embeddings (sentence-transformers style).
            reranker: object exposing ``predict(pairs)`` where each pair is
                ``[query, passage]``, returning one relevance score per pair.
        """
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000, chunk_overlap=150
        )
        self.embedding_model = embedding_model
        self.reranker = reranker
        self.chunks_with_meta: List[dict] = []  # [{'content': str, 'source': ...}]
        self.faiss_index = None
        self.bm25_index = None
        self.all_chunks: List[str] = []  # parallel list of raw chunk texts

    def index_research(self, research_items: List[dict]) -> None:
        """Create an index of research material for fast retrieval.

        Args:
            research_items: dicts with a ``'content'`` key (text to index)
                and a ``'source'`` key (provenance stored with each chunk).
        """
        # Rebuild from scratch so repeated calls don't accumulate stale chunks.
        self.chunks_with_meta = []
        self.all_chunks = []
        for item in research_items:
            for chunk in self.text_splitter.split_text(item['content']):
                self.chunks_with_meta.append(
                    {'content': chunk, 'source': item['source']}
                )
                self.all_chunks.append(chunk)

        if not self.all_chunks:
            # BUGFIX: also drop any indexes from a previous corpus; otherwise
            # stale FAISS/BM25 indexes would outlive the now-empty chunk list.
            self.faiss_index = None
            self.bm25_index = None
            print("Warning: No chunks to index.")
            return

        print(f"--> Embedding {len(self.all_chunks)} chunks...")
        embeddings = self.embedding_model.encode(
            self.all_chunks, convert_to_tensor=False
        )
        # FAISS requires float32; IndexFlatL2 does exact L2 search.
        self.faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
        self.faiss_index.add(np.array(embeddings, dtype=np.float32))

        # Simple whitespace tokenization keeps BM25 consistent with the query side.
        tokenized_corpus = [doc.split(" ") for doc in self.all_chunks]
        self.bm25_index = BM25Okapi(tokenized_corpus)

    def retrieve_and_rerank(self, query: str, top_k: int = 10) -> List[dict]:
        """Retrieve relevant chunks and rerank them for the final context.

        Candidates come from both dense (FAISS) and sparse (BM25) retrieval;
        their union is scored by the reranker and the best ``top_k`` returned.

        Args:
            query: free-text query string.
            top_k: number of chunks to return after reranking.

        Returns:
            Up to ``top_k`` chunk dicts (``{'content', 'source'}``), best
            first; ``[]`` when nothing has been indexed.
        """
        if not self.chunks_with_meta or self.faiss_index is None or self.bm25_index is None:
            return []

        print(f"--> Retrieving and re-ranking for query: '{query[:50]}...'")
        # Over-fetch 2x candidates from each retriever, capped at corpus size.
        n_candidates = min(top_k * 2, len(self.all_chunks))

        # Dense candidates (distances are unused — only the ids matter here).
        query_embedding = self.embedding_model.encode([query], convert_to_tensor=False)
        _, faiss_indices = self.faiss_index.search(
            np.array(query_embedding, dtype=np.float32), k=n_candidates
        )

        # Sparse candidates: top-n by BM25 score.
        bm25_scores = self.bm25_index.get_scores(query.split(" "))
        bm25_indices = np.argsort(bm25_scores)[::-1][:n_candidates]

        # BUGFIX: materialize the union as a list so rerank_pairs and the
        # score/index zip below share one guaranteed iteration order.  The
        # original iterated the same set twice, which only aligned scores to
        # chunks by CPython implementation accident.
        candidate_ids = [int(i) for i in set(faiss_indices[0]).union(bm25_indices)]
        rerank_pairs = [
            [query, self.chunks_with_meta[idx]['content']] for idx in candidate_ids
        ]
        if not rerank_pairs:
            return []

        scores = self.reranker.predict(rerank_pairs)
        ranked = sorted(zip(scores, candidate_ids), key=lambda x: x[0], reverse=True)
        return [self.chunks_with_meta[idx] for _, idx in ranked[:top_k]]