MrSimple07 commited on
Commit
d2e7d9e
·
1 Parent(s): f79b229

new keyword score based index retriever + answer question

Browse files
Files changed (2) hide show
  1. index_retriever.py +136 -61
  2. utils.py +18 -16
index_retriever.py CHANGED
@@ -1,25 +1,8 @@
1
  from llama_index.core import VectorStoreIndex
2
- from llama_index.core.query_engine import RetrieverQueryEngine
3
  from llama_index.core.retrievers import VectorIndexRetriever
4
  from llama_index.retrievers.bm25 import BM25Retriever
5
- from llama_index.core.retrievers import QueryFusionRetriever
6
- from llama_index.core.response_synthesizers import get_response_synthesizer
7
  from my_logging import log_message
8
-
9
- SIMPLE_PROMPT = """Вы - эксперт по нормативной документации.
10
-
11
- Контекст:
12
- {context_str}
13
-
14
- Вопрос: {query_str}
15
-
16
- Инструкция:
17
- 1. Отвечайте ТОЛЬКО на основе предоставленного контекста
18
- 2. Цитируйте конкретные источники (документ, раздел, таблицу)
19
- 3. Если информации недостаточно, четко укажите это
20
- 4. Будьте точны и конкретны
21
-
22
- Ответ:"""
23
 
24
  def create_vector_index(documents):
25
  """Create vector index from documents"""
@@ -28,59 +11,151 @@ def create_vector_index(documents):
28
  log_message("✓ Index created")
29
  return index
30
 
31
- def keyword_filter_nodes(query, nodes, min_keyword_matches=1):
32
- """Return nodes that contain at least one keyword from the query."""
33
- keywords = [w.lower() for w in query.split() if len(w) > 2]
34
- filtered = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  for node in nodes:
36
- text = node.text.lower()
37
- if any(k in text for k in keywords):
38
- filtered.append(node)
39
- return filtered
 
 
 
 
40
 
41
- def create_query_engine(vector_index):
42
- """Create hybrid retrieval engine with deduplication"""
43
- log_message("Creating query engine...")
44
 
 
 
 
 
45
  vector_retriever = VectorIndexRetriever(
46
  index=vector_index,
47
- similarity_top_k=40 # Reduced from 50
48
  )
 
 
 
