RAG_AIEXP_01 / index_retriever.py
MrSimple07's picture
new index
56751ff
raw
history blame
7.84 kB
from llama_index.core import VectorStoreIndex, Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.google_genai import GoogleGenAI
from llama_index.llms.openai import OpenAI
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.response_synthesizers import get_response_synthesizer, ResponseMode
from llama_index.core.prompts import PromptTemplate
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.retrievers import QueryFusionRetriever
from sentence_transformers import CrossEncoder
import logging
from config import *
logger = logging.getLogger(__name__)
vector_index = None
query_engine = None
reranker = None
current_model = DEFAULT_MODEL
def log_message(message):
logger.info(message)
print(message, flush=True)
def get_llm_model(model_name):
try:
model_config = AVAILABLE_MODELS.get(model_name)
if not model_config:
log_message(f"Модель {model_name} не найдена, использую модель по умолчанию")
model_config = AVAILABLE_MODELS[DEFAULT_MODEL]
if not model_config.get("api_key"):
raise Exception(f"API ключ не найден для модели {model_name}")
if model_config["provider"] == "google":
return GoogleGenAI(
model=model_config["model_name"],
api_key=model_config["api_key"]
)
elif model_config["provider"] == "openai":
return OpenAI(
model=model_config["model_name"],
api_key=model_config["api_key"]
)
else:
raise Exception(f"Неподдерживаемый провайдер: {model_config['provider']}")
except Exception as e:
log_message(f"Ошибка создания модели {model_name}: {str(e)}")
return GoogleGenAI(model="gemini-2.0-flash", api_key=GOOGLE_API_KEY)
def initialize_models(documents):
global vector_index, query_engine, reranker, current_model
try:
log_message("Инициализация моделей и индекса")
embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
llm = get_llm_model(current_model)
log_message("Инициализирую переранкер")
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
Settings.embed_model = embed_model
Settings.llm = llm
log_message(f"Строю векторный индекс из {len(documents)} документов")
vector_index = VectorStoreIndex.from_documents(documents)
create_query_engine()
log_message(f"Модели и индекс успешно инициализированы с моделью: {current_model}")
return True
except Exception as e:
log_message(f"Ошибка инициализации моделей: {str(e)}")
return False
def create_query_engine():
global query_engine
try:
log_message(f"Применяется промпт: {PROMPT_SIMPLE_POISK[:100]}...")
bm25_retriever = BM25Retriever.from_defaults(
docstore=vector_index.docstore,
similarity_top_k=15
)
vector_retriever = VectorIndexRetriever(
index=vector_index,
similarity_top_k=20,
similarity_cutoff=0.5
)
hybrid_retriever = QueryFusionRetriever(
[vector_retriever, bm25_retriever],
similarity_top_k=30,
num_queries=1
)
custom_prompt_template = PromptTemplate(PROMPT_SIMPLE_POISK)
response_synthesizer = get_response_synthesizer(
response_mode=ResponseMode.TREE_SUMMARIZE,
text_qa_template=custom_prompt_template
)
query_engine = RetrieverQueryEngine(
retriever=hybrid_retriever,
response_synthesizer=response_synthesizer
)
log_message("Query engine успешно создан с кастомным промптом")
except Exception as e:
log_message(f"Ошибка создания query engine: {str(e)}")
raise
def query(question):
if query_engine is None:
log_message("❌ Query engine не инициализирован")
return "❌ Система не инициализирована"
try:
log_message(f"Получен вопрос: {question}")
log_message(f"Используется модель: {current_model}")
log_message(f"Применяется промпт: {PROMPT_SIMPLE_POISK[:150]}...")
log_message(f"Обрабатываю запрос: {question}")
response = query_engine.query(question)
log_message(f"Ответ получен, длина: {len(str(response))}")
return str(response)
except Exception as e:
error_msg = f"Ошибка обработки запроса: {str(e)}"
log_message(error_msg)
return f"❌ {error_msg}"
def switch_model(model_name):
global current_model
try:
log_message(f"Переключение на модель: {model_name}")
new_llm = get_llm_model(model_name)
Settings.llm = new_llm
if vector_index is not None:
create_query_engine()
current_model = model_name
log_message(f"Модель успешно переключена на: {model_name}")
return f"✅ Модель переключена на: {model_name}"
else:
return "❌ Ошибка: система не инициализирована"
except Exception as e:
error_msg = f"Ошибка переключения модели: {str(e)}"
log_message(error_msg)
return f"❌ {error_msg}"
def rerank_nodes(query_text, nodes, top_k=10):
if not nodes or not reranker:
return nodes[:top_k]
try:
log_message(f"Переранжирую {len(nodes)} узлов")
pairs = []
for node in nodes:
pairs.append([query_text, node.text])
scores = reranker.predict(pairs)
scored_nodes = list(zip(nodes, scores))
scored_nodes.sort(key=lambda x: x[1], reverse=True)
reranked_nodes = [node for node, score in scored_nodes[:top_k]]
log_message(f"Возвращаю топ-{len(reranked_nodes)} переранжированных узлов")
return reranked_nodes
except Exception as e:
log_message(f"Ошибка переранжировки: {str(e)}")
return nodes[:top_k]
def retrieve_nodes(question):
if query_engine is None:
return []
try:
log_message(f"Извлекаю релевантные узлы для вопроса: {question}")
retrieved_nodes = query_engine.retriever.retrieve(question)
log_message(f"Извлечено {len(retrieved_nodes)} узлов")
log_message("Применяю переранжировку")
reranked_nodes = rerank_nodes(question, retrieved_nodes, top_k=10)
return reranked_nodes
except Exception as e:
log_message(f"Ошибка извлечения узлов: {str(e)}")
return []
def get_current_model():
return current_model
def is_initialized():
return query_engine is not None