Spaces:
Sleeping
Sleeping
File size: 5,559 Bytes
ce8469e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
from llama_index.core.postprocessor.rankGPT_rerank import RankGPTRerank
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore
from llama_index.core import QueryBundle
from llama_index.core import get_response_synthesizer
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever, KeywordTableSimpleRetriever
from typing import List
from knowledgeBase.collection import CollectionManager
class HybridRetriever(BaseRetriever):
    """
    Retriever that merges vector-based and keyword-based retrieval results.

    Both underlying indexes are loaded through ``CollectionManager`` using the
    query engine name, and each incoming query is sent to both retrievers. The
    two result lists are concatenated (vector hits first) and de-duplicated
    by node id.

    Attributes:
        query_engine_name (str): The name of the query engine; also used to
            locate the persisted indexes.
        query_engine_description (str): A description of the query engine.
        model_llm: The language model handed to the keyword index loader.
        model_embd: The embedding model handed to the vector index loader.
        _vector_retriever (VectorIndexRetriever): Retriever over the vector index.
        _keyword_retriever (KeywordTableSimpleRetriever): Retriever over the keyword index.
    """

    def __init__(
        self,
        model_llm,
        model_embd,
        query_engine_name: str,
        query_engine_description: str,
        k_semantic: int = 16,
        k_keyword: int = 6,
    ) -> None:
        """
        Load both indexes and build the two underlying retrievers.

        Args:
            model_llm: Language model passed to
                ``CollectionManager.load_keyword_index_from_file``.
            model_embd: Embedding model passed to
                ``CollectionManager.load_vector_index_from_file``.
            query_engine_name (str): Name identifying which indexes to load.
            query_engine_description (str): Description stored on the instance.
            k_semantic (int): ``similarity_top_k`` for the vector retriever.
            k_keyword (int): ``num_chunks_per_query`` for the keyword retriever.
        """
        # Let the base class initialise its own state; custom retrievers are
        # expected to call this so the public retrieve() path works reliably.
        super().__init__()

        self.query_engine_name = query_engine_name
        self.query_engine_description = query_engine_description
        self.model_llm = model_llm
        self.model_embd = model_embd

        # Load the vector index and keyword index
        collection_manager = CollectionManager()
        vector_index = collection_manager.load_vector_index_from_file(
            query_engine_name=query_engine_name, model_embd=model_embd
        )
        keyword_index = collection_manager.load_keyword_index_from_file(
            query_engine_name=query_engine_name, model_llm=model_llm
        )

        self._vector_retriever = VectorIndexRetriever(
            index=vector_index, similarity_top_k=k_semantic
        )
        self._keyword_retriever = KeywordTableSimpleRetriever(
            index=keyword_index, num_chunks_per_query=k_keyword
        )

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """
        Retrieve nodes for a query by combining vector and keyword results.

        Args:
            query_bundle (QueryBundle): The query bundle containing the query information.

        Returns:
            List[NodeWithScore]: Vector results in retriever order, followed by
            keyword results whose node id has not been seen yet. Each node id
            appears at most once.
        """
        vector_nodes = self._vector_retriever.retrieve(query_bundle)
        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)

        resulting_nodes: List[NodeWithScore] = []
        node_ids_added = set()

        # Single pass over both lists; vector hits come first so that, for a
        # node found by both retrievers, the vector-scored entry is the one kept.
        for scored_node in [*vector_nodes, *keyword_nodes]:
            node_id = scored_node.node.node_id
            if node_id in node_ids_added:
                continue
            node_ids_added.add(node_id)
            resulting_nodes.append(scored_node)

        return resulting_nodes
def load_hybrid_query_engine(model_llm, model_embd, query_engine_name, query_engine_description, k_semantic=18, k_keyword=6):
    """
    Build a query engine backed by a HybridRetriever and an LLM-based reranker.

    Args:
        model_llm (object): Language model used by the retriever's keyword index
            loader, the RankGPT reranker, and the response synthesizer.
        model_embd (object): Embedding model used for vector-based retrieval.
        query_engine_name (str): The name of the query engine.
        query_engine_description (str): A description of the query engine.
        k_semantic (int, optional): Number of results requested from semantic search. Defaults to 18.
        k_keyword (int, optional): Number of results requested from keyword search. Defaults to 6.

    Returns:
        RetrieverQueryEngine: Engine that retrieves with the hybrid retriever,
        reranks with RankGPT, and synthesizes a response with ``model_llm``.
    """
    # Retrieval stage: vector + keyword results merged by HybridRetriever.
    retriever = HybridRetriever(
        model_llm=model_llm,
        model_embd=model_embd,
        query_engine_name=query_engine_name,
        query_engine_description=query_engine_description,
        k_semantic=k_semantic,
        k_keyword=k_keyword,
    )

    # Rerank stage: keep half of the requested candidate budget, but never
    # fewer than one node.
    rerank_top_n = max(1, (k_semantic + k_keyword) // 2)
    reranker = RankGPTRerank(top_n=rerank_top_n, llm=model_llm, verbose=True)

    # Synthesis stage is created inline and wired together with the rest.
    return RetrieverQueryEngine(
        retriever=retriever,
        response_synthesizer=get_response_synthesizer(llm=model_llm),
        node_postprocessors=[reranker],
    )