antimoda1 commited on
Commit ·
b6d731b
1
Parent(s): 873ada4
fix app
Browse files- app.py +65 -112
- generation.py +2 -1
- retrieval.py +7 -7
app.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
-
import re
|
| 2 |
import gradio as gr
|
|
|
|
|
|
|
| 3 |
from generation import wrap_prompt
|
| 4 |
from llm import get_llm_answer
|
| 5 |
from retrieval import Retrieval
|
|
@@ -8,114 +9,69 @@ from vocabulary.parse_vocabulary import parse_vocabulary
|
|
| 8 |
|
| 9 |
|
| 10 |
vocabulary, _ = parse_vocabulary('vocabulary/vocabulary.md')
|
| 11 |
-
retrieval = Retrieval()
|
| 12 |
|
| 13 |
|
| 14 |
-
|
| 15 |
-
"""Этап 1: Поиск и возврат результатов с фильтром по датам"""
|
| 16 |
-
|
| 17 |
-
if not query:
|
| 18 |
-
return None, [], [], "Введите вопрос для поиска"
|
| 19 |
-
|
| 20 |
-
# Преобразуем входные значения
|
| 21 |
-
try:
|
| 22 |
-
year_from = _parse_single_year(year_from)
|
| 23 |
-
year_to = _parse_single_year(year_to)
|
| 24 |
-
|
| 25 |
-
# Проверяем корректность диапазона
|
| 26 |
-
if year_from > year_to:
|
| 27 |
-
year_from, year_to = year_to, year_from
|
| 28 |
-
|
| 29 |
-
except (ValueError, TypeError):
|
| 30 |
-
return None, [], [], f"⚠️ Ошибка: некорректный диапазон лет ({year_from} - {year_to})"
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
# Применяем ЖЕСТКИЙ фильтр по датам ДО выбора top-k
|
| 40 |
-
year_search_range = (year_from, year_to)
|
| 41 |
-
filtered_by_date = retrieval.filter_by_year_range(chunk_ids, year_search_range)
|
| 42 |
-
|
| 43 |
-
# Если нет результатов после фильтра по датам
|
| 44 |
-
if not filtered_by_date:
|
| 45 |
-
return scores, chunk_ids, [], f"⚠️ Нет результатов в диапазоне {year_from}-{year_to}"
|
| 46 |
-
|
| 47 |
-
# Находим top-k среди отфильтрованных по датам (сортируем по релевантности BM25)
|
| 48 |
-
top_k = min(top_k, len(filtered_by_date))
|
| 49 |
-
filtered_scores = [(idx, scores[idx]) for idx in filtered_by_date]
|
| 50 |
-
filtered_scores.sort(key=lambda x: x[1], reverse=True)
|
| 51 |
-
top_k_indices = [idx for idx, _ in filtered_scores[:top_k]]
|
| 52 |
-
|
| 53 |
-
status = f"Найдено {len(scores)} чанков, {len(filtered_by_date)} в диапазоне {year_from}-{year_to}. Top-{top_k} выбраны."
|
| 54 |
-
|
| 55 |
-
return scores, chunk_ids, top_k_indices, status
|
| 56 |
|
| 57 |
-
def
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
return ""
|
| 66 |
-
|
| 67 |
-
# Найдем все уникальные абзацы из выбранных чанков
|
| 68 |
-
paragraphs_to_show = {} # paragraph_id -> doc_id
|
| 69 |
-
|
| 70 |
-
for idx in selected_indices:
|
| 71 |
-
if idx >= len(retrieval.paragraph_metadata) or idx >= len(retrieval.docs_metadata):
|
| 72 |
-
continue
|
| 73 |
-
|
| 74 |
-
paragraph_id = retrieval.paragraph_metadata[idx]
|
| 75 |
-
doc_id = retrieval.docs_metadata[idx]
|
| 76 |
-
paragraphs_to_show[paragraph_id] = doc_id
|
| 77 |
-
|
| 78 |
-
# Для каждого отмеченного абзаца найдем ВСЕ его чанки
|
| 79 |
-
full_paragraph_chunks = {} # paragraph_id -> [chunk_ids]
|
| 80 |
-
for chunk_id, paragraph_id in enumerate(retrieval.paragraph_metadata):
|
| 81 |
-
if paragraph_id in paragraphs_to_show:
|
| 82 |
-
if paragraph_id not in full_paragraph_chunks:
|
| 83 |
-
full_paragraph_chunks[paragraph_id] = []
|
| 84 |
-
full_paragraph_chunks[paragraph_id].append(chunk_id)
|
| 85 |
-
|
| 86 |
-
# Форматируем вывод
|
| 87 |
-
result_lines = []
|
| 88 |
-
for paragraph_id in sorted(paragraphs_to_show.keys()):
|
| 89 |
-
doc_id = paragraphs_to_show[paragraph_id]
|
| 90 |
-
chunk_indices = sorted(full_paragraph_chunks[paragraph_id])
|
| 91 |
-
|
| 92 |
-
doc_name = retrieval.docs_names[doc_id] if doc_id < len(retrieval.docs_names) else "Неизвестный документ"
|
| 93 |
-
|
| 94 |
-
# Объединяем все чанки абзаца в полный текст
|
| 95 |
-
paragraph_text = " ".join([retrieval.chunks[idx] for idx in chunk_indices])
|
| 96 |
-
|
| 97 |
-
# Форматируем вывод с названием документа
|
| 98 |
-
result_lines.append(f"Документ {doc_name}:")
|
| 99 |
-
result_lines.append(paragraph_text)
|
| 100 |
-
result_lines.append("") # Пустая строка между абзацами
|
| 101 |
-
|
| 102 |
-
return "\n".join(result_lines)
|
| 103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
-
def format_retrieval_results(filtered_indices, top_k_results):
|
| 106 |
-
"""Форматирует результаты retrieval для отображения в текстовом поле
|
| 107 |
-
|
| 108 |
-
Берет top_k результатов и выводит целые абзацы с названиями документов
|
| 109 |
-
"""
|
| 110 |
-
if len(filtered_indices) == 0:
|
| 111 |
-
return "Нет результатов"
|
| 112 |
-
|
| 113 |
-
top_k_results = min(top_k_results, len(filtered_indices))
|
| 114 |
-
|
| 115 |
-
# Берем top-k индексов (уже отсортированы по релевантности)
|
| 116 |
-
top_k_indices = filtered_indices[:top_k_results]
|
| 117 |
-
|
| 118 |
-
return format_selected_chunks(top_k_indices)
|
| 119 |
|
| 120 |
def ask_llm(query, filtered_indices_state):
|
| 121 |
"""Этап 2: Отправка отфильтрованных чанков в LLM с потоковой выдачей"""
|
|
@@ -123,22 +79,19 @@ def ask_llm(query, filtered_indices_state):
|
|
| 123 |
yield "Введите вопрос"
|
| 124 |
return
|
| 125 |
|
| 126 |
-
|
| 127 |
-
chunks_to_use = filtered_indices_state if filtered_indices_state else []
|
| 128 |
-
|
| 129 |
-
if not chunks_to_use:
|
| 130 |
yield "Нет выбранных чанков для отправки в LLM"
|
| 131 |
return
|
| 132 |
|
| 133 |
# Форматируем контекст используя ту же функцию, что и в интерфейсе
|
| 134 |
-
context = format_selected_chunks(
|
| 135 |
|
| 136 |
if not context or context == "Нет валидных чанков":
|
| 137 |
yield "Нет валидных чанков для отправки"
|
| 138 |
return
|
| 139 |
|
| 140 |
# Формируем промпт и отправляем в LLM
|
| 141 |
-
prompt = wrap_prompt(context, query,
|
| 142 |
|
| 143 |
# Потоковая выдача ответа
|
| 144 |
full_answer = ""
|
|
@@ -237,18 +190,18 @@ with gr.Blocks(title="RAG Application", theme=gr.themes.Soft()) as iface:
|
|
| 237 |
|
| 238 |
# Обработчик поиска
|
| 239 |
search_btn.click(
|
| 240 |
-
fn=perform_search,
|
| 241 |
inputs=[search_query_input, top_k_slider, year_from_input, year_to_input],
|
| 242 |
outputs=[all_scores_state, all_chunk_ids_state, top_k_indices_state, search_status]
|
| 243 |
).then(
|
| 244 |
-
fn=format_retrieval_results,
|
| 245 |
inputs=[top_k_indices_state, top_k_slider],
|
| 246 |
outputs=[retrieval_results]
|
| 247 |
)
|
| 248 |
|
| 249 |
# Обработчик изменения слайдера top_k
|
| 250 |
top_k_slider.change(
|
| 251 |
-
fn=format_retrieval_results,
|
| 252 |
inputs=[top_k_indices_state, top_k_slider],
|
| 253 |
outputs=[retrieval_results]
|
| 254 |
)
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
from generation import wrap_prompt
|
| 5 |
from llm import get_llm_answer
|
| 6 |
from retrieval import Retrieval
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
vocabulary, _ = parse_vocabulary('vocabulary/vocabulary.md')
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
+
class Perform:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
+
def __init__(self):
|
| 17 |
+
self.retrieval = Retrieval()
|
| 18 |
+
lengthh = len(self.retrieval.paragraphs_df)
|
| 19 |
+
self.scores = None
|
| 20 |
+
self.sorted_idx = None
|
| 21 |
+
self.years_mask = np.ones(lengthh, dtype=bool)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
+
def get_years_range_mask(self, year_from, year_to):
|
| 24 |
+
try:
|
| 25 |
+
year_from = _parse_single_year(year_from)
|
| 26 |
+
year_to = _parse_single_year(year_to)
|
| 27 |
+
if year_from > year_to:
|
| 28 |
+
year_from, year_to = year_to, year_from
|
| 29 |
+
except (ValueError, TypeError):
|
| 30 |
+
raise ValueError(f"Некорректный диапазон лет: {year_from} - {year_to}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
+
self.years_mask = (
|
| 33 |
+
(self.retrieval.paragraphs_df["end_year"] >= year_from) &
|
| 34 |
+
(self.retrieval.paragraphs_df["start_year"] <= year_to)
|
| 35 |
+
).values
|
| 36 |
+
|
| 37 |
+
def perform_search(self, query, top_k, year_from, year_to):
|
| 38 |
+
self.get_years_range_mask(year_from, year_to)
|
| 39 |
+
|
| 40 |
+
# если есть query → считаем scores
|
| 41 |
+
if query:
|
| 42 |
+
self.scores = self.retrieval.search(query)
|
| 43 |
+
self.sorted_idx = np.argsort(self.scores)[::-1]
|
| 44 |
+
|
| 45 |
+
# если нет query и scores нет → используем только фильтр
|
| 46 |
+
if self.scores is None:
|
| 47 |
+
filtered = np.where(self.years_mask)[0]
|
| 48 |
+
if len(filtered) <= top_k:
|
| 49 |
+
return None, None, filtered, "Показаны все записи по фильтру лет"
|
| 50 |
+
return None, None, filtered[-top_k:], "Показаны записи по фильтру лет"
|
| 51 |
+
|
| 52 |
+
# применяем mask к отсортированным индексам
|
| 53 |
+
filtered_sorted = self.sorted_idx[self.years_mask[self.sorted_idx]]
|
| 54 |
+
|
| 55 |
+
if len(filtered_sorted) == 0:
|
| 56 |
+
return self.scores, None, [], "⚠️ Нет результатов в выбранном диапазоне лет"
|
| 57 |
+
|
| 58 |
+
top_k_indices = filtered_sorted[:top_k]
|
| 59 |
+
|
| 60 |
+
return self.scores, None, top_k_indices, f"Найдено {len(filtered_sorted)} результатов"
|
| 61 |
+
|
| 62 |
+
def format_retrieval_results(self, top_k_indices):
|
| 63 |
+
if len(top_k_indices) == 0:
|
| 64 |
+
return "Нет результатов"
|
| 65 |
+
|
| 66 |
+
texts = self.retrieval.paragraphs_df["texts"].iloc[top_k_indices]
|
| 67 |
+
return "\n\n".join(texts)
|
| 68 |
+
|
| 69 |
+
def format_selected_chunks(self, indices):
|
| 70 |
+
texts = self.retrieval.paragraphs_df["texts"].iloc[indices]
|
| 71 |
+
return "\n\n".join(texts)
|
| 72 |
+
|
| 73 |
+
perform = Perform()
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
def ask_llm(query, filtered_indices_state):
|
| 77 |
"""Этап 2: Отправка отфильтрованных чанков в LLM с потоковой выдачей"""
|
|
|
|
| 79 |
yield "Введите вопрос"
|
| 80 |
return
|
| 81 |
|
| 82 |
+
if not filtered_indices_state:
|
|
|
|
|
|
|
|
|
|
| 83 |
yield "Нет выбранных чанков для отправки в LLM"
|
| 84 |
return
|
| 85 |
|
| 86 |
# Форматируем контекст используя ту же функцию, что и в интерфейсе
|
| 87 |
+
context = perform.format_selected_chunks(filtered_indices_state)
|
| 88 |
|
| 89 |
if not context or context == "Нет валидных чанков":
|
| 90 |
yield "Нет валидных чанков для отправки"
|
| 91 |
return
|
| 92 |
|
| 93 |
# Формируем промпт и отправляем в LLM
|
| 94 |
+
prompt = wrap_prompt(context, query, vocabulary)
|
| 95 |
|
| 96 |
# Потоковая выдача ответа
|
| 97 |
full_answer = ""
|
|
|
|
| 190 |
|
| 191 |
# Обработчик поиска
|
| 192 |
search_btn.click(
|
| 193 |
+
fn=perform.perform_search,
|
| 194 |
inputs=[search_query_input, top_k_slider, year_from_input, year_to_input],
|
| 195 |
outputs=[all_scores_state, all_chunk_ids_state, top_k_indices_state, search_status]
|
| 196 |
).then(
|
| 197 |
+
fn=perform.format_retrieval_results,
|
| 198 |
inputs=[top_k_indices_state, top_k_slider],
|
| 199 |
outputs=[retrieval_results]
|
| 200 |
)
|
| 201 |
|
| 202 |
# Обработчик изменения слайдера top_k
|
| 203 |
top_k_slider.change(
|
| 204 |
+
fn=perform.format_retrieval_results,
|
| 205 |
inputs=[top_k_indices_state, top_k_slider],
|
| 206 |
outputs=[retrieval_results]
|
| 207 |
)
|
generation.py
CHANGED
|
@@ -64,7 +64,8 @@ def lemmatize(text, vocabulary):
|
|
| 64 |
return found_terms
|
| 65 |
|
| 66 |
|
| 67 |
-
def wrap_prompt(retrieved_text, query_text,
|
|
|
|
| 68 |
tokens_from_query = lemmatize(query_text, vocabula)
|
| 69 |
tokens_from_retrieved = lemmatize(retrieved_text, vocabula)
|
| 70 |
info_for_llm = ''
|
|
|
|
| 64 |
return found_terms
|
| 65 |
|
| 66 |
|
| 67 |
+
def wrap_prompt(retrieved_text, query_text, inp_vocabula: dict[str, str]):
|
| 68 |
+
vocabula = inp_vocabula.copy() # Создаем копию словаря, чтобы не изменять оригинал
|
| 69 |
tokens_from_query = lemmatize(query_text, vocabula)
|
| 70 |
tokens_from_retrieved = lemmatize(retrieved_text, vocabula)
|
| 71 |
info_for_llm = ''
|
retrieval.py
CHANGED
|
@@ -29,18 +29,19 @@ class Retrieval:
|
|
| 29 |
|
| 30 |
1. ДАТАФРЕЙМ ПАРАГРАФОВ (self.paragraphs_df):
|
| 31 |
┌──────────────────────┬─────────────────────────────────┐
|
| 32 |
-
│ Колонка │ Описание
|
| 33 |
├──────────────────────┼─────────────────────────────────┤
|
| 34 |
│ paragraph_id │ Уникальный ID параграфа │
|
| 35 |
│ summary │ Название документа/раздела │
|
| 36 |
│ start_year │ Год начала периода │
|
| 37 |
│ end_year │ Год окончания периода │
|
|
|
|
| 38 |
│ document_id │ Ссылка на исходный документ │
|
| 39 |
└──────────────────────┴─────────────────────────────────┘
|
| 40 |
|
| 41 |
2. ДАТАФРЕЙМ ЧАНКОВ (self.chunks_df):
|
| 42 |
┌──────────────────────┬─────────────────────────────────┐
|
| 43 |
-
│ Колонка │ Описание
|
| 44 |
├──────────────────────┼─────────────────────────────────┤
|
| 45 |
│ chunk_id │ Уникальный ID чанка │
|
| 46 |
│ paragraph_id │ Foreign key на параграф │
|
|
@@ -101,6 +102,7 @@ class Retrieval:
|
|
| 101 |
- summary: название документа/раздела
|
| 102 |
- start_year: год начала периода
|
| 103 |
- end_year: год окончания периода
|
|
|
|
| 104 |
- document_id: ссылка на исходный документ
|
| 105 |
|
| 106 |
chunks_df:
|
|
@@ -130,6 +132,7 @@ class Retrieval:
|
|
| 130 |
'summary': summary,
|
| 131 |
'start_year': year_range[0],
|
| 132 |
'end_year': year_range[1],
|
|
|
|
| 133 |
'document_id': doc_id
|
| 134 |
})
|
| 135 |
|
|
@@ -361,11 +364,8 @@ class Retrieval:
|
|
| 361 |
paragraph_scores = df.groupby('paragraph_id')['score'].max().reindex(self.paragraphs_df['paragraph_id']).fillna(0)
|
| 362 |
return paragraph_scores
|
| 363 |
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
target_summary: str, weight_bm25: float = 0.5, weight_semantic: float = 0.5) -> None:
|
| 367 |
-
""" Тестирует запрос с cross-encoder и выводит результаты.
|
| 368 |
-
|
| 369 |
Args:
|
| 370 |
query: Текст запроса
|
| 371 |
target_summary: Ожидаемый summary
|
|
|
|
| 29 |
|
| 30 |
1. ДАТАФРЕЙМ ПАРАГРАФОВ (self.paragraphs_df):
|
| 31 |
┌──────────────────────┬─────────────────────────────────┐
|
| 32 |
+
│ Колонка │ Описание │
|
| 33 |
├──────────────────────┼─────────────────────────────────┤
|
| 34 |
│ paragraph_id │ Уникальный ID параграфа │
|
| 35 |
│ summary │ Название документа/раздела │
|
| 36 |
│ start_year │ Год начала периода │
|
| 37 |
│ end_year │ Год окончания периода │
|
| 38 |
+
│ text │ Текст │
|
| 39 |
│ document_id │ Ссылка на исходный документ │
|
| 40 |
└──────────────────────┴─────────────────────────────────┘
|
| 41 |
|
| 42 |
2. ДАТАФРЕЙМ ЧАНКОВ (self.chunks_df):
|
| 43 |
┌──────────────────────┬─────────────────────────────────┐
|
| 44 |
+
│ Колонка │ Описание │
|
| 45 |
├──────────────────────┼─────────────────────────────────┤
|
| 46 |
│ chunk_id │ Уникальный ID чанка │
|
| 47 |
│ paragraph_id │ Foreign key на параграф │
|
|
|
|
| 102 |
- summary: название документа/раздела
|
| 103 |
- start_year: год начала периода
|
| 104 |
- end_year: год окончания периода
|
| 105 |
+
- text: текст абзаца
|
| 106 |
- document_id: ссылка на исходный документ
|
| 107 |
|
| 108 |
chunks_df:
|
|
|
|
| 132 |
'summary': summary,
|
| 133 |
'start_year': year_range[0],
|
| 134 |
'end_year': year_range[1],
|
| 135 |
+
'text': paragraph,
|
| 136 |
'document_id': doc_id
|
| 137 |
})
|
| 138 |
|
|
|
|
| 364 |
paragraph_scores = df.groupby('paragraph_id')['score'].max().reindex(self.paragraphs_df['paragraph_id']).fillna(0)
|
| 365 |
return paragraph_scores
|
| 366 |
|
| 367 |
+
def search(self, query: str, target_summary: str, weight_bm25: float = 0.5, weight_semantic: float = 0.5) -> None:
|
| 368 |
+
"""
|
|
|
|
|
|
|
|
|
|
| 369 |
Args:
|
| 370 |
query: Текст запроса
|
| 371 |
target_summary: Ожидаемый summary
|