MrSimple07 committed on
Commit
a2280fa
·
1 Parent(s): 6c83262
Files changed (2) hide show
  1. index_retriever.py +53 -27
  2. utils.py +57 -45
index_retriever.py CHANGED
@@ -16,24 +16,24 @@ def create_query_engine(vector_index):
16
  try:
17
  bm25_retriever = BM25Retriever.from_defaults(
18
  docstore=vector_index.docstore,
19
- similarity_top_k=15
20
  )
21
 
22
  vector_retriever = VectorIndexRetriever(
23
  index=vector_index,
24
- similarity_top_k=20,
25
- similarity_cutoff=0.7
26
  )
27
 
28
  hybrid_retriever = QueryFusionRetriever(
29
  [vector_retriever, bm25_retriever],
30
- similarity_top_k=30,
31
  num_queries=1
32
  )
33
 
34
  custom_prompt_template = PromptTemplate(PROMPT_SIMPLE_POISK)
35
  response_synthesizer = get_response_synthesizer(
36
- response_mode=ResponseMode.TREE_SUMMARIZE,
37
  text_qa_template=custom_prompt_template
38
  )
39
 
@@ -49,16 +49,16 @@ def create_query_engine(vector_index):
49
  log_message(f"Ошибка создания query engine: {str(e)}")
50
  raise
51
 
52
- def rerank_nodes(query, nodes, reranker, top_k=20, min_score_threshold=None):
53
  """
54
- Rerank nodes with adaptive top_k based on score distribution
55
  """
56
  if not nodes or not reranker:
57
  return nodes[:top_k]
58
-
59
  try:
60
  log_message(f"Переранжирую {len(nodes)} узлов")
61
-
62
  pairs = [[query, node.text] for node in nodes]
63
  scores = reranker.predict(pairs)
64
  scored_nodes = list(zip(nodes, scores))
@@ -66,30 +66,56 @@ def rerank_nodes(query, nodes, reranker, top_k=20, min_score_threshold=None):
66
  # Sort by score descending
67
  scored_nodes.sort(key=lambda x: x[1], reverse=True)
68
 
69
- # Apply minimum score threshold if specified
70
  if min_score_threshold is not None:
71
- scored_nodes = [(node, score) for node, score in scored_nodes if score >= min_score_threshold]
 
72
  log_message(f"После фильтрации по порогу {min_score_threshold}: {len(scored_nodes)} узлов")
73
 
74
- # Adaptive top_k: if we have many high-scoring results, keep more
75
- if len(scored_nodes) > top_k:
76
- top_score = scored_nodes[0][1] if scored_nodes else 0
77
- # If 30th node still has >70% of top score, expand to 30
78
- if len(scored_nodes) >= 30 and scored_nodes[29][1] / top_score > 0.7:
79
- effective_top_k = 30
80
- log_message(f"Расширяю top_k до {effective_top_k} из-за высоких скоров")
81
- else:
82
- effective_top_k = top_k
83
- else:
84
- effective_top_k = len(scored_nodes)
85
 
86
- reranked_nodes = [node for node, score in scored_nodes[:effective_top_k]]
 
 
 
87
 
