antimoda1 commited on
Commit
b6d731b
·
1 Parent(s): 873ada4
Files changed (3) hide show
  1. app.py +65 -112
  2. generation.py +2 -1
  3. retrieval.py +7 -7
app.py CHANGED
@@ -1,5 +1,6 @@
1
- import re
2
  import gradio as gr
 
 
3
  from generation import wrap_prompt
4
  from llm import get_llm_answer
5
  from retrieval import Retrieval
@@ -8,114 +9,69 @@ from vocabulary.parse_vocabulary import parse_vocabulary
8
 
9
 
10
  vocabulary, _ = parse_vocabulary('vocabulary/vocabulary.md')
11
- retrieval = Retrieval()
12
 
13
 
14
- def perform_search(query, top_k, year_from, year_to):
15
- """Этап 1: Поиск и возврат результатов с фильтром по датам"""
16
-
17
- if not query:
18
- return None, [], [], "Введите вопрос для поиска"
19
-
20
- # Преобразуем входные значения
21
- try:
22
- year_from = _parse_single_year(year_from)
23
- year_to = _parse_single_year(year_to)
24
-
25
- # Проверяем корректность диапазона
26
- if year_from > year_to:
27
- year_from, year_to = year_to, year_from
28
-
29
- except (ValueError, TypeError):
30
- return None, [], [], f"⚠️ Ошибка: некорректный диапазон лет ({year_from} - {year_to})"
31
 
32
- # Выполняем поиск BM25
33
- scores = retrieval.bm25_search(query)
34
- scores = list(scores) # Преобразуем в список если это ndarray
35
-
36
- # Получаем индексы чанков
37
- chunk_ids = list(range(len(scores)))
38
-
39
- # Применяем ЖЕСТКИЙ фильтр по датам ДО выбора top-k
40
- year_search_range = (year_from, year_to)
41
- filtered_by_date = retrieval.filter_by_year_range(chunk_ids, year_search_range)
42
-
43
- # Если нет результатов после фильтра по датам
44
- if not filtered_by_date:
45
- return scores, chunk_ids, [], f"⚠️ Нет результатов в диапазоне {year_from}-{year_to}"
46
-
47
- # Находим top-k среди отфильтрованных по датам (сортируем по релевантности BM25)
48
- top_k = min(top_k, len(filtered_by_date))
49
- filtered_scores = [(idx, scores[idx]) for idx in filtered_by_date]
50
- filtered_scores.sort(key=lambda x: x[1], reverse=True)
51
- top_k_indices = [idx for idx, _ in filtered_scores[:top_k]]
52
-
53
- status = f"Найдено {len(scores)} чанков, {len(filtered_by_date)} в диапазоне {year_from}-{year_to}. Top-{top_k} выбраны."
54
-
55
- return scores, chunk_ids, top_k_indices, status
56
 
57
- def format_selected_chunks(selected_indices):
58
- """Форматирует выбранные чанки в единый текст для вывода и LLM
59
-
60
- Выводит целые абзацы по выбранным чанкам с названиями документов:
61
- Документ {название}:
62
- {полный текст абзаца}
63
- """
64
- if not selected_indices:
65
- return ""
66
-
67
- # Найдем все уникальные абзацы из выбранных чанков
68
- paragraphs_to_show = {} # paragraph_id -> doc_id
69
-
70
- for idx in selected_indices:
71
- if idx >= len(retrieval.paragraph_metadata) or idx >= len(retrieval.docs_metadata):
72
- continue
73
-
74
- paragraph_id = retrieval.paragraph_metadata[idx]
75
- doc_id = retrieval.docs_metadata[idx]
76
- paragraphs_to_show[paragraph_id] = doc_id
77
-
78
- # Для каждого отмеченного абзаца найдем ВСЕ его чанки
79
- full_paragraph_chunks = {} # paragraph_id -> [chunk_ids]
80
- for chunk_id, paragraph_id in enumerate(retrieval.paragraph_metadata):
81
- if paragraph_id in paragraphs_to_show:
82
- if paragraph_id not in full_paragraph_chunks:
83
- full_paragraph_chunks[paragraph_id] = []
84
- full_paragraph_chunks[paragraph_id].append(chunk_id)
85
-
86
- # Форматируем вывод
87
- result_lines = []
88
- for paragraph_id in sorted(paragraphs_to_show.keys()):
89
- doc_id = paragraphs_to_show[paragraph_id]
90
- chunk_indices = sorted(full_paragraph_chunks[paragraph_id])
91
-
92
- doc_name = retrieval.docs_names[doc_id] if doc_id < len(retrieval.docs_names) else "Неизвестный документ"
93
-
94
- # Объединяем все чанки абзаца в полный текст
95
- paragraph_text = " ".join([retrieval.chunks[idx] for idx in chunk_indices])
96
-
97
- # Форматируем вывод с названием документа
98
- result_lines.append(f"Документ {doc_name}:")
99
- result_lines.append(paragraph_text)
100
- result_lines.append("") # Пустая строка между абзацами
101
-
102
- return "\n".join(result_lines)
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
- def format_retrieval_results(filtered_indices, top_k_results):
106
- """Форматирует результаты retrieval для отображения в текстовом поле
107
-
108
- Берет top_k результатов и выводит целые абзацы с названиями документов
109
- """
110
- if len(filtered_indices) == 0:
111
- return "Нет результатов"
112
-
113
- top_k_results = min(top_k_results, len(filtered_indices))
114
-
115
- # Берем top-k индексов (уже отсортированы по релевантности)
116
- top_k_indices = filtered_indices[:top_k_results]
117
-
118
- return format_selected_chunks(top_k_indices)
119
 
