Spaces:
Sleeping
Sleeping
| from typing import List | |
| import numpy as np | |
| import faiss | |
| from rank_bm25 import BM25Okapi | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
class RAGPipeline:
    """A pipeline for Retrieval-Augmented Generation.

    Combines dense (FAISS, exact L2) and sparse (BM25) retrieval over
    chunked research documents, then re-ranks the merged candidate pool
    with a cross-encoder style reranker.
    """

    def __init__(self, embedding_model, reranker):
        """Initialize the pipeline.

        Args:
            embedding_model: Object exposing ``encode(texts, convert_to_tensor=False)``
                and returning a 2-D array of embeddings (e.g. a SentenceTransformer).
            reranker: Object exposing ``predict(pairs)`` that scores
                ``[query, passage]`` pairs (e.g. a CrossEncoder).
        """
        # Overlapping chunks so context is not lost at chunk boundaries.
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
        self.embedding_model = embedding_model
        self.reranker = reranker
        # Parallel structures: chunks_with_meta[i] is the metadata record for
        # all_chunks[i]; both are rebuilt on every index_research() call.
        self.chunks_with_meta: List[dict] = []
        self.all_chunks: List[str] = []
        self.faiss_index = None
        self.bm25_index = None

    def index_research(self, research_items: List[dict]) -> None:
        """Create an index of research material for fast retrieval.

        Args:
            research_items: Dicts with at least a ``'content'`` key (text to
                index) and a ``'source'`` key (provenance label attached to
                every chunk produced from that item).
        """
        self.chunks_with_meta = []
        self.all_chunks = []
        for item in research_items:
            for chunk in self.text_splitter.split_text(item['content']):
                self.chunks_with_meta.append({'content': chunk, 'source': item['source']})
                self.all_chunks.append(chunk)
        if not self.all_chunks:
            print("Warning: No chunks to index.")
            return
        print(f"--> Embedding {len(self.all_chunks)} chunks...")
        embeddings = self.embedding_model.encode(self.all_chunks, convert_to_tensor=False)
        # Exact (brute-force) L2 index; adequate for corpora of this size.
        self.faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
        self.faiss_index.add(np.array(embeddings, dtype=np.float32))
        # BM25 over the same whitespace tokenization used at query time.
        tokenized_corpus = [doc.split(" ") for doc in self.all_chunks]
        self.bm25_index = BM25Okapi(tokenized_corpus)

    def retrieve_and_rerank(self, query: str, top_k: int = 10) -> List[dict]:
        """Hybrid-retrieve candidate chunks and rerank them for the final context.

        Pulls up to ``2 * top_k`` candidates from each of the dense and
        sparse indexes, unions them, reranks the union, and returns the
        ``top_k`` best chunk records.

        Args:
            query: Natural-language search query.
            top_k: Number of chunk records to return.

        Returns:
            List of ``{'content', 'source'}`` dicts, best-scored first.
            Empty list when nothing has been indexed yet.
        """
        if not self.chunks_with_meta or self.faiss_index is None or self.bm25_index is None:
            return []
        print(f"--> Retrieving and re-ranking for query: '{query[:50]}...'")
        n_candidates = min(top_k * 2, len(self.all_chunks))
        query_embedding = self.embedding_model.encode([query], convert_to_tensor=False)
        _, faiss_indices = self.faiss_index.search(
            np.array(query_embedding, dtype=np.float32), k=n_candidates
        )
        tokenized_query = query.split(" ")
        bm25_scores = self.bm25_index.get_scores(tokenized_query)
        bm25_indices = np.argsort(bm25_scores)[::-1][:n_candidates]
        # Union the two candidate pools. Materialize as a sorted list of
        # plain ints so (a) the pairing between rerank_pairs and indices is
        # guaranteed stable (the previous code iterated the same set twice
        # and relied on identical iteration order), (b) ties in reranker
        # scores break deterministically, and (c) the -1 padding FAISS emits
        # when it finds fewer than k hits never indexes chunks_with_meta[-1].
        combined_indices = sorted(
            int(idx)
            for idx in set(faiss_indices[0]).union(bm25_indices)
            if idx >= 0
        )
        rerank_pairs = [[query, self.chunks_with_meta[idx]['content']] for idx in combined_indices]
        if not rerank_pairs:
            return []
        scores = self.reranker.predict(rerank_pairs)
        # Highest reranker score first; stable sort keeps ties in ascending
        # index order.
        scored_items = sorted(zip(scores, combined_indices), key=lambda x: x[0], reverse=True)
        return [self.chunks_with_meta[idx] for _, idx in scored_items[:top_k]]