MrSimple07 commited on
Commit
df5cefa
·
1 Parent(s): efb39e4
Files changed (1) hide show
  1. index_retriever.py +185 -177
index_retriever.py CHANGED
@@ -14,194 +14,202 @@ from config import *
14
 
15
  logger = logging.getLogger(__name__)
16
 
 
 
 
 
 
17
  def log_message(message):
18
  logger.info(message)
19
  print(message, flush=True)
20
 
21
- class IndexRetriever:
22
- def __init__(self, config):
23
- self.config = config
24
- self.vector_index = None
25
- self.query_engine = None
26
- self.reranker = None
27
- self.current_model = config.DEFAULT_MODEL
28
-
29
- def get_llm_model(self, model_name):
30
- try:
31
- model_config = self.config.AVAILABLE_MODELS.get(model_name)
32
- if not model_config:
33
- log_message(f"Модель {model_name} не найдена, использую модель по умолчанию")
34
- model_config = self.config.AVAILABLE_MODELS[self.config.DEFAULT_MODEL]
35
-
36
- if not model_config.get("api_key"):
37
- raise Exception(f"API ключ не найден для модели {model_name}")
38
-
39
- if model_config["provider"] == "google":
40
- return GoogleGenAI(
41
- model=model_config["model_name"],
42
- api_key=model_config["api_key"]
43
- )
44
- elif model_config["provider"] == "openai":
45
- return OpenAI(
46
- model=model_config["model_name"],
47
- api_key=model_config["api_key"]
48
- )
49
- else:
50
- raise Exception(f"Неподдерживаемый провайдер: {model_config['provider']}")
51
-
52
- except Exception as e:
53
- log_message(f"Ошибка создания модели {model_name}: {str(e)}")
54
- return GoogleGenAI(model="gemini-2.0-flash", api_key=self.config.GOOGLE_API_KEY)
55
-
56
- def initialize_models(self, documents):
57
- try:
58
- log_message("Инициализация моделей и индекса")
59
-
60
- embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
61
- llm = self.get_llm_model(self.current_model)
62
-
63
- log_message("Инициализирую переранкер")
64
- self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
65
-
66
- Settings.embed_model = embed_model
67
- Settings.llm = llm
68
-
69
- log_message(f"Строю векторный индекс из {len(documents)} документов")
70
- self.vector_index = VectorStoreIndex.from_documents(documents)
71
-
72
- self.create_query_engine()
73
-
74
- log_message(f"Модели и индекс успешно инициализированы с моделью: {self.current_model}")
75
- return True
76
-
77
- except Exception as e:
78
- log_message(f"Ошибка инициализации моделей: {str(e)}")
79
- return False
80
-
81
- def create_query_engine(self):
82
- try:
83
- log_message(f"Применяется промпт: {self.config.PROMPT_SIMPLE_POISK[:100]}...")
84
-
85
- bm25_retriever = BM25Retriever.from_defaults(
86
- docstore=self.vector_index.docstore,
87
- similarity_top_k=15
88
- )
89
-
90
- vector_retriever = VectorIndexRetriever(
91
- index=self.vector_index,
92
- similarity_top_k=20,
93
- similarity_cutoff=0.5
94
- )
95
-
96
- hybrid_retriever = QueryFusionRetriever(
97
- [vector_retriever, bm25_retriever],
98
- similarity_top_k=30,
99
- num_queries=1
100
- )
101
-
102
- custom_prompt_template = PromptTemplate(self.config.PROMPT_SIMPLE_POISK)
103
- response_synthesizer = get_response_synthesizer(
104
- response_mode=ResponseMode.TREE_SUMMARIZE,
105
- text_qa_template=custom_prompt_template
106
  )
107
-
108
- self.query_engine = RetrieverQueryEngine(
109
- retriever=hybrid_retriever,
110
- response_synthesizer=response_synthesizer
111
  )
 
 
112
 
113
- log_message("Query engine успешно создан с кастомным промптом")
114
-
115
- except Exception as e:
116
- log_message(f"Ошибка создания query engine: {str(e)}")
117
- raise
118
 
119
- def query(self, question):
120
- """Метод для выполнения запроса с применением промпта"""
121
- if self.query_engine is None:
122
- log_message("❌ Query engine не инициализирован")
123
- return " Система не инициализирована"
124
-
125
- try:
126
- log_message(f"Получен вопрос: {question}")
127
- log_message(f"Используется модель: {self.current_model}")
128
- log_message(f"Применяется промпт: {self.config.PROMPT_SIMPLE_POISK[:150]}...")
129
- log_message(f"Обрабатываю запрос: {question}")
130
-
131
- response = self.query_engine.query(question)
132
- log_message(f"Ответ получен, длина: {len(str(response))}")
133
-
134
- return str(response)
135
-
136
- except Exception as e:
137
- error_msg = f"Ошибка обработки запроса: {str(e)}"
138
- log_message(error_msg)
139
- return f" {error_msg}"
 
 
 
 
 
