File size: 5,612 Bytes
ba52088
 
 
2e8b03f
ba52088
9985d37
 
 
e10965e
703587b
 
 
e10965e
 
9985d37
e10965e
 
2595129
9985d37
 
 
3b55526
9985d37
 
 
3b55526
b38db64
c7a9dbd
9985d37
c7a9dbd
b38db64
9985d37
 
b38db64
9985d37
 
b38db64
c7a9dbd
3ac0ce6
ae5a669
 
dfc7ba2
ba52088
2595129
 
2edec29
dfc7ba2
 
c33deff
d577496
dfc7ba2
ec64429
dfc7ba2
806f3f9
 
200954f
806f3f9
 
 
 
 
 
 
 
 
200954f
806f3f9
 
 
30be7bf
806f3f9
 
 
 
 
 
 
 
 
 
 
 
 
 
ec64429
806f3f9
 
 
 
 
 
ec64429
806f3f9
 
 
 
 
 
2595129
2edec29
 
c33deff
7565a55
 
2edec29
9985d37
7565a55
2edec29
9985d37
 
200954f
 
9985d37
30be7bf
ec64429
806f3f9
2edec29
806f3f9
2edec29
 
806f3f9
9985d37
 
806f3f9
 
 
 
 
 
 
 
 
 
 
 
 
2edec29
806f3f9
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from llama_index.llms.google_genai import GoogleGenAI
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from sentence_transformers import CrossEncoder
from my_logging import log_message

def get_llm_model(api_key, model_name="gemini-2.0-flash"):
    """Build a Google GenAI LLM client for the given API key and model."""
    llm = GoogleGenAI(model=model_name, api_key=api_key)
    return llm

