Course_Project / app.py
fruitpicker01's picture
Update app.py
12bd40e verified
raw
history blame
24.1 kB
#!/usr/bin/env python3
"""
Финальная векторная RAG система для HuggingFace Spaces
Адаптированная версия с поддержкой векторного поиска и резервным режимом
"""
import os
import json
import pickle
import tempfile
from pathlib import Path
from typing import Optional, Dict, Any, List, Tuple
import traceback
import re
try:
import numpy as np
import faiss
HAS_FAISS = True
except ImportError:
HAS_FAISS = False
print("⚠️ FAISS не установлен, будет использован поиск по ключевым словам")
try:
import gradio as gr
HAS_GRADIO = True
except ImportError:
HAS_GRADIO = False
print("⚠️ Gradio не установлен")
from openai import OpenAI
class VectorRAGSystem:
"""RAG система с векторным поиском и резервным режимом"""
def __init__(self):
self.chunks = []
self.word_index = {}
self.faiss_index = None
self.metadata = {}
self.client = None
self.is_initialized = False
# Модели и параметры
self.embedding_model = "text-embedding-3-large"
self.embedding_dim = 3072
self.generation_model = "gpt-4o"
self.reranking_model = "gpt-4o-mini"
# Параметры поиска
self.max_chunks_for_rerank = 15
self.final_chunks_count = 5
self.vector_search_k = 20
# Режим работы
self.vector_mode = HAS_FAISS
def initialize_with_api_key(self, api_key: str) -> Tuple[str, str]:
"""Инициализация системы с API ключом"""
try:
if not api_key.strip():
return "❌ Введите OpenAI API ключ", ""
# Инициализация OpenAI клиента
self.client = OpenAI(api_key=api_key.strip())
# Загрузка данных
if not self.load_data():
return "❌ Ошибка загрузки данных", ""
self.is_initialized = True
stats = self._generate_stats()
return "✅ Векторная RAG система инициализирована", stats
except Exception as e:
return f"❌ Ошибка инициализации: {str(e)}", ""
def load_data(self) -> bool:
"""Загрузка данных (векторных или обычных)"""
try:
# Сначала пробуем загрузить векторные данные
if self.vector_mode and self.load_vector_data():
return True
# Если не удалось, загружаем обычные данные
return self.load_fallback_data()
except Exception as e:
print(f"❌ Ошибка загрузки данных: {e}")
return False
def load_vector_data(self) -> bool:
"""Загрузка векторных данных"""
try:
print("🔄 Попытка загрузки векторных данных...")
# Файлы векторных данных
chunks_file = "vector_enhanced_sber_chunks.pkl"
metadata_file = "vector_enhanced_sber_metadata.json"
faiss_file = "vector_enhanced_sber_faiss.index"
if not all(os.path.exists(f) for f in [chunks_file, metadata_file, faiss_file]):
print("📁 Файлы векторных данных не найдены")
return False
# Загружаем чанки
with open(chunks_file, 'rb') as f:
chunks_data = pickle.load(f)
self.chunks = []
for chunk_data in chunks_data:
self.chunks.append({
"text": chunk_data["text"],
"page": chunk_data["page"],
"chunk_index": chunk_data["chunk_index"],
"embedding": np.array(chunk_data["embedding"]) if chunk_data.get("embedding") else None,
"metadata": chunk_data.get("metadata", {}),
"full_page_text": chunk_data.get("full_page_text", chunk_data["text"])
})
# Загружаем метаданные
with open(metadata_file, 'r', encoding='utf-8') as f:
self.metadata = json.load(f)
# Загружаем FAISS индекс
if HAS_FAISS:
self.faiss_index = faiss.read_index(faiss_file)
print(f"✅ Загружены векторные данные: {len(self.chunks)} чанков")
return True
except Exception as e:
print(f"❌ Ошибка загрузки векторных данных: {e}")
return False
def load_fallback_data(self) -> bool:
"""Загрузка обычных данных"""
try:
print("🔄 Загрузка резервных данных...")
index_file = "enhanced_sber_index.pkl"
if not os.path.exists(index_file):
print(f"❌ Файл резервных данных не найден: {index_file}")
return False
with open(index_file, 'rb') as f:
index_data = pickle.load(f)
# Конвертируем в формат чанков
self.chunks = []
chunk_texts = index_data.get("chunks", [])
for i, chunk_text in enumerate(chunk_texts):
chunk = {
"text": chunk_text,
"page": index_data.get("metadata", {}).get("chunk_pages", {}).get(str(i), 1),
"chunk_index": i,
"embedding": None,
"metadata": {},
"full_page_text": chunk_text
}
self.chunks.append(chunk)
# Создаем словарный индекс для поиска
self.word_index = index_data.get("word_index", {})
self.metadata = index_data.get("metadata", {})
self.vector_mode = False # Отключаем векторный режим
print(f"✅ Загружены резервные данные: {len(self.chunks)} чанков")
return True
except Exception as e:
print(f"❌ Ошибка загрузки резервных данных: {e}")
return False
def _generate_stats(self) -> str:
"""Генерация статистики системы"""
total_chunks = len(self.chunks)
mode = "Векторный поиск" if self.vector_mode and self.faiss_index else "Поиск по ключевым словам"
stats = f"""✅ **RAG система готова!**
📊 **Статистика:**
- 📦 Загружено чанков: {total_chunks}
- 🔍 Режим поиска: {mode}
- 🧠 Модель генерации: {self.generation_model}
- 🎯 Реранкинг: {self.reranking_model}
🔍 **Возможности:**
- 🔎 Семантический/ключевой поиск
- 📄 Контекстное обогащение
- 🧠 LLM реранкинг результатов
- 📝 Интеллектуальная генерация ответов
- 📊 Анализ годового отчета ПАО Сбербанк 2023
🚀 **Готова к работе!**"""
return stats
def search(self, query: str, k: int = 20) -> List[Tuple[Dict, float]]:
"""Основной метод поиска"""
if self.vector_mode and self.faiss_index and self.client:
return self.vector_search(query, k)
else:
return self.keyword_search(query, k)
def vector_search(self, query: str, k: int = 20) -> List[Tuple[Dict, float]]:
"""Векторный поиск по запросу"""
if not self.faiss_index or not self.client:
return self.keyword_search(query, k)
try:
# Создаем эмбеддинг для запроса
response = self.client.embeddings.create(
model=self.embedding_model,
input=[query]
)
query_embedding = np.array(response.data[0].embedding, dtype=np.float32)
query_embedding = query_embedding.reshape(1, -1)
# Нормализуем для Inner Product
faiss.normalize_L2(query_embedding)
# Поиск в FAISS индексе
scores, indices = self.faiss_index.search(query_embedding, k)
# Формируем результаты
results = []
for score, idx in zip(scores[0], indices[0]):
if 0 <= idx < len(self.chunks):
chunk = self.chunks[idx]
results.append((chunk, float(score)))
return results
except Exception as e:
print(f"❌ Ошибка векторного поиска: {e}")
return self.keyword_search(query, k)
def keyword_search(self, query: str, k: int = 20) -> List[Tuple[Dict, float]]:
"""Поиск по ключевым словам"""
query_words = set(re.findall(r'\b\w+\b', query.lower()))
if self.word_index:
# Используем готовый индекс
chunk_scores = {}
for word in query_words:
if word in self.word_index:
for chunk_idx in self.word_index[word]:
if chunk_idx not in chunk_scores:
chunk_scores[chunk_idx] = 0
chunk_scores[chunk_idx] += 1
else:
# Создаем индекс на лету
chunk_scores = {}
for i, chunk in enumerate(self.chunks):
text_words = set(re.findall(r'\b\w+\b', chunk["text"].lower()))
score = len(query_words.intersection(text_words))
if score > 0:
chunk_scores[i] = score
# Сортируем по скору
sorted_chunks = sorted(chunk_scores.items(), key=lambda x: x[1], reverse=True)
results = []
for chunk_idx, score in sorted_chunks[:k]:
if chunk_idx < len(self.chunks):
chunk = self.chunks[chunk_idx]
results.append((chunk, float(score)))
return results
def rerank_with_llm(self, query: str, chunks: List[Tuple[Dict, float]]) -> List[Tuple[Dict, float]]:
"""LLM реранкинг результатов"""
if not chunks or not self.client:
return chunks
try:
chunks_to_rerank = chunks[:self.max_chunks_for_rerank]
docs_text = ""
for i, (chunk, _) in enumerate(chunks_to_rerank):
preview = chunk['text'][:300] + "..." if len(chunk['text']) > 300 else chunk['text']
docs_text += f"\nДокумент {i+1} (стр. {chunk['page']}):\n{preview}\n"
prompt = f"""Оцени релевантность каждого документа для ответа на вопрос по шкале 1-10.
Вопрос: {query}
Документы:{docs_text}
Инструкции:
1. Оценивай точность и полноту информации для ответа
2. Высшие баллы (8-10) - прямой ответ на вопрос
3. Средние баллы (5-7) - частично релевантная информация
4. Низкие баллы (1-4) - слабо связано с вопросом
Верни только числа через запятую (например: 8,6,9,4,7):"""
response = self.client.chat.completions.create(
model=self.reranking_model,
messages=[{"role": "user", "content": prompt}],
max_tokens=100,
temperature=0
)
scores_text = response.choices[0].message.content.strip()
numbers = re.findall(r'\d+\.?\d*', scores_text)
scores = [max(0, min(10, float(num))) for num in numbers]
reranked = []
for i, (chunk, original_score) in enumerate(chunks):
rerank_score = scores[i] if i < len(scores) else 0
reranked.append((chunk, rerank_score))
reranked.sort(key=lambda x: x[1], reverse=True)
return reranked
except Exception as e:
print(f"❌ Ошибка реранкинга: {e}")
return chunks
def generate_answer(self, query: str, context_chunks: List[Tuple[Dict, float]]) -> str:
"""Генерация ответа на основе контекста"""
if not self.client:
return "❌ OpenAI API не настроен"
try:
context_parts = []
for i, (chunk, score) in enumerate(context_chunks[:self.final_chunks_count]):
text = chunk.get('full_page_text', chunk['text'])
clean_text = text.encode('utf-8', errors='ignore').decode('utf-8')
context_parts.append(f"Фрагмент {i+1} (страница {chunk['page']}, релевантность: {score:.2f}):\n{clean_text}")
context = "\n\n".join(context_parts)
clean_query = query.encode('utf-8', errors='ignore').decode('utf-8')
prompt = f"""Ты - эксперт по анализу финансовых отчетов. Ответь на вопрос пользователя на основе предоставленного контекста из годового отчета ПАО Сбербанк 2023.
ВОПРОС: {clean_query}
КОНТЕКСТ ИЗ ОТЧЕТА:
{context}
ИНСТРУКЦИИ:
1. Отвечай только на основе предоставленной информации
2. Если информации недостаточно, честно об этом скажи
3. Используй конкретные данные и цифры из отчета
4. Структурируй ответ четко и понятно
5. Указывай номера страниц при цитировании
6. Отвечай на русском языке
ОТВЕТ:"""
response = self.client.chat.completions.create(
model=self.generation_model,
messages=[{"role": "user", "content": prompt}],
max_tokens=1500,
temperature=0.1
)
return response.choices[0].message.content.strip()
except Exception as e:
return f"❌ Ошибка генерации ответа: {str(e)}"
def process_query(self, query: str) -> Dict[str, Any]:
"""Обработка пользовательского запроса"""
if not self.is_initialized:
return {
"answer": "❌ Система не инициализирована. Введите API ключ.",
"sources": [],
"debug_info": {}
}
if not query.strip():
return {
"answer": "Пожалуйста, введите ваш вопрос.",
"sources": [],
"debug_info": {}
}
try:
# Поиск
search_results = self.search(query, k=self.vector_search_k)
if not search_results:
return {
"answer": "К сожалению, не удалось найти релевантную информацию по вашему вопросу.",
"sources": [],
"debug_info": {"step": "search", "results_count": 0}
}
# Реранкинг
reranked_results = self.rerank_with_llm(query, search_results)
# Генерация ответа
answer = self.generate_answer(query, reranked_results)
# Подготовка источников
sources = []
for chunk, score in reranked_results[:self.final_chunks_count]:
sources.append({
"page": chunk["page"],
"search_score": search_results[0][1] if search_results else 0,
"rerank_score": score,
"preview": chunk["text"][:200] + "..." if len(chunk["text"]) > 200 else chunk["text"]
})
debug_info = {
"search_results": len(search_results),
"reranked_results": len(reranked_results),
"final_chunks": len(sources),
"search_method": "vector" if self.vector_mode else "keyword"
}
return {
"answer": answer,
"sources": sources,
"debug_info": debug_info
}
except Exception as e:
print(f"❌ Ошибка обработки запроса: {e}")
traceback.print_exc()
return {
"answer": f"❌ Ошибка обработки запроса: {str(e)}",
"sources": [],
"debug_info": {"error": str(e)}
}
# Глобальная переменная системы
rag_system = VectorRAGSystem()
def initialize_system(api_key: str) -> Tuple[str, str]:
"""Инициализация системы"""
return rag_system.initialize_with_api_key(api_key)
def ask_question(question: str) -> Tuple[str, str]:
"""Обработка вопроса"""
result = rag_system.process_query(question)
answer = result["answer"]
# Форматируем информацию об источниках
sources_info = ""
if result["sources"]:
sources_info = "\n📚 **Источники:**\n"
for i, source in enumerate(result["sources"], 1):
sources_info += f"\n**{i}.** Страница {source['page']} "
sources_info += f"(поиск: {source['search_score']:.3f}, "
sources_info += f"релевантность: {source['rerank_score']:.1f}/10)\n"
sources_info += f"*Превью:* {source['preview']}\n"
# Добавляем отладочную информацию
if result.get("debug_info"):
debug = result["debug_info"]
sources_info += f"\n🔍 **Статистика поиска:**\n"
sources_info += f"- Метод поиска: {debug.get('search_method', 'unknown')}\n"
sources_info += f"- Найдено результатов: {debug.get('search_results', 0)}\n"
sources_info += f"- После реранкинга: {debug.get('reranked_results', 0)}\n"
sources_info += f"- Использовано в ответе: {debug.get('final_chunks', 0)}\n"
return answer, sources_info
def create_demo_interface():
"""Создание демо интерфейса"""
if not HAS_GRADIO:
print("❌ Gradio не установлен. Установите: pip install gradio")
return None
with gr.Blocks(
title="Vector RAG Demo - Сбер 2023",
theme=gr.themes.Soft(),
css="""
.main-header { text-align: center; margin-bottom: 2rem; }
.feature-box { background-color: #f8f9fa; padding: 1rem; border-radius: 8px; margin: 1rem 0; }
"""
) as demo:
gr.Markdown("""
<div class="main-header">
<h1>🚀 Advanced RAG Demo: Анализ отчета Сбера 2023</h1>
<p>Умная система с векторным поиском и адаптивным режимом</p>
<p><strong>OpenAI embeddings • FAISS IndexFlatIP • LLM reranking • Fallback mode</strong></p>
</div>
""")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### ⚙️ Настройка")
api_key_input = gr.Textbox(
label="OpenAI API Key",
placeholder="sk-proj-...",
type="password",
info="Введите ваш OpenAI API ключ"
)
init_btn = gr.Button("🚀 Инициализировать", variant="primary")
status_output = gr.Textbox(
label="Статус",
interactive=False,
lines=2
)
with gr.Column(scale=1):
stats_output = gr.Markdown("### 📊 Ожидание инициализации...")
gr.Markdown("### 💬 Задайте вопрос")
with gr.Row():
question_input = gr.Textbox(
label="Ваш вопрос",
placeholder="Например: Каковы основные финансовые показатели Сбера за 2023 год?",
lines=2,
scale=4
)
ask_btn = gr.Button("🔍 Поиск", variant="primary", scale=1)
with gr.Row():
with gr.Column(scale=2):
answer_output = gr.Textbox(
label="Ответ системы",
lines=12,
interactive=False
)
with gr.Column(scale=1):
sources_output = gr.Textbox(
label="Источники и статистика",
lines=12,
interactive=False
)
# Примеры вопросов
gr.Markdown("""
### 💡 Примеры вопросов:
- Каковы основные финансовые показатели Сбера за 2023 год?
- Какова чистая прибыль банка в 2023 году?
- Расскажите о кредитном портфеле Сбербанка
- Какие технологические инициативы развивает Сбер?
- Каковы показатели рентабельности банка?
- Какие ESG инициативы реализует Сбер?
""")
# Event handlers
init_btn.click(
fn=initialize_system,
inputs=[api_key_input],
outputs=[status_output, stats_output]
)
ask_btn.click(
fn=ask_question,
inputs=[question_input],
outputs=[answer_output, sources_output]
)
question_input.submit(
fn=ask_question,
inputs=[question_input],
outputs=[answer_output, sources_output]
)
return demo
# Запуск для Hugging Face Spaces
demo = create_demo_interface()
demo.launch()