120
  def ask_llm(query, filtered_indices_state):
121
  """Этап 2: Отправка отфильтрованных чанков в LLM с потоковой выдачей"""
@@ -123,22 +79,19 @@ def ask_llm(query, filtered_indices_state):
123
  yield "Введите вопрос"
124
  return
125
 
126
- # Используем все отфильтрованные чанки
127
- chunks_to_use = filtered_indices_state if filtered_indices_state else []
128
-
129
- if not chunks_to_use:
130
  yield "Нет выбранных чанков для отправки в LLM"
131
  return
132
 
133
  # Форматируем контекст используя ту же функцию, что и в интерфейсе
134
- context = format_selected_chunks(list(chunks_to_use))
135
 
136
  if not context or context == "Нет валидных чанков":
137
  yield "Нет валидных чанков для отправки"
138
  return
139
 
140
  # Формируем промпт и отправляем в LLM
141
- prompt = wrap_prompt(context, query, vocabula=vocabulary.copy())
142
 
143
  # Потоковая выдача ответа
144
  full_answer = ""
@@ -237,18 +190,18 @@ with gr.Blocks(title="RAG Application", theme=gr.themes.Soft()) as iface:
237
 
238
  # Обработчик поиска
239
  search_btn.click(
240
- fn=perform_search,
241
  inputs=[search_query_input, top_k_slider, year_from_input, year_to_input],
242
  outputs=[all_scores_state, all_chunk_ids_state, top_k_indices_state, search_status]
243
  ).then(
244
- fn=format_retrieval_results,
245
  inputs=[top_k_indices_state, top_k_slider],
246
  outputs=[retrieval_results]
247
  )
248
 
249
  # Обработчик изменения слайдера top_k
250
  top_k_slider.change(
251
- fn=format_retrieval_results,
252
  inputs=[top_k_indices_state, top_k_slider],
253
  outputs=[retrieval_results]
254
  )
 
 
1
  import gradio as gr
2
+ import numpy as np
3
+
4
  from generation import wrap_prompt
5
  from llm import get_llm_answer
6
  from retrieval import Retrieval
 
9
 
10
 
11
  vocabulary, _ = parse_vocabulary('vocabulary/vocabulary.md')
 
12
 
13
 