88
- log_message(f"Возвращаю топ-{effective_top_k} узлов после переранжировки")
89
- log_message(f"Score range: {scored_nodes[0][1]:.3f} to {scored_nodes[min(effective_top_k-1, len(scored_nodes)-1)][1]:.3f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
- return reranked_nodes
92
-
93
  except Exception as e:
94
  log_message(f"Ошибка переранжировки: {str(e)}")
95
  return nodes[:top_k]
 
16
  try:
17
  bm25_retriever = BM25Retriever.from_defaults(
18
  docstore=vector_index.docstore,
19
+ similarity_top_k=20 # Increased for more candidates
20
  )
21
 
22
  vector_retriever = VectorIndexRetriever(
23
  index=vector_index,
24
+ similarity_top_k=25, # Increased
25
+ similarity_cutoff=0.65 # Slightly lower for recall
26
  )
27
 
28
  hybrid_retriever = QueryFusionRetriever(
29
  [vector_retriever, bm25_retriever],
30
+ similarity_top_k=40, # More candidates for reranking
31
  num_queries=1
32
  )
33
 
34
  custom_prompt_template = PromptTemplate(PROMPT_SIMPLE_POISK)
35
  response_synthesizer = get_response_synthesizer(
36
+ response_mode=ResponseMode.TREE_SUMMARIZE,
37
  text_qa_template=custom_prompt_template
38
  )
39
 
 
49
  log_message(f"Ошибка создания query engine: {str(e)}")
50
  raise
51
 
52
def _scale_down(value, keep_ratio):
    """Return *value* made "worse" by ``keep_ratio``, whatever its sign.

    For a non-negative value this is plain ``value * keep_ratio``. For a
    negative value multiplying by a ratio below 1 would move it *up* (toward
    zero), so the mirrored factor ``2 - keep_ratio`` is used instead, which
    moves it further down by the same relative amount.
    """
    if value >= 0:
        return value * keep_ratio
    return value * (2 - keep_ratio)


def rerank_nodes(query, nodes, reranker, top_k=20, min_score_threshold=0.5, diversity_penalty=0.3):
    """
    Rerank nodes with a cross-encoder, then pick a diverse top-k subset.

    Args:
        query: Query text that every node is scored against.
        nodes: Retrieved nodes; each must expose ``.text`` and may expose a
            ``.metadata`` mapping with ``document_id`` and
            ``section_path`` / ``section_id`` keys.
        reranker: Object with ``predict(pairs)`` returning one score per
            ``[query, text]`` pair. Scores may be negative (e.g. raw logits),
            so all relative thresholds below are sign-aware.
        top_k: Maximum number of nodes to return.
        min_score_threshold: Absolute score floor; ``None`` disables it. If
            nothing passes, the floor is relaxed relative to the best score.
        diversity_penalty: Relative penalty for a node whose section was
            already selected; half of it is applied for a repeated document.

    Returns:
        List of selected nodes, best score first. If there are no nodes or no
        reranker, or if reranking fails, the first ``top_k`` input nodes are
        returned unchanged (an empty list when ``nodes`` is ``None``).
    """
    if nodes is None:
        return []
    if not nodes or not reranker:
        return nodes[:top_k]

    try:
        log_message(f"Переранжирую {len(nodes)} узлов")

        pairs = [[query, node.text] for node in nodes]
        scores = reranker.predict(pairs)

        # Best score first; sorted() is stable, so ties keep retrieval order.
        ranked = sorted(zip(nodes, scores), key=lambda pair: pair[1], reverse=True)

        scored_nodes = ranked
        if min_score_threshold is not None:
            scored_nodes = [(node, score) for node, score in ranked
                            if score >= min_score_threshold]
            log_message(f"После фильтрации по порогу {min_score_threshold}: {len(scored_nodes)} узлов")

        if not scored_nodes:
            log_message("Нет узлов после фильтрации, снижаю порог")
            # Relax to a band below the best score. A plain `top * 0.5` would
            # land ABOVE a negative top score and filter out everything again.
            relaxed_threshold = _scale_down(ranked[0][1], 0.5)
            scored_nodes = [(node, score) for node, score in ranked
                            if score >= relaxed_threshold]

        # MMR-like diversity selection
        selected_nodes = []
        selected_docs = set()
        selected_sections = set()

        for node, score in scored_nodes:
            if len(selected_nodes) >= top_k:
                break

            # metadata may be missing or None; treat both as "no metadata".
            metadata = getattr(node, 'metadata', None) or {}
            doc_id = metadata.get('document_id')
            section = metadata.get('section_path', metadata.get('section_id', ''))
            # Unknown identifiers never count as repeats: otherwise every node
            # without metadata would look like the same document and section,
            # and only the first of them could ever be selected.
            section_key = f"{doc_id}_{section}" if doc_id is not None and section else None

            # Apply diversity penalty
            penalty = 0
            if doc_id is not None and doc_id in selected_docs:
                penalty += diversity_penalty * 0.5
            if section_key is not None and section_key in selected_sections:
                penalty += diversity_penalty

            adjusted_score = _scale_down(score, 1 - penalty)

            # Add if still competitive with the best selected node
            # (within 60% of it for positive scores, mirrored for negative).
            if not selected_nodes or adjusted_score >= _scale_down(selected_nodes[0][1], 0.6):
                selected_nodes.append((node, score))
                if doc_id is not None:
                    selected_docs.add(doc_id)
                if section_key is not None:
                    selected_sections.add(section_key)

        log_message(f"Выбрано {len(selected_nodes)} узлов с разнообразием")
        log_message(f"Уникальных документов: {len(selected_docs)}, секций: {len(selected_sections)}")

        if selected_nodes:
            log_message(f"Score range: {selected_nodes[0][1]:.3f} to {selected_nodes[-1][1]:.3f}")

        return [node for node, _ in selected_nodes]

    except Exception as e:
        # Deliberate best-effort: a reranker failure must not break answering,
        # so fall back to the retrieval order.
        log_message(f"Ошибка переранжировки: {str(e)}")
        return nodes[:top_k]
utils.py CHANGED
@@ -10,39 +10,6 @@ from index_retriever import rerank_nodes
10
  from my_logging import log_message
11
  from config import PROMPT_SIMPLE_POISK
12
 
13
- def get_llm_model(model_name):
14
- try:
15
- model_config = AVAILABLE_MODELS.get(model_name)
16
- if not model_config:
17
- log_message(f"Модель {model_name} не найдена, использую модель по умолчанию")
18
- model_config = AVAILABLE_MODELS[DEFAULT_MODEL]
19
-
20
- if not model_config.get("api_key"):
21
- raise Exception(f"API ключ не найден для модели {model_name}")
22
-
23
- if model_config["provider"] == "google":
24
- return GoogleGenAI(
25
- model=model_config["model_name"],
26
- api_key=model_config["api_key"]
27
- )
28
- elif model_config["provider"] == "openai":
29
- return OpenAI(
30
- model=model_config["model_name"],
31
- api_key=model_config["api_key"]
32
- )
33
- else:
34
- raise Exception(f"Неподдерживаемый провайдер: {model_config['provider']}")
35
-
36
- except Exception as e:
37
- log_message(f"Ошибка создания модели {model_name}: {str(e)}")
38
- return GoogleGenAI(model="gemini-2.0-flash", api_key=GOOGLE_API_KEY)
39
-
40
- def get_embedding_model(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"):
41
- return HuggingFaceEmbedding(model_name=model_name)
42
-
43
- def get_reranker_model(model_name='cross-encoder/ms-marco-MiniLM-L-12-v2'):
44
- return CrossEncoder(model_name)
45
-
46
  def get_llm_model(model_name):
47
  try:
48
  model_config = AVAILABLE_MODELS.get(model_name)
@@ -168,7 +135,7 @@ def format_context_for_llm(nodes):
168
 
169
  return "\n".join(context_parts)
170
 
171
-
172
  def generate_sources_html(nodes, chunks_df=None):
173
  html = "<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; max-height: 400px; overflow-y: auto;'>"
174
  html += "<h3 style='color: #63b3ed; margin-top: 0;'>Источники:</h3>"
@@ -259,6 +226,31 @@ def generate_sources_html(nodes, chunks_df=None):
259
  html += "</div>"
260
  return html
261
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  def answer_question(question, query_engine, reranker, current_model, chunks_df=None):
263
  if query_engine is None:
264
  return "<div style='background-color: #e53e3e; color: white; padding: 20px; border-radius: 10px;'>Система не инициализирована</div>", "", ""
@@ -266,25 +258,45 @@ def answer_question(question, query_engine, reranker, current_model, chunks_df=N
266
  try:
267
  start_time = time.time()
268
 
269
- retrieved_nodes = query_engine.retriever.retrieve(question)
270
- log_message(f"Получено узлов после гибридного поиска: {len(retrieved_nodes)}")
271
 
272
- # Use adaptive reranking with lower threshold for better recall
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  reranked_nodes = rerank_nodes(
274
- question,
275
- retrieved_nodes,
276
  reranker,
277
- top_k=25, # Increased from 20
278
- min_score_threshold=-5.0 # Add threshold to filter very low scores
 
279
  )
280
 
281
  formatted_context = format_context_for_llm(reranked_nodes)
282
 
283
- enhanced_question = f"""
284
- Контекст из базы данных:
285
  {formatted_context}
286
 
287
- Вопрос пользователя: {question}"""
 
 
 
288
 
289
  response = query_engine.query(enhanced_question)
290
 
@@ -299,7 +311,7 @@ def answer_question(question, query_engine, reranker, current_model, chunks_df=N
299
  <h3 style='color: #63b3ed; margin-top: 0;'>Ответ (Модель: {current_model}):</h3>
300
  <div style='line-height: 1.6; font-size: 16px;'>{response.response}</div>
301
  <div style='margin-top: 15px; padding-top: 10px; border-top: 1px solid #4a5568; font-size: 14px; color: #a0aec0;'>
302
- Время обработки: {processing_time:.2f} секунд | Источников: {len(reranked_nodes)}
303
  </div>
304
  </div>"""
305
 
 
10
  from my_logging import log_message
11
  from config import PROMPT_SIMPLE_POISK
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def get_llm_model(model_name):
14
  try:
15
  model_config = AVAILABLE_MODELS.get(model_name)
 
135
 
136
  return "\n".join(context_parts)
137
 
138
+
139
  def generate_sources_html(nodes, chunks_df=None):
140
  html = "<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; max-height: 400px; overflow-y: auto;'>"
141
  html += "<h3 style='color: #63b3ed; margin-top: 0;'>Источники:</h3>"
 
226
  html += "</div>"
227
  return html
228
 
229
def _parse_variation_line(line):
    """Split one LLM output line into ``(text, had_marker)``.

    A leading list marker of the form ``<digits>.`` or ``<digits>)`` is
    removed, and ``had_marker`` reports whether one was found. A pair of
    enclosing square brackets (echoed from the prompt's format example) is
    also dropped.
    """
    text = line.strip()
    had_marker = False
    for separator in ('.', ')'):
        head, found, tail = text.partition(separator)
        # Only strip when everything before the separator is a number, so a
        # query that merely starts with digits (e.g. a year) stays intact.
        if found and head.strip().isdigit():
            text = tail.strip()
            had_marker = True
            break
    if len(text) >= 2 and text.startswith('[') and text.endswith(']'):
        text = text[1:-1].strip()
    return text, had_marker


def expand_query(question, llm_model):
    """
    Generate multiple query variations for better retrieval.

    Args:
        question: Original user question.
        llm_model: Object whose ``complete(prompt)`` returns a response with
            a ``.text`` attribute.

    Returns:
        A list starting with ``question`` followed by at most two distinct
        alternative formulations. Only ``[question]`` is returned when the
        LLM call or the parsing fails.
    """
    expansion_prompt = "\n".join([
        f'Дан вопрос: "{question}"',
        "",
        "Сгенерируй 2 альтернативные формулировки этого вопроса для поиска в базе данных.",
        "Используй синонимы и разные формулировки, сохраняя смысл.",
        "",
        "Формат ответа (только вопросы, по одному на строку):",
        "1. [первая формулировка]",
        "2. [вторая формулировка]",
    ])

    try:
        response = llm_model.complete(expansion_prompt)
        parsed = [_parse_variation_line(line)
                  for line in (response.text or "").splitlines()]

        # The prompt asks for a numbered list, so numbered lines ARE the
        # answer. If any are present, ignore unnumbered chatter around them;
        # otherwise accept plain lines.
        if any(had_marker for _, had_marker in parsed):
            candidates = [text for text, had_marker in parsed if had_marker]
        else:
            candidates = [text for text, _ in parsed]

        expanded = []
        for candidate in candidates:
            # Very short lines are not usable search queries.
            if len(candidate) <= 10:
                continue
            # Skip repeats of the original or of an earlier variation:
            # they would only trigger redundant retrieval.
            if candidate == question or candidate in expanded:
                continue
            expanded.append(candidate)
            if len(expanded) == 2:
                break

        log_message(f"Query expansion: {len(expanded)} вариантов")
        return [question] + expanded
    except Exception as e:
        # Best-effort: query expansion is optional, fall back to the original.
        log_message(f"Ошибка расширения запроса: {str(e)}")
        return [question]
252
+
253
+
254
  def answer_question(question, query_engine, reranker, current_model, chunks_df=None):
255
  if query_engine is None:
256
  return "<div style='background-color: #e53e3e; color: white; padding: 20px; border-radius: 10px;'>Система не инициализирована</div>", "", ""
 
258
  try:
259
  start_time = time.time()
260
 
261
+ # Get LLM for query expansion
262
+ llm = get_llm_model(current_model)
263
 
264
+ # Expand query
265
+ query_variations = expand_query(question, llm)
266
+
267
+ # Retrieve with multiple queries and deduplicate
268
+ all_nodes = []
269
+ seen_node_ids = set()
270
+
271
+ for query_var in query_variations:
272
+ retrieved = query_engine.retriever.retrieve(query_var)
273
+ for node in retrieved:
274
+ node_id = f"{node.node_id if hasattr(node, 'node_id') else hash(node.text)}"
275
+ if node_id not in seen_node_ids:
276
+ all_nodes.append(node)
277
+ seen_node_ids.add(node_id)
278
+
279
+ log_message(f"Получено {len(all_nodes)} уникальных узлов из {len(query_variations)} запросов")
280
+
281
+ # Rerank with stricter threshold and diversity
282
  reranked_nodes = rerank_nodes(
283
+ question, # Use original question for reranking
284
+ all_nodes,
285
  reranker,
286
+ top_k=20,
287
+ min_score_threshold=0.5, # Much stricter threshold
288
+ diversity_penalty=0.3
289
  )
290
 
291
  formatted_context = format_context_for_llm(reranked_nodes)
292
 
293
+ enhanced_question = f"""Контекст из базы данных:
 
294
  {formatted_context}
295
 
296
+ Вопрос пользователя: {question}
297
+
298
+ Инструкция: Ответь на вопрос, используя ТОЛЬКО информацию из контекста выше.
299
+ Если информации недостаточно, четко укажи это. Цитируй конкретные источники."""
300
 
301
  response = query_engine.query(enhanced_question)
302
 
 
311
  <h3 style='color: #63b3ed; margin-top: 0;'>Ответ (Модель: {current_model}):</h3>
312
  <div style='line-height: 1.6; font-size: 16px;'>{response.response}</div>
313
  <div style='margin-top: 15px; padding-top: 10px; border-top: 1px solid #4a5568; font-size: 14px; color: #a0aec0;'>
314
+ Время обработки: {processing_time:.2f} секунд | Источников: {len(reranked_nodes)} | Запросов: {len(query_variations)}
315
  </div>
316
  </div>"""
317