def get_embedding_model(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"):
    """Build a HuggingFace embedding model (multilingual MiniLM by default)."""
    embedder = HuggingFaceEmbedding(model_name=model_name)
    return embedder

def get_reranker_model(model_name='cross-encoder/ms-marco-MiniLM-L-12-v2'):
    """Build the cross-encoder used for reranking retrieved nodes."""
    reranker = CrossEncoder(model_name)
    return reranker


def format_sources(nodes):
    """Format retrieved source nodes into a human-readable, one-per-line list.

    Each node is labelled by its metadata ``type`` (``table`` / ``image`` /
    anything else, treated as text) and its ``document_id``.

    Args:
        nodes: iterable of objects with a dict-like ``metadata`` attribute.

    Returns:
        A newline-joined string of unique source descriptions, preserving
        the order in which the nodes were supplied.
    """
    sources = []
    for node in nodes:
        meta = node.metadata
        doc_type = meta.get('type', 'text')
        doc_id = meta.get('document_id', 'unknown')

        if doc_type == 'table':
            table_num = meta.get('table_number', 'unknown')
            title = meta.get('table_title', '')
            sources.append(f"📊 {doc_id} - Таблица {table_num}: {title}")
        elif doc_type == 'image':
            img_num = meta.get('image_number', 'unknown')
            sources.append(f"🖼️ {doc_id} - Рисунок {img_num}")
        else:
            section = meta.get('section_id', '')
            sources.append(f"📄 {doc_id} - Раздел {section}")

    # dict.fromkeys deduplicates while preserving first-seen order; the
    # previous set() made the output line order nondeterministic per run.
    return "\n".join(dict.fromkeys(sources))

import re

def answer_question(question, query_engine, reranker):
    """Answer a question via retrieve -> rerank -> group-by-document -> LLM.

    Args:
        question: the user's query string.
        query_engine: object exposing ``retrieve(question)`` -> list of nodes.
        reranker: cross-encoder forwarded to ``rerank_nodes``.

    Returns:
        ``(answer_text, formatted_sources)`` on success; on any exception,
        ``("Ошибка: <e>", "")`` after logging the traceback.
    """
    try:
        log_message(f"\n{'='*70}")
        log_message(f"QUERY: {question}")

        retrieved = query_engine.retrieve(question)
        log_message(f"RETRIEVED: {len(retrieved)} unique nodes")

        # Permissive rerank: min_score=-0.5 overrides the stricter default
        # so that borderline nodes survive to the grouping stage.
        reranked = rerank_nodes(question, retrieved, reranker, top_k=20, min_score=-0.5)
        log_message(f"RERANKED: {len(reranked)} nodes")
        

        # Group by document and type
        doc_groups = {}
        for n in reranked:
            doc_id = n.metadata.get('document_id', 'unknown')
            if doc_id not in doc_groups:
                doc_groups[doc_id] = {'tables': [], 'text': [], 'images': []}
            
            node_type = n.metadata.get('type', 'text')
            if node_type == 'table':
                doc_groups[doc_id]['tables'].append(n)
            elif node_type == 'image':
                doc_groups[doc_id]['images'].append(n)
            else:
                doc_groups[doc_id]['text'].append(n)
        
        log_message(f"Documents found: {list(doc_groups.keys())}")

        # Format context by document
        # NOTE(review): the 'images' group is collected above but never added
        # to the context below — confirm whether that is intentional.
        context_parts = []
        for doc_id, groups in doc_groups.items():
            doc_section = [f"=== ДОКУМЕНТ: {doc_id} ==="]
            
            # Tables first (most important for your queries)
            if groups['tables']:
                doc_section.append("\n--- ТАБЛИЦЫ ---")
                for n in groups['tables']:
                    meta = n.metadata
                    # Prefer the explicit identifier; fall back to the number.
                    table_id = meta.get('table_identifier', meta.get('table_number', 'unknown'))
                    title = meta.get('table_title', '')
                    doc_section.append(f"\n[Таблица {table_id}] {title}")
                    doc_section.append(n.text[:1500])  # Limit length
                    log_message(f"  Included table {table_id} from {doc_id}")
            
            # Then text
            if groups['text']:
                doc_section.append("\n--- ТЕКСТ ---")
                for n in groups['text'][:3]:  # Limit text chunks
                    doc_section.append(n.text[:800])
                    log_message(f"  Included text section from {doc_id}")
            
            context_parts.append("\n".join(doc_section))

        context = "\n\n" + ("="*70 + "\n\n").join(context_parts)
        
        log_message(f"Context length: {len(context)} chars")

        # Local imports keep module load light and avoid import cycles.
        from config import CUSTOM_PROMPT
        prompt = CUSTOM_PROMPT.format(context_str=context, query_str=question)

        from llama_index.core import Settings
        # Uses the globally configured LLM rather than a per-call instance.
        response = Settings.llm.complete(prompt)

        sources = format_sources(reranked)
        return response.text, sources

    except Exception as e:
        # Broad catch: this is the top-level boundary for the Q&A pipeline;
        # the error is logged and surfaced to the caller as a message.
        log_message(f"Error: {e}")
        import traceback
        log_message(traceback.format_exc())
        return f"Ошибка: {e}", ""

def rerank_nodes(query, nodes, reranker, top_k=20, min_score=0.1):
    """Rerank candidate nodes with a cross-encoder, logging score statistics.

    Args:
        query: the user query paired against each node's text.
        nodes: candidate nodes exposing a ``text`` attribute.
        reranker: model with ``predict(pairs)`` -> scores.
        top_k: maximum number of nodes to return.
        min_score: minimum score a node must reach to be kept.

    Returns:
        Up to ``top_k`` nodes scoring at least ``min_score``; if none pass
        the threshold, the ``top_k`` best-scored nodes are returned instead.
        With no nodes or no reranker, the input is passed through truncated.
    """
    if not nodes or not reranker:
        log_message("WARNING: No nodes or reranker available")
        return nodes[:top_k]

    # Truncate node text so the cross-encoder input stays short.
    candidate_pairs = [[query, node.text[:500]] for node in nodes]
    scores = reranker.predict(candidate_pairs)
    ranked = sorted(zip(nodes, scores), key=lambda pair: pair[1], reverse=True)

    # Log the score distribution for tuning the threshold.
    if ranked:
        head = [score for _, score in ranked[:5]]
        tail = [score for _, score in ranked[-5:]]
        log_message(f"Score range: {min(scores):.3f} to {max(scores):.3f}")
        log_message(f"Top 5 scores: {head}")
        log_message(f"Bottom 5 scores: {tail}")

    passing = sum(1 for _, score in ranked if score >= min_score)
    log_message(f"Nodes above threshold ({min_score}): {passing}/{len(ranked)}")

    kept = [node for node, score in ranked if score >= min_score]
    if kept:
        selected = kept[:top_k]
    else:
        # Threshold filtered everything out — fall back to the best top_k.
        selected = [node for node, _ in ranked[:top_k]]

    log_message(f"Returning {len(selected)} nodes after reranking")
    return selected