MrSimple07 committed on
Commit
eefdfd0
·
1 Parent(s): 90e6b4c

hybrid_retrieve_with_keywords implemented

Browse files
Files changed (2) hide show
  1. index_retriever.py +113 -137
  2. utils.py +8 -11
index_retriever.py CHANGED
@@ -49,6 +49,119 @@ def create_query_engine(vector_index):
49
  log_message(f"Ошибка создания query engine: {str(e)}")
50
  raise
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def rerank_nodes(query, nodes, reranker, top_k=25, min_score_threshold=0.5, diversity_penalty=0.3):
53
  if not nodes or not reranker:
54
  return nodes[:top_k]
@@ -114,140 +227,3 @@ def rerank_nodes(query, nodes, reranker, top_k=25, min_score_threshold=0.5, dive
114
  log_message(f"Ошибка переранжировки: {str(e)}")
115
  return nodes[:top_k]
116
 
117
-
118
- from rank_bm25 import BM25Okapi
119
- import numpy as np
120
-
121
- class HybridRetriever:
122
- def __init__(self, vector_retriever, documents):
123
- self.vector_retriever = vector_retriever
124
- self.documents = documents
125
-
126
- # Build BM25 index
127
- tokenized_docs = [doc.text.lower().split() for doc in documents]
128
- self.bm25 = BM25Okapi(tokenized_docs)
129
-
130
- # Build metadata index for exact matching
131
- self.metadata_index = self._build_metadata_index(documents)
132
-
133
- def _build_metadata_index(self, documents):
134
- """Index by materials, GOSTs, classes for exact matching"""
135
- index = {
136
- 'materials': {},
137
- 'gosts': {},
138
- 'classes': {},
139
- 'key_terms': {}
140
- }
141
-
142
- for i, doc in enumerate(documents):
143
- metadata = doc.metadata
144
-
145
- # Index materials
146
- for material in metadata.get('materials', []):
147
- if material not in index['materials']:
148
- index['materials'][material] = []
149
- index['materials'][material].append(i)
150
-
151
- # Index GOSTs
152
- for gost in metadata.get('gosts', []):
153
- if gost not in index['gosts']:
154
- index['gosts'][gost] = []
155
- index['gosts'][gost].append(i)
156
-
157
- # Index classes
158
- for cls in metadata.get('classes', []):
159
- if cls not in index['classes']:
160
- index['classes'][cls] = []
161
- index['classes'][cls].append(i)
162
-
163
- # Index key terms
164
- for term in metadata.get('key_terms', []):
165
- term_lower = term.lower()
166
- if term_lower not in index['key_terms']:
167
- index['key_terms'][term_lower] = []
168
- index['key_terms'][term_lower].append(i)
169
-
170
- return index
171
-
172
- def retrieve(self, query, top_k=20, vector_weight=0.5, bm25_weight=0.3, metadata_weight=0.2):
173
- """Hybrid retrieval combining vector, BM25, and metadata matching"""
174
-
175
- # 1. Vector search
176
- vector_results = self.vector_retriever.retrieve(query)
177
- vector_scores = {node.node_id: node.score for node in vector_results}
178
-
179
- # 2. BM25 search
180
- tokenized_query = query.lower().split()
181
- bm25_scores = self.bm25.get_scores(tokenized_query)
182
-
183
- # 3. Metadata exact matching
184
- metadata_scores = self._get_metadata_scores(query)
185
-
186
- # 4. Combine scores
187
- all_node_ids = set(list(vector_scores.keys()) +
188
- list(range(len(self.documents))))
189
-
190
- combined_scores = {}
191
- for node_id in all_node_ids:
192
- vec_score = vector_scores.get(node_id, 0.0)
193
- bm25_score = bm25_scores[node_id] if isinstance(node_id, int) and node_id < len(bm25_scores) else 0.0
194
- meta_score = metadata_scores.get(node_id, 0.0)
195
-
196
- # Normalize and combine
197
- combined_scores[node_id] = (
198
- vector_weight * vec_score +
199
- bm25_weight * (bm25_score / (max(bm25_scores) + 1e-6)) +
200
- metadata_weight * meta_score
201
- )
202
-
203
- # 5. Get top-k
204
- sorted_nodes = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
205
-
206
- # Return as node objects
207
- results = []
208
- for node_id, score in sorted_nodes:
209
- if isinstance(node_id, int) and node_id < len(self.documents):
210
- doc = self.documents[node_id]
211
- # Create node-like object
212
- from types import SimpleNamespace
213
- node = SimpleNamespace(
214
- text=doc.text,
215
- metadata=doc.metadata,
216
- score=score,
217
- node_id=node_id
218
- )
219
- results.append(node)
220
-
221
- return results
222
-
223
- def _get_metadata_scores(self, query):
224
- """Score documents by exact metadata matches"""
225
- scores = {}
226
- query_lower = query.lower()
227
-
228
- # Check for material codes
229
- import re
230
- material_pattern = r'\b\d{2}[ХНТМКВБА]+\d{1,2}[ХНТМКВБА]*\d*\b'
231
- materials_in_query = re.findall(material_pattern, query, re.IGNORECASE)
232
-
233
- for material in materials_in_query:
234
- if material in self.metadata_index['materials']:
235
- for doc_id in self.metadata_index['materials'][material]:
236
- scores[doc_id] = scores.get(doc_id, 0) + 1.0
237
-
238
- # Check for GOSTs
239
- gost_pattern = r'ГОСТ\s+[РЕН\s]*\d+[\.\-\d]*'
240
- gosts_in_query = re.findall(gost_pattern, query, re.IGNORECASE)
241
-
242
- for gost in gosts_in_query:
243
- if gost in self.metadata_index['gosts']:
244
- for doc_id in self.metadata_index['gosts'][gost]:
245
- scores[doc_id] = scores.get(doc_id, 0) + 0.8
246
-
247
- # Check for key terms
248
- for term, doc_ids in self.metadata_index['key_terms'].items():
249
- if term in query_lower:
250
- for doc_id in doc_ids:
251
- scores[doc_id] = scores.get(doc_id, 0) + 0.5
252
-
253
- return scores
 
49
  log_message(f"Ошибка создания query engine: {str(e)}")
50
  raise
51
 
52
+ import re
53
+ from typing import List, Dict, Set
54
+ from my_logging import log_message
55
+
56
def extract_keywords_from_query(query: str) -> Dict[str, List[str]]:
    """Extract technical keywords from a user query.

    Args:
        query: Free-text question, typically in Russian.

    Returns:
        Dict with keys 'materials', 'gosts', 'classes', 'technical_terms',
        each mapping to a (possibly empty) list of matches found in *query*.
    """
    keywords: Dict[str, List[str]] = {
        'materials': [],
        'gosts': [],
        'classes': [],
        'technical_terms': []
    }

    # Material codes: 08Х18Н10Т, 12Х18Н10Т, etc. — two leading digits, then
    # one or more letter-group/digit runs.  The previous pattern
    # (r'\b\d{2}[ХНТМКВБА]+\d{1,2}[ХНТМКВБА]*\d*\b') allowed at most two
    # letter groups and therefore failed on its own example 08Х18Н10Т
    # (the trailing 'Т' broke the word boundary); the repeated group fixes
    # that and also covers short grades like 40Х.
    material_pattern = r'\b\d{2}(?:[ХНТМКВБА]+\d{0,2})+\b'
    keywords['materials'] = re.findall(material_pattern, query, re.IGNORECASE)

    # GOST standards, e.g. "ГОСТ 5632", "ГОСТ Р 52630-2012".
    gost_pattern = r'ГОСТ\s+[РЕН\s]*\d+[\.\-\d]*'
    keywords['gosts'] = re.findall(gost_pattern, query, re.IGNORECASE)

    # Classification codes: 3СIIIa, 2BII, etc.
    # NOTE(review): the original comment also cited "1А", which has no
    # Roman-numeral part and does not match this pattern — confirm the
    # intended code format before widening the regex.
    class_pattern = r'\b\d[АБВГСD]+[IV]+[a-z]?\b'
    keywords['classes'] = re.findall(class_pattern, query, re.IGNORECASE)

    # Domain terms matched as plain substrings; lower the query once
    # instead of recomputing it for every term.
    query_lower = query.lower()
    terms = ['полуфабрикат', 'план качества', 'контроль', 'арматура',
             'ультразвуковой', 'сварка', 'испытание']
    keywords['technical_terms'] = [term for term in terms
                                   if term.lower() in query_lower]

    return keywords
85
+
86
def keyword_search_nodes(nodes: List, keywords: Dict[str, List[str]]) -> List:
    """Filter *nodes* down to those with exact keyword matches in their text.

    A node is kept when its text contains any extracted material, GOST or
    class code; technical terms are weaker evidence, so at least two of
    them must appear.  When no keywords were extracted at all, filtering
    is skipped and *nodes* is returned unchanged.

    Args:
        nodes: Retrieved node objects exposing a ``.text`` attribute.
        keywords: Output of ``extract_keywords_from_query``.

    Returns:
        List of matching nodes (original order preserved).
    """
    if not any(keywords.values()):
        return nodes

    matched_nodes = []
    for node in nodes:
        # Note: the original also read node.metadata here but never used it;
        # that dead local has been removed.
        text_lower = node.text.lower()
        if (any(material.lower() in text_lower
                for material in keywords['materials'])
                or any(gost.lower() in text_lower
                       for gost in keywords['gosts'])
                or any(cls.lower() in text_lower
                       for cls in keywords['classes'])
                or sum(term.lower() in text_lower
                       for term in keywords['technical_terms']) >= 2):
            matched_nodes.append(node)

    return matched_nodes
122
+
123
def hybrid_retrieve_with_keywords(question: str, query_engine, top_k: int = 40) -> List:
    """Retrieve nodes by vector search, then promote exact keyword matches.

    Keywords (material codes, GOSTs, classes, technical terms) extracted
    from *question* are matched against the vector results; matching nodes
    are placed first and the remaining vector hits fill the list up to
    *top_k*.  Falls back to plain vector ranking when no keywords are
    found or none of the nodes match them.
    """
    keywords = extract_keywords_from_query(question)
    log_message(f"Извлечены ключевые слова: {keywords}")

    vector_nodes = query_engine.retriever.retrieve(question)
    log_message(f"Векторный поиск: {len(vector_nodes)} узлов")

    # Guard: no technical keywords in the query -> vector ranking as-is.
    if not any(keywords.values()):
        return vector_nodes[:top_k]

    keyword_nodes = keyword_search_nodes(vector_nodes, keywords)
    log_message(f"После фильтрации по ключевым словам: {len(keyword_nodes)} узлов")

    # Guard: keywords present but nothing matched -> vector ranking as-is.
    if not keyword_nodes:
        return vector_nodes[:top_k]

    # Keyword matches first, then the rest of the vector results, in one
    # pass; deduplicate by object identity since both lists hold the very
    # same node objects.
    combined_nodes = []
    seen_ids = set()
    for node in keyword_nodes[:top_k] + list(vector_nodes):
        if len(combined_nodes) >= top_k:
            break
        marker = id(node)
        if marker not in seen_ids:
            seen_ids.add(marker)
            combined_nodes.append(node)

    return combined_nodes
164
+
165
  def rerank_nodes(query, nodes, reranker, top_k=25, min_score_threshold=0.5, diversity_penalty=0.3):
166
  if not nodes or not reranker:
167
  return nodes[:top_k]
 
227
  log_message(f"Ошибка переранжировки: {str(e)}")
228
  return nodes[:top_k]
229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils.py CHANGED
@@ -232,6 +232,7 @@ def generate_sources_html(nodes, chunks_df=None):
232
  html += "</div>"
233
  return html
234
  def answer_question(question, query_engine, reranker, current_model, chunks_df=None, hybrid_retriever=None):
 
235
  if query_engine is None:
236
  return "<div style='background-color: #e53e3e; color: white; padding: 20px; border-radius: 10px;'>Система не инициализирована</div>", "", ""
237
 
@@ -240,22 +241,18 @@ def answer_question(question, query_engine, reranker, current_model, chunks_df=N
240
 
241
  llm = get_llm_model(current_model)
242
 
243
- # Use hybrid retriever if available
244
- if hybrid_retriever:
245
- retrieved_nodes = hybrid_retriever.retrieve(question, top_k=30)
246
- log_message(f"Hybrid retrieval: получено {len(retrieved_nodes)} узлов")
247
- else:
248
- retrieved_nodes = query_engine.retriever.retrieve(question)
249
- log_message(f"Vector retrieval: получено {len(retrieved_nodes)} узлов")
250
 
251
- # Rerank with increased top_k
252
  reranked_nodes = rerank_nodes(
253
  question,
254
  retrieved_nodes,
255
  reranker,
256
- top_k=25, # Increased from 20
257
- min_score_threshold=0.3, # Lowered from 0.5 to catch more results
258
- diversity_penalty=0.2 # Reduced penalty
259
  )
260
 
261
  formatted_context = format_context_for_llm(reranked_nodes)
 
232
  html += "</div>"
233
  return html
234
  def answer_question(question, query_engine, reranker, current_model, chunks_df=None, hybrid_retriever=None):
235
+ from index_retriever import hybrid_retrieve_with_keywords
236
  if query_engine is None:
237
  return "<div style='background-color: #e53e3e; color: white; padding: 20px; border-radius: 10px;'>Система не инициализирована</div>", "", ""
238
 
 
241
 
242
  llm = get_llm_model(current_model)
243
 
244
+ # Use keyword-enhanced retrieval
245
+ retrieved_nodes = hybrid_retrieve_with_keywords(question, query_engine, top_k=40)
246
+ log_message(f"Hybrid keyword retrieval: получено {len(retrieved_nodes)} узлов")
 
 
 
 
247
 
248
+ # Rerank
249
  reranked_nodes = rerank_nodes(
250
  question,
251
  retrieved_nodes,
252
  reranker,
253
+ top_k=25,
254
+ min_score_threshold=0.3,
255
+ diversity_penalty=0.2
256
  )
257
 
258
  formatted_context = format_context_for_llm(reranked_nodes)