# RAG_AIEXP / index_retriever.py
# (header cleaned: Hugging Face upload-page residue — "MrSimple01's picture /
#  Upload 10 files / fa02ae1 verified" — was not valid Python and broke parsing)
from llama_index.core import VectorStoreIndex, Settings
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.response_synthesizers import get_response_synthesizer, ResponseMode
from llama_index.core.prompts import PromptTemplate
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.retrievers import QueryFusionRetriever
from my_logging import log_message
from config import CUSTOM_PROMPT, PROMPT_SIMPLE_POISK
def create_vector_index(documents):
    """Build and return a ``VectorStoreIndex`` over *documents*.

    Parameters
    ----------
    documents : list
        llama_index ``Document`` objects; embedding model comes from the
        global ``Settings`` configured elsewhere.

    Returns
    -------
    VectorStoreIndex
        Freshly built in-memory vector index.
    """
    log_message("Строю векторный индекс")
    # NOTE(review): the original version also tallied table-type documents
    # (table_count) and grouped table ids by connection_type, but both
    # results were discarded — dead code, removed.  Re-introduce it only if
    # those statistics are meant to be logged or returned.
    return VectorStoreIndex.from_documents(documents)
def rerank_nodes(query, nodes, reranker, top_k=25, min_score_threshold=0.5):
    """Re-rank *nodes* for *query* using a cross-encoder-style reranker.

    Scores every ``(query, node.text)`` pair via ``reranker.predict``, keeps
    the nodes scoring at least ``min_score_threshold`` (falling back to the
    highest-scored ones when nothing clears the cutoff) and returns at most
    ``top_k`` nodes, best first.  When *nodes* or *reranker* is missing, or
    any error occurs, the first ``top_k`` input nodes are returned unchanged.
    """
    # Guard clause: nothing to score, or no model to score with.
    if not nodes or not reranker:
        return nodes[:top_k]
    try:
        log_message(f"Переранжирую {len(nodes)} узлов")
        scores = reranker.predict([[query, n.text] for n in nodes])
        ranked = sorted(zip(nodes, scores), key=lambda pair: pair[1], reverse=True)
        kept = [pair for pair in ranked if pair[1] >= min_score_threshold]
        if not kept:
            # Threshold eliminated everything — fall back to the best top_k.
            kept = ranked[:top_k]
        log_message(f"Выбрано {min(len(kept), top_k)} узлов")
        return [node for node, _ in kept[:top_k]]
    except Exception as e:
        # Best-effort: log and return the input order rather than failing.
        log_message(f"Ошибка переранжировки: {str(e)}")
        return nodes[:top_k]
def create_query_engine(vector_index, vector_top_k=50, bm25_top_k=50,
                        similarity_cutoff=0.55, hybrid_top_k=100):
    """Build a hybrid (dense vector + sparse BM25) RetrieverQueryEngine.

    Parameters
    ----------
    vector_index : VectorStoreIndex
        Index whose docstore backs the BM25 retriever and whose embeddings
        back the dense retriever.
    vector_top_k : int
        Candidate count fetched by the dense vector retriever.
    bm25_top_k : int
        Candidate count fetched by the BM25 retriever.
    similarity_cutoff : float
        Minimum similarity passed to the vector retriever.
    hybrid_top_k : int
        Final node count kept after fusing both retrievers' results.

    Returns
    -------
    RetrieverQueryEngine
        Engine that answers queries with a TREE_SUMMARIZE synthesizer using
        the project's CUSTOM_PROMPT template.

    Raises
    ------
    Exception
        Any construction failure is logged and re-raised unchanged.
    """
    try:
        # Fix: dropped the redundant function-local
        # `from config import CUSTOM_PROMPT` — it is already imported at
        # module level and shadowing it here only obscured that.
        bm25_retriever = BM25Retriever.from_defaults(
            docstore=vector_index.docstore,
            similarity_top_k=bm25_top_k
        )
        vector_retriever = VectorIndexRetriever(
            index=vector_index,
            similarity_top_k=vector_top_k,
            similarity_cutoff=similarity_cutoff
        )
        # num_queries=1 disables LLM-generated query variants, so fusion
        # only merges the two retrievers' result lists for the user query.
        hybrid_retriever = QueryFusionRetriever(
            [vector_retriever, bm25_retriever],
            similarity_top_k=hybrid_top_k,
            num_queries=1
        )
        custom_prompt_template = PromptTemplate(CUSTOM_PROMPT)
        response_synthesizer = get_response_synthesizer(
            response_mode=ResponseMode.TREE_SUMMARIZE,
            text_qa_template=custom_prompt_template
        )
        query_engine = RetrieverQueryEngine(
            retriever=hybrid_retriever,
            response_synthesizer=response_synthesizer
        )
        log_message(f"Query engine created: vector_top_k={vector_top_k}, "
                    f"bm25_top_k={bm25_top_k}, similarity_cutoff={similarity_cutoff}, "
                    f"hybrid_top_k={hybrid_top_k}")
        return query_engine
    except Exception as e:
        log_message(f"Ошибка создания query engine: {str(e)}")
        raise