# index_retriever.py — hybrid (vector + BM25) index construction and retrieval helpers
from llama_index.core import VectorStoreIndex, Settings
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.response_synthesizers import get_response_synthesizer, ResponseMode
from llama_index.core.prompts import PromptTemplate
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.retrievers import QueryFusionRetriever
from my_logging import log_message
from config import CUSTOM_PROMPT, PROMPT_SIMPLE_POISK
def create_vector_index(documents):
    """Build a ``VectorStoreIndex`` over *documents*, logging table statistics.

    Before indexing, documents whose metadata marks them as table chunks are
    grouped by their ``connection_type`` metadata, and a per-type summary
    (chunk count and distinct source tables) is written to the log.

    Args:
        documents: Iterable of llama_index documents; table chunks are
            expected to carry ``type``, ``connection_type``, ``document_id``
            and ``table_number`` metadata keys.

    Returns:
        The ``VectorStoreIndex`` built from all *documents*.
    """
    log_message("Строю векторный индекс")
    table_docs = [d for d in documents if d.metadata.get('type') == 'table']
    sources_by_conn = {}
    for doc in table_docs:
        conn_type = doc.metadata.get('connection_type', '')
        if not conn_type:
            continue
        doc_id = doc.metadata.get('document_id', 'unknown')
        table_no = doc.metadata.get('table_number', 'N/A')
        sources_by_conn.setdefault(conn_type, []).append(f"{doc_id} Table {table_no}")
    log_message("=" * 60)
    log_message(f"INDEXING {len(table_docs)} TABLE CHUNKS")
    log_message("CONNECTION TYPES IN INDEX WITH SOURCES:")
    for conn_type in sorted(sources_by_conn):
        chunk_ids = sources_by_conn[conn_type]
        sources = list(set(chunk_ids))  # Unique sources
        log_message(f" {conn_type}: {len(chunk_ids)} chunks from {len(sources)} tables")
        for src in sources:
            log_message(f" - {src}")
    log_message("=" * 60)
    return VectorStoreIndex.from_documents(documents)
def rerank_nodes(query, nodes, reranker, top_k=25, min_score_threshold=0.5):
    """Re-rank retrieval *nodes* against *query* using a pairwise reranker.

    Args:
        query: Query string scored against each node's text.
        nodes: Candidate nodes, each exposing a ``.text`` attribute. May be
            ``None`` or empty.
        reranker: Object with ``predict(pairs) -> scores`` (e.g. a
            cross-encoder). If falsy, nodes are returned unreranked.
        top_k: Maximum number of nodes to return.
        min_score_threshold: Nodes scoring below this are dropped; if no
            node passes, the threshold is ignored and the best ``top_k``
            are kept instead.

    Returns:
        Up to ``top_k`` nodes, best score first. On any reranker failure
        the first ``top_k`` input nodes are returned unchanged.
    """
    # BUGFIX: the original combined guard (`not nodes or not reranker`)
    # fell through to `nodes[:top_k]`, which raises TypeError when nodes
    # is None (slicing None). Handle the two conditions separately.
    if not nodes:
        return []
    if not reranker:
        return nodes[:top_k]
    try:
        log_message(f"Переранжирую {len(nodes)} узлов")
        pairs = [[query, node.text] for node in nodes]
        scores = reranker.predict(pairs)
        scored_nodes = sorted(zip(nodes, scores), key=lambda x: x[1], reverse=True)
        # Apply threshold
        filtered = [(node, score) for node, score in scored_nodes if score >= min_score_threshold]
        if not filtered:
            # Lower threshold if nothing passes
            filtered = scored_nodes[:top_k]
        log_message(f"Выбрано {min(len(filtered), top_k)} узлов")
        return [node for node, score in filtered[:top_k]]
    except Exception as e:
        log_message(f"Ошибка переранжировки: {str(e)}")
        return nodes[:top_k]
def create_query_engine(vector_index, vector_top_k=50, bm25_top_k=50,
                        similarity_cutoff=0.55, hybrid_top_k=100):
    """Build a hybrid (dense vector + BM25) query engine over *vector_index*.

    Both retrievers are fused with ``QueryFusionRetriever`` and answers are
    synthesized with the ``CUSTOM_PROMPT`` template in tree-summarize mode.

    Args:
        vector_index: A ``VectorStoreIndex`` whose docstore backs the BM25
            retriever.
        vector_top_k: Candidates fetched by the dense vector retriever.
        bm25_top_k: Candidates fetched by the sparse BM25 retriever.
        similarity_cutoff: Minimum similarity passed to the vector retriever.
        hybrid_top_k: Candidates kept after fusion of both retrievers.

    Returns:
        A configured ``RetrieverQueryEngine``.

    Raises:
        Exception: Re-raised after logging if any component fails to build.
    """
    # NOTE: CUSTOM_PROMPT comes from the module-level import; the previous
    # redundant function-local `from config import CUSTOM_PROMPT` was removed.
    try:
        # Sparse keyword retriever over the same document store.
        bm25_retriever = BM25Retriever.from_defaults(
            docstore=vector_index.docstore,
            similarity_top_k=bm25_top_k,
        )
        vector_retriever = VectorIndexRetriever(
            index=vector_index,
            similarity_top_k=vector_top_k,
            similarity_cutoff=similarity_cutoff,
        )
        # num_queries=1 disables LLM query generation: pure retrieval fusion.
        hybrid_retriever = QueryFusionRetriever(
            [vector_retriever, bm25_retriever],
            similarity_top_k=hybrid_top_k,
            num_queries=1,
        )
        response_synthesizer = get_response_synthesizer(
            response_mode=ResponseMode.TREE_SUMMARIZE,
            text_qa_template=PromptTemplate(CUSTOM_PROMPT),
        )
        query_engine = RetrieverQueryEngine(
            retriever=hybrid_retriever,
            response_synthesizer=response_synthesizer,
        )
        log_message(f"Query engine created: vector_top_k={vector_top_k}, "
                    f"bm25_top_k={bm25_top_k}, similarity_cutoff={similarity_cutoff}, "
                    f"hybrid_top_k={hybrid_top_k}")
        return query_engine
    except Exception as e:
        log_message(f"Ошибка создания query engine: {str(e)}")
        raise