Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Демо RAG система для HuggingFace Spaces | |
| Использует предварительно обработанные чанки отчета Сбера | |
| Оптимизирована для быстрого запуска без тяжелых зависимостей | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import pickle | |
| import tempfile | |
| from pathlib import Path | |
| from typing import Optional, Dict, Any, List, Tuple | |
| import traceback | |
| import re | |
| import gradio as gr | |
| import numpy as np | |
| # OpenAI для генерации ответов | |
| from openai import OpenAI | |
| class LightweightRAGSystem: | |
| """Легковесная RAG система с предзагруженными чанками""" | |
| def __init__(self): | |
| self.chunks = [] | |
| self.word_index = {} | |
| self.metadata = {} | |
| self.client = None | |
| self.is_initialized = False | |
| # Конфигурация | |
| self.generation_model = "gpt-4o" | |
| self.reranking_model = "gpt-4o-mini" | |
| self.max_chunks_for_rerank = 15 | |
| self.final_chunks_count = 5 | |
| def load_preprocessed_data(self) -> bool: | |
| """Загрузка предварительно обработанных данных""" | |
| try: | |
| print("🔄 Загрузка предварительно обработанных данных...") | |
| # Загружаем улучшенный индекс с таблицами | |
| index_file = "enhanced_sber_index.pkl" | |
| if not os.path.exists(index_file): | |
| print(f"❌ Файл индекса не найден: {index_file}") | |
| return False | |
| with open(index_file, 'rb') as f: | |
| index_data = pickle.load(f) | |
| self.chunks = index_data["chunks"] | |
| self.word_index = index_data["word_index"] | |
| self.metadata = index_data["metadata"] | |
| print(f"✅ Загружено {len(self.chunks)} чанков") | |
| print(f"✅ Создан словарный индекс из {len(self.word_index)} слов") | |
| return True | |
| except Exception as e: | |
| print(f"❌ Ошибка загрузки данных: {e}") | |
| traceback.print_exc() | |
| return False | |
| def initialize_with_api_key(self, api_key: str) -> Tuple[str, str]: | |
| """Инициализация системы с API ключом""" | |
| try: | |
| if not api_key.strip(): | |
| return "❌ Введите OpenAI API ключ", "" | |
| # Инициализация OpenAI клиента | |
| self.client = OpenAI(api_key=api_key.strip()) | |
| # Загрузка данных | |
| if not self.load_preprocessed_data(): | |
| return "❌ Ошибка загрузки данных", "" | |
| self.is_initialized = True | |
| # Генерация статистики | |
| stats = self._generate_stats() | |
| return "✅ Система инициализирована успешно", stats | |
| except Exception as e: | |
| return f"❌ Ошибка инициализации: {str(e)}", "" | |
| def _generate_stats(self) -> str: | |
| """Генерация статистики системы""" | |
| total_chunks = self.metadata.get("total_chunks", 0) | |
| avg_length = self.metadata.get("avg_chunk_length", 0) | |
| avg_tokens = self.metadata.get("avg_token_count", 0) | |
| pages = self.metadata.get("pages_processed", 0) | |
| # Добавим информацию о таблицах | |
| text_chunks = self.metadata.get("text_chunks", 0) | |
| table_chunks = self.metadata.get("table_chunks", 0) | |
| table_pages = self.metadata.get("table_pages", 0) | |
| stats = f"""✅ **Улучшенная система готова к работе!** | |
| 📊 **Статистика:** | |
| - 📦 Загружено чанков: {total_chunks} | |
| - 📝 Текстовых чанков: {text_chunks} | |
| - 📋 Табличных чанков: {table_chunks} | |
| - 📏 Средняя длина чанка: {avg_length:.0f} символов | |
| - 🔢 Средний размер: {avg_tokens:.0f} токенов | |
| - 📖 Страниц отчета: {pages} | |
| - 📊 Страниц с таблицами: {table_pages} | |
| 🔍 **Возможности:** | |
| - 🔎 Быстрый поиск по ключевым словам | |
| - 📋 Извлечение структурированных таблиц | |
| - 🧠 LLM реранкинг результатов (GPT-4o-mini) | |
| - 📝 Интеллектуальная генерация ответов (GPT-4o) | |
| - 📊 Анализ годового отчета ПАО Сбербанк 2023 | |
| 🚀 **Готова отвечать на вопросы с поддержкой таблиц!**""" | |
| return stats | |
| def search_by_keywords(self, query: str, max_results: int = 30) -> List[Dict]: | |
| """Поиск по ключевым словам""" | |
| if not query.strip(): | |
| return [] | |
| # Извлекаем ключевые слова из запроса | |
| query_words = set(re.findall(r'\b\w+\b', query.lower())) | |
| # Находим чанки, содержащие эти слова | |
| chunk_scores = {} | |
| for word in query_words: | |
| if word in self.word_index: | |
| for chunk_idx in self.word_index[word]: | |
| if chunk_idx not in chunk_scores: | |
| chunk_scores[chunk_idx] = 0 | |
| chunk_scores[chunk_idx] += 1 | |
| # Сортируем по количеству совпадений | |
| sorted_chunks = sorted(chunk_scores.items(), key=lambda x: x[1], reverse=True) | |
| # Возвращаем результаты | |
| results = [] | |
| for chunk_idx, score in sorted_chunks[:max_results]: | |
| if chunk_idx < len(self.chunks): | |
| chunk = self.chunks[chunk_idx].copy() | |
| chunk["keyword_score"] = score | |
| chunk["similarity"] = score / len(query_words) # Нормализованный score | |
| results.append(chunk) | |
| return results | |
| def rerank_with_llm(self, query: str, chunks: List[Dict]) -> List[Dict]: | |
| """LLM реранкинг результатов""" | |
| if not chunks or not self.client: | |
| return chunks | |
| try: | |
| # Ограничиваем количество чанков для реранкинга | |
| chunks_to_rerank = chunks[:self.max_chunks_for_rerank] | |
| # Подготавливаем документы для реранкинга | |
| docs_text = "" | |
| for i, chunk in enumerate(chunks_to_rerank): | |
| preview = chunk['text'][:300] + "..." if len(chunk['text']) > 300 else chunk['text'] | |
| docs_text += f"\nДокумент {i+1} (стр. {chunk['page']}):\n{preview}\n" | |
| prompt = f"""Оцени релевантность каждого документа для ответа на вопрос по шкале 1-10. | |
| Вопрос: {query} | |
| Документы:{docs_text} | |
| Инструкции: | |
| 1. Оценивай точность и полноту информации для ответа | |
| 2. Высшие баллы (8-10) - прямой ответ на вопрос | |
| 3. Средние баллы (5-7) - частично релевантная информация | |
| 4. Низкие баллы (1-4) - слабо связано с вопросом | |
| Верни только числа через запятую (например: 8,6,9,4,7):""" | |
| response = self.client.chat.completions.create( | |
| model=self.reranking_model, | |
| messages=[{"role": "user", "content": prompt}], | |
| max_tokens=100, | |
| temperature=0 | |
| ) | |
| # Парсим оценки | |
| scores_text = response.choices[0].message.content.strip() | |
| scores = [] | |
| numbers = re.findall(r'\d+\.?\d*', scores_text) | |
| for num in numbers: | |
| score = float(num) | |
| score = max(0, min(10, score)) # Ограничиваем 0-10 | |
| scores.append(score) | |
| # Применяем оценки | |
| reranked = [] | |
| for i, chunk in enumerate(chunks): | |
| chunk_copy = chunk.copy() | |
| if i < len(scores): | |
| chunk_copy["rerank_score"] = scores[i] | |
| else: | |
| chunk_copy["rerank_score"] = 0 | |
| reranked.append(chunk_copy) | |
| # Сортируем по реранк скору | |
| reranked.sort(key=lambda x: x["rerank_score"], reverse=True) | |
| return reranked | |
| except Exception as e: | |
| print(f"❌ Ошибка реранкинга: {e}") | |
| return chunks | |
| def generate_answer(self, query: str, context_chunks: List[Dict]) -> str: | |
| """Генерация ответа на основе контекста""" | |
| if not self.client: | |
| return "❌ OpenAI API не настроен" | |
| try: | |
| # Подготавливаем контекст | |
| context_parts = [] | |
| for i, chunk in enumerate(context_chunks[:self.final_chunks_count]): | |
| context_parts.append(f"Фрагмент {i+1} (страница {chunk['page']}):\n{chunk['text']}") | |
| context = "\n\n".join(context_parts) | |
| prompt = f"""Ты - эксперт по анализу финансовых отчетов. Ответь на вопрос пользователя на основе предоставленного контекста из годового отчета ПАО Сбербанк 2023. | |
| ВОПРОС: {query} | |
| КОНТЕКСТ ИЗ ОТЧЕТА: | |
| {context} | |
| ИНСТРУКЦИИ: | |
| 1. Отвечай только на основе предоставленной информации | |
| 2. Если информации недостаточно, честно об этом скажи | |
| 3. Используй конкретные данные и цифры из отчета | |
| 4. Структурируй ответ четко и понятно | |
| 5. Указывай номера страниц при цитировании | |
| 6. Отвечай на русском языке | |
| ОТВЕТ:""" | |
| response = self.client.chat.completions.create( | |
| model=self.generation_model, | |
| messages=[{"role": "user", "content": prompt}], | |
| max_tokens=1500, | |
| temperature=0.1 | |
| ) | |
| return response.choices[0].message.content.strip() | |
| except Exception as e: | |
| return f"❌ Ошибка генерации ответа: {str(e)}" | |
| def process_query(self, query: str) -> Dict[str, Any]: | |
| """Обработка пользовательского запроса""" | |
| if not self.is_initialized: | |
| return { | |
| "answer": "❌ Система не инициализирована. Введите API ключ.", | |
| "sources": [], | |
| "debug_info": {} | |
| } | |
| if not query.strip(): | |
| return { | |
| "answer": "Пожалуйста, введите ваш вопрос.", | |
| "sources": [], | |
| "debug_info": {} | |
| } | |
| try: | |
| # Шаг 1: Поиск по ключевым словам | |
| initial_results = self.search_by_keywords(query, max_results=30) | |
| if not initial_results: | |
| return { | |
| "answer": "К сожалению, не удалось найти релевантную информацию по вашему вопросу.", | |
| "sources": [], | |
| "debug_info": {"step": "keyword_search", "results_count": 0} | |
| } | |
| # Шаг 2: LLM реранкинг | |
| reranked_results = self.rerank_with_llm(query, initial_results) | |
| # Шаг 3: Генерация ответа | |
| top_chunks = reranked_results[:self.final_chunks_count] | |
| answer = self.generate_answer(query, top_chunks) | |
| # Подготовка источников | |
| sources = [] | |
| for chunk in top_chunks: | |
| sources.append({ | |
| "page": chunk["page"], | |
| "keyword_score": chunk.get("keyword_score", 0), | |
| "rerank_score": chunk.get("rerank_score", 0), | |
| "preview": chunk["text"][:200] + "..." if len(chunk["text"]) > 200 else chunk["text"] | |
| }) | |
| debug_info = { | |
| "initial_results": len(initial_results), | |
| "reranked_results": len(reranked_results), | |
| "final_chunks": len(top_chunks), | |
| "avg_keyword_score": np.mean([s["keyword_score"] for s in sources]) if sources else 0, | |
| "avg_rerank_score": np.mean([s["rerank_score"] for s in sources]) if sources else 0 | |
| } | |
| return { | |
| "answer": answer, | |
| "sources": sources, | |
| "debug_info": debug_info | |
| } | |
| except Exception as e: | |
| print(f"❌ Ошибка обработки запроса: {e}") | |
| traceback.print_exc() | |
| return { | |
| "answer": f"❌ Ошибка обработки запроса: {str(e)}", | |
| "sources": [], | |
| "debug_info": {"error": str(e)} | |
| } | |
| # Глобальная переменная системы | |
| rag_system = LightweightRAGSystem() | |
| def initialize_system(api_key: str) -> Tuple[str, str]: | |
| """Инициализация системы""" | |
| return rag_system.initialize_with_api_key(api_key) | |
| def ask_question(question: str) -> Tuple[str, str]: | |
| """Обработка вопроса""" | |
| result = rag_system.process_query(question) | |
| answer = result["answer"] | |
| # Форматируем информацию об источниках | |
| sources_info = "" | |
| if result["sources"]: | |
| sources_info = "\n📚 **Источники:**\n" | |
| for i, source in enumerate(result["sources"], 1): | |
| sources_info += f"\n**{i}.** Страница {source['page']} " | |
| sources_info += f"(ключевые слова: {source['keyword_score']}, " | |
| sources_info += f"релевантность: {source['rerank_score']:.1f}/10)\n" | |
| sources_info += f"*Превью:* {source['preview']}\n" | |
| # Добавляем отладочную информацию | |
| if result.get("debug_info"): | |
| debug = result["debug_info"] | |
| sources_info += f"\n🔍 **Статистика поиска:**\n" | |
| sources_info += f"- Найдено по ключевым словам: {debug.get('initial_results', 0)}\n" | |
| sources_info += f"- После реранкинга: {debug.get('reranked_results', 0)}\n" | |
| sources_info += f"- Использовано в ответе: {debug.get('final_chunks', 0)}\n" | |
| if debug.get('avg_rerank_score'): | |
| sources_info += f"- Средняя релевантность: {debug.get('avg_rerank_score', 0):.1f}/10\n" | |
| return answer, sources_info | |
| def create_demo_interface(): | |
| """Создание демо интерфейса для HF""" | |
| with gr.Blocks( | |
| title="RAG Demo - Сбер 2023", | |
| theme=gr.themes.Soft(), | |
| css=""" | |
| .main-header { text-align: center; margin-bottom: 2rem; } | |
| .feature-box { background-color: #f8f9fa; padding: 1rem; border-radius: 8px; margin: 1rem 0; } | |
| """ | |
| ) as demo: | |
| gr.Markdown(""" | |
| <div class="main-header"> | |
| <h1>🏆 Enhanced RAG Demo: Анализ отчета Сбера 2023</h1> | |
| <p>Улучшенная система поиска с поддержкой таблиц</p> | |
| <p><strong>84 извлеченные таблицы • 2009 чанков • pdfplumber обработка</strong></p> | |
| </div> | |
| """) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown("### ⚙️ Настройка") | |
| api_key_input = gr.Textbox( | |
| label="OpenAI API Key", | |
| placeholder="sk-...", | |
| type="password", | |
| info="Введите ваш OpenAI API ключ для работы системы" | |
| ) | |
| init_btn = gr.Button("🚀 Инициализировать", variant="primary") | |
| status_output = gr.Textbox( | |
| label="Статус", | |
| interactive=False, | |
| lines=2 | |
| ) | |
| with gr.Column(scale=1): | |
| stats_output = gr.Markdown("### 📊 Ожидание инициализации...") | |
| gr.Markdown("### 💬 Задайте вопрос") | |
| with gr.Row(): | |
| question_input = gr.Textbox( | |
| label="Ваш вопрос", | |
| placeholder="Например: Каковы основные финансовые показатели Сбера за 2023 год?", | |
| lines=2, | |
| scale=4 | |
| ) | |
| ask_btn = gr.Button("📝 Спросить", variant="primary", scale=1) | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| answer_output = gr.Textbox( | |
| label="Ответ системы", | |
| lines=12, | |
| interactive=False | |
| ) | |
| with gr.Column(scale=1): | |
| sources_output = gr.Textbox( | |
| label="Источники и статистика", | |
| lines=12, | |
| interactive=False | |
| ) | |
| # Примеры вопросов | |
| gr.Markdown(""" | |
| ### 💡 Примеры вопросов: | |
| - Каковы основные финансовые показатели Сбера за 2023 год? | |
| - Какова чистая прибыль банка в 2023 году? | |
| - Расскажите о кредитном портфеле Сбербанка | |
| - Какие технологические инициативы развивает Сбер? | |
| - Каковы показатели рентабельности банка? | |
| """) | |
| # Event handlers | |
| init_btn.click( | |
| fn=initialize_system, | |
| inputs=[api_key_input], | |
| outputs=[status_output, stats_output] | |
| ) | |
| ask_btn.click( | |
| fn=ask_question, | |
| inputs=[question_input], | |
| outputs=[answer_output, sources_output] | |
| ) | |
| question_input.submit( | |
| fn=ask_question, | |
| inputs=[question_input], | |
| outputs=[answer_output, sources_output] | |
| ) | |
| return demo | |
| if __name__ == "__main__": | |
| demo = create_demo_interface() | |
| demo.launch( | |
| share=False, | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| show_error=True | |
| ) |