49
  bm25_retriever = BM25Retriever.from_defaults(
50
  docstore=vector_index.docstore,
51
- similarity_top_k=40 # Reduced from 50
52
- )
53
- hybrid_retriever = QueryFusionRetriever(
54
- [vector_retriever, bm25_retriever],
55
- similarity_top_k=50, # Reduced from 60
56
- num_queries=1
57
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- class DeduplicatedQueryEngine(RetrieverQueryEngine):
60
- def retrieve(self, query):
61
- nodes = hybrid_retriever.retrieve(query)
62
-
63
- # CRITICAL: Deduplicate by text content hash
64
- seen_hashes = set()
65
- unique_nodes = []
66
-
67
- for node in nodes:
68
- # Create hash from first 200 chars to detect duplicates
69
- text_hash = hash(node.text[:200])
70
-
71
- if text_hash not in seen_hashes:
72
- seen_hashes.add(text_hash)
73
- unique_nodes.append(node)
74
-
75
- log_message(f"Retrieved: {len(nodes)} → Unique: {len(unique_nodes)}")
76
- return unique_nodes[:50] # Return top 50 unique
77
 
78
- response_synthesizer = get_response_synthesizer()
 
 
79
 
80
- query_engine = DeduplicatedQueryEngine(
81
- retriever=hybrid_retriever,
82
- response_synthesizer=response_synthesizer
83
- )
 
 
 
 
 
 
 
 
 
84
 
85
- log_message("✓ Query engine created (with deduplication)")
86
- return query_engine
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from llama_index.core import VectorStoreIndex
 
2
  from llama_index.core.retrievers import VectorIndexRetriever
3
  from llama_index.retrievers.bm25 import BM25Retriever
 
 
4
  from my_logging import log_message
5
+ import re
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  def create_vector_index(documents):
8
  """Create vector index from documents"""
 
11
  log_message("✓ Index created")
12
  return index
13
 
14
+
15
def extract_keywords(query):
    """Pull search-relevant keywords out of a (Russian) user query.

    Keeps tokens that are longer than 2 characters and not stopwords, plus
    any token containing a digit (material grades, standard numbers, etc.).
    Capitalized/hyphenated designators (e.g. "08Х18Н10Т", "С-25") are
    captured separately from the original-cased query and lowercased.
    Returns a deduplicated list (order not guaranteed).
    """
    # Common Russian question words / function words to drop.
    stopwords = {
        'какой', 'какие', 'каком', 'какая', 'где', 'когда', 'как', 'что', 'чем',
        'для', 'при', 'или', 'это', 'есть', 'быть', 'мочь', 'должен', 'нужно',
        'можно', 'требуется', 'необходимо', 'я', 'мы', 'вы', 'он', 'она', 'они',
        'в', 'на', 'с', 'по', 'из', 'к', 'о', 'об', 'и', 'а', 'но', 'за', 'до', 'от'
    }

    tokens = re.findall(r'\b[\wа-яА-Я0-9]+\b', query.lower())

    # Plain-word filter: non-stopword and long enough, OR contains a digit
    # (digit-bearing tokens are likely specific designations, always keep).
    keywords = [
        tok for tok in tokens
        if (tok not in stopwords and len(tok) > 2) or any(ch.isdigit() for ch in tok)
    ]

    # Designator-style matches taken from the original (non-lowercased) query,
    # e.g. hyphenated codes, then normalized to lowercase.
    for designator in re.findall(r'\b[А-ЯA-Z0-9][а-яА-Яa-zA-Z0-9\-]*\b', query):
        if len(designator) > 2:
            keywords.append(designator.lower())

    log_message(f"Keywords extracted: {set(keywords)}")
    return list(set(keywords))
37
+
38
+
39
def calculate_keyword_score(text, keywords):
    """Score how strongly `text` matches `keywords`.

    Matching is case-insensitive substring counting: each keyword
    contributes (occurrence count) * (len(keyword) / 5.0), so longer —
    presumably more specific — keywords carry more weight. Returns 0 when
    nothing matches.
    """
    haystack = text.lower()
    total = 0

    for keyword in keywords:
        hits = haystack.count(keyword.lower())
        if hits:
            # Weight grows with keyword length (more specific term).
            total += hits * (len(keyword) / 5.0)

    return total
53
+
54
+
55
def deduplicate_nodes(nodes):
    """Drop nodes whose text duplicates an earlier node's.

    Two nodes count as duplicates when the first 200 characters of their
    text hash identically; the first occurrence wins and input order is
    preserved.
    """
    fingerprints = set()
    unique = []

    for node in nodes:
        fingerprint = hash(node.text[:200])
        if fingerprint in fingerprints:
            continue
        fingerprints.add(fingerprint)
        unique.append(node)

    return unique
69
 
 
 
 
70
 
71
def hybrid_retrieve(query, vector_index, top_k=50):
    """Hybrid retrieval: vector + BM25 + keyword boosting.

    Runs vector and BM25 retrieval over the same index, merges and
    deduplicates the candidates, then re-ranks them by adding a capped
    keyword-match bonus to each node's retrieval score.

    Args:
        query: The user query string.
        vector_index: A llama_index VectorStoreIndex (its docstore is
            reused for BM25).
        top_k: How many nodes to fetch per retriever and to return.

    Returns:
        At most `top_k` nodes, best combined score first (scores are not
        attached to the returned nodes).
    """
    # 1. Vector retrieval
    vector_retriever = VectorIndexRetriever(
        index=vector_index,
        similarity_top_k=top_k
    )
    vector_nodes = vector_retriever.retrieve(query)

    # 2. BM25 retrieval over the same docstore
    bm25_retriever = BM25Retriever.from_defaults(
        docstore=vector_index.docstore,
        similarity_top_k=top_k
    )
    bm25_nodes = bm25_retriever.retrieve(query)

    # 3. Combine and deduplicate (first-200-chars text fingerprint)
    all_nodes = vector_nodes + bm25_nodes
    unique_nodes = deduplicate_nodes(all_nodes)

    log_message(f"Vector: {len(vector_nodes)}, BM25: {len(bm25_nodes)}, Unique: {len(unique_nodes)}")

    # 4. Extract keywords once for the whole candidate set
    keywords = extract_keywords(query)

    # 5. Score each node: retrieval score + capped keyword bonus
    scored_nodes = []
    for node in unique_nodes:
        keyword_score = calculate_keyword_score(node.text, keywords)

        # BUGFIX: only a *missing* score falls back to the neutral 0.5.
        # The previous truthiness check (`... and node.score`) silently
        # bumped a legitimate retrieval score of 0.0 up to 0.5.
        original_score = getattr(node, 'score', None)
        if original_score is None:
            original_score = 0.5

        # Boost formula: keyword bonus is capped at 0.3 so it can never
        # swamp the underlying retrieval score.
        keyword_boost = min(keyword_score * 0.1, 0.3)
        combined_score = original_score + keyword_boost

        scored_nodes.append((node, combined_score, keyword_score))

    # 6. Sort by combined score, best first
    scored_nodes.sort(key=lambda x: x[1], reverse=True)

    # Log top scores for debugging retrieval quality
    log_message("\nTop 10 scores after keyword boosting:")
    for i, (node, combined, kw_score) in enumerate(scored_nodes[:10], 1):
        doc_id = node.metadata.get('document_id', '?')
        node_type = node.metadata.get('type', '?')
        log_message(f" {i}. [{doc_id}] {node_type} - Score: {combined:.3f} (kw: {kw_score:.2f})")

    # Return nodes only (without scores)
    return [node for node, _, _ in scored_nodes[:top_k]]
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
def keyword_retrieve_fallback(query, vector_index, keywords, top_k=20):
    """Fallback: direct keyword search across every stored chunk.

    Scores each docstore chunk against `keywords` with
    calculate_keyword_score and returns the `top_k` best positive matches,
    best first. `query` is unused here (keywords arrive pre-extracted) but
    kept for signature symmetry with the other retrieval helpers.
    """
    candidates = vector_index.docstore.docs.values()

    # Pair every chunk with its keyword score, keep positive scores only,
    # ranked best-first.
    pairs = (
        (node, calculate_keyword_score(node.text, keywords))
        for node in candidates
    )
    scored = sorted(
        (pair for pair in pairs if pair[1] > 0),
        key=lambda pair: pair[1],
        reverse=True,
    )

    if scored:
        log_message(f"\nKeyword fallback found {len(scored)} matches")
        log_message(f"Top scores: {[s for _, s in scored[:5]]}")

    return [node for node, _ in scored[:top_k]]
142
 
143
+
144
def create_query_engine(vector_index):
    """Build the retrieval entry point for `vector_index`.

    Returns a `retrieve(query)` callable that runs hybrid retrieval
    (vector + BM25 + keyword boost) and, when fewer than 20 nodes come
    back, augments the result with a pure keyword-match fallback before
    deduplicating again. The result is capped at 50 nodes.
    """

    def retrieve(query):
        nodes = hybrid_retrieve(query, vector_index, top_k=60)

        # Fallback: if too few results, add pure keyword matches.
        # PERF: extract keywords only when the fallback actually triggers —
        # hybrid_retrieve already performs its own extraction internally,
        # so the previous unconditional call here was redundant work.
        if len(nodes) < 20:
            keywords = extract_keywords(query)
            if keywords:
                log_message("\n⚠ Adding keyword fallback results...")
                fallback_nodes = keyword_retrieve_fallback(query, vector_index, keywords, top_k=30)
                nodes.extend(fallback_nodes)
                nodes = deduplicate_nodes(nodes)

        log_message(f"\nFinal retrieval: {len(nodes)} nodes")
        return nodes[:50]  # Cap at 50

    log_message("✓ Query engine created (hybrid + keyword boost)")
    return retrieve
utils.py CHANGED
@@ -37,20 +37,23 @@ def format_sources(nodes):
37
 
38
  return "\n".join(set(sources))
39
 
40
- def answer_question(question, query_engine, reranker):
41
  try:
42
  log_message(f"\n{'='*70}")
43
  log_message(f"QUERY: {question}")
44
 
45
- # Retrieve nodes (already deduplicated)
46
- retrieved = query_engine.retrieve(question)
47
- log_message(f"RETRIEVED: {len(retrieved)} unique nodes")
48
 
49
- # Rerank
50
- reranked = rerank_nodes(question, retrieved, reranker, top_k=15, min_score=0.25) # Reduced top_k
51
- log_message(f"RERANKED: {len(reranked)} nodes")
52
-
53
- # Build context - NO TRUNCATION
 
 
 
54
  context_parts = []
55
  for n in reranked:
56
  meta = n.metadata
@@ -66,7 +69,7 @@ def answer_question(question, query_engine, reranker):
66
  else:
67
  source_label = f"[{doc_id}]"
68
 
69
- context_parts.append(f"{source_label}\n{n.text}") # Full text
70
 
71
  context = "\n\n" + ("="*50 + "\n\n").join(context_parts)
72
 
@@ -79,18 +82,17 @@ def answer_question(question, query_engine, reranker):
79
 
80
  sources = format_sources(reranked)
81
 
82
- # Log retrieved chunks WITHOUT duplicates
83
  log_message(f"\n{'='*70}")
84
- log_message("RETRIEVED CHUNKS:")
85
  for i, node in enumerate(reranked, 1):
86
  log_message(f"\n--- Chunk {i} ---")
87
- log_message(f"Document: {node.metadata.get('document_id')}")
88
  log_message(f"Type: {node.metadata.get('type')}")
89
  if node.metadata.get('type') == 'table':
90
  table_id = node.metadata.get('table_identifier')
91
- rows = f"{node.metadata.get('row_start', 0)}-{node.metadata.get('row_end', 0)}"
92
- log_message(f"Table: {table_id} (rows {rows})")
93
- log_message(f"Text: {node.text[:300]}...")
94
 
95
  return response.text, sources
96
 
 
37
 
38
  return "\n".join(set(sources))
39
 
40
+ def answer_question(question, retrieve_func, reranker):
41
  try:
42
  log_message(f"\n{'='*70}")
43
  log_message(f"QUERY: {question}")
44
 
45
+ # Retrieve with keyword boosting
46
+ retrieved = retrieve_func(question)
47
+ log_message(f"RETRIEVED: {len(retrieved)} nodes")
48
 
49
+ # Rerank (optional - уже есть keyword boost)
50
+ if reranker:
51
+ reranked = rerank_nodes(question, retrieved, reranker, top_k=25, min_score=0.2)
52
+ log_message(f"RERANKED: {len(reranked)} nodes")
53
+ else:
54
+ reranked = retrieved[:25]
55
+
56
+ # Build context
57
  context_parts = []
58
  for n in reranked:
59
  meta = n.metadata
 
69
  else:
70
  source_label = f"[{doc_id}]"
71
 
72
+ context_parts.append(f"{source_label}\n{n.text}")
73
 
74
  context = "\n\n" + ("="*50 + "\n\n").join(context_parts)
75
 
 
82
 
83
  sources = format_sources(reranked)
84
 
85
+ # Detailed logging
86
  log_message(f"\n{'='*70}")
87
+ log_message("FINAL CHUNKS:")
88
  for i, node in enumerate(reranked, 1):
89
  log_message(f"\n--- Chunk {i} ---")
90
+ log_message(f"Doc: {node.metadata.get('document_id')}")
91
  log_message(f"Type: {node.metadata.get('type')}")
92
  if node.metadata.get('type') == 'table':
93
  table_id = node.metadata.get('table_identifier')
94
+ log_message(f"Table: {table_id}")
95
+ log_message(f"Preview: {node.text[:400]}...")
 
96
 
97
  return response.text, sources
98