Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from huggingface_hub import hf_hub_download, list_repo_files | |
| import faiss | |
| import pandas as pd | |
| import os | |
| import json | |
| from llama_index.core import Document, VectorStoreIndex, Settings | |
| from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
| from llama_index.llms.google_genai import GoogleGenAI | |
| from llama_index.llms.openai import OpenAI | |
| from llama_index.core.query_engine import RetrieverQueryEngine | |
| from llama_index.core.retrievers import VectorIndexRetriever | |
| from llama_index.core.response_synthesizers import get_response_synthesizer, ResponseMode | |
| from llama_index.core.prompts import PromptTemplate | |
| from llama_index.retrievers.bm25 import BM25Retriever | |
| from sentence_transformers import CrossEncoder | |
| from llama_index.core.retrievers import QueryFusionRetriever | |
| import time | |
| import sys | |
| import logging | |
| from config import * | |
| REPO_ID = "MrSimple01/AIEXP_RAG_FILES" | |
| faiss_index_filename = "cleaned_faiss_index.index" | |
| chunks_filename = "processed_chunks.csv" | |
| table_data_dir = "Табличные данные_JSON" | |
| image_data_dir = "Изображения" | |
| download_dir = "rag_files" | |
| HF_TOKEN = os.getenv('HF_TOKEN') | |
| # Global variables | |
| query_engine = None | |
| chunks_df = None | |
| reranker = None | |
| vector_index = None | |
| current_model = DEFAULT_MODEL | |
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') | |
| logger = logging.getLogger(__name__) | |
| def log_message(message): | |
| logger.info(message) | |
| print(message, flush=True) | |
| sys.stdout.flush() | |
| def get_llm_model(model_name): | |
| """Get LLM model instance based on model name""" | |
| try: | |
| model_config = AVAILABLE_MODELS.get(model_name) | |
| if not model_config: | |
| log_message(f"Модель {model_name} не найдена, использую модель по умолчанию") | |
| model_config = AVAILABLE_MODELS[DEFAULT_MODEL] | |
| if not model_config.get("api_key"): | |
| raise Exception(f"API ключ не найден для модели {model_name}") | |
| if model_config["provider"] == "google": | |
| return GoogleGenAI( | |
| model=model_config["model_name"], | |
| api_key=model_config["api_key"] | |
| ) | |
| elif model_config["provider"] == "openai": | |
| return OpenAI( | |
| model=model_config["model_name"], | |
| api_key=model_config["api_key"] | |
| ) | |
| else: | |
| raise Exception(f"Неподдерживаемый провайдер: {model_config['provider']}") | |
| except Exception as e: | |
| log_message(f"Ошибка создания модели {model_name}: {str(e)}") | |
| # Fallback to default Google model | |
| return GoogleGenAI(model="gemini-2.0-flash", api_key=GOOGLE_API_KEY) | |
| def switch_model(model_name): | |
| """Switch to a different LLM model""" | |
| global query_engine, current_model | |
| try: | |
| log_message(f"Переключение на модель: {model_name}") | |
| # Create new LLM instance | |
| new_llm = get_llm_model(model_name) | |
| Settings.llm = new_llm | |
| # Recreate query engine with new model | |
| if vector_index is not None: | |
| recreate_query_engine() | |
| current_model = model_name | |
| log_message(f"Модель успешно переключена на: {model_name}") | |
| return f"✅ Модель переключена на: {model_name}" | |
| else: | |
| return "❌ Ошибка: система не инициализирована" | |
| except Exception as e: | |
| error_msg = f"Ошибка переключения модели: {str(e)}" | |
| log_message(error_msg) | |
| return f"❌ {error_msg}" | |
| def recreate_query_engine(): | |
| """Recreate query engine with current settings""" | |
| global query_engine | |
| try: | |
| # Create BM25 retriever | |
| bm25_retriever = BM25Retriever.from_defaults( | |
| docstore=vector_index.docstore, | |
| similarity_top_k=15 | |
| ) | |
| # Create vector retriever | |
| vector_retriever = VectorIndexRetriever( | |
| index=vector_index, | |
| similarity_top_k=20, | |
| similarity_cutoff=0.5 | |
| ) | |
| # Create hybrid retriever | |
| hybrid_retriever = QueryFusionRetriever( | |
| [vector_retriever, bm25_retriever], | |
| similarity_top_k=30, | |
| num_queries=1 | |
| ) | |
| # Create response synthesizer | |
| custom_prompt_template = PromptTemplate(CUSTOM_PROMPT) | |
| response_synthesizer = get_response_synthesizer( | |
| response_mode=ResponseMode.TREE_SUMMARIZE, | |
| text_qa_template=custom_prompt_template | |
| ) | |
| # Create new query engine | |
| query_engine = RetrieverQueryEngine( | |
| retriever=hybrid_retriever, | |
| response_synthesizer=response_synthesizer | |
| ) | |
| log_message("Query engine успешно пересоздан") | |
| except Exception as e: | |
| log_message(f"Ошибка пересоздания query engine: {str(e)}") | |
| raise | |
| def table_to_document(table_data, document_id=None): | |
| content = "" | |
| if isinstance(table_data, dict): | |
| doc_id = document_id or table_data.get('document_id', table_data.get('document', 'Неизвестно')) | |
| table_num = table_data.get('table_number', 'Неизвестно') | |
| table_title = table_data.get('table_title', 'Неизвестно') | |
| section = table_data.get('section', 'Неизвестно') | |
| content += f"Таблица: {table_num}\n" | |
| content += f"Название: {table_title}\n" | |
| content += f"Документ: {doc_id}\n" | |
| content += f"Раздел: {section}\n" | |
| if 'data' in table_data and isinstance(table_data['data'], list): | |
| for row in table_data['data']: | |
| if isinstance(row, dict): | |
| row_text = " | ".join([f"{k}: {v}" for k, v in row.items()]) | |
| content += f"{row_text}\n" | |
| return Document( | |
| text=content, | |
| metadata={ | |
| "type": "table", | |
| "table_number": table_data.get('table_number', 'unknown'), | |
| "table_title": table_data.get('table_title', 'unknown'), | |
| "document_id": doc_id or table_data.get('document_id', table_data.get('document', 'unknown')), | |
| "section": table_data.get('section', 'unknown') | |
| } | |
| ) | |
| def download_table_data(): | |
| log_message("Начинаю загрузку табличных данных") | |
| table_files = [] | |
| try: | |
| files = list_repo_files(repo_id=REPO_ID, repo_type="dataset", token=HF_TOKEN) | |
| for file in files: | |
| if file.startswith(table_data_dir) and file.endswith('.json'): | |
| table_files.append(file) | |
| log_message(f"Найдено {len(table_files)} JSON файлов с таблицами") | |
| table_documents = [] | |
| for file_path in table_files: | |
| try: | |
| log_message(f"Обрабатываю файл: {file_path}") | |
| local_path = hf_hub_download( | |
| repo_id=REPO_ID, | |
| filename=file_path, | |
| local_dir='', | |
| repo_type="dataset", | |
| token=HF_TOKEN | |
| ) | |
| with open(local_path, 'r', encoding='utf-8') as f: | |
| table_data = json.load(f) | |
| if isinstance(table_data, dict): | |
| document_id = table_data.get('document', 'unknown') | |
| if 'sheets' in table_data: | |
| for sheet in table_data['sheets']: | |
| sheet['document'] = document_id | |
| doc = table_to_document(sheet, document_id) | |
| table_documents.append(doc) | |
| else: | |
| doc = table_to_document(table_data, document_id) | |
| table_documents.append(doc) | |
| elif isinstance(table_data, list): | |
| for table_json in table_data: | |
| doc = table_to_document(table_json) | |
| table_documents.append(doc) | |
| except Exception as e: | |
| log_message(f"Ошибка обработки файла {file_path}: {str(e)}") | |
| continue | |
| log_message(f"Создано {len(table_documents)} документов из таблиц") | |
| return table_documents | |
| except Exception as e: | |
| log_message(f"Ошибка загрузки табличных данных: {str(e)}") | |
| return [] | |
| def download_image_data(): | |
| log_message("Начинаю загрузку данных изображений") | |
| image_files = [] | |
| try: | |
| files = list_repo_files(repo_id=REPO_ID, repo_type="dataset", token=HF_TOKEN) | |
| for file in files: | |
| if file.startswith(image_data_dir) and file.endswith('.csv'): | |
| image_files.append(file) | |
| log_message(f"Найдено {len(image_files)} CSV файлов с изображениями") | |
| image_documents = [] | |
| for file_path in image_files: | |
| try: | |
| log_message(f"Обрабатываю файл изображений: {file_path}") | |
| local_path = hf_hub_download( | |
| repo_id=REPO_ID, | |
| filename=file_path, | |
| local_dir='', | |
| repo_type="dataset", | |
| token=HF_TOKEN | |
| ) | |
| df = pd.read_csv(local_path) | |
| log_message(f"Загружено {len(df)} записей изображений из файла {file_path}") | |
| for _, row in df.iterrows(): | |
| content = f"Изображение: {row.get('№ Изображения', 'Неизвестно')}\n" | |
| content += f"Название: {row.get('Название изображения', 'Неизвестно')}\n" | |
| content += f"Описание: {row.get('Описание изображение', 'Неизвестно')}\n" | |
| content += f"Документ: {row.get('Обозначение документа', 'Неизвестно')}\n" | |
| content += f"Раздел: {row.get('Раздел документа', 'Неизвестно')}\n" | |
| content += f"Файл: {row.get('Файл изображения', 'Неизвестно')}\n" | |
| doc = Document( | |
| text=content, | |
| metadata={ | |
| "type": "image", | |
| "image_number": row.get('№ Изображения', 'unknown'), | |
| "document_id": row.get('Обозначение документа', 'unknown'), | |
| "file_path": row.get('Файл изображения', 'unknown'), | |
| "section": row.get('Раздел документа', 'unknown') | |
| } | |
| ) | |
| image_documents.append(doc) | |
| except Exception as e: | |
| log_message(f"Ошибка обработки файла {file_path}: {str(e)}") | |
| continue | |
| log_message(f"Создано {len(image_documents)} документов из изображений") | |
| return image_documents | |
| except Exception as e: | |
| log_message(f"Ошибка загрузки данных изображений: {str(e)}") | |
| return [] | |
| def initialize_models(): | |
| global query_engine, chunks_df, reranker, vector_index, current_model | |
| try: | |
| log_message("Инициализация системы") | |
| os.makedirs(download_dir, exist_ok=True) | |
| log_message("Загружаю основные файлы") | |
| chunks_csv_path = hf_hub_download( | |
| repo_id=REPO_ID, | |
| filename=chunks_filename, | |
| local_dir=download_dir, | |
| repo_type="dataset", | |
| token=HF_TOKEN | |
| ) | |
| log_message("Загружаю данные чанков") | |
| chunks_df = pd.read_csv(chunks_csv_path) | |
| log_message(f"Загружено {len(chunks_df)} чанков") | |
| log_message("Инициализирую модели") | |
| embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2") | |
| llm = get_llm_model(current_model) | |
| log_message("Инициализирую переранкер") | |
| reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2') | |
| Settings.embed_model = embed_model | |
| Settings.llm = llm | |
| text_column = None | |
| for col in chunks_df.columns: | |
| if 'text' in col.lower() or 'content' in col.lower() or 'chunk' in col.lower(): | |
| text_column = col | |
| break | |
| if text_column is None: | |
| text_column = chunks_df.columns[0] | |
| log_message(f"Использую колонку: {text_column}") | |
| log_message("Создаю документы из чанков") | |
| documents = [] | |
| for i, (_, row) in enumerate(chunks_df.iterrows()): | |
| doc = Document( | |
| text=str(row[text_column]), | |
| metadata={ | |
| "chunk_id": row.get('chunk_id', i), | |
| "document_id": row.get('document_id', 'unknown'), | |
| "type": "text" | |
| } | |
| ) | |
| documents.append(doc) | |
| log_message(f"Создано {len(documents)} текстовых документов") | |
| log_message("Добавляю табличные данные") | |
| table_documents = download_table_data() | |
| documents.extend(table_documents) | |
| log_message("Добавляю данные изображений") | |
| image_documents = download_image_data() | |
| documents.extend(image_documents) | |
| log_message(f"Всего документов: {len(documents)}") | |
| log_message("Строю векторный индекс") | |
| vector_index = VectorStoreIndex.from_documents(documents) | |
| # Create query engine | |
| recreate_query_engine() | |
| log_message(f"Система успешно инициализирована с моделью: {current_model}") | |
| return True | |
| except Exception as e: | |
| log_message(f"Ошибка инициализации: {str(e)}") | |
| return False | |
| def rerank_nodes(query, nodes, top_k=10): | |
| if not nodes or not reranker: | |
| return nodes[:top_k] | |
| try: | |
| log_message(f"Переранжирую {len(nodes)} узлов") | |
| pairs = [] | |
| for node in nodes: | |
| pairs.append([query, node.text]) | |
| scores = reranker.predict(pairs) | |
| scored_nodes = list(zip(nodes, scores)) | |
| scored_nodes.sort(key=lambda x: x[1], reverse=True) | |
| reranked_nodes = [node for node, score in scored_nodes[:top_k]] | |
| log_message(f"Возвращаю топ-{len(reranked_nodes)} переранжированных узлов") | |
| return reranked_nodes | |
| except Exception as e: | |
| log_message(f"Ошибка переранжировки: {str(e)}") | |
| return nodes[:top_k] | |
| def answer_question(question): | |
| global query_engine, chunks_df, current_model | |
| if query_engine is None: | |
| return "<div style='background-color: #e53e3e; color: white; padding: 20px; border-radius: 10px;'>Система не инициализирована</div>", "" | |
| try: | |
| log_message(f"Получен вопрос: {question}") | |
| log_message(f"Используется модель: {current_model}") | |
| start_time = time.time() | |
| log_message("Извлекаю релевантные узлы") | |
| retrieved_nodes = query_engine.retriever.retrieve(question) | |
| log_message(f"Извлечено {len(retrieved_nodes)} узлов") | |
| log_message("Применяю переранжировку") | |
| reranked_nodes = rerank_nodes(question, retrieved_nodes, top_k=10) | |
| log_message(f"Отправляю запрос в LLM с {len(reranked_nodes)} узлами") | |
| response = query_engine.query(question) | |
| end_time = time.time() | |
| processing_time = end_time - start_time | |
| log_message(f"Обработка завершена за {processing_time:.2f} секунд") | |
| sources_html = generate_sources_html(reranked_nodes) | |
| answer_with_time = f"""<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; margin-bottom: 10px;'> | |
| <h3 style='color: #63b3ed; margin-top: 0;'>Ответ (Модель: {current_model}):</h3> | |
| <div style='line-height: 1.6; font-size: 16px;'>{response.response}</div> | |
| <div style='margin-top: 15px; padding-top: 10px; border-top: 1px solid #4a5568; font-size: 14px; color: #a0aec0;'> | |
| Время обработки: {processing_time:.2f} секунд | |
| </div> | |
| </div>""" | |
| return answer_with_time, sources_html | |
| except Exception as e: | |
| log_message(f"Ошибка обработки вопроса: {str(e)}") | |
| error_msg = f"<div style='background-color: #e53e3e; color: white; padding: 20px; border-radius: 10px;'>Ошибка обработки вопроса: {str(e)}</div>" | |
| return error_msg, "" | |
| def generate_sources_html(nodes): | |
| html = "<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; max-height: 400px; overflow-y: auto;'>" | |
| html += "<h3 style='color: #63b3ed; margin-top: 0;'>Источники:</h3>" | |
| for i, node in enumerate(nodes): | |
| metadata = node.metadata if hasattr(node, 'metadata') else {} | |
| doc_type = metadata.get('type', 'text') | |
| doc_id = metadata.get('document_id', 'unknown') | |
| html += f"<div style='margin-bottom: 15px; padding: 15px; border: 1px solid #4a5568; border-radius: 8px; background-color: #1a202c;'>" | |
| if doc_type == 'text': | |
| html += f"<h4 style='margin: 0 0 10px 0; color: #63b3ed;'>📄 {doc_id}</h4>" | |
| elif doc_type == 'table': | |
| table_num = metadata.get('table_number', 'unknown') | |
| if table_num and table_num != 'unknown': | |
| if not table_num.startswith('№'): | |
| table_num = f"№{table_num}" | |
| html += f"<h4 style='margin: 0 0 10px 0; color: #68d391;'>📊 Таблица {table_num} - {doc_id}</h4>" | |
| else: | |
| html += f"<h4 style='margin: 0 0 10px 0; color: #68d391;'>📊 Таблица - {doc_id}</h4>" | |
| elif doc_type == 'image': | |
| image_num = metadata.get('image_number', 'unknown') | |
| section = metadata.get('section', '') | |
| if image_num and image_num != 'unknown': | |
| if not str(image_num).startswith('№'): | |
| image_num = f"№{image_num}" | |
| html += f"<h4 style='margin: 0 0 10px 0; color: #fbb6ce;'>🖼️ Изображение {image_num} - {doc_id} ({section})</h4>" | |
| else: | |
| html += f"<h4 style='margin: 0 0 10px 0; color: #fbb6ce;'>🖼️ Изображение - {doc_id} ({section})</h4>" | |
| if chunks_df is not None and 'file_link' in chunks_df.columns and doc_type == 'text': | |
| doc_rows = chunks_df[chunks_df['document_id'] == doc_id] | |
| if not doc_rows.empty: | |
| file_link = doc_rows.iloc[0]['file_link'] | |
| html += f"<a href='{file_link}' target='_blank' style='color: #68d391; text-decoration: none; font-size: 14px; display: inline-block; margin-top: 10px;'>🔗 Ссылка на документ</a><br>" | |
| html += "</div>" | |
| html += "</div>" | |
| return html | |
| def create_demo_interface(): | |
| with gr.Blocks(title="AIEXP - AI Expert для нормативной документации", theme=gr.themes.Soft()) as demo: | |
| gr.Markdown(""" | |
| # AIEXP - Artificial Intelligence Expert | |
| ## Инструмент для работы с нормативной документацией | |
| """) | |
| with gr.Tab("🏠 Поиск по нормативным документам"): | |
| gr.Markdown("### Задайте вопрос по нормативной документации") | |
| # Model selection section | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| model_dropdown = gr.Dropdown( | |
| choices=list(AVAILABLE_MODELS.keys()), | |
| value=current_model, | |
| label="🤖 Выберите языковую модель", | |
| info="Выберите модель для генерации ответов" | |
| ) | |
| with gr.Column(scale=1): | |
| switch_btn = gr.Button("🔄 Переключить модель", variant="secondary") | |
| model_status = gr.Textbox( | |
| value=f"Текущая модель: {current_model}", | |
| label="Статус модели", | |
| interactive=False | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=3): | |
| question_input = gr.Textbox( | |
| label="Ваш вопрос к базе знаний", | |
| placeholder="Введите вопрос по нормативным документам...", | |
| lines=3 | |
| ) | |
| ask_btn = gr.Button("🔍 Найти ответ", variant="primary", size="lg") | |
| gr.Examples( | |
| examples=[ | |
| "О чем этот рисунок: ГОСТ Р 50.04.07-2022 Приложение Л. Л.1.5 Рисунок Л.2", | |
| "Л.9 Формула в ГОСТ Р 50.04.07 - 2022 что и о чем там?", | |
| "Какой стандарт устанавливает порядок признания протоколов испытаний продукции в области использования атомной энергии?", | |
| "Кто несет ответственность за организацию и проведение признания протоколов испытаний продукции?", | |
| "В каких случаях могут быть признаны протоколы испытаний, проведенные лабораториями?", | |
| ], | |
| inputs=question_input | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| answer_output = gr.HTML( | |
| label="", | |
| value=f"<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; text-align: center;'>Здесь появится ответ на ваш вопрос...<br><small>Текущая модель: {current_model}</small></div>", | |
| ) | |
| with gr.Column(scale=1): | |
| sources_output = gr.HTML( | |
| label="", | |
| value="<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; text-align: center;'>Здесь появятся источники...</div>", | |
| ) | |
| # Event handlers | |
| def update_model_status(new_model): | |
| result = switch_model(new_model) | |
| return result | |
| switch_btn.click( | |
| fn=update_model_status, | |
| inputs=[model_dropdown], | |
| outputs=[model_status] | |
| ) | |
| ask_btn.click( | |
| fn=answer_question, | |
| inputs=[question_input], | |
| outputs=[answer_output, sources_output] | |
| ) | |
| question_input.submit( | |
| fn=answer_question, | |
| inputs=[question_input], | |
| outputs=[answer_output, sources_output] | |
| ) | |
| return demo | |
| if __name__ == "__main__": | |
| log_message("Запуск AIEXP - AI Expert для нормативной документации") | |
| if initialize_models(): | |
| log_message("Запуск веб-интерфейса") | |
| demo = create_demo_interface() | |
| demo.launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| share=True, | |
| debug=False | |
| ) | |
| else: | |
| log_message("Невозможно запустить приложение из-за ошибки инициализации") | |
| sys.exit(1) |