MrSimple07 committed
Commit dfc7ba2 · Parent(s): d2e7d9e

new keyword score based index retriever + answer question

Files changed (2):
  1. index_retriever.py +60 -136
  2. utils.py +16 -18
index_retriever.py CHANGED

@@ -1,8 +1,25 @@
 from llama_index.core import VectorStoreIndex
+from llama_index.core.query_engine import RetrieverQueryEngine
 from llama_index.core.retrievers import VectorIndexRetriever
 from llama_index.retrievers.bm25 import BM25Retriever
+from llama_index.core.retrievers import QueryFusionRetriever
+from llama_index.core.response_synthesizers import get_response_synthesizer
 from my_logging import log_message
-import re
+
+SIMPLE_PROMPT = """Вы - эксперт по нормативной документации.
+
+Контекст:
+{context_str}
+
+Вопрос: {query_str}
+
+Инструкция:
+1. Отвечайте ТОЛЬКО на основе предоставленного контекста
+2. Цитируйте конкретные источники (документ, раздел, таблицу)
+3. Если информации недостаточно, четко укажите это
+4. Будьте точны и конкретны
+
+Ответ:"""
 
 def create_vector_index(documents):
     """Create vector index from documents"""
@@ -11,151 +28,58 @@ def create_vector_index(documents):
     log_message("✓ Index created")
     return index
 
-
-def extract_keywords(query):
-    stopwords = {
-        'какой', 'какие', 'каком', 'какая', 'где', 'когда', 'как', 'что', 'чем',
-        'для', 'при', 'или', 'это', 'есть', 'быть', 'мочь', 'должен', 'нужно',
-        'можно', 'требуется', 'необходимо', 'я', 'мы', 'вы', 'он', 'она', 'они',
-        'в', 'на', 'с', 'по', 'из', 'к', 'о', 'об', 'и', 'а', 'но', 'за', 'до', 'от'
-    }
-
-    words = re.findall(r'\b[\wа-яА-Я0-9]+\b', query.lower())
-
-    # Filter keywords
-    keywords = []
-    for word in words:
-        if (word not in stopwords and len(word) > 2) or any(c.isdigit() for c in word):
-            keywords.append(word)
-
-    # Also extract exact phrases with hyphens/caps (e.g., "08Х18Н10Т", "С-25")
-    exact_matches = re.findall(r'\b[А-ЯA-Z0-9][а-яА-Яa-zA-Z0-9\-]*\b', query)
-    keywords.extend([m.lower() for m in exact_matches if len(m) > 2])
-
-    log_message(f"Keywords extracted: {set(keywords)}")
-    return list(set(keywords))
-
-
-def calculate_keyword_score(text, keywords):
-    """Calculate keyword match score for a text chunk"""
-    text_lower = text.lower()
-    score = 0
-
-    for keyword in keywords:
-        # Exact match (case-insensitive)
-        count = text_lower.count(keyword.lower())
-        if count > 0:
-            # Higher weight for longer keywords (likely more specific)
-            weight = len(keyword) / 5.0
-            score += count * weight
-
-    return score
-
-
-def deduplicate_nodes(nodes):
-    """Remove duplicate nodes based on text content"""
-    seen_hashes = set()
-    unique_nodes = []
-
+def keyword_filter_nodes(query, nodes, min_keyword_matches=1):
+    """Return nodes that contain at least one keyword from the query."""
+    keywords = [w.lower() for w in query.split() if len(w) > 2]
+    filtered = []
     for node in nodes:
-        # Use first 200 chars as fingerprint
-        text_hash = hash(node.text[:200])
-
-        if text_hash not in seen_hashes:
-            seen_hashes.add(text_hash)
-            unique_nodes.append(node)
-
-    return unique_nodes
-
-
-def hybrid_retrieve(query, vector_index, top_k=50):
-    """Hybrid retrieval: vector + BM25 + keyword boosting"""
-
-    # 1. Vector retrieval
+        text = node.text.lower()
+        if any(k in text for k in keywords):
+            filtered.append(node)
+    return filtered
+
+def create_query_engine(vector_index):
+    """Create hybrid retrieval engine with deduplication"""
+    log_message("Creating query engine...")
+
     vector_retriever = VectorIndexRetriever(
         index=vector_index,
-        similarity_top_k=top_k
+        similarity_top_k=50
     )
-    vector_nodes = vector_retriever.retrieve(query)
-
-    # 2. BM25 retrieval
     bm25_retriever = BM25Retriever.from_defaults(
         docstore=vector_index.docstore,
-        similarity_top_k=top_k
+        similarity_top_k=50
+    )
+    hybrid_retriever = QueryFusionRetriever(
+        [vector_retriever, bm25_retriever],
+        similarity_top_k=60,
+        num_queries=1
     )
-    bm25_nodes = bm25_retriever.retrieve(query)
-
-    # 3. Combine and deduplicate
-    all_nodes = vector_nodes + bm25_nodes
-    unique_nodes = deduplicate_nodes(all_nodes)
-
-    log_message(f"Vector: {len(vector_nodes)}, BM25: {len(bm25_nodes)}, Unique: {len(unique_nodes)}")
-
-    # 4. Extract keywords
-    keywords = extract_keywords(query)
-
-    # 5. Add keyword scores
-    scored_nodes = []
-    for node in unique_nodes:
-        keyword_score = calculate_keyword_score(node.text, keywords)
-
-        # Combine with original similarity score
-        original_score = node.score if hasattr(node, 'score') and node.score else 0.5
-
-        # Boost formula: original score + keyword bonus (capped at 0.3)
-        keyword_boost = min(keyword_score * 0.1, 0.3)
-        combined_score = original_score + keyword_boost
-
-        scored_nodes.append((node, combined_score, keyword_score))
-
-    # 6. Sort by combined score
-    scored_nodes.sort(key=lambda x: x[1], reverse=True)
-
-    # Log top scores
-    log_message("\nTop 10 scores after keyword boosting:")
-    for i, (node, combined, kw_score) in enumerate(scored_nodes[:10], 1):
-        doc_id = node.metadata.get('document_id', '?')
-        node_type = node.metadata.get('type', '?')
-        log_message(f"  {i}. [{doc_id}] {node_type} - Score: {combined:.3f} (kw: {kw_score:.2f})")
-
-    # Return nodes only (without scores)
-    return [node for node, _, _ in scored_nodes[:top_k]]
-
-
-def keyword_retrieve_fallback(query, vector_index, keywords, top_k=20):
-    """Fallback: direct keyword search in all documents"""
-    all_nodes = list(vector_index.docstore.docs.values())
-
-    scored = []
-    for node in all_nodes:
-        score = calculate_keyword_score(node.text, keywords)
-        if score > 0:
-            scored.append((node, score))
-
-    scored.sort(key=lambda x: x[1], reverse=True)
-
-    if scored:
-        log_message(f"\nKeyword fallback found {len(scored)} matches")
-        log_message(f"Top scores: {[s for _, s in scored[:5]]}")
-
-    return [node for node, _ in scored[:top_k]]
-
-
-def create_query_engine(vector_index):
-
-    def retrieve(query):
-        nodes = hybrid_retrieve(query, vector_index, top_k=60)
-
-        # Fallback: If too few results, add pure keyword matches
-        keywords = extract_keywords(query)
-        if len(nodes) < 20 and keywords:
-            log_message("\n⚠ Adding keyword fallback results...")
-            fallback_nodes = keyword_retrieve_fallback(query, vector_index, keywords, top_k=30)
-            nodes.extend(fallback_nodes)
-            nodes = deduplicate_nodes(nodes)
-
-        log_message(f"\nFinal retrieval: {len(nodes)} nodes")
-        return nodes[:50]  # Cap at 50
-
-    log_message("✓ Query engine created (hybrid + keyword boost)")
-    return retrieve
+
+    class DeduplicatedQueryEngine(RetrieverQueryEngine):
+        def retrieve(self, query):
+            nodes = hybrid_retriever.retrieve(query)
+
+            # CRITICAL: Deduplicate by text content hash
+            seen_hashes = set()
+            unique_nodes = []
+
+            for node in nodes:
+                text_hash = hash(node.text[:200])
+
+                if text_hash not in seen_hashes:
+                    seen_hashes.add(text_hash)
+                    unique_nodes.append(node)
+
+            log_message(f"Retrieved: {len(nodes)} → Unique: {len(unique_nodes)}")
+            return unique_nodes[:50]  # Return top 50 unique
+
+    response_synthesizer = get_response_synthesizer()
+
+    query_engine = DeduplicatedQueryEngine(
+        retriever=hybrid_retriever,
+        response_synthesizer=response_synthesizer
+    )
+
+    log_message("✓ Query engine created (with deduplication)")
+    return query_engine
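
For reference, a minimal usage sketch of the new retrieval path (not part of the commit): the document loader and the sample query below are illustrative assumptions, while create_vector_index, create_query_engine, and the retrieve() behaviour come from index_retriever.py above.

# Usage sketch, assuming documents are supplied by the Space's own ingestion
# code; SimpleDirectoryReader and the "./docs" path are placeholders.
from llama_index.core import SimpleDirectoryReader
from index_retriever import create_vector_index, create_query_engine

documents = SimpleDirectoryReader("./docs").load_data()

index = create_vector_index(documents)
query_engine = create_query_engine(index)

# DeduplicatedQueryEngine.retrieve() fuses vector and BM25 candidates, drops
# chunks whose first 200 characters hash identically, and returns at most
# 50 unique nodes.
nodes = query_engine.retrieve("Какая марка стали указана в таблице 3?")  # illustrative query
for node in nodes[:5]:
    print(node.metadata.get("document_id"), node.text[:80])

Because the deduplication lives inside retrieve(), downstream code such as utils.answer_question can keep treating the engine as a plain retriever.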
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils.py CHANGED

@@ -37,23 +37,20 @@ def format_sources(nodes):
 
     return "\n".join(set(sources))
 
-def answer_question(question, retrieve_func, reranker):
+def answer_question(question, query_engine, reranker):
     try:
         log_message(f"\n{'='*70}")
         log_message(f"QUERY: {question}")
 
-        # Retrieve with keyword boosting
-        retrieved = retrieve_func(question)
-        log_message(f"RETRIEVED: {len(retrieved)} nodes")
+        # Retrieve nodes (already deduplicated)
+        retrieved = query_engine.retrieve(question)
+        log_message(f"RETRIEVED: {len(retrieved)} unique nodes")
 
-        # Rerank (optional - keyword boost is already applied)
-        if reranker:
-            reranked = rerank_nodes(question, retrieved, reranker, top_k=25, min_score=0.2)
-            log_message(f"RERANKED: {len(reranked)} nodes")
-        else:
-            reranked = retrieved[:25]
-
-        # Build context
+        # Rerank
+        reranked = rerank_nodes(question, retrieved, reranker, top_k=25, min_score=0.25)
+        log_message(f"RERANKED: {len(reranked)} nodes")
+
+        # Build context - NO TRUNCATION
         context_parts = []
         for n in reranked:
             meta = n.metadata
@@ -69,7 +66,7 @@ def answer_question(question, retrieve_func, reranker):
             else:
                 source_label = f"[{doc_id}]"
 
-            context_parts.append(f"{source_label}\n{n.text}")
+            context_parts.append(f"{source_label}\n{n.text}")  # Full text
 
         context = "\n\n" + ("="*50 + "\n\n").join(context_parts)
 
@@ -82,17 +79,18 @@ def answer_question(question, retrieve_func, reranker):
 
         sources = format_sources(reranked)
 
-        # Detailed logging
+        # Log retrieved chunks WITHOUT duplicates
        log_message(f"\n{'='*70}")
-        log_message("FINAL CHUNKS:")
+        log_message("RETRIEVED CHUNKS:")
        for i, node in enumerate(reranked, 1):
            log_message(f"\n--- Chunk {i} ---")
-            log_message(f"Doc: {node.metadata.get('document_id')}")
+            log_message(f"Document: {node.metadata.get('document_id')}")
            log_message(f"Type: {node.metadata.get('type')}")
            if node.metadata.get('type') == 'table':
                table_id = node.metadata.get('table_identifier')
-                log_message(f"Table: {table_id}")
-            log_message(f"Preview: {node.text[:400]}...")
+                rows = f"{node.metadata.get('row_start', 0)}-{node.metadata.get('row_end', 0)}"
+                log_message(f"Table: {table_id} (rows {rows})")
+            log_message(f"Text: {node.text[:300]}...")
 
        return response.text, sources
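
Below is a hypothetical glue helper (not in the commit) showing the new call signature on the utils.py side: answer_question now takes the query engine object instead of the old retrieve_func callable, and the reranker is no longer optional because rerank_nodes() is always called.

from index_retriever import create_query_engine
from utils import answer_question

def answer_with_engine(question, index, reranker):
    """Hypothetical wrapper: 'index' is a VectorStoreIndex from
    create_vector_index(); 'reranker' is whatever the app already passes to
    utils.rerank_nodes (its construction is not part of this commit)."""
    query_engine = create_query_engine(index)
    answer, sources = answer_question(question, query_engine, reranker)
    return answer, sources

Compared to the old path, the reranked = retrieved[:25] fallback is gone and min_score rises from 0.2 to 0.25, so callers that previously passed reranker=None now need to supply a real reranker.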