MrSimple07 committed on
Commit
40de98c
·
1 Parent(s): c28dd72

eski holat with utils

Browse files
Files changed (2) hide show
  1. index_retriever.py +64 -1
  2. utils.py +37 -2
index_retriever.py CHANGED
@@ -12,7 +12,70 @@ def create_vector_index(documents):
12
  log_message("Строю векторный индекс")
13
  return VectorStoreIndex.from_documents(documents)
14
 
15
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  def create_query_engine(vector_index):
18
  try:
 
12
  log_message("Строю векторный индекс")
13
  return VectorStoreIndex.from_documents(documents)
14
 
15
def _relative_floor(score, ratio=0.6):
    """Return the lowest score still considered close to ``score``.

    For non-negative scores this is ``score * ratio``. For negative scores
    the same multiplication would yield a value *above* ``score``, so the
    floor is mirrored to lie the same relative distance below it.
    """
    if score >= 0:
        return score * ratio
    return score * (2 - ratio)


def _penalize(score, penalty):
    """Lower ``score`` by the relative ``penalty`` regardless of its sign."""
    if score >= 0:
        return score * (1 - penalty)
    # Multiplying a negative score by (1 - penalty) would raise it.
    return score * (1 + penalty)


def rerank_nodes(query, nodes, reranker, top_k=25, min_score_threshold=0.5, diversity_penalty=0.3):
    """Re-score retrieved nodes with a reranker and select a diverse subset.

    Steps:
      1. Score every ``[query, node.text]`` pair with ``reranker.predict``.
      2. Sort by score (descending) and drop nodes below
         ``min_score_threshold`` (skipped when the threshold is ``None``).
      3. If nothing survives, relax the threshold relative to the best score.
      4. Walk the survivors in score order, penalizing nodes whose document
         or section was already selected, and keep those that stay close to
         the best selected score, up to ``top_k`` nodes.

    Args:
        query: Query text paired with each node's text for scoring.
        nodes: Sequence of retrieved nodes exposing ``.text`` and, optionally,
            a ``.metadata`` mapping. ``None`` is treated as "no nodes".
        reranker: Object with a ``predict(pairs)`` method returning one score
            per pair. When falsy, reranking is skipped.
        top_k: Maximum number of nodes to return.
        min_score_threshold: Minimum score to keep a node, or ``None`` to
            disable the filter.
        diversity_penalty: Relative penalty for a repeated section; a repeated
            document costs half of it.

    Returns:
        A list of the selected nodes in descending score order. If reranking
        is skipped or fails, the first ``top_k`` input nodes are returned.
    """
    if nodes is None:
        # The slice below would raise on None; there is nothing to rank.
        return []
    if not nodes or not reranker:
        return nodes[:top_k]

    try:
        log_message(f"Переранжирую {len(nodes)} узлов")

        pairs = [[query, node.text] for node in nodes]
        scores = reranker.predict(pairs)
        # Keep the full ranking so the fallback below can reuse it.
        ranked = sorted(zip(nodes, scores), key=lambda pair: pair[1], reverse=True)

        scored_nodes = ranked
        if min_score_threshold is not None:
            scored_nodes = [(node, score) for node, score in ranked
                            if score >= min_score_threshold]
            log_message(f"После фильтрации по порогу {min_score_threshold}: {len(scored_nodes)} узлов")

        if not scored_nodes:
            log_message("Нет узлов после фильтрации, снижаю порог")
            # Relax relative to the best score. The sign-aware floor keeps
            # at least the best node even when all scores are negative.
            min_score_threshold = _relative_floor(ranked[0][1])
            scored_nodes = [(node, score) for node, score in ranked
                            if score >= min_score_threshold]

        selected_nodes = []
        selected_docs = set()
        selected_sections = set()

        for node, score in scored_nodes:
            if len(selected_nodes) >= top_k:
                break

            metadata = getattr(node, 'metadata', None) or {}
            doc_id = metadata.get('document_id', 'unknown')
            section_key = f"{doc_id}_{metadata.get('section_path', metadata.get('section_id', ''))}"

            # Penalize content from an already selected document or section.
            penalty = 0
            if doc_id in selected_docs:
                penalty += diversity_penalty * 0.5
            if section_key in selected_sections:
                penalty += diversity_penalty

            adjusted_score = _penalize(score, penalty)

            # Keep the node only if it stays close to the best selected score.
            if not selected_nodes or adjusted_score >= _relative_floor(selected_nodes[0][1]):
                selected_nodes.append((node, score))
                selected_docs.add(doc_id)
                selected_sections.add(section_key)

        log_message(f"Выбрано {len(selected_nodes)} узлов с разнообразием")
        log_message(f"Уникальных документов: {len(selected_docs)}, секций: {len(selected_sections)}")

        if selected_nodes:
            log_message(f"Score range: {selected_nodes[0][1]:.3f} to {selected_nodes[-1][1]:.3f}")

        return [node for node, score in selected_nodes]

    except Exception as e:
        # Best-effort: a reranker failure must not break retrieval, so fall
        # back to the original retrieval order.
        log_message(f"Ошибка переранжировки: {str(e)}")
        return nodes[:top_k]
79
 
80
  def create_query_engine(vector_index):
81
  try:
utils.py CHANGED
@@ -226,6 +226,33 @@ def generate_sources_html(nodes, chunks_df=None):
226
  html += "</div>"
227
  return html
228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
  def answer_question(question, query_engine, reranker, current_model, chunks_df=None):
231
  if query_engine is None:
@@ -239,11 +266,19 @@ def answer_question(question, query_engine, reranker, current_model, chunks_df=N
239
  # Direct retrieval without query expansion
240
  retrieved_nodes = query_engine.retriever.retrieve(question)
241
 
242
- log_message(f"Получено {len(retrieved_nodes)} узлов")
 
 
 
 
 
 
 
 
243
 
244
  reranked_nodes = rerank_nodes(
245
  question,
246
- retrieved_nodes,
247
  reranker,
248
  top_k=20,
249
  min_score_threshold=0.5,
 
226
  html += "</div>"
227
  return html
228
 
229
def deduplicate_nodes(nodes):
    """Remove nodes whose metadata identifies an already seen chunk.

    The identity key depends on the node's ``type`` metadata field:
      * ``'table'``: document id, table number and chunk id;
      * ``'image'``: document id and image number;
      * anything else: document id, section id and chunk id.

    Nodes that carry none of the fields used for their key cannot be told
    apart by metadata. They are always kept, because they would otherwise
    all share one key and be dropped as false duplicates.

    Args:
        nodes: Iterable of nodes with an optional ``.metadata`` mapping.
            ``None`` is treated as an empty input.

    Returns:
        A list with the first occurrence of each node, in input order.
    """
    seen = set()
    unique_nodes = []

    for node in nodes or []:
        # Tolerate nodes with a missing or None metadata attribute.
        metadata = getattr(node, 'metadata', None) or {}
        doc_id = metadata.get('document_id', '')
        chunk_id = metadata.get('chunk_id', 0)
        node_type = metadata.get('type', 'text')

        if node_type == 'table':
            key_fields = ('document_id', 'table_number', 'chunk_id')
            table_num = metadata.get('table_number', '')
            identifier = f"{doc_id}|table|{table_num}|{chunk_id}"
        elif node_type == 'image':
            key_fields = ('document_id', 'image_number')
            img_num = metadata.get('image_number', '')
            identifier = f"{doc_id}|image|{img_num}"
        else:
            key_fields = ('document_id', 'section_id', 'chunk_id')
            section_id = metadata.get('section_id', '')
            identifier = f"{doc_id}|{section_id}|{chunk_id}"

        if not any(field in metadata for field in key_fields):
            # The key would consist of defaults only; keep the node.
            unique_nodes.append(node)
            continue

        if identifier not in seen:
            seen.add(identifier)
            unique_nodes.append(node)

    return unique_nodes
255
+
256
 
257
  def answer_question(question, query_engine, reranker, current_model, chunks_df=None):
258
  if query_engine is None:
 
266
  # Direct retrieval without query expansion
267
  retrieved_nodes = query_engine.retriever.retrieve(question)
268
 
269
+ total_retrieved = len(retrieved_nodes)
270
+ log_message(f"RETRIEVED: {total_retrieved} nodes (before deduplication)")
271
+
272
+ # Deduplicate
273
+ unique_retrieved = deduplicate_nodes(retrieved_nodes)
274
+ duplicates_removed = total_retrieved - len(unique_retrieved)
275
+ log_message(f"DEDUPLICATION: {duplicates_removed} duplicates removed")
276
+ log_message(f"UNIQUE NODES: {len(unique_retrieved)} nodes")
277
+
278
 
279
  reranked_nodes = rerank_nodes(
280
  question,
281
+ unique_retrieved,
282
  reranker,
283
  top_k=20,
284
  min_score_threshold=0.5,