RAG2 / app.py
antimoda1
add TODO
0aa6d2c
import gradio as gr
import numpy as np
from llm import get_llm_answer
from retrieval import Retrieval
from parse_documents import _parse_single_year
from vocabulary.parse_vocabulary import VOCABULARY_MANAGER
class Perform:
def __init__(self):
self.retrieval = Retrieval()
lengthh = len(self.retrieval.paragraphs_df)
self.scores = None
self.sorted_idx = None
self.years_mask = np.ones(lengthh, dtype=bool)
def get_years_range_mask(self, year_from, year_to):
try:
year_from = _parse_single_year(year_from)
year_to = _parse_single_year(year_to)
if year_from > year_to:
year_from, year_to = year_to, year_from
except (ValueError, TypeError):
raise ValueError(f"Некорректный диапазон лет: {year_from} - {year_to}")
self.years_mask = (
(self.retrieval.paragraphs_df["end_year"] >= year_from) &
(self.retrieval.paragraphs_df["start_year"] <= year_to)
).values
def perform_search(self, query, top_k, year_from, year_to):
self.get_years_range_mask(year_from, year_to)
# если есть query → считаем scores
if query:
self.scores = self.retrieval.search(query)
self.sorted_idx = np.argsort(self.scores)[::-1]
# если нет query и scores нет → используем только фильтр
if self.scores is None:
filtered = np.where(self.years_mask)[0]
if len(filtered) <= top_k:
return None, None, filtered, "Показаны все записи по фильтру лет"
return None, None, filtered[-top_k:], "Показаны записи по фильтру лет"
# применяем mask к отсортированным индексам
filtered_sorted = self.sorted_idx[self.years_mask[self.sorted_idx]]
if len(filtered_sorted) == 0:
return self.scores, None, [], "⚠️ Нет результатов в выбранном диапазоне лет"
top_k_indices = filtered_sorted[:top_k]
return self.scores, None, top_k_indices, f"Найдено {len(filtered_sorted)} результатов"
def format_retrieval_results(self, indices):
if len(indices) == 0:
return "Нет результатов"
texts = self.retrieval.paragraphs_df["text"].iloc[indices]
return "\n\n".join(texts)
perform = Perform()
def ui_search(query, top_k, year_from, year_to):
return perform.perform_search(query, top_k, year_from, year_to)
def ui_format_results(indices, top_k):
if indices is None:
return "Нет результатов"
indices = indices[:top_k]
return perform.format_retrieval_results(indices)
def ask_llm(query, filtered_indices_state):
"""Этап 2: Отправка отфильтрованных чанков в LLM с потоковой выдачей"""
if not query:
yield "Введите вопрос"
return
context = perform.format_retrieval_results(filtered_indices_state)
if not context or context == "Нет валидных чанков":
yield "Нет валидных чанков для отправки"
return
# Формируем промпт и отправляем в LLM
prompt = VOCABULARY_MANAGER.wrap_prompt(context, query)
# Потоковая выдача ответа
full_answer = ""
for chunk in get_llm_answer(prompt):
full_answer += chunk
yield full_answer
# Создаем интерфейс Gradio
with gr.Blocks(title="RAG Application", theme=gr.themes.Soft()) as iface:
gr.Markdown("## Справочник по общественного истории транспорта Рязани")
# Строка 1: поиск и фильтр по датам
with gr.Row():
search_query_input = gr.Textbox(
label="Запрос для поиска",
placeholder="Введите запрос для поиска",
lines=1,
scale=3
)
year_from_input = gr.Textbox(
label="От года",
value='1918',
placeholder="Год",
lines=1,
scale=1
)
year_to_input = gr.Textbox(
label="До года",
value='2026',
placeholder="Год",
lines=1,
scale=1
)
# Строка 2: top-k параметр
with gr.Row():
top_k_slider = gr.Slider(
minimum=1,
maximum=100,
value=30,
step=1,
label="Top-k для поиска",
scale=2
)
# Строка 3: кнопка поиска и статус
with gr.Row():
search_btn = gr.Button(
"🔍 Выполнить поиск",
variant="primary",
scale=1
)
search_status = gr.Textbox(
label="Статус",
interactive=False,
scale=3
)
with gr.Row():
with gr.Column(scale=2):
# Большое текстовое поле для результатов retrieval
retrieval_results = gr.Textbox(
label="Результаты поиска",
placeholder="Результаты поиска появятся здесь",
lines=15,
max_lines=30,
interactive=False
)
with gr.Row():
with gr.Column(scale=1):
# Ввод вопроса для LLM
llm_query_input = gr.Textbox(
label="Ваш вопрос по результатам поиска",
placeholder="Введите вопрос по историческим документам...",
lines=2
)
with gr.Column(scale=2):
# Кнопка отправки в LLM
llm_btn = gr.Button("Спросить LLM", variant="secondary")
with gr.Column(scale=3):
# Ответ LLM с потоковой выдачей
llm_answer = gr.Markdown(
label="Ответ LLM (появляется постепенно)"
)
# Состояния для хранения данных между вызовами
all_scores_state = gr.State()
all_chunk_ids_state = gr.State()
top_k_indices_state = gr.State()
# Обработчик поиска
search_btn.click(
fn=ui_search,
inputs=[search_query_input, top_k_slider, year_from_input, year_to_input],
outputs=[all_scores_state, all_chunk_ids_state, top_k_indices_state, search_status]
).then(
fn=ui_format_results,
inputs=[top_k_indices_state, top_k_slider],
outputs=[retrieval_results]
)
# Обработчик изменения слайдера top_k
top_k_slider.change(
fn=ui_format_results,
inputs=[top_k_indices_state, top_k_slider],
outputs=[retrieval_results]
)
# Отправка в LLM с потоковой выдачей
llm_btn.click(
fn=ask_llm,
inputs=[llm_query_input, top_k_indices_state],
outputs=[llm_answer]
)
if __name__ == "__main__":
iface.launch(ssr_mode=False,
share=True
)