# RAG_AIEXP_01 / index_retriever.py
# (commit f49e798 — "top k in reranker = 30, bm = 20")
from llama_index.core import VectorStoreIndex, Settings
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.response_synthesizers import get_response_synthesizer, ResponseMode
from llama_index.core.prompts import PromptTemplate
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.retrievers import QueryFusionRetriever
from my_logging import log_message
from config import CUSTOM_PROMPT, PROMPT_SIMPLE_POISK
def create_vector_index(documents):
    """Build and return a ``VectorStoreIndex`` over *documents*."""
    log_message("Строю векторный индекс")
    index = VectorStoreIndex.from_documents(documents)
    return index
def create_query_engine(vector_index):
    """Assemble a hybrid (dense + BM25) query engine over *vector_index*.

    Combines a vector retriever and a BM25 keyword retriever via
    ``QueryFusionRetriever``, then wires the fused retriever to a
    tree-summarize response synthesizer using the project prompt.

    Raises:
        Exception: re-raised after logging if any component fails to build.
    """
    try:
        # Sparse keyword retriever over the index's docstore.
        keyword_retriever = BM25Retriever.from_defaults(
            docstore=vector_index.docstore,
            similarity_top_k=20,
        )
        # Dense embedding retriever with a similarity cutoff.
        dense_retriever = VectorIndexRetriever(
            index=vector_index,
            similarity_top_k=50,
            similarity_cutoff=0.7,
        )
        # Fuse both candidate lists into a single ranked result set.
        fused_retriever = QueryFusionRetriever(
            [dense_retriever, keyword_retriever],
            similarity_top_k=70,
            num_queries=1,
        )
        synthesizer = get_response_synthesizer(
            response_mode=ResponseMode.TREE_SUMMARIZE,
            text_qa_template=PromptTemplate(PROMPT_SIMPLE_POISK),
        )
        engine = RetrieverQueryEngine(
            retriever=fused_retriever,
            response_synthesizer=synthesizer,
        )
        log_message("Query engine успешно создан")
        return engine
    except Exception as e:
        log_message(f"Ошибка создания query engine: {str(e)}")
        raise
def rerank_nodes(query, nodes, reranker, top_k=25, min_score_threshold=0.5, diversity_penalty=0.3):
    """Rerank retrieved nodes with a cross-encoder and enforce source diversity.

    Args:
        query: User query string, paired with each node's ``.text`` for scoring.
        nodes: Retrieved nodes; each must expose ``.text`` (``.metadata`` is
            optional). May be ``None`` or empty.
        reranker: Model exposing ``predict(pairs) -> scores`` (cross-encoder
            style — presumably sentence-transformers; confirm against caller).
        top_k: Maximum number of nodes to return.
        min_score_threshold: Drop nodes scoring below this value; if that
            removes everything, the threshold is relaxed to 60% of the best
            score. Pass ``None`` to skip threshold filtering.
        diversity_penalty: Penalty factor applied to candidates from an
            already-selected document (half weight) or section (full weight).

    Returns:
        Up to ``top_k`` nodes in descending reranker-score order. On any
        scoring error, falls back to the first ``top_k`` input nodes.
    """
    # Guard clauses. Bug fix: the original `nodes[:top_k]` raised TypeError
    # when nodes was None; an empty/missing input now yields [].
    if not nodes:
        return []
    if not reranker:
        return nodes[:top_k]
    try:
        log_message(f"Переранжирую {len(nodes)} узлов")
        # Score every (query, passage) pair and sort best-first.
        pairs = [[query, node.text] for node in nodes]
        scores = reranker.predict(pairs)
        scored_nodes = sorted(zip(nodes, scores), key=lambda x: x[1], reverse=True)
        # Absolute threshold filter.
        if min_score_threshold is not None:
            scored_nodes = [(node, score) for node, score in scored_nodes
                            if score >= min_score_threshold]
            log_message(f"После фильтрации по порогу {min_score_threshold}: {len(scored_nodes)} узлов")
        # If the threshold removed everything, relax it to 60% of the top score.
        if not scored_nodes:
            log_message("Нет узлов после фильтрации, снижаю порог")
            scored_nodes = sorted(zip(nodes, scores), key=lambda x: x[1], reverse=True)
            min_score_threshold = scored_nodes[0][1] * 0.6
            scored_nodes = [(node, score) for node, score in scored_nodes
                            if score >= min_score_threshold]
        # Greedy selection with a diversity penalty: repeated documents or
        # sections get their score discounted; a candidate is kept only if its
        # discounted score stays within 60% of the best selected score.
        selected_nodes = []
        selected_docs = set()
        selected_sections = set()
        for node, score in scored_nodes:
            if len(selected_nodes) >= top_k:
                break
            metadata = node.metadata if hasattr(node, 'metadata') else {}
            doc_id = metadata.get('document_id', 'unknown')
            section_key = f"{doc_id}_{metadata.get('section_path', metadata.get('section_id', ''))}"
            penalty = 0
            if doc_id in selected_docs:
                penalty += diversity_penalty * 0.5
            if section_key in selected_sections:
                penalty += diversity_penalty
            adjusted_score = score * (1 - penalty)
            # selected_nodes[0] holds the highest (first-selected) score;
            # the original (undiscounted) score is what gets stored.
            if not selected_nodes or adjusted_score >= selected_nodes[0][1] * 0.6:
                selected_nodes.append((node, score))
                selected_docs.add(doc_id)
                selected_sections.add(section_key)
        log_message(f"Выбрано {len(selected_nodes)} узлов с разнообразием")
        log_message(f"Уникальных документов: {len(selected_docs)}, секций: {len(selected_sections)}")
        if selected_nodes:
            # Selection preserves descending order, so first/last = max/min.
            log_message(f"Score range: {selected_nodes[0][1]:.3f} to {selected_nodes[-1][1]:.3f}")
        return [node for node, score in selected_nodes]
    except Exception as e:
        # Best-effort fallback: return the untouched head of the input.
        log_message(f"Ошибка переранжировки: {str(e)}")
        return nodes[:top_k]