RAG_AIEXP / app.py
MrSimple01's picture
Upload 10 files
fa02ae1 verified
import gradio as gr
import os
from llama_index.core import Settings
from documents_prep import load_json_documents, load_table_documents, load_image_documents
from my_logging import log_message
from index_retriever import create_vector_index, create_query_engine
import sys
from config import (
HF_REPO_ID, HF_TOKEN, DOWNLOAD_DIR, CHUNKS_FILENAME,
JSON_FILES_DIR, TABLE_DATA_DIR, IMAGE_DATA_DIR, DEFAULT_MODEL, AVAILABLE_MODELS
)
from converters.converter import process_uploaded_file, convert_single_excel_to_json, convert_single_excel_to_csv
from main_utils import *
def restart_system():
"""Перезапуск системы для применения новых документов"""
global query_engine, chunks_df, reranker, vector_index, current_model
try:
log_message("Начало перезапуска системы...")
log_message("Очистка кэша HuggingFace...")
import shutil
cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
if os.path.exists(cache_dir):
try:
shutil.rmtree(cache_dir)
log_message("✓ Кэш очищен")
except:
log_message("⚠ Не удалось очистить кэш полностью")
query_engine, chunks_df, reranker, vector_index, chunk_info = initialize_system(
repo_id=HF_REPO_ID,
hf_token=HF_TOKEN,
download_dir=DOWNLOAD_DIR,
json_files_dir=JSON_FILES_DIR,
table_data_dir=TABLE_DATA_DIR,
image_data_dir=IMAGE_DATA_DIR,
use_json_instead_csv=True,
)
if query_engine:
# Get updated stats
stats = get_repository_stats(HF_REPO_ID, HF_TOKEN, JSON_FILES_DIR,
TABLE_DATA_DIR, IMAGE_DATA_DIR)
stats_display = format_stats_display(stats)
log_message("Система успешно перезапущена")
return "✅ Система успешно перезапущена! Новые документы загружены.", stats_display
else:
return "❌ Ошибка при перезапуске системы", "Статистика недоступна"
except Exception as e:
error_msg = f"Ошибка перезапуска: {str(e)}"
log_message(error_msg)
return f"❌ {error_msg}", "Статистика недоступна"
def initialize_system(repo_id, hf_token, download_dir, chunks_filename=None,
json_files_dir=None, table_data_dir=None, image_data_dir=None,
use_json_instead_csv=False):
try:
log_message("Инициализация системы")
os.makedirs(download_dir, exist_ok=True)
from config import CHUNK_SIZE, CHUNK_OVERLAP
from llama_index.core.text_splitter import TokenTextSplitter
embed_model = get_embedding_model()
llm = get_llm_model(DEFAULT_MODEL)
reranker = get_reranker_model()
Settings.embed_model = embed_model
Settings.llm = llm
Settings.text_splitter = TokenTextSplitter(
chunk_size=CHUNK_SIZE,
chunk_overlap=CHUNK_OVERLAP,
separator=" ",
backup_separators=["\n", ".", "!", "?"]
)
all_documents = []
chunks_df = None
if use_json_instead_csv and json_files_dir:
log_message("Используем JSON файлы вместо CSV")
from documents_prep import load_all_documents
all_documents = load_all_documents(
repo_id=repo_id,
hf_token=hf_token,
json_dir=json_files_dir,
table_dir=table_data_dir if table_data_dir else "",
image_dir=image_data_dir if image_data_dir else ""
)
else:
if chunks_filename:
log_message("Загружаем данные из CSV")
if table_data_dir:
from documents_prep import load_table_documents
table_chunks = load_table_documents(repo_id, hf_token, table_data_dir)
log_message(f"Загружено {len(table_chunks)} табличных чанков")
all_documents.extend(table_chunks)
if image_data_dir:
from documents_prep import load_image_documents
image_documents = load_image_documents(repo_id, hf_token, image_data_dir)
log_message(f"Загружено {len(image_documents)} документов изображений")
all_documents.extend(image_documents)
log_message(f"Всего документов после всей обработки: {len(all_documents)}")
vector_index = create_vector_index(all_documents)
query_engine = create_query_engine(vector_index)
chunk_info = []
for doc in all_documents:
chunk_info.append({
'document_id': doc.metadata.get('document_id', 'unknown'),
'section_id': doc.metadata.get('section_id', 'unknown'),
'type': doc.metadata.get('type', 'text'),
'chunk_text': doc.text[:200] + '...' if len(doc.text) > 200 else doc.text,
'table_number': doc.metadata.get('table_number', ''),
'image_number': doc.metadata.get('image_number', ''),
'section': doc.metadata.get('section', ''),
'connection_type': doc.metadata.get('connection_type', '')
})
log_message(f"Система успешно инициализирована")
return query_engine, chunks_df, reranker, vector_index, chunk_info
except Exception as e:
log_message(f"Ошибка инициализации: {str(e)}")
import traceback
log_message(traceback.format_exc())
return None, None, None, None, []
def switch_model(model_name, vector_index):
from llama_index.core import Settings
from index_retriever import create_query_engine
try:
log_message(f"Переключение на модель: {model_name}")
new_llm = get_llm_model(model_name)
Settings.llm = new_llm
if vector_index is not None:
new_query_engine = create_query_engine(vector_index)
log_message(f"Модель успешно переключена на: {model_name}")
return new_query_engine, f"✅ Модель переключена на: {model_name}"
else:
return None, "❌ Ошибка: система не инициализирована"
except Exception as e:
error_msg = f"Ошибка переключения модели: {str(e)}"
log_message(error_msg)
return None, f"❌ {error_msg}"
retrieval_params = {
'vector_top_k': 70,
'bm25_top_k': 70,
'similarity_cutoff': 0.45,
'hybrid_top_k': 140,
'rerank_top_k': 20
}
def create_query_engine(vector_index, vector_top_k=70, bm25_top_k=70,
similarity_cutoff=0.45, hybrid_top_k=140):
try:
from config import CUSTOM_PROMPT
from index_retriever import create_query_engine as create_index_query_engine
query_engine = create_index_query_engine(
vector_index=vector_index,
vector_top_k=vector_top_k,
bm25_top_k=bm25_top_k,
similarity_cutoff=similarity_cutoff,
hybrid_top_k=hybrid_top_k
)
log_message(f"Query engine created with params: vector_top_k={vector_top_k}, "
f"bm25_top_k={bm25_top_k}, cutoff={similarity_cutoff}, hybrid_top_k={hybrid_top_k}")
return query_engine
except Exception as e:
log_message(f"Ошибка создания query engine: {str(e)}")
raise
def main_answer_question(question):
global query_engine, reranker, current_model, chunks_df, retrieval_params
if not question.strip():
return ("<div style='color: black;'>Пожалуйста, введите вопрос</div>",
"<div style='color: black;'>Источники появятся после обработки запроса</div>",
"<div style='color: black;'>Чанки появятся после обработки запроса</div>")
try:
answer_html, sources_html, chunks_html = answer_question(
question, query_engine, reranker, current_model, chunks_df,
rerank_top_k=retrieval_params['rerank_top_k']
)
return answer_html, sources_html, chunks_html
except Exception as e:
log_message(f"Ошибка при ответе на вопрос: {str(e)}")
return (f"<div style='color: red;'>Ошибка: {str(e)}</div>",
"<div style='color: black;'>Источники недоступны из-за ошибки</div>",
"<div style='color: black;'>Чанки недоступны из-за ошибки</div>")
def update_retrieval_params(vector_top_k, bm25_top_k, similarity_cutoff, hybrid_top_k, rerank_top_k):
global query_engine, vector_index, retrieval_params
try:
retrieval_params['vector_top_k'] = vector_top_k
retrieval_params['bm25_top_k'] = bm25_top_k
retrieval_params['similarity_cutoff'] = similarity_cutoff
retrieval_params['hybrid_top_k'] = hybrid_top_k
retrieval_params['rerank_top_k'] = rerank_top_k
# Recreate query engine with new parameters
if vector_index is not None:
query_engine = create_query_engine(
vector_index=vector_index,
vector_top_k=vector_top_k,
bm25_top_k=bm25_top_k,
similarity_cutoff=similarity_cutoff,
hybrid_top_k=hybrid_top_k
)
log_message(f"Параметры поиска обновлены: vector_top_k={vector_top_k}, "
f"bm25_top_k={bm25_top_k}, cutoff={similarity_cutoff}, "
f"hybrid_top_k={hybrid_top_k}, rerank_top_k={rerank_top_k}")
return f"✅ Параметры обновлены"
else:
return "❌ Система не инициализирована"
except Exception as e:
error_msg = f"Ошибка обновления параметров: {str(e)}"
log_message(error_msg)
return f"❌ {error_msg}"
def retrieve_chunks(question: str, top_k: int = 20) -> list:
from index_retriever import rerank_nodes
global query_engine, reranker
if query_engine is None:
return []
try:
retrieved_nodes = query_engine.retriever.retrieve(question)
log_message(f"Получено {len(retrieved_nodes)} узлов")
reranked_nodes = rerank_nodes(
question,
retrieved_nodes,
reranker,
top_k=top_k,
min_score_threshold=0.5
)
chunks_data = []
for i, node in enumerate(reranked_nodes):
metadata = node.metadata if hasattr(node, 'metadata') else {}
chunk = {
'rank': i + 1,
'document_id': metadata.get('document_id', 'unknown'),
'section_id': metadata.get('section_id', ''),
'section_path': metadata.get('section_path', ''),
'section_text': metadata.get('section_text', ''),
'type': metadata.get('type', 'text'),
'table_number': metadata.get('table_number', ''),
'image_number': metadata.get('image_number', ''),
'text': node.text
}
chunks_data.append(chunk)
log_message(f"Возвращено {len(chunks_data)} чанков")
return chunks_data
except Exception as e:
log_message(f"Ошибка получения чанков: {str(e)}")
return []
def create_demo_interface(answer_question_func, switch_model_func, current_model, chunk_info=None):
with gr.Blocks(title="AIEXP - AI Expert для нормативной документации", theme=gr.themes.Soft()) as demo:
gr.api(retrieve_chunks, api_name="retrieve_chunks")
gr.Markdown("""
# AIEXP - Artificial Intelligence Expert
## Инструмент для работы с нормативной документацией
""")
with gr.Tab("Поиск по нормативным документам"):
gr.Markdown("### Задайте вопрос по нормативной документации")
with gr.Row():
with gr.Column(scale=2):
model_dropdown = gr.Dropdown(
choices=list(AVAILABLE_MODELS.keys()),
value=current_model,
label="Выберите языковую модель",
info="Выберите модель для генерации ответов"
)
with gr.Column(scale=1):
switch_btn = gr.Button("Переключить модель", variant="secondary")
model_status = gr.Textbox(
value=f"Текущая модель: {current_model}",
label="Статус модели",
interactive=False
)
with gr.Row():
with gr.Column(scale=3):
question_input = gr.Textbox(
label="Ваш вопрос к базе знаний",
placeholder="Введите вопрос по нормативным документам...",
lines=3
)
ask_btn = gr.Button("Найти ответ", variant="primary", size="lg")
gr.Examples(
examples=[
"О чем этот рисунок: ГОСТ Р 50.04.07-2022 Приложение Л. Л.1.5 Рисунок Л.2",
"Л.9 Формула в ГОСТ Р 50.04.07 - 2022 что и о чем там?",
"Какой стандарт устанавливает порядок признания протоколов испытаний продукции в области использования атомной энергии?",
"Кто несет ответственность за организацию и проведение признания протоколов испытаний продукции?",
"В каких случаях могут быть признаны протоколы испытаний, проведенные лабораториями?",
"В какой таблице можно найти информацию о методы исследований при аттестационных испытаниях технологии термической обработки заготовок из легированных сталей? Какой документ и какой раздел?"
],
inputs=question_input
)
with gr.Row():
with gr.Column(scale=2):
answer_output = gr.HTML(
label="",
value=f"<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; text-align: center;'>Здесь появится ответ на ваш вопрос...<br><small>Текущая модель: {current_model}</small></div>",
)
with gr.Column(scale=1):
sources_output = gr.HTML(
label="",
value="<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; text-align: center;'>Здесь появятся релевантные чанки...</div>",
)
with gr.Column(scale=1):
chunks_output = gr.HTML(
label="Релевантные чанки",
value="<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; text-align: center;'>Здесь появятся релевантные чанки...</div>",
)
with gr.Tab("⚙️ Параметры поиска"):
gr.Markdown("### Настройка параметров векторного поиска и переранжирования")
with gr.Row():
with gr.Column():
vector_top_k = gr.Slider(
minimum=10,
maximum=200,
value=70,
step=10,
label="Vector Top K",
info="Количество результатов из векторного поиска"
)
with gr.Column():
bm25_top_k = gr.Slider(
minimum=10,
maximum=200,
value=70,
step=10,
label="BM25 Top K",
info="Количество результатов из BM25 поиска"
)
with gr.Row():
with gr.Column():
similarity_cutoff = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.45,
step=0.05,
label="Similarity Cutoff",
info="Минимальный порог схожести для векторного поиска"
)
with gr.Column():
hybrid_top_k = gr.Slider(
minimum=10,
maximum=300,
value=140,
step=10,
label="Hybrid Top K",
info="Количество результатов из гибридного поиска"
)
with gr.Row():
with gr.Column():
rerank_top_k = gr.Slider(
minimum=5,
maximum=100,
value=20,
step=5,
label="Rerank Top K",
info="Количество результатов после переранжирования"
)
with gr.Column():
update_btn = gr.Button("Применить параметры", variant="primary")
update_status = gr.Textbox(
value="Параметры готовы к применению",
label="Статус",
interactive=False
)
gr.Markdown("""
### Рекомендации:
- **Vector Top K**: Увеличьте для более полного поиска по семантике (50-100)
- **BM25 Top K**: Увеличьте для лучшего поиска по ключевым словам (30-80)
- **Similarity Cutoff**: Снизьте для более мягких критериев (0.3-0.6), повысьте для строгих (0.7-0.9)
- **Hybrid Top K**: Объединённые результаты (100-150)
- **Rerank Top K**: Финальные результаты (10-30)
""")
update_btn.click(
fn=update_retrieval_params,
inputs=[vector_top_k, bm25_top_k, similarity_cutoff, hybrid_top_k, rerank_top_k],
outputs=[update_status]
)
gr.Markdown("### Текущие параметры:")
current_params_display = gr.Textbox(
value="Vector: 70 | BM25: 70 | Cutoff: 0.45 | Hybrid: 140 | Rerank: 20",
label="",
interactive=False,
lines=2
)
def display_current_params():
return f"""Vector Top K: {retrieval_params['vector_top_k']}
BM25 Top K: {retrieval_params['bm25_top_k']}
Similarity Cutoff: {retrieval_params['similarity_cutoff']}
Hybrid Top K: {retrieval_params['hybrid_top_k']}
Rerank Top K: {retrieval_params['rerank_top_k']}"""
demo.load(
fn=display_current_params,
outputs=[current_params_display]
)
update_btn.click(
fn=display_current_params,
outputs=[current_params_display]
)
with gr.Tab("📤 Загрузка документов"):
gr.Markdown("""
### Загрузка новых документов в систему
Выберите тип документа и загрузите файл. Система автоматически обработает и добавит его в базу знаний.
""")
# Add stats display at the top
stats_display = gr.Markdown(
value=format_stats_display(
get_repository_stats(HF_REPO_ID, HF_TOKEN, JSON_FILES_DIR,
TABLE_DATA_DIR, IMAGE_DATA_DIR)
),
label=""
)
gr.Markdown("---") # Separator
with gr.Row():
with gr.Column(scale=2):
file_type_radio = gr.Radio(
choices=["Таблица", "Изображение", "Текстовый JSON"],
value="Таблица",
label="Тип документа",
info="Выберите тип загружаемого документа"
)
file_upload = gr.File(
label="Выберите файл",
file_types=[".xlsx", ".xls", ".csv", ".json"],
type="filepath"
)
with gr.Row():
upload_btn = gr.Button("📤 Загрузить и обработать", variant="primary", size="lg")
restart_btn = gr.Button("🔄 Перезапустить систему", variant="secondary", size="lg")
upload_status = gr.Textbox(
label="Статус загрузки",
value="Ожидание загрузки файла...",
interactive=False,
lines=8
)
restart_status = gr.Textbox(
label="Статус перезапуска",
value="Система готова к работе",
interactive=False,
lines=2
)
with gr.Column(scale=1):
gr.Markdown("""
### Требования к файлам:
**Таблицы (Excel → JSON):**
- Формат: .xlsx или .xls
- Обязательные колонки:
- Номер таблицы
- Обозначение документа
- Раздел документа
- Название таблицы
**Изображения (Excel → CSV):**
- Формат: .xlsx, .xls или .csv
- Метаданные изображений
**JSON документы:**
- Формат: .json
- Структурированные данные
### Процесс загрузки:
1. Выберите тип документа
2. Загрузите файл
3. Дождитесь обработки
4. Нажмите "Перезапустить систему"
""")
upload_btn.click(
fn=process_uploaded_file,
inputs=[file_upload, file_type_radio],
outputs=[upload_status]
)
restart_btn.click(
fn=restart_system,
inputs=[],
outputs=[restart_status, stats_display]
)
switch_btn.click(
fn=switch_model_func,
inputs=[model_dropdown],
outputs=[model_status]
)
ask_btn.click(
fn=answer_question_func,
inputs=[question_input],
outputs=[answer_output, sources_output, chunks_output]
)
question_input.submit(
fn=answer_question_func,
inputs=[question_input],
outputs=[answer_output, sources_output, chunks_output]
)
return demo
query_engine = None
chunks_df = None
reranker = None
vector_index = None
current_model = DEFAULT_MODEL
def main_switch_model(model_name):
global query_engine, vector_index, current_model
new_query_engine, status_message = switch_model(model_name, vector_index)
if new_query_engine:
query_engine = new_query_engine
current_model = model_name
return status_message
def main():
global query_engine, chunks_df, reranker, vector_index, current_model
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY", "")
if GOOGLE_API_KEY:
log_message("Использование Google API для модели генерации текста")
else:
log_message("Google API ключ не найден, использование локальной модели")
log_message("Запуск AIEXP - AI Expert для нормативной документации")
query_engine, chunks_df, reranker, vector_index, chunk_info = initialize_system(
repo_id=HF_REPO_ID,
hf_token=HF_TOKEN,
download_dir=DOWNLOAD_DIR,
json_files_dir=JSON_FILES_DIR,
table_data_dir=TABLE_DATA_DIR,
image_data_dir=IMAGE_DATA_DIR,
use_json_instead_csv=True,
)
if query_engine:
log_message("Запуск веб-интерфейса")
demo = create_demo_interface(
answer_question_func=main_answer_question,
switch_model_func=main_switch_model,
current_model=current_model,
chunk_info=chunk_info
)
demo.api = "retrieve_chunks"
demo.queue()
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=True,
debug=False
)
else:
log_message("Невозможно запустить приложение из-за ошибки инициализации")
sys.exit(1)
if __name__ == "__main__":
main()