140
 
141
- def switch_model(self, model_name):
142
- try:
143
- log_message(f"Переключение на модель: {model_name}")
144
-
145
- new_llm = self.get_llm_model(model_name)
146
- Settings.llm = new_llm
147
-
148
- if self.vector_index is not None:
149
- self.create_query_engine()
150
- self.current_model = model_name
151
- log_message(f"Модель успешно переключена на: {model_name}")
152
- return f"✅ Модель переключена на: {model_name}"
153
- else:
154
- return "❌ Ошибка: система не инициализирована"
155
-
156
- except Exception as e:
157
- error_msg = f"Ошибка переключения модели: {str(e)}"
158
- log_message(error_msg)
159
- return f"❌ {error_msg}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
- def rerank_nodes(self, query, nodes, top_k=10):
162
- if not nodes or not self.reranker:
163
- return nodes[:top_k]
 
 
 
 
 
 
 
 
 
164
 
165
- try:
166
- log_message(f"Переранжирую {len(nodes)} узлов")
167
-
168
- pairs = []
169
- for node in nodes:
170
- pairs.append([query, node.text])
171
-
172
- scores = self.reranker.predict(pairs)
173
-
174
- scored_nodes = list(zip(nodes, scores))
175
- scored_nodes.sort(key=lambda x: x[1], reverse=True)
176
-
177
- reranked_nodes = [node for node, score in scored_nodes[:top_k]]
178
- log_message(f"Возвращаю топ-{len(reranked_nodes)} переранжированных узлов")
179
-
180
- return reranked_nodes
181
- except Exception as e:
182
- log_message(f"Ошибка переранжировки: {str(e)}")
183
- return nodes[:top_k]
184
 
185
- def retrieve_nodes(self, question):
186
- if self.query_engine is None:
187
- return []
188
-
189
- try:
190
- log_message(f"Извлекаю релевантные узлы для вопроса: {question}")
191
- retrieved_nodes = self.query_engine.retriever.retrieve(question)
192
- log_message(f"Извлечено {len(retrieved_nodes)} узлов")
193
-
194
- log_message("Применяю переранжировку")
195
- reranked_nodes = self.rerank_nodes(question, retrieved_nodes, top_k=10)
196
-
197
- return reranked_nodes
198
-
199
- except Exception as e:
200
- log_message(f"Ошибка извлечения узлов: {str(e)}")
201
- return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
- def get_current_model(self):
204
- return self.current_model
205
 
206
- def is_initialized(self):
207
- return self.query_engine is not None
 
14
 
15
  logger = logging.getLogger(__name__)
16
 
17
+ vector_index = None
18
+ query_engine = None
19
+ reranker = None
20
+ current_model = DEFAULT_MODEL
21
+
22
  def log_message(message):
23
  logger.info(message)
24
  print(message, flush=True)
25
 
