Course_Project / app.py
fruitpicker01's picture
Update app.py
bcb1683 verified
raw
history blame
19.9 kB
#!/usr/bin/env python3
"""
Демо RAG система для HuggingFace Spaces
Использует предварительно обработанные чанки отчета Сбера
Оптимизирована для быстрого запуска без тяжелых зависимостей
"""
import os
import sys
import json
import pickle
import tempfile
from pathlib import Path
from typing import Optional, Dict, Any, List, Tuple
import traceback
import re
import gradio as gr
import numpy as np
# OpenAI для генерации ответов
from openai import OpenAI
class LightweightRAGSystem:
"""Легковесная RAG система с предзагруженными чанками"""
def __init__(self):
self.chunks = []
self.word_index = {}
self.metadata = {}
self.client = None
self.is_initialized = False
# Конфигурация
self.generation_model = "gpt-4o"
self.reranking_model = "gpt-4o-mini"
self.max_chunks_for_rerank = 15
self.final_chunks_count = 5
def load_preprocessed_data(self) -> bool:
"""Загрузка предварительно обработанных данных"""
try:
print("🔄 Загрузка предварительно обработанных данных...")
# Загружаем улучшенный индекс с таблицами
index_file = "enhanced_sber_index.pkl"
if not os.path.exists(index_file):
print(f"❌ Файл индекса не найден: {index_file}")
return False
with open(index_file, 'rb') as f:
index_data = pickle.load(f)
self.chunks = index_data["chunks"]
self.word_index = index_data["word_index"]
self.metadata = index_data["metadata"]
print(f"✅ Загружено {len(self.chunks)} чанков")
print(f"✅ Создан словарный индекс из {len(self.word_index)} слов")
return True
except Exception as e:
print(f"❌ Ошибка загрузки данных: {e}")
traceback.print_exc()
return False
def initialize_with_api_key(self, api_key: str) -> Tuple[str, str]:
"""Инициализация системы с API ключом"""
try:
if not api_key.strip():
return "❌ Введите OpenAI API ключ", ""
# Инициализация OpenAI клиента
self.client = OpenAI(api_key=api_key.strip())
# Загрузка данных
if not self.load_preprocessed_data():
return "❌ Ошибка загрузки данных", ""
self.is_initialized = True
# Генерация статистики
stats = self._generate_stats()
return "✅ Система инициализирована успешно", stats
except Exception as e:
return f"❌ Ошибка инициализации: {str(e)}", ""
def _generate_stats(self) -> str:
"""Генерация статистики системы"""
total_chunks = self.metadata.get("total_chunks", 0)
avg_length = self.metadata.get("avg_chunk_length", 0)
avg_tokens = self.metadata.get("avg_token_count", 0)
pages = self.metadata.get("pages_processed", 0)
# Добавим информацию о таблицах
text_chunks = self.metadata.get("text_chunks", 0)
table_chunks = self.metadata.get("table_chunks", 0)
table_pages = self.metadata.get("table_pages", 0)
stats = f"""✅ **Улучшенная система готова к работе!**
📊 **Статистика:**
- 📦 Загружено чанков: {total_chunks}
- 📝 Текстовых чанков: {text_chunks}
- 📋 Табличных чанков: {table_chunks}
- 📏 Средняя длина чанка: {avg_length:.0f} символов
- 🔢 Средний размер: {avg_tokens:.0f} токенов
- 📖 Страниц отчета: {pages}
- 📊 Страниц с таблицами: {table_pages}
🔍 **Возможности:**
- 🔎 Быстрый поиск по ключевым словам
- 📋 Извлечение структурированных таблиц
- 🧠 LLM реранкинг результатов (GPT-4o-mini)
- 📝 Интеллектуальная генерация ответов (GPT-4o)
- 📊 Анализ годового отчета ПАО Сбербанк 2023
🚀 **Готова отвечать на вопросы с поддержкой таблиц!**"""
return stats
def search_by_keywords(self, query: str, max_results: int = 30) -> List[Dict]:
"""Поиск по ключевым словам"""
if not query.strip():
return []
# Извлекаем ключевые слова из запроса
query_words = set(re.findall(r'\b\w+\b', query.lower()))
# Находим чанки, содержащие эти слова
chunk_scores = {}
for word in query_words:
if word in self.word_index:
for chunk_idx in self.word_index[word]:
if chunk_idx not in chunk_scores:
chunk_scores[chunk_idx] = 0
chunk_scores[chunk_idx] += 1
# Сортируем по количеству совпадений
sorted_chunks = sorted(chunk_scores.items(), key=lambda x: x[1], reverse=True)
# Возвращаем результаты
results = []
for chunk_idx, score in sorted_chunks[:max_results]:
if chunk_idx < len(self.chunks):
chunk = self.chunks[chunk_idx].copy()
chunk["keyword_score"] = score
chunk["similarity"] = score / len(query_words) # Нормализованный score
results.append(chunk)
return results
def rerank_with_llm(self, query: str, chunks: List[Dict]) -> List[Dict]:
"""LLM реранкинг результатов"""
if not chunks or not self.client:
return chunks
try:
# Ограничиваем количество чанков для реранкинга
chunks_to_rerank = chunks[:self.max_chunks_for_rerank]
# Подготавливаем документы для реранкинга
docs_text = ""
for i, chunk in enumerate(chunks_to_rerank):
preview = chunk['text'][:300] + "..." if len(chunk['text']) > 300 else chunk['text']
docs_text += f"\nДокумент {i+1} (стр. {chunk['page']}):\n{preview}\n"
prompt = f"""Оцени релевантность каждого документа для ответа на вопрос по шкале 1-10.
Вопрос: {query}
Документы:{docs_text}
Инструкции:
1. Оценивай точность и полноту информации для ответа
2. Высшие баллы (8-10) - прямой ответ на вопрос
3. Средние баллы (5-7) - частично релевантная информация
4. Низкие баллы (1-4) - слабо связано с вопросом
Верни только числа через запятую (например: 8,6,9,4,7):"""
response = self.client.chat.completions.create(
model=self.reranking_model,
messages=[{"role": "user", "content": prompt}],
max_tokens=100,
temperature=0
)
# Парсим оценки
scores_text = response.choices[0].message.content.strip()
scores = []
numbers = re.findall(r'\d+\.?\d*', scores_text)
for num in numbers:
score = float(num)
score = max(0, min(10, score)) # Ограничиваем 0-10
scores.append(score)
# Применяем оценки
reranked = []
for i, chunk in enumerate(chunks):
chunk_copy = chunk.copy()
if i < len(scores):
chunk_copy["rerank_score"] = scores[i]
else:
chunk_copy["rerank_score"] = 0
reranked.append(chunk_copy)
# Сортируем по реранк скору
reranked.sort(key=lambda x: x["rerank_score"], reverse=True)
return reranked
except Exception as e:
print(f"❌ Ошибка реранкинга: {e}")
return chunks
def generate_answer(self, query: str, context_chunks: List[Dict]) -> str:
"""Генерация ответа на основе контекста"""
if not self.client:
return "❌ OpenAI API не настроен"
try:
# Подготавливаем контекст
context_parts = []
for i, chunk in enumerate(context_chunks[:self.final_chunks_count]):
context_parts.append(f"Фрагмент {i+1} (страница {chunk['page']}):\n{chunk['text']}")
context = "\n\n".join(context_parts)
prompt = f"""Ты - эксперт по анализу финансовых отчетов. Ответь на вопрос пользователя на основе предоставленного контекста из годового отчета ПАО Сбербанк 2023.
ВОПРОС: {query}
КОНТЕКСТ ИЗ ОТЧЕТА:
{context}
ИНСТРУКЦИИ:
1. Отвечай только на основе предоставленной информации
2. Если информации недостаточно, честно об этом скажи
3. Используй конкретные данные и цифры из отчета
4. Структурируй ответ четко и понятно
5. Указывай номера страниц при цитировании
6. Отвечай на русском языке
ОТВЕТ:"""
response = self.client.chat.completions.create(
model=self.generation_model,
messages=[{"role": "user", "content": prompt}],
max_tokens=1500,
temperature=0.1
)
return response.choices[0].message.content.strip()
except Exception as e:
return f"❌ Ошибка генерации ответа: {str(e)}"
def process_query(self, query: str) -> Dict[str, Any]:
"""Обработка пользовательского запроса"""
if not self.is_initialized:
return {
"answer": "❌ Система не инициализирована. Введите API ключ.",
"sources": [],
"debug_info": {}
}
if not query.strip():
return {
"answer": "Пожалуйста, введите ваш вопрос.",
"sources": [],
"debug_info": {}
}
try:
# Шаг 1: Поиск по ключевым словам
initial_results = self.search_by_keywords(query, max_results=30)
if not initial_results:
return {
"answer": "К сожалению, не удалось найти релевантную информацию по вашему вопросу.",
"sources": [],
"debug_info": {"step": "keyword_search", "results_count": 0}
}
# Шаг 2: LLM реранкинг
reranked_results = self.rerank_with_llm(query, initial_results)
# Шаг 3: Генерация ответа
top_chunks = reranked_results[:self.final_chunks_count]
answer = self.generate_answer(query, top_chunks)
# Подготовка источников
sources = []
for chunk in top_chunks:
sources.append({
"page": chunk["page"],
"keyword_score": chunk.get("keyword_score", 0),
"rerank_score": chunk.get("rerank_score", 0),
"preview": chunk["text"][:200] + "..." if len(chunk["text"]) > 200 else chunk["text"]
})
debug_info = {
"initial_results": len(initial_results),
"reranked_results": len(reranked_results),
"final_chunks": len(top_chunks),
"avg_keyword_score": np.mean([s["keyword_score"] for s in sources]) if sources else 0,
"avg_rerank_score": np.mean([s["rerank_score"] for s in sources]) if sources else 0
}
return {
"answer": answer,
"sources": sources,
"debug_info": debug_info
}
except Exception as e:
print(f"❌ Ошибка обработки запроса: {e}")
traceback.print_exc()
return {
"answer": f"❌ Ошибка обработки запроса: {str(e)}",
"sources": [],
"debug_info": {"error": str(e)}
}
# Глобальная переменная системы
rag_system = LightweightRAGSystem()
def initialize_system(api_key: str) -> Tuple[str, str]:
"""Инициализация системы"""
return rag_system.initialize_with_api_key(api_key)
def ask_question(question: str) -> Tuple[str, str]:
"""Обработка вопроса"""
result = rag_system.process_query(question)
answer = result["answer"]
# Форматируем информацию об источниках
sources_info = ""
if result["sources"]:
sources_info = "\n📚 **Источники:**\n"
for i, source in enumerate(result["sources"], 1):
sources_info += f"\n**{i}.** Страница {source['page']} "
sources_info += f"(ключевые слова: {source['keyword_score']}, "
sources_info += f"релевантность: {source['rerank_score']:.1f}/10)\n"
sources_info += f"*Превью:* {source['preview']}\n"
# Добавляем отладочную информацию
if result.get("debug_info"):
debug = result["debug_info"]
sources_info += f"\n🔍 **Статистика поиска:**\n"
sources_info += f"- Найдено по ключевым словам: {debug.get('initial_results', 0)}\n"
sources_info += f"- После реранкинга: {debug.get('reranked_results', 0)}\n"
sources_info += f"- Использовано в ответе: {debug.get('final_chunks', 0)}\n"
if debug.get('avg_rerank_score'):
sources_info += f"- Средняя релевантность: {debug.get('avg_rerank_score', 0):.1f}/10\n"
return answer, sources_info
def create_demo_interface():
"""Создание демо интерфейса для HF"""
with gr.Blocks(
title="RAG Demo - Сбер 2023",
theme=gr.themes.Soft(),
css="""
.main-header { text-align: center; margin-bottom: 2rem; }
.feature-box { background-color: #f8f9fa; padding: 1rem; border-radius: 8px; margin: 1rem 0; }
"""
) as demo:
gr.Markdown("""
<div class="main-header">
<h1>🏆 Enhanced RAG Demo: Анализ отчета Сбера 2023</h1>
<p>Улучшенная система поиска с поддержкой таблиц</p>
<p><strong>84 извлеченные таблицы • 2009 чанков • pdfplumber обработка</strong></p>
</div>
""")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### ⚙️ Настройка")
api_key_input = gr.Textbox(
label="OpenAI API Key",
placeholder="sk-...",
type="password",
info="Введите ваш OpenAI API ключ для работы системы"
)
init_btn = gr.Button("🚀 Инициализировать", variant="primary")
status_output = gr.Textbox(
label="Статус",
interactive=False,
lines=2
)
with gr.Column(scale=1):
stats_output = gr.Markdown("### 📊 Ожидание инициализации...")
gr.Markdown("### 💬 Задайте вопрос")
with gr.Row():
question_input = gr.Textbox(
label="Ваш вопрос",
placeholder="Например: Каковы основные финансовые показатели Сбера за 2023 год?",
lines=2,
scale=4
)
ask_btn = gr.Button("📝 Спросить", variant="primary", scale=1)
with gr.Row():
with gr.Column(scale=2):
answer_output = gr.Textbox(
label="Ответ системы",
lines=12,
interactive=False
)
with gr.Column(scale=1):
sources_output = gr.Textbox(
label="Источники и статистика",
lines=12,
interactive=False
)
# Примеры вопросов
gr.Markdown("""
### 💡 Примеры вопросов:
- Каковы основные финансовые показатели Сбера за 2023 год?
- Какова чистая прибыль банка в 2023 году?
- Расскажите о кредитном портфеле Сбербанка
- Какие технологические инициативы развивает Сбер?
- Каковы показатели рентабельности банка?
""")
# Event handlers
init_btn.click(
fn=initialize_system,
inputs=[api_key_input],
outputs=[status_output, stats_output]
)
ask_btn.click(
fn=ask_question,
inputs=[question_input],
outputs=[answer_output, sources_output]
)
question_input.submit(
fn=ask_question,
inputs=[question_input],
outputs=[answer_output, sources_output]
)
return demo
if __name__ == "__main__":
demo = create_demo_interface()
demo.launch(
share=False,
server_name="0.0.0.0",
server_port=7860,
show_error=True
)