Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| from llama_index.core import Settings | |
| from documents_prep import load_json_documents, load_table_documents, load_image_documents | |
| from my_logging import log_message | |
| from index_retriever import create_vector_index, create_query_engine | |
| import sys | |
| from config import ( | |
| HF_REPO_ID, HF_TOKEN, DOWNLOAD_DIR, CHUNKS_FILENAME, | |
| JSON_FILES_DIR, TABLE_DATA_DIR, IMAGE_DATA_DIR, DEFAULT_MODEL, AVAILABLE_MODELS | |
| ) | |
| from converters.converter import process_uploaded_file, convert_single_excel_to_json, convert_single_excel_to_csv | |
| from main_utils import * | |
| def restart_system(): | |
| """Перезапуск системы для применения новых документов""" | |
| global query_engine, chunks_df, reranker, vector_index, current_model | |
| try: | |
| log_message("Начало перезапуска системы...") | |
| log_message("Очистка кэша HuggingFace...") | |
| import shutil | |
| cache_dir = os.path.expanduser("~/.cache/huggingface/hub") | |
| if os.path.exists(cache_dir): | |
| try: | |
| shutil.rmtree(cache_dir) | |
| log_message("✓ Кэш очищен") | |
| except: | |
| log_message("⚠ Не удалось очистить кэш полностью") | |
| query_engine, chunks_df, reranker, vector_index, chunk_info = initialize_system( | |
| repo_id=HF_REPO_ID, | |
| hf_token=HF_TOKEN, | |
| download_dir=DOWNLOAD_DIR, | |
| json_files_dir=JSON_FILES_DIR, | |
| table_data_dir=TABLE_DATA_DIR, | |
| image_data_dir=IMAGE_DATA_DIR, | |
| use_json_instead_csv=True, | |
| ) | |
| if query_engine: | |
| # Get updated stats | |
| stats = get_repository_stats(HF_REPO_ID, HF_TOKEN, JSON_FILES_DIR, | |
| TABLE_DATA_DIR, IMAGE_DATA_DIR) | |
| stats_display = format_stats_display(stats) | |
| log_message("Система успешно перезапущена") | |
| return "✅ Система успешно перезапущена! Новые документы загружены.", stats_display | |
| else: | |
| return "❌ Ошибка при перезапуске системы", "Статистика недоступна" | |
| except Exception as e: | |
| error_msg = f"Ошибка перезапуска: {str(e)}" | |
| log_message(error_msg) | |
| return f"❌ {error_msg}", "Статистика недоступна" | |
| def initialize_system(repo_id, hf_token, download_dir, chunks_filename=None, | |
| json_files_dir=None, table_data_dir=None, image_data_dir=None, | |
| use_json_instead_csv=False): | |
| try: | |
| log_message("Инициализация системы") | |
| os.makedirs(download_dir, exist_ok=True) | |
| from config import CHUNK_SIZE, CHUNK_OVERLAP | |
| from llama_index.core.text_splitter import TokenTextSplitter | |
| embed_model = get_embedding_model() | |
| llm = get_llm_model(DEFAULT_MODEL) | |
| reranker = get_reranker_model() | |
| Settings.embed_model = embed_model | |
| Settings.llm = llm | |
| Settings.text_splitter = TokenTextSplitter( | |
| chunk_size=CHUNK_SIZE, | |
| chunk_overlap=CHUNK_OVERLAP, | |
| separator=" ", | |
| backup_separators=["\n", ".", "!", "?"] | |
| ) | |
| all_documents = [] | |
| chunks_df = None | |
| if use_json_instead_csv and json_files_dir: | |
| log_message("Используем JSON файлы вместо CSV") | |
| from documents_prep import load_all_documents | |
| all_documents = load_all_documents( | |
| repo_id=repo_id, | |
| hf_token=hf_token, | |
| json_dir=json_files_dir, | |
| table_dir=table_data_dir if table_data_dir else "", | |
| image_dir=image_data_dir if image_data_dir else "" | |
| ) | |
| else: | |
| if chunks_filename: | |
| log_message("Загружаем данные из CSV") | |
| if table_data_dir: | |
| from documents_prep import load_table_documents | |
| table_chunks = load_table_documents(repo_id, hf_token, table_data_dir) | |
| log_message(f"Загружено {len(table_chunks)} табличных чанков") | |
| all_documents.extend(table_chunks) | |
| if image_data_dir: | |
| from documents_prep import load_image_documents | |
| image_documents = load_image_documents(repo_id, hf_token, image_data_dir) | |
| log_message(f"Загружено {len(image_documents)} документов изображений") | |
| all_documents.extend(image_documents) | |
| log_message(f"Всего документов после всей обработки: {len(all_documents)}") | |
| vector_index = create_vector_index(all_documents) | |
| query_engine = create_query_engine(vector_index) | |
| chunk_info = [] | |
| for doc in all_documents: | |
| chunk_info.append({ | |
| 'document_id': doc.metadata.get('document_id', 'unknown'), | |
| 'section_id': doc.metadata.get('section_id', 'unknown'), | |
| 'type': doc.metadata.get('type', 'text'), | |
| 'chunk_text': doc.text[:200] + '...' if len(doc.text) > 200 else doc.text, | |
| 'table_number': doc.metadata.get('table_number', ''), | |
| 'image_number': doc.metadata.get('image_number', ''), | |
| 'section': doc.metadata.get('section', ''), | |
| 'connection_type': doc.metadata.get('connection_type', '') | |
| }) | |
| log_message(f"Система успешно инициализирована") | |
| return query_engine, chunks_df, reranker, vector_index, chunk_info | |
| except Exception as e: | |
| log_message(f"Ошибка инициализации: {str(e)}") | |
| import traceback | |
| log_message(traceback.format_exc()) | |
| return None, None, None, None, [] | |
| def switch_model(model_name, vector_index): | |
| from llama_index.core import Settings | |
| from index_retriever import create_query_engine | |
| try: | |
| log_message(f"Переключение на модель: {model_name}") | |
| new_llm = get_llm_model(model_name) | |
| Settings.llm = new_llm | |
| if vector_index is not None: | |
| new_query_engine = create_query_engine(vector_index) | |
| log_message(f"Модель успешно переключена на: {model_name}") | |
| return new_query_engine, f"✅ Модель переключена на: {model_name}" | |
| else: | |
| return None, "❌ Ошибка: система не инициализирована" | |
| except Exception as e: | |
| error_msg = f"Ошибка переключения модели: {str(e)}" | |
| log_message(error_msg) | |
| return None, f"❌ {error_msg}" | |
| retrieval_params = { | |
| 'vector_top_k': 70, | |
| 'bm25_top_k': 70, | |
| 'similarity_cutoff': 0.45, | |
| 'hybrid_top_k': 140, | |
| 'rerank_top_k': 20 | |
| } | |
| def create_query_engine(vector_index, vector_top_k=70, bm25_top_k=70, | |
| similarity_cutoff=0.45, hybrid_top_k=140): | |
| try: | |
| from config import CUSTOM_PROMPT | |
| from index_retriever import create_query_engine as create_index_query_engine | |
| query_engine = create_index_query_engine( | |
| vector_index=vector_index, | |
| vector_top_k=vector_top_k, | |
| bm25_top_k=bm25_top_k, | |
| similarity_cutoff=similarity_cutoff, | |
| hybrid_top_k=hybrid_top_k | |
| ) | |
| log_message(f"Query engine created with params: vector_top_k={vector_top_k}, " | |
| f"bm25_top_k={bm25_top_k}, cutoff={similarity_cutoff}, hybrid_top_k={hybrid_top_k}") | |
| return query_engine | |
| except Exception as e: | |
| log_message(f"Ошибка создания query engine: {str(e)}") | |
| raise | |
| def main_answer_question(question): | |
| global query_engine, reranker, current_model, chunks_df, retrieval_params | |
| if not question.strip(): | |
| return ("<div style='color: black;'>Пожалуйста, введите вопрос</div>", | |
| "<div style='color: black;'>Источники появятся после обработки запроса</div>", | |
| "<div style='color: black;'>Чанки появятся после обработки запроса</div>") | |
| try: | |
| answer_html, sources_html, chunks_html = answer_question( | |
| question, query_engine, reranker, current_model, chunks_df, | |
| rerank_top_k=retrieval_params['rerank_top_k'] | |
| ) | |
| return answer_html, sources_html, chunks_html | |
| except Exception as e: | |
| log_message(f"Ошибка при ответе на вопрос: {str(e)}") | |
| return (f"<div style='color: red;'>Ошибка: {str(e)}</div>", | |
| "<div style='color: black;'>Источники недоступны из-за ошибки</div>", | |
| "<div style='color: black;'>Чанки недоступны из-за ошибки</div>") | |
| def update_retrieval_params(vector_top_k, bm25_top_k, similarity_cutoff, hybrid_top_k, rerank_top_k): | |
| global query_engine, vector_index, retrieval_params | |
| try: | |
| retrieval_params['vector_top_k'] = vector_top_k | |
| retrieval_params['bm25_top_k'] = bm25_top_k | |
| retrieval_params['similarity_cutoff'] = similarity_cutoff | |
| retrieval_params['hybrid_top_k'] = hybrid_top_k | |
| retrieval_params['rerank_top_k'] = rerank_top_k | |
| # Recreate query engine with new parameters | |
| if vector_index is not None: | |
| query_engine = create_query_engine( | |
| vector_index=vector_index, | |
| vector_top_k=vector_top_k, | |
| bm25_top_k=bm25_top_k, | |
| similarity_cutoff=similarity_cutoff, | |
| hybrid_top_k=hybrid_top_k | |
| ) | |
| log_message(f"Параметры поиска обновлены: vector_top_k={vector_top_k}, " | |
| f"bm25_top_k={bm25_top_k}, cutoff={similarity_cutoff}, " | |
| f"hybrid_top_k={hybrid_top_k}, rerank_top_k={rerank_top_k}") | |
| return f"✅ Параметры обновлены" | |
| else: | |
| return "❌ Система не инициализирована" | |
| except Exception as e: | |
| error_msg = f"Ошибка обновления параметров: {str(e)}" | |
| log_message(error_msg) | |
| return f"❌ {error_msg}" | |
| def retrieve_chunks(question: str, top_k: int = 20) -> list: | |
| from index_retriever import rerank_nodes | |
| global query_engine, reranker | |
| if query_engine is None: | |
| return [] | |
| try: | |
| retrieved_nodes = query_engine.retriever.retrieve(question) | |
| log_message(f"Получено {len(retrieved_nodes)} узлов") | |
| reranked_nodes = rerank_nodes( | |
| question, | |
| retrieved_nodes, | |
| reranker, | |
| top_k=top_k, | |
| min_score_threshold=0.5 | |
| ) | |
| chunks_data = [] | |
| for i, node in enumerate(reranked_nodes): | |
| metadata = node.metadata if hasattr(node, 'metadata') else {} | |
| chunk = { | |
| 'rank': i + 1, | |
| 'document_id': metadata.get('document_id', 'unknown'), | |
| 'section_id': metadata.get('section_id', ''), | |
| 'section_path': metadata.get('section_path', ''), | |
| 'section_text': metadata.get('section_text', ''), | |
| 'type': metadata.get('type', 'text'), | |
| 'table_number': metadata.get('table_number', ''), | |
| 'image_number': metadata.get('image_number', ''), | |
| 'text': node.text | |
| } | |
| chunks_data.append(chunk) | |
| log_message(f"Возвращено {len(chunks_data)} чанков") | |
| return chunks_data | |
| except Exception as e: | |
| log_message(f"Ошибка получения чанков: {str(e)}") | |
| return [] | |
| def create_demo_interface(answer_question_func, switch_model_func, current_model, chunk_info=None): | |
| with gr.Blocks(title="AIEXP - AI Expert для нормативной документации", theme=gr.themes.Soft()) as demo: | |
| gr.api(retrieve_chunks, api_name="retrieve_chunks") | |
| gr.Markdown(""" | |
| # AIEXP - Artificial Intelligence Expert | |
| ## Инструмент для работы с нормативной документацией | |
| """) | |
| with gr.Tab("Поиск по нормативным документам"): | |
| gr.Markdown("### Задайте вопрос по нормативной документации") | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| model_dropdown = gr.Dropdown( | |
| choices=list(AVAILABLE_MODELS.keys()), | |
| value=current_model, | |
| label="Выберите языковую модель", | |
| info="Выберите модель для генерации ответов" | |
| ) | |
| with gr.Column(scale=1): | |
| switch_btn = gr.Button("Переключить модель", variant="secondary") | |
| model_status = gr.Textbox( | |
| value=f"Текущая модель: {current_model}", | |
| label="Статус модели", | |
| interactive=False | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=3): | |
| question_input = gr.Textbox( | |
| label="Ваш вопрос к базе знаний", | |
| placeholder="Введите вопрос по нормативным документам...", | |
| lines=3 | |
| ) | |
| ask_btn = gr.Button("Найти ответ", variant="primary", size="lg") | |
| gr.Examples( | |
| examples=[ | |
| "О чем этот рисунок: ГОСТ Р 50.04.07-2022 Приложение Л. Л.1.5 Рисунок Л.2", | |
| "Л.9 Формула в ГОСТ Р 50.04.07 - 2022 что и о чем там?", | |
| "Какой стандарт устанавливает порядок признания протоколов испытаний продукции в области использования атомной энергии?", | |
| "Кто несет ответственность за организацию и проведение признания протоколов испытаний продукции?", | |
| "В каких случаях могут быть признаны протоколы испытаний, проведенные лабораториями?", | |
| "В какой таблице можно найти информацию о методы исследований при аттестационных испытаниях технологии термической обработки заготовок из легированных сталей? Какой документ и какой раздел?" | |
| ], | |
| inputs=question_input | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| answer_output = gr.HTML( | |
| label="", | |
| value=f"<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; text-align: center;'>Здесь появится ответ на ваш вопрос...<br><small>Текущая модель: {current_model}</small></div>", | |
| ) | |
| with gr.Column(scale=1): | |
| sources_output = gr.HTML( | |
| label="", | |
| value="<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; text-align: center;'>Здесь появятся релевантные чанки...</div>", | |
| ) | |
| with gr.Column(scale=1): | |
| chunks_output = gr.HTML( | |
| label="Релевантные чанки", | |
| value="<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; text-align: center;'>Здесь появятся релевантные чанки...</div>", | |
| ) | |
| with gr.Tab("⚙️ Параметры поиска"): | |
| gr.Markdown("### Настройка параметров векторного поиска и переранжирования") | |
| with gr.Row(): | |
| with gr.Column(): | |
| vector_top_k = gr.Slider( | |
| minimum=10, | |
| maximum=200, | |
| value=70, | |
| step=10, | |
| label="Vector Top K", | |
| info="Количество результатов из векторного поиска" | |
| ) | |
| with gr.Column(): | |
| bm25_top_k = gr.Slider( | |
| minimum=10, | |
| maximum=200, | |
| value=70, | |
| step=10, | |
| label="BM25 Top K", | |
| info="Количество результатов из BM25 поиска" | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| similarity_cutoff = gr.Slider( | |
| minimum=0.0, | |
| maximum=1.0, | |
| value=0.45, | |
| step=0.05, | |
| label="Similarity Cutoff", | |
| info="Минимальный порог схожести для векторного поиска" | |
| ) | |
| with gr.Column(): | |
| hybrid_top_k = gr.Slider( | |
| minimum=10, | |
| maximum=300, | |
| value=140, | |
| step=10, | |
| label="Hybrid Top K", | |
| info="Количество результатов из гибридного поиска" | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| rerank_top_k = gr.Slider( | |
| minimum=5, | |
| maximum=100, | |
| value=20, | |
| step=5, | |
| label="Rerank Top K", | |
| info="Количество результатов после переранжирования" | |
| ) | |
| with gr.Column(): | |
| update_btn = gr.Button("Применить параметры", variant="primary") | |
| update_status = gr.Textbox( | |
| value="Параметры готовы к применению", | |
| label="Статус", | |
| interactive=False | |
| ) | |
| gr.Markdown(""" | |
| ### Рекомендации: | |
| - **Vector Top K**: Увеличьте для более полного поиска по семантике (50-100) | |
| - **BM25 Top K**: Увеличьте для лучшего поиска по ключевым словам (30-80) | |
| - **Similarity Cutoff**: Снизьте для более мягких критериев (0.3-0.6), повысьте для строгих (0.7-0.9) | |
| - **Hybrid Top K**: Объединённые результаты (100-150) | |
| - **Rerank Top K**: Финальные результаты (10-30) | |
| """) | |
| update_btn.click( | |
| fn=update_retrieval_params, | |
| inputs=[vector_top_k, bm25_top_k, similarity_cutoff, hybrid_top_k, rerank_top_k], | |
| outputs=[update_status] | |
| ) | |
| gr.Markdown("### Текущие параметры:") | |
| current_params_display = gr.Textbox( | |
| value="Vector: 70 | BM25: 70 | Cutoff: 0.45 | Hybrid: 140 | Rerank: 20", | |
| label="", | |
| interactive=False, | |
| lines=2 | |
| ) | |
| def display_current_params(): | |
| return f"""Vector Top K: {retrieval_params['vector_top_k']} | |
| BM25 Top K: {retrieval_params['bm25_top_k']} | |
| Similarity Cutoff: {retrieval_params['similarity_cutoff']} | |
| Hybrid Top K: {retrieval_params['hybrid_top_k']} | |
| Rerank Top K: {retrieval_params['rerank_top_k']}""" | |
| demo.load( | |
| fn=display_current_params, | |
| outputs=[current_params_display] | |
| ) | |
| update_btn.click( | |
| fn=display_current_params, | |
| outputs=[current_params_display] | |
| ) | |
| with gr.Tab("📤 Загрузка документов"): | |
| gr.Markdown(""" | |
| ### Загрузка новых документов в систему | |
| Выберите тип документа и загрузите файл. Система автоматически обработает и добавит его в базу знаний. | |
| """) | |
| # Add stats display at the top | |
| stats_display = gr.Markdown( | |
| value=format_stats_display( | |
| get_repository_stats(HF_REPO_ID, HF_TOKEN, JSON_FILES_DIR, | |
| TABLE_DATA_DIR, IMAGE_DATA_DIR) | |
| ), | |
| label="" | |
| ) | |
| gr.Markdown("---") # Separator | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| file_type_radio = gr.Radio( | |
| choices=["Таблица", "Изображение", "Текстовый JSON"], | |
| value="Таблица", | |
| label="Тип документа", | |
| info="Выберите тип загружаемого документа" | |
| ) | |
| file_upload = gr.File( | |
| label="Выберите файл", | |
| file_types=[".xlsx", ".xls", ".csv", ".json"], | |
| type="filepath" | |
| ) | |
| with gr.Row(): | |
| upload_btn = gr.Button("📤 Загрузить и обработать", variant="primary", size="lg") | |
| restart_btn = gr.Button("🔄 Перезапустить систему", variant="secondary", size="lg") | |
| upload_status = gr.Textbox( | |
| label="Статус загрузки", | |
| value="Ожидание загрузки файла...", | |
| interactive=False, | |
| lines=8 | |
| ) | |
| restart_status = gr.Textbox( | |
| label="Статус перезапуска", | |
| value="Система готова к работе", | |
| interactive=False, | |
| lines=2 | |
| ) | |
| with gr.Column(scale=1): | |
| gr.Markdown(""" | |
| ### Требования к файлам: | |
| **Таблицы (Excel → JSON):** | |
| - Формат: .xlsx или .xls | |
| - Обязательные колонки: | |
| - Номер таблицы | |
| - Обозначение документа | |
| - Раздел документа | |
| - Название таблицы | |
| **Изображения (Excel → CSV):** | |
| - Формат: .xlsx, .xls или .csv | |
| - Метаданные изображений | |
| **JSON документы:** | |
| - Формат: .json | |
| - Структурированные данные | |
| ### Процесс загрузки: | |
| 1. Выберите тип документа | |
| 2. Загрузите файл | |
| 3. Дождитесь обработки | |
| 4. Нажмите "Перезапустить систему" | |
| """) | |
| upload_btn.click( | |
| fn=process_uploaded_file, | |
| inputs=[file_upload, file_type_radio], | |
| outputs=[upload_status] | |
| ) | |
| restart_btn.click( | |
| fn=restart_system, | |
| inputs=[], | |
| outputs=[restart_status, stats_display] | |
| ) | |
| switch_btn.click( | |
| fn=switch_model_func, | |
| inputs=[model_dropdown], | |
| outputs=[model_status] | |
| ) | |
| ask_btn.click( | |
| fn=answer_question_func, | |
| inputs=[question_input], | |
| outputs=[answer_output, sources_output, chunks_output] | |
| ) | |
| question_input.submit( | |
| fn=answer_question_func, | |
| inputs=[question_input], | |
| outputs=[answer_output, sources_output, chunks_output] | |
| ) | |
| return demo | |
| query_engine = None | |
| chunks_df = None | |
| reranker = None | |
| vector_index = None | |
| current_model = DEFAULT_MODEL | |
| def main_switch_model(model_name): | |
| global query_engine, vector_index, current_model | |
| new_query_engine, status_message = switch_model(model_name, vector_index) | |
| if new_query_engine: | |
| query_engine = new_query_engine | |
| current_model = model_name | |
| return status_message | |
| def main(): | |
| global query_engine, chunks_df, reranker, vector_index, current_model | |
| GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY", "") | |
| if GOOGLE_API_KEY: | |
| log_message("Использование Google API для модели генерации текста") | |
| else: | |
| log_message("Google API ключ не найден, использование локальной модели") | |
| log_message("Запуск AIEXP - AI Expert для нормативной документации") | |
| query_engine, chunks_df, reranker, vector_index, chunk_info = initialize_system( | |
| repo_id=HF_REPO_ID, | |
| hf_token=HF_TOKEN, | |
| download_dir=DOWNLOAD_DIR, | |
| json_files_dir=JSON_FILES_DIR, | |
| table_data_dir=TABLE_DATA_DIR, | |
| image_data_dir=IMAGE_DATA_DIR, | |
| use_json_instead_csv=True, | |
| ) | |
| if query_engine: | |
| log_message("Запуск веб-интерфейса") | |
| demo = create_demo_interface( | |
| answer_question_func=main_answer_question, | |
| switch_model_func=main_switch_model, | |
| current_model=current_model, | |
| chunk_info=chunk_info | |
| ) | |
| demo.api = "retrieve_chunks" | |
| demo.queue() | |
| demo.launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| share=True, | |
| debug=False | |
| ) | |
| else: | |
| log_message("Невозможно запустить приложение из-за ошибки инициализации") | |
| sys.exit(1) | |
| if __name__ == "__main__": | |
| main() |