MrSimple07 committed on
Commit
abfdf7a
·
1 Parent(s): 1368f74

bm = 25, semantic = 35. hybrid = 40

Browse files
Files changed (2) hide show
  1. index_retriever.py +73 -8
  2. utils.py +26 -10
index_retriever.py CHANGED
@@ -11,29 +11,29 @@ from config import CUSTOM_PROMPT, PROMPT_SIMPLE_POISK
11
  def create_vector_index(documents):
12
  log_message("Строю векторный индекс")
13
  return VectorStoreIndex.from_documents(documents)
 
14
  def create_query_engine(vector_index):
15
  try:
16
  bm25_retriever = BM25Retriever.from_defaults(
17
  docstore=vector_index.docstore,
18
- similarity_top_k=15 # Lower since we're combining with semantic
19
  )
20
 
21
  vector_retriever = VectorIndexRetriever(
22
  index=vector_index,
23
- similarity_top_k=15, # Lower since we're combining with BM25
24
- similarity_cutoff=0.6 # Slightly lower threshold
25
  )
26
 
27
- # Hybrid retriever combines both approaches
28
  hybrid_retriever = QueryFusionRetriever(
29
  [vector_retriever, bm25_retriever],
30
- similarity_top_k=30, # Final top_k after fusion
31
  num_queries=1
32
  )
33
 
34
  custom_prompt_template = PromptTemplate(PROMPT_SIMPLE_POISK)
35
  response_synthesizer = get_response_synthesizer(
36
- response_mode=ResponseMode.TREE_SUMMARIZE,
37
  text_qa_template=custom_prompt_template
38
  )
39
 
@@ -42,9 +42,74 @@ def create_query_engine(vector_index):
42
  response_synthesizer=response_synthesizer
43
  )
44
 
45
- log_message("Query engine создан (BM25 + Semantic, без reranking)")
46
  return query_engine
47
 
48
  except Exception as e:
49
  log_message(f"Ошибка создания query engine: {str(e)}")
50
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def create_vector_index(documents):
12
  log_message("Строю векторный индекс")
13
  return VectorStoreIndex.from_documents(documents)
14
+
15
  def create_query_engine(vector_index):
16
  try:
17
  bm25_retriever = BM25Retriever.from_defaults(
18
  docstore=vector_index.docstore,
19
+ similarity_top_k=25
20
  )
21
 
22
  vector_retriever = VectorIndexRetriever(
23
  index=vector_index,
24
+ similarity_top_k=35,
25
+ similarity_cutoff=0.7
26
  )
27
 
 
28
  hybrid_retriever = QueryFusionRetriever(
29
  [vector_retriever, bm25_retriever],
30
+ similarity_top_k=40,
31
  num_queries=1
32
  )
33
 
34
  custom_prompt_template = PromptTemplate(PROMPT_SIMPLE_POISK)
35
  response_synthesizer = get_response_synthesizer(
36
+ response_mode=ResponseMode.TREE_SUMMARIZE,
37
  text_qa_template=custom_prompt_template
38
  )
39
 
 
42
  response_synthesizer=response_synthesizer
43
  )
44
 
45
+ log_message("Query engine успешно создан")
46
  return query_engine
47
 
48
  except Exception as e:
49
  log_message(f"Ошибка создания query engine: {str(e)}")
50
+ raise
51
+
52
+ def rerank_nodes(query, nodes, reranker, top_k=25, min_score_threshold=0.5, diversity_penalty=0.3):
53
+ if not nodes or not reranker:
54
+ return nodes[:top_k]
55
+
56
+ try:
57
+ log_message(f"Переранжирую {len(nodes)} узлов")
58
+
59
+ pairs = [[query, node.text] for node in nodes]
60
+ scores = reranker.predict(pairs)
61
+ scored_nodes = list(zip(nodes, scores))
62
+
63
+ scored_nodes.sort(key=lambda x: x[1], reverse=True)
64
+
65
+ if min_score_threshold is not None:
66
+ scored_nodes = [(node, score) for node, score in scored_nodes
67
+ if score >= min_score_threshold]
68
+ log_message(f"После фильтрации по порогу {min_score_threshold}: {len(scored_nodes)} узлов")
69
+
70
+ if not scored_nodes:
71
+ log_message("Нет узлов после фильтрации, снижаю порог")
72
+ scored_nodes = list(zip(nodes, scores))
73
+ scored_nodes.sort(key=lambda x: x[1], reverse=True)
74
+ min_score_threshold = scored_nodes[0][1] * 0.6
75
+ scored_nodes = [(node, score) for node, score in scored_nodes
76
+ if score >= min_score_threshold]
77
+
78
+ selected_nodes = []
79
+ selected_docs = set()
80
+ selected_sections = set()
81
+
82
+ for node, score in scored_nodes:
83
+ if len(selected_nodes) >= top_k:
84
+ break
85
+
86
+ metadata = node.metadata if hasattr(node, 'metadata') else {}
87
+ doc_id = metadata.get('document_id', 'unknown')
88
+ section_key = f"{doc_id}_{metadata.get('section_path', metadata.get('section_id', ''))}"
89
+
90
+ # Apply diversity penalty
91
+ penalty = 0
92
+ if doc_id in selected_docs:
93
+ penalty += diversity_penalty * 0.5
94
+ if section_key in selected_sections:
95
+ penalty += diversity_penalty
96
+
97
+ adjusted_score = score * (1 - penalty)
98
+
99
+ # Add if still competitive
100
+ if not selected_nodes or adjusted_score >= selected_nodes[0][1] * 0.6:
101
+ selected_nodes.append((node, score))
102
+ selected_docs.add(doc_id)
103
+ selected_sections.add(section_key)
104
+
105
+ log_message(f"Выбрано {len(selected_nodes)} узлов с разнообразием")
106
+ log_message(f"Уникальных документов: {len(selected_docs)}, секций: {len(selected_sections)}")
107
+
108
+ if selected_nodes:
109
+ log_message(f"Score range: {selected_nodes[0][1]:.3f} to {selected_nodes[-1][1]:.3f}")
110
+
111
+ return [node for node, score in selected_nodes]
112
+
113
+ except Exception as e:
114
+ log_message(f"Ошибка переранжировки: {str(e)}")
115
+ return nodes[:top_k]
utils.py CHANGED
@@ -6,7 +6,7 @@ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
6
  from sentence_transformers import CrossEncoder
