Spaces:
Sleeping
Sleeping
| from llama_index.core import VectorStoreIndex, Settings | |
| from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
| from llama_index.llms.google_genai import GoogleGenAI | |
| from llama_index.llms.openai import OpenAI | |
| from llama_index.core.query_engine import RetrieverQueryEngine | |
| from llama_index.core.retrievers import VectorIndexRetriever | |
| from llama_index.core.response_synthesizers import get_response_synthesizer, ResponseMode | |
| from llama_index.core.prompts import PromptTemplate | |
| from llama_index.retrievers.bm25 import BM25Retriever | |
| from llama_index.core.retrievers import QueryFusionRetriever | |
| from sentence_transformers import CrossEncoder | |
| import logging | |
| from config import * | |
logger = logging.getLogger(__name__)

# Module-level mutable state, populated by initialize_models().
vector_index = None            # VectorStoreIndex built from the loaded documents
query_engine = None            # RetrieverQueryEngine (hybrid retrieval + synthesis)
reranker = None                # sentence-transformers CrossEncoder used by rerank_nodes()
current_model = DEFAULT_MODEL  # key into AVAILABLE_MODELS naming the active LLM
def log_message(message):
    """Record *message* twice: via the module logger and echoed to stdout.

    The flushed ``print`` keeps output visible in container/Spaces logs even
    when the logging subsystem is not configured.
    """
    logger.info(message)
    print(message, flush=True)
def get_llm_model(model_name):
    """Construct an LLM client for *model_name* using AVAILABLE_MODELS.

    Unknown names fall back to the DEFAULT_MODEL configuration; any failure
    (missing API key, unsupported provider, constructor error) is logged and
    answered with a hard-coded Gemini client so callers always get an LLM.
    """
    try:
        cfg = AVAILABLE_MODELS.get(model_name)
        if not cfg:
            log_message(f"Модель {model_name} не найдена, использую модель по умолчанию")
            cfg = AVAILABLE_MODELS[DEFAULT_MODEL]
        if not cfg.get("api_key"):
            raise Exception(f"API ключ не найден для модели {model_name}")
        provider = cfg["provider"]
        if provider == "google":
            return GoogleGenAI(model=cfg["model_name"], api_key=cfg["api_key"])
        if provider == "openai":
            return OpenAI(model=cfg["model_name"], api_key=cfg["api_key"])
        raise Exception(f"Неподдерживаемый провайдер: {cfg['provider']}")
    except Exception as e:
        log_message(f"Ошибка создания модели {model_name}: {str(e)}")
        # Last-resort fallback so the pipeline never ends up without an LLM.
        return GoogleGenAI(model="gemini-2.0-flash", api_key=GOOGLE_API_KEY)
def initialize_models(documents):
    """Initialize embeddings, LLM, reranker and the vector index.

    Populates the module globals and builds the query engine; returns True on
    success, False (after logging) on any failure.
    """
    global vector_index, query_engine, reranker, current_model
    try:
        log_message("Инициализация моделей и индекса")
        embedder = HuggingFaceEmbedding(
            model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
        )
        language_model = get_llm_model(current_model)
        log_message("Инициализирую переранкер")
        reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
        # Register models globally so llama_index components pick them up.
        Settings.embed_model = embedder
        Settings.llm = language_model
        log_message(f"Строю векторный индекс из {len(documents)} документов")
        vector_index = VectorStoreIndex.from_documents(documents)
        create_query_engine()
        log_message(f"Модели и индекс успешно инициализированы с моделью: {current_model}")
        return True
    except Exception as e:
        log_message(f"Ошибка инициализации моделей: {str(e)}")
        return False
def create_query_engine():
    """(Re)build the global hybrid query engine from the current vector_index.

    Fuses BM25 (keyword) and vector (dense) retrieval via QueryFusionRetriever
    and attaches a tree-summarize response synthesizer driven by the custom
    PROMPT_SIMPLE_POISK template.  Assumes vector_index is already built;
    re-raises after logging on any failure.
    """
    global query_engine
    try:
        log_message(f"Применяется промпт: {PROMPT_SIMPLE_POISK[:100]}...")
        # Keyword retrieval over the same docstore that backs the vector index.
        bm25_retriever = BM25Retriever.from_defaults(
            docstore=vector_index.docstore,
            similarity_top_k=15
        )
        # Dense retrieval. NOTE(review): similarity_cutoff is normally a
        # SimilarityPostprocessor option — confirm VectorIndexRetriever
        # actually honours this kwarg rather than silently ignoring it.
        vector_retriever = VectorIndexRetriever(
            index=vector_index,
            similarity_top_k=20,
            similarity_cutoff=0.5
        )
        # num_queries=1 disables LLM query generation: pure fusion of the two
        # retrievers' results for the original query only.
        hybrid_retriever = QueryFusionRetriever(
            [vector_retriever, bm25_retriever],
            similarity_top_k=30,
            num_queries=1
        )
        custom_prompt_template = PromptTemplate(PROMPT_SIMPLE_POISK)
        response_synthesizer = get_response_synthesizer(
            response_mode=ResponseMode.TREE_SUMMARIZE,
            text_qa_template=custom_prompt_template
        )
        query_engine = RetrieverQueryEngine(
            retriever=hybrid_retriever,
            response_synthesizer=response_synthesizer
        )
        log_message("Query engine успешно создан с кастомным промптом")
    except Exception as e:
        log_message(f"Ошибка создания query engine: {str(e)}")
        raise