14
+ class Perform:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ def __init__(self):
17
+ self.retrieval = Retrieval()
18
+ lengthh = len(self.retrieval.paragraphs_df)
19
+ self.scores = None
20
+ self.sorted_idx = None
21
+ self.years_mask = np.ones(lengthh, dtype=bool)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ def get_years_range_mask(self, year_from, year_to):
24
+ try:
25
+ year_from = _parse_single_year(year_from)
26
+ year_to = _parse_single_year(year_to)
27
+ if year_from > year_to:
28
+ year_from, year_to = year_to, year_from
29
+ except (ValueError, TypeError):
30
+ raise ValueError(f"Некорректный диапазон лет: {year_from} - {year_to}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
+ self.years_mask = (
33
+ (self.retrieval.paragraphs_df["end_year"] >= year_from) &
34
+ (self.retrieval.paragraphs_df["start_year"] <= year_to)
35
+ ).values
36
+
37
+ def perform_search(self, query, top_k, year_from, year_to):
38
+ self.get_years_range_mask(year_from, year_to)
39
+
40
+ # если есть query → считаем scores
41
+ if query:
42
+ self.scores = self.retrieval.search(query)
43
+ self.sorted_idx = np.argsort(self.scores)[::-1]
44
+
45
+ # если нет query и scores нет → используем только фильтр
46
+ if self.scores is None:
47
+ filtered = np.where(self.years_mask)[0]
48
+ if len(filtered) <= top_k:
49
+ return None, None, filtered, "Показаны все записи по фильтру лет"
50
+ return None, None, filtered[-top_k:], "Показаны записи по фильтру лет"
51
+
52
+ # применяем mask к отсортированным индексам
53
+ filtered_sorted = self.sorted_idx[self.years_mask[self.sorted_idx]]
54
+
55
+ if len(filtered_sorted) == 0:
56
+ return self.scores, None, [], "⚠️ Нет результатов в выбранном диапазоне лет"
57
+
58
+ top_k_indices = filtered_sorted[:top_k]
59
+
60
+ return self.scores, None, top_k_indices, f"Найдено {len(filtered_sorted)} результатов"
61
+
62
+ def format_retrieval_results(self, top_k_indices):
63
+ if len(top_k_indices) == 0:
64
+ return "Нет результатов"
65
+
66
+ texts = self.retrieval.paragraphs_df["texts"].iloc[top_k_indices]
67
+ return "\n\n".join(texts)
68
+
69
+ def format_selected_chunks(self, indices):
70
+ texts = self.retrieval.paragraphs_df["texts"].iloc[indices]
71
+ return "\n\n".join(texts)
72
+
73
+ perform = Perform()
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  def ask_llm(query, filtered_indices_state):
77
  """Этап 2: Отправка отфильтрованных чанков в LLM с потоковой выдачей"""
 
79
  yield "Введите вопрос"
80
  return
81
 
82
+ if not filtered_indices_state:
 
 
 
83
  yield "Нет выбранных чанков для отправки в LLM"
84
  return
85
 
86
  # Форматируем контекст используя ту же функцию, что и в интерфейсе
87
+ context = perform.format_selected_chunks(filtered_indices_state)
88
 
89
  if not context or context == "Нет валидных чанков":
90
  yield "Нет валидных чанков для отправки"
91
  return
92
 
93
  # Формируем промпт и отправляем в LLM
94
+ prompt = wrap_prompt(context, query, vocabulary)
95
 
96
  # Потоковая выдача ответа
97
  full_answer = ""
 
190
 
191
  # Обработчик поиска
192
  search_btn.click(
193
+ fn=perform.perform_search,
194
  inputs=[search_query_input, top_k_slider, year_from_input, year_to_input],
195
  outputs=[all_scores_state, all_chunk_ids_state, top_k_indices_state, search_status]
196
  ).then(
197
+ fn=perform.format_retrieval_results,
198
  inputs=[top_k_indices_state, top_k_slider],
199
  outputs=[retrieval_results]
200
  )
201
 
202
  # Обработчик изменения слайдера top_k
203
  top_k_slider.change(
204
+ fn=perform.format_retrieval_results,
205
  inputs=[top_k_indices_state, top_k_slider],
206
  outputs=[retrieval_results]
207
  )
generation.py CHANGED
@@ -64,7 +64,8 @@ def lemmatize(text, vocabulary):
64
  return found_terms
65
 
66
 
67
- def wrap_prompt(retrieved_text, query_text, vocabula: dict[str, str]):
 
68
  tokens_from_query = lemmatize(query_text, vocabula)
69
  tokens_from_retrieved = lemmatize(retrieved_text, vocabula)
70
  info_for_llm = ''
 
64
  return found_terms
65
 
66
 
67
+ def wrap_prompt(retrieved_text, query_text, inp_vocabula: dict[str, str]):
68
+ vocabula = inp_vocabula.copy() # Создаем копию словаря, чтобы не изменять оригинал
69
  tokens_from_query = lemmatize(query_text, vocabula)
70
  tokens_from_retrieved = lemmatize(retrieved_text, vocabula)
71
  info_for_llm = ''
retrieval.py CHANGED
@@ -29,18 +29,19 @@ class Retrieval:
29
 
30
  1. ДАТАФРЕЙМ ПАРАГРАФОВ (self.paragraphs_df):
31
  ┌──────────────────────┬─────────────────────────────────┐
32
- │ Колонка │ Описание
33
  ├──────────────────────┼─────────────────────────────────┤
34
  │ paragraph_id │ Уникальный ID параграфа │
35
  │ summary │ Название документа/раздела │
36
  │ start_year │ Год начала периода │
37
  │ end_year │ Год окончания периода │
 
38
  │ document_id │ Ссылка на исходный документ │
39
  └──────────────────────┴─────────────────────────────────┘
40
 
41
  2. ДАТАФРЕЙМ ЧАНКОВ (self.chunks_df):
42
  ┌──────────────────────┬─────────────────────────────────┐
43
- │ Колонка │ Описание
44
  ├──────────────────────┼─────────────────────────────────┤
45
  │ chunk_id │ Уникальный ID чанка │
46
  │ paragraph_id │ Foreign key на параграф │
@@ -101,6 +102,7 @@ class Retrieval:
101
  - summary: название документа/раздела
102
  - start_year: год начала периода
103
  - end_year: год окончания периода
 
104
  - document_id: ссылка на исходный документ
105
 
106
  chunks_df:
@@ -130,6 +132,7 @@ class Retrieval:
130
  'summary': summary,
131
  'start_year': year_range[0],
132
  'end_year': year_range[1],
 
133
  'document_id': doc_id
134
  })
