File size: 2,989 Bytes
20d3dd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from typing import List
import numpy as np
import faiss
from rank_bm25 import BM25Okapi
from langchain.text_splitter import RecursiveCharacterTextSplitter

class RAGPipeline:
    """A pipeline for Retrieval-Augmented Generation.

    Builds a hybrid index (dense FAISS + sparse BM25) over chunked research
    documents, then retrieves candidates from both indexes and re-ranks the
    union with a cross-encoder-style reranker.
    """

    def __init__(self, embedding_model, reranker):
        """
        Args:
            embedding_model: object exposing ``encode(texts, convert_to_tensor=False)``
                returning a 2-D array of embeddings (sentence-transformers style —
                assumed; confirm against caller).
            reranker: object exposing ``predict(pairs)`` returning one relevance
                score per ``[query, passage]`` pair (cross-encoder style — assumed).
        """
        # 1000-char chunks with 150-char overlap preserve context across boundaries.
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
        self.embedding_model = embedding_model
        self.reranker = reranker
        self.chunks_with_meta = []  # [{'content': str, 'source': str}, ...]
        self.faiss_index = None     # dense index; None until index_research() succeeds
        self.bm25_index = None      # sparse index; None until index_research() succeeds
        self.all_chunks = []        # chunk texts, parallel to chunks_with_meta

    def index_research(self, research_items: List[dict]):
        """Create an index of research material for fast retrieval.

        Replaces any previously indexed state. Each item must provide
        'content' (text to be chunked) and 'source' (provenance label kept
        with every chunk).

        Args:
            research_items: list of dicts with 'content' and 'source' keys.
        """
        self.chunks_with_meta = []
        self.all_chunks = []
        for item in research_items:
            chunks = self.text_splitter.split_text(item['content'])
            for chunk in chunks:
                self.chunks_with_meta.append({'content': chunk, 'source': item['source']})
                self.all_chunks.append(chunk)

        if not self.all_chunks:
            # Nothing to index; leave faiss_index/bm25_index as None so
            # retrieve_and_rerank() short-circuits to [].
            print("Warning: No chunks to index.")
            return

        print(f"--> Embedding {len(self.all_chunks)} chunks...")
        embeddings = self.embedding_model.encode(self.all_chunks, convert_to_tensor=False)
        # FAISS requires float32 input; plain L2 distance over raw embeddings.
        self.faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
        self.faiss_index.add(np.array(embeddings, dtype=np.float32))

        # Simple whitespace tokenization — kept consistent with the query
        # tokenization in retrieve_and_rerank().
        tokenized_corpus = [doc.split(" ") for doc in self.all_chunks]
        self.bm25_index = BM25Okapi(tokenized_corpus)

    def retrieve_and_rerank(self, query: str, top_k: int = 10):
        """Retrieve relevant chunks and rerank them for the final context.

        Pulls up to 2*top_k candidates each from FAISS (dense) and BM25
        (sparse), unions them, scores every (query, chunk) pair with the
        reranker, and returns the top_k chunk dicts ({'content', 'source'})
        ordered by descending reranker score.

        Args:
            query: the search query text.
            top_k: number of final results to return (default 10).

        Returns:
            List of chunk-metadata dicts; empty if nothing has been indexed.
        """
        if not self.chunks_with_meta or self.faiss_index is None or self.bm25_index is None:
            return []

        print(f"--> Retrieving and re-ranking for query: '{query[:50]}...'")

        # Never ask either retriever for more candidates than exist.
        n_candidates = min(top_k * 2, len(self.all_chunks))

        # Dense retrieval (distances are not needed, only the indices).
        query_embedding = self.embedding_model.encode([query], convert_to_tensor=False)
        _, faiss_indices = self.faiss_index.search(
            np.array(query_embedding, dtype=np.float32), k=n_candidates)

        # Sparse retrieval: top n_candidates chunks by BM25 score.
        tokenized_query = query.split(" ")
        bm25_scores = self.bm25_index.get_scores(tokenized_query)
        bm25_indices = np.argsort(bm25_scores)[::-1][:n_candidates]

        # Materialize the union ONCE as a list of plain ints. The original
        # iterated the set twice (once to build pairs, once in the zip below),
        # relying on set iteration order staying identical between passes;
        # a single list makes the score<->index pairing structurally safe.
        candidate_indices = [int(i) for i in set(faiss_indices[0]).union(bm25_indices)]

        rerank_pairs = [[query, self.chunks_with_meta[idx]['content']] for idx in candidate_indices]

        if not rerank_pairs:
            return []

        scores = self.reranker.predict(rerank_pairs)

        # Highest reranker score first; ties resolved by sort stability.
        scored_items = sorted(zip(scores, candidate_indices), key=lambda x: x[0], reverse=True)

        final_results = [self.chunks_with_meta[idx] for score, idx in scored_items[:top_k]]

        return final_results