def query(question):
    """Answer *question* through the global query engine.

    Returns the answer as a string, or a user-facing ❌ message when the
    engine is uninitialized or the request fails.
    """
    if query_engine is None:
        log_message("❌ Query engine не инициализирован")
        return "❌ Система не инициализирована"
    try:
        log_message(f"Получен вопрос: {question}")
        log_message(f"Используется модель: {current_model}")
        log_message(f"Применяется промпт: {PROMPT_SIMPLE_POISK[:150]}...")
        log_message(f"Обрабатываю запрос: {question}")
        answer = str(query_engine.query(question))
        log_message(f"Ответ получен, длина: {len(answer)}")
        return answer
    except Exception as e:
        error_msg = f"Ошибка обработки запроса: {str(e)}"
        log_message(error_msg)
        return f"❌ {error_msg}"
def switch_model(model_name):
    """Switch the active LLM to *model_name* and rebuild the query engine.

    Returns a user-facing status string.  On failure the previous model stays
    selected — current_model is only updated after the engine is rebuilt.
    """
    global current_model
    try:
        log_message(f"Переключение на модель: {model_name}")
        # Check initialization BEFORE touching global state: the original
        # assigned Settings.llm first, so an uninitialized system still had
        # its global LLM swapped while current_model stayed stale.
        if vector_index is None:
            return "❌ Ошибка: система не инициализирована"
        new_llm = get_llm_model(model_name)
        Settings.llm = new_llm
        create_query_engine()
        current_model = model_name
        log_message(f"Модель успешно переключена на: {model_name}")
        return f"✅ Модель переключена на: {model_name}"
    except Exception as e:
        error_msg = f"Ошибка переключения модели: {str(e)}"
        log_message(error_msg)
        return f"❌ {error_msg}"
def rerank_nodes(query_text, nodes, top_k=10):
    """Re-rank *nodes* against *query_text* with the global CrossEncoder.

    Returns at most *top_k* nodes, best first.  Degrades gracefully: with no
    nodes it returns an empty list; with no reranker, or on any scoring
    error, it returns the first *top_k* nodes unchanged.
    """
    # Guard None/empty explicitly: the original returned nodes[:top_k] here,
    # which raises TypeError when nodes is None (None is not subscriptable).
    if not nodes:
        return []
    if not reranker:
        return nodes[:top_k]
    try:
        log_message(f"Переранжирую {len(nodes)} узлов")
        pairs = [[query_text, node.text] for node in nodes]
        scores = reranker.predict(pairs)
        # Stable sort by cross-encoder score, highest relevance first.
        ranked = sorted(zip(nodes, scores), key=lambda pair: pair[1], reverse=True)
        reranked_nodes = [node for node, _ in ranked[:top_k]]
        log_message(f"Возвращаю топ-{len(reranked_nodes)} переранжированных узлов")
        return reranked_nodes
    except Exception as e:
        log_message(f"Ошибка переранжировки: {str(e)}")
        return nodes[:top_k]
def retrieve_nodes(question):
    """Retrieve and re-rank the most relevant nodes for *question*.

    Returns an empty list when the engine is not initialized or when
    retrieval fails.
    """
    if query_engine is None:
        return []
    try:
        log_message(f"Извлекаю релевантные узлы для вопроса: {question}")
        candidates = query_engine.retriever.retrieve(question)
        log_message(f"Извлечено {len(candidates)} узлов")
        log_message("Применяю переранжировку")
        return rerank_nodes(question, candidates, top_k=10)
    except Exception as e:
        log_message(f"Ошибка извлечения узлов: {str(e)}")
        return []
def get_current_model():
    """Return the name of the currently selected LLM model."""
    return current_model
def is_initialized():
    """Return True once initialize_models() has built the query engine."""
    return query_engine is not None