135
 
@@ -361,11 +364,8 @@ class Retrieval:
361
  paragraph_scores = df.groupby('paragraph_id')['score'].max().reindex(self.paragraphs_df['paragraph_id']).fillna(0)
362
  return paragraph_scores
363
 
364
- # ============ Для тестирования cross-encoder ============
365
- def test_query_with_cross_encoder(self, query: str,
366
- target_summary: str, weight_bm25: float = 0.5, weight_semantic: float = 0.5) -> None:
367
- """ Тестирует запрос с cross-encoder и выводит результаты.
368
-
369
  Args:
370
  query: Текст запроса
371
  target_summary: Ожидаемый summary
 
29
 
30
  1. ДАТАФРЕЙМ ПАРАГРАФОВ (self.paragraphs_df):
31
  ┌──────────────────────┬─────────────────────────────────┐
32
+ │ Колонка │ Описание
33
  ├──────────────────────┼─────────────────────────────────┤
34
  │ paragraph_id │ Уникальный ID параграфа │
35
  │ summary │ Название документа/раздела │
36
  │ start_year │ Год начала периода │
37
  │ end_year │ Год окончания периода │
38
+ │ text │ Текст │
39
  │ document_id │ Ссылка на исходный документ │
40
  └──────────────────────┴─────────────────────────────────┘
41
 
42
  2. ДАТАФРЕЙМ ЧАНКОВ (self.chunks_df):
43
  ┌──────────────────────┬─────────────────────────────────┐
44
+ │ Колонка │ Описание
45
  ├──────────────────────┼─────────────────────────────────┤
46
  │ chunk_id │ Уникальный ID чанка │
47
  │ paragraph_id │ Foreign key на параграф │
 
102
  - summary: название документа/раздела
103
  - start_year: год начала периода
104
  - end_year: год окончания периода
105
+ - text: текст абзаца
106
  - document_id: ссылка на исходный документ
107
 
108
  chunks_df:
 
132
  'summary': summary,
133
  'start_year': year_range[0],
134
  'end_year': year_range[1],
135
+ 'text': paragraph,
136
  'document_id': doc_id
137
  })
138
 
 
364
  paragraph_scores = df.groupby('paragraph_id')['score'].max().reindex(self.paragraphs_df['paragraph_id']).fillna(0)
365
  return paragraph_scores
366
 
367
+ def search(self, query: str, target_summary: str, weight_bm25: float = 0.5, weight_semantic: float = 0.5) -> None:
368
+ """
 
 
 
369
  Args:
370
  query: Текст запроса
371
  target_summary: Ожидаемый summary