File size: 3,774 Bytes
fa02ae1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from llama_index.core import VectorStoreIndex, Settings
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.response_synthesizers import get_response_synthesizer, ResponseMode
from llama_index.core.prompts import PromptTemplate
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.retrievers import QueryFusionRetriever
from my_logging import log_message
from config import CUSTOM_PROMPT, PROMPT_SIMPLE_POISK

def create_vector_index(documents):
    """Build and return a VectorStoreIndex over *documents*.

    Args:
        documents: Iterable of llama_index ``Document`` objects.

    Returns:
        VectorStoreIndex built from the given documents with the
        globally configured ``Settings`` (embedding model, etc.).
    """
    log_message("Строю векторный индекс")
    # NOTE(review): a previous version also tallied table documents into
    # connection_type_sources / table_count here, but neither value was
    # ever used or returned — that dead bookkeeping has been removed.
    return VectorStoreIndex.from_documents(documents)

def rerank_nodes(query, nodes, reranker, top_k=25, min_score_threshold=0.5):
    """Rerank retrieved nodes against *query* using a cross-encoder reranker.

    Args:
        query: The user query string.
        nodes: Retrieved nodes; each must expose a ``.text`` attribute.
            May be ``None`` or empty.
        reranker: Object with a ``predict(pairs)`` method returning one
            relevance score per ``[query, text]`` pair (e.g. a
            sentence-transformers CrossEncoder). May be ``None``.
        top_k: Maximum number of nodes to return.
        min_score_threshold: Nodes scoring below this are dropped, unless
            that would leave nothing — then the best ``top_k`` are kept.

    Returns:
        Up to ``top_k`` nodes ordered by descending reranker score.
        On any reranking error, the first ``top_k`` input nodes unchanged.
    """
    # Nothing to rerank, or no reranker available.
    # (Bug fix: the original sliced `nodes` directly, which raised
    # TypeError when `nodes` was None.)
    if not nodes or not reranker:
        return (nodes or [])[:top_k]

    try:
        log_message(f"Переранжирую {len(nodes)} узлов")

        pairs = [[query, node.text] for node in nodes]
        scores = reranker.predict(pairs)

        # Pair each node with its score, best first.
        scored_nodes = sorted(zip(nodes, scores), key=lambda p: p[1], reverse=True)

        filtered = [(node, score) for node, score in scored_nodes
                    if score >= min_score_threshold]

        if not filtered:
            # Every score fell below the threshold — degrade gracefully
            # to the top_k best-scoring nodes instead of returning nothing.
            filtered = scored_nodes[:top_k]

        log_message(f"Выбрано {min(len(filtered), top_k)} узлов")

        return [node for node, _score in filtered[:top_k]]

    except Exception as e:
        # Best-effort: reranking is an enhancement, not a requirement.
        log_message(f"Ошибка переранжировки: {str(e)}")
        return nodes[:top_k]

def create_query_engine(vector_index, vector_top_k=50, bm25_top_k=50,
                        similarity_cutoff=0.55, hybrid_top_k=100):
    """Create a hybrid (BM25 + vector) RetrieverQueryEngine over *vector_index*.

    Args:
        vector_index: A built ``VectorStoreIndex`` whose docstore feeds BM25.
        vector_top_k: Candidates fetched by the dense vector retriever.
        bm25_top_k: Candidates fetched by the BM25 keyword retriever.
        similarity_cutoff: Minimum similarity for vector results.
        hybrid_top_k: Candidates kept after fusing both retrievers.

    Returns:
        A ``RetrieverQueryEngine`` that fuses BM25 and vector retrieval and
        synthesizes answers with the ``CUSTOM_PROMPT`` template in
        tree-summarize mode.

    Raises:
        Exception: Re-raised after logging if any component fails to build.
    """
    try:
        # Keyword retriever over the same documents as the vector index.
        bm25_retriever = BM25Retriever.from_defaults(
            docstore=vector_index.docstore,
            similarity_top_k=bm25_top_k,
        )

        # NOTE(review): similarity_cutoff is passed straight to
        # VectorIndexRetriever; verify the installed llama_index version
        # honors it here (some versions apply cutoffs only via a
        # SimilarityPostprocessor).
        vector_retriever = VectorIndexRetriever(
            index=vector_index,
            similarity_top_k=vector_top_k,
            similarity_cutoff=similarity_cutoff,
        )

        # Fuse both candidate lists; num_queries=1 disables query rewriting.
        hybrid_retriever = QueryFusionRetriever(
            [vector_retriever, bm25_retriever],
            similarity_top_k=hybrid_top_k,
            num_queries=1,
        )

        # CUSTOM_PROMPT comes from the module-level import; the redundant
        # shadowing local `from config import CUSTOM_PROMPT` was removed.
        custom_prompt_template = PromptTemplate(CUSTOM_PROMPT)
        response_synthesizer = get_response_synthesizer(
            response_mode=ResponseMode.TREE_SUMMARIZE,
            text_qa_template=custom_prompt_template,
        )

        query_engine = RetrieverQueryEngine(
            retriever=hybrid_retriever,
            response_synthesizer=response_synthesizer,
        )

        log_message(f"Query engine created: vector_top_k={vector_top_k}, "
                   f"bm25_top_k={bm25_top_k}, similarity_cutoff={similarity_cutoff}, "
                   f"hybrid_top_k={hybrid_top_k}")
        return query_engine

    except Exception as e:
        log_message(f"Ошибка создания query engine: {str(e)}")
        raise