26
+ def get_llm_model(model_name):
27
+ try:
28
+ model_config = AVAILABLE_MODELS.get(model_name)
29
+ if not model_config:
30
+ log_message(f"Модель {model_name} не найдена, использую модель по умолчанию")
31
+ model_config = AVAILABLE_MODELS[DEFAULT_MODEL]
32
+
33
+ if not model_config.get("api_key"):
34
+ raise Exception(f"API ключ не найден для модели {model_name}")
35
+
36
+ if model_config["provider"] == "google":
37
+ return GoogleGenAI(
38
+ model=model_config["model_name"],
39
+ api_key=model_config["api_key"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  )
41
+ elif model_config["provider"] == "openai":
42
+ return OpenAI(
43
+ model=model_config["model_name"],
44
+ api_key=model_config["api_key"]
45
  )
46
+ else:
47
+ raise Exception(f"Неподдерживаемый провайдер: {model_config['provider']}")
48
 
49
+ except Exception as e:
50
+ log_message(f"Ошибка создания модели {model_name}: {str(e)}")
51
+ return GoogleGenAI(model="gemini-2.0-flash", api_key=GOOGLE_API_KEY)
 
 
52
 
53
+ def initialize_models(documents):
54
+ global vector_index, query_engine, reranker, current_model
55
+
56
+ try:
57
+ log_message("Инициализация моделей и индекса")
58
+
59
+ embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
60
+ llm = get_llm_model(current_model)
61
+
62
+ log_message("Инициализирую переранкер")
63
+ reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
64
+
65
+ Settings.embed_model = embed_model
66
+ Settings.llm = llm
67
+
68
+ log_message(f"Строю векторный индекс из {len(documents)} документов")
69
+ vector_index = VectorStoreIndex.from_documents(documents)
70
+
71
+ create_query_engine()
72
+
73
+ log_message(f"Модели и индекс успешно инициализированы с моделью: {current_model}")
74
+ return True
75
+
76
+ except Exception as e:
77
+ log_message(f"Ошибка инициализации моделей: {str(e)}")
78
+ return False
79
 
80
+ def create_query_engine():
81
+ global query_engine, vector_index
82
+
83
+ try:
84
+ log_message(f"Применяется промпт: {PROMPT_SIMPLE_POISK[:100]}...")
85
+
86
+ bm25_retriever = BM25Retriever.from_defaults(
87
+ docstore=vector_index.docstore,
88
+ similarity_top_k=15
89
+ )
90
+
91
+ vector_retriever = VectorIndexRetriever(
92
+ index=vector_index,
93
+ similarity_top_k=20,
94
+ similarity_cutoff=0.5
95
+ )
96
+
97
+ hybrid_retriever = QueryFusionRetriever(
98
+ [vector_retriever, bm25_retriever],
99
+ similarity_top_k=30,
100
+ num_queries=1
101
+ )
102
+
103
+ custom_prompt_template = PromptTemplate(PROMPT_SIMPLE_POISK)
104
+ response_synthesizer = get_response_synthesizer(
105
+ response_mode=ResponseMode.TREE_SUMMARIZE,
106
+ text_qa_template=custom_prompt_template
107
+ )
108
+
109
+ query_engine = RetrieverQueryEngine(
110
+ retriever=hybrid_retriever,
111
+ response_synthesizer=response_synthesizer
112
+ )
113
+
114
+ log_message("Query engine успешно создан с кастомным промптом")
115
+
116
+ except Exception as e:
117
+ log_message(f"Ошибка создания query engine: {str(e)}")
118
+ raise
119
 
120
+ def query(question):
121
+ global query_engine, current_model
122
+
123
+ if query_engine is None:
124
+ log_message("❌ Query engine не инициализирован")
125
+ return "❌ Система не инициализирована"
126
+
127
+ try:
128
+ log_message(f"Получен вопрос: {question}")
129
+ log_message(f"Используется модель: {current_model}")
130
+ log_message(f"Применяется промпт: {PROMPT_SIMPLE_POISK[:150]}...")
131
+ log_message(f"Обрабатываю запрос: {question}")
132
 
133
+ response = query_engine.query(question)
134
+ log_message(f"Ответ получен, длина: {len(str(response))}")
135
+
136
+ return str(response)
137
+
138
+ except Exception as e:
139
+ error_msg = f"Ошибка обработки запроса: {str(e)}"
140
+ log_message(error_msg)
141
+ return f"❌ {error_msg}"
 
 
 
 
 
 
 
 
 
 
142
 
143
+ def switch_model(model_name):
144
+ global current_model, vector_index
145
+
146
+ try:
147
+ log_message(f"Переключение на модель: {model_name}")
148
+
149
+ new_llm = get_llm_model(model_name)
150
+ Settings.llm = new_llm
151
+
152
+ if vector_index is not None:
153
+ create_query_engine()
154
+ current_model = model_name
155
+ log_message(f"Модель успешно переключена на: {model_name}")
156
+ return f"✅ Модель переключена на: {model_name}"
157
+ else:
158
+ return " Ошибка: система не инициализирована"
159
+
160
+ except Exception as e:
161
+ error_msg = f"Ошибка переключения модели: {str(e)}"
162
+ log_message(error_msg)
163
+ return f"❌ {error_msg}"
164
+
165
+ def rerank_nodes(query_text, nodes, top_k=10):
166
+ global reranker
167
+
168
+ if not nodes or not reranker:
169
+ return nodes[:top_k]
170
+
171
+ try:
172
+ log_message(f"Переранжирую {len(nodes)} узлов")
173
+
174
+ pairs = []
175
+ for node in nodes:
176
+ pairs.append([query_text, node.text])
177
+
178
+ scores = reranker.predict(pairs)
179
+
180
+ scored_nodes = list(zip(nodes, scores))
181
+ scored_nodes.sort(key=lambda x: x[1], reverse=True)
182
+
183
+ reranked_nodes = [node for node, score in scored_nodes[:top_k]]
184
+ log_message(f"Возвращаю топ-{len(reranked_nodes)} переранжированных узлов")
185
+
186
+ return reranked_nodes
187
+ except Exception as e:
188
+ log_message(f"Ошибка переранжировки: {str(e)}")
189
+ return nodes[:top_k]
190
+
191
+ def retrieve_nodes(question):
192
+ global query_engine
193
+
194
+ if query_engine is None:
195
+ return []
196
+
197
+ try:
198
+ log_message(f"Извлекаю релевантные узлы для вопроса: {question}")
199
+ retrieved_nodes = query_engine.retriever.retrieve(question)
200
+ log_message(f"Извлечено {len(retrieved_nodes)} узлов")
201
+
202
+ log_message("Применяю переранжировку")
203
+ reranked_nodes = rerank_nodes(question, retrieved_nodes, top_k=10)
204
+
205
+ return reranked_nodes
206
+
207
+ except Exception as e:
208
+ log_message(f"Ошибка извлечения узлов: {str(e)}")
209
+ return []
210
 
211
+ def get_current_model():
212
+ return current_model
213
 
214
+ def is_initialized():
215
+ return query_engine is not None