7
  from config import AVAILABLE_MODELS, DEFAULT_MODEL, GOOGLE_API_KEY
8
  import time
9
- # from index_retriever import rerank_nodes
10
  from my_logging import log_message
11
  from config import PROMPT_SIMPLE_POISK
12
 
@@ -260,15 +260,31 @@ def answer_question(question, query_engine, reranker, current_model, chunks_df=N
260
 
261
  llm = get_llm_model(current_model)
262
 
263
- # Simple retrieval without query expansion
264
- retrieved_nodes = query_engine.retriever.retrieve(question)
265
 
266
- log_message(f"Получено {len(retrieved_nodes)} узлов (BM25 + Semantic)")
 
267
 
268
- # Use nodes directly without reranking
269
- final_nodes = retrieved_nodes[:30] # Ensure we use top 30
 
 
 
 
 
270
 
271
- formatted_context = format_context_for_llm(final_nodes)
 
 
 
 
 
 
 
 
 
 
 
272
 
273
  enhanced_question = f"""Контекст из базы данных:
274
  {formatted_context}
@@ -285,18 +301,18 @@ def answer_question(question, query_engine, reranker, current_model, chunks_df=N
285
 
286
  log_message(f"Обработка завершена за {processing_time:.2f}с")
287
 
288
- sources_html = generate_sources_html(final_nodes, chunks_df)
289
 
290
  answer_with_time = f"""<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; margin-bottom: 10px;'>
291
  <h3 style='color: #63b3ed; margin-top: 0;'>Ответ (Модель: {current_model}):</h3>
292
  <div style='line-height: 1.6; font-size: 16px;'>{response.response}</div>
293
  <div style='margin-top: 15px; padding-top: 10px; border-top: 1px solid #4a5568; font-size: 14px; color: #a0aec0;'>
294
- Время обработки: {processing_time:.2f} секунд | Метод: BM25 + Semantic (без reranking)
295
  </div>
296
  </div>"""
297
 
298
  chunk_info = []
299
- for node in final_nodes:
300
  metadata = node.metadata if hasattr(node, 'metadata') else {}
301
  chunk_info.append({
302
  'document_id': metadata.get('document_id', 'unknown'),
 
6
  from sentence_transformers import CrossEncoder
7
  from config import AVAILABLE_MODELS, DEFAULT_MODEL, GOOGLE_API_KEY
8
  import time
9
+ from index_retriever import rerank_nodes
10
  from my_logging import log_message
11
  from config import PROMPT_SIMPLE_POISK
12
 
 
260
 
261
  llm = get_llm_model(current_model)
262
 
263
+ query_variations = expand_query(question, llm)
 
264
 
265
+ all_nodes = []
266
+ seen_node_ids = set()
267
 
268
+ for query_var in query_variations:
269
+ retrieved = query_engine.retriever.retrieve(query_var)
270
+ for node in retrieved:
271
+ node_id = f"{node.node_id if hasattr(node, 'node_id') else hash(node.text)}"
272
+ if node_id not in seen_node_ids:
273
+ all_nodes.append(node)
274
+ seen_node_ids.add(node_id)
275
 
276
+ log_message(f"Получено {len(all_nodes)} уникальных узлов из {len(query_variations)} запросов")
277
+
278
+ reranked_nodes = rerank_nodes(
279
+ question,
280
+ all_nodes,
281
+ reranker,
282
+ top_k=25,
283
+ min_score_threshold=0.5,
284
+ diversity_penalty=0.3
285
+ )
286
+
287
+ formatted_context = format_context_for_llm(reranked_nodes)
288
 
289
  enhanced_question = f"""Контекст из базы данных:
290
  {formatted_context}
 
301
 
302
  log_message(f"Обработка завершена за {processing_time:.2f}с")
303
 
304
+ sources_html = generate_sources_html(reranked_nodes, chunks_df)
305
 
306
  answer_with_time = f"""<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; margin-bottom: 10px;'>
307
  <h3 style='color: #63b3ed; margin-top: 0;'>Ответ (Модель: {current_model}):</h3>
308
  <div style='line-height: 1.6; font-size: 16px;'>{response.response}</div>
309
  <div style='margin-top: 15px; padding-top: 10px; border-top: 1px solid #4a5568; font-size: 14px; color: #a0aec0;'>
310
+ Время обработки: {processing_time:.2f} секунд
311
  </div>
312
  </div>"""
313
 
314
  chunk_info = []
315
+ for node in reranked_nodes:
316
  metadata = node.metadata if hasattr(node, 'metadata') else {}
317
  chunk_info.append({
318
  'document_id': metadata.get('document_id', 'unknown'),