RAG_AIEXP_01 / index_retriever.py
MrSimple07's picture
tree summarizer + top k = 30
a5d5837
raw
history blame
2.7 kB
from llama_index.core import VectorStoreIndex, Settings
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.response_synthesizers import get_response_synthesizer, ResponseMode
from llama_index.core.prompts import PromptTemplate
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.retrievers import QueryFusionRetriever
from my_logging import log_message
from config import CUSTOM_PROMPT, PROMPT_SIMPLE_POISK
def create_vector_index(documents):
    """Build a vector store index over ``documents``.

    Emits a progress message through ``log_message`` and then delegates
    to ``VectorStoreIndex.from_documents``.

    Args:
        documents: Passed unchanged to ``VectorStoreIndex.from_documents``.

    Returns:
        Whatever ``VectorStoreIndex.from_documents`` returns for the input.
    """
    log_message("Строю векторный индекс")
    index = VectorStoreIndex.from_documents(documents)
    return index
def create_query_engine(vector_index):
    """Build a hybrid (BM25 + vector) query engine on top of ``vector_index``.

    Two retrievers are created over the same index -- a BM25 retriever
    reading from ``vector_index.docstore`` and a vector retriever -- and
    combined with ``QueryFusionRetriever``. Answers are synthesized in
    ``TREE_SUMMARIZE`` mode using the ``PROMPT_SIMPLE_POISK`` prompt.

    Args:
        vector_index: Index object exposing ``docstore``; also passed to
            ``VectorIndexRetriever`` as ``index``.

    Returns:
        A ``RetrieverQueryEngine`` wired to the hybrid retriever and the
        response synthesizer.

    Raises:
        Exception: Any error raised while building the components is
            logged via ``log_message`` and re-raised unchanged.
    """
    try:
        bm25_retriever = BM25Retriever.from_defaults(
            docstore=vector_index.docstore,
            similarity_top_k=15,
        )

        # NOTE(review): ``similarity_cutoff`` does not appear to be a
        # documented VectorIndexRetriever parameter and is presumably
        # swallowed by **kwargs without effect -- confirm. If a score cutoff
        # is really wanted, a node postprocessor (e.g. SimilarityPostprocessor)
        # on the query engine is the usual place for it.
        vector_retriever = VectorIndexRetriever(
            index=vector_index,
            similarity_top_k=30,
            similarity_cutoff=0.7,
        )

        # num_queries=1: only the original query is run through both
        # retrievers (no extra generated queries, per the llama_index docs).
        hybrid_retriever = QueryFusionRetriever(
            [vector_retriever, bm25_retriever],
            similarity_top_k=30,
            num_queries=1,
        )

        custom_prompt_template = PromptTemplate(PROMPT_SIMPLE_POISK)

        # TREE_SUMMARIZE builds its prompts from ``summary_template``;
        # ``text_qa_template`` is not consulted in that mode, so the custom
        # prompt must be passed as ``summary_template`` to take effect.
        # ``text_qa_template`` is kept so the prompt still applies if the
        # response mode is switched back to a QA-style mode.
        # Assumes PROMPT_SIMPLE_POISK uses the {context_str}/{query_str}
        # placeholders -- TODO confirm against config.
        response_synthesizer = get_response_synthesizer(
            response_mode=ResponseMode.TREE_SUMMARIZE,
            summary_template=custom_prompt_template,
            text_qa_template=custom_prompt_template,
        )

        query_engine = RetrieverQueryEngine(
            retriever=hybrid_retriever,
            response_synthesizer=response_synthesizer,
        )

        log_message("Query engine успешно создан")
        return query_engine

    except Exception as e:
        # Log for diagnostics, then let the caller decide how to handle it.
        log_message(f"Ошибка создания query engine: {str(e)}")
        raise
def rerank_nodes(query, nodes, reranker, top_k=10):
    """Reorder ``nodes`` by reranker score and keep the best ``top_k``.

    Each node is paired with ``query`` as ``[query, node.text]`` and the
    pairs are scored in a single ``reranker.predict`` call. Nodes are then
    sorted by score, highest first.

    This is best-effort: if there is nothing to rerank, no reranker is
    supplied, or scoring fails, the first ``top_k`` nodes are returned in
    their original order instead of raising.

    Args:
        query: Query string placed first in every scored pair.
        nodes: Sequence of objects exposing a ``text`` attribute; ``None``
            is treated as "no candidates".
        reranker: Object with a ``predict(pairs)`` method returning one
            score per pair; a falsy value disables reranking.
        top_k: Maximum number of nodes to return.

    Returns:
        A list of at most ``top_k`` nodes (without their scores).
    """
    if nodes is None:
        # None cannot be sliced; report "no results" rather than raising.
        return []
    if not nodes or not reranker:
        return nodes[:top_k]

    try:
        log_message(f"Переранжирую {len(nodes)} узлов")

        pairs = [[query, node.text] for node in nodes]
        scores = reranker.predict(pairs)

        # sorted() is stable, so nodes with equal scores keep their
        # original relative order.
        ranked = sorted(zip(nodes, scores), key=lambda pair: pair[1], reverse=True)
        reranked_nodes = [node for node, _score in ranked[:top_k]]

        log_message(f"Возвращаю топ-{len(reranked_nodes)} переранжированных узлов")
        return reranked_nodes

    except Exception as e:
        # Deliberate fallback: a failing reranker must not break retrieval.
        log_message(f"Ошибка переранжировки: {str(e)}")
        return nodes[:top_k]