MrSimple07 commited on
Commit
a7e15db
·
1 Parent(s): 0486693

fixed add_to_history error

Browse files
Files changed (3) hide show
  1. app.py +16 -54
  2. chat_handler.py +1 -1
  3. index_retriever.py +177 -179
app.py CHANGED
@@ -2,9 +2,9 @@ import gradio as gr
2
  import os
3
  import sys
4
  import logging
5
- from config import *
6
  from documents_prep import DocumentsPreparation
7
- import index_retriever
8
  from chat_handler import ChatHandler
9
 
10
  REPO_ID = "MrSimple01/AIEXP_RAG_FILES"
@@ -14,6 +14,7 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
14
  logger = logging.getLogger(__name__)
15
 
16
  doc_prep = None
 
17
  chat_handler = None
18
 
19
  def log_message(message):
@@ -22,12 +23,13 @@ def log_message(message):
22
  sys.stdout.flush()
23
 
24
  def initialize_system():
25
- global doc_prep, chat_handler
26
 
27
  try:
28
  log_message("Запуск инициализации системы AIEXP")
29
 
30
  doc_prep = DocumentsPreparation(REPO_ID, HF_TOKEN)
 
31
 
32
  log_message("Подготовка документов")
33
  all_documents = doc_prep.prepare_all_documents()
@@ -41,7 +43,7 @@ def initialize_system():
41
  log_message("Не удалось инициализировать модели")
42
  return False
43
 
44
- chat_handler = ChatHandler(None)
45
 
46
  log_message("Система успешно инициализирована")
47
  return True
@@ -53,57 +55,17 @@ def initialize_system():
53
  def handle_question(question):
54
  if chat_handler is None:
55
  return "Система не инициализирована", ""
56
-
57
- try:
58
- answer = index_retriever.query(question)
59
- sources = get_sources_for_question(question)
60
-
61
- # chat_handler.add_to_history(question, answer)
62
-
63
- return answer, sources
64
- except Exception as e:
65
- error_msg = f"Ошибка обработки вопроса: {str(e)}"
66
- log_message(error_msg)
67
- return error_msg, ""
68
-
69
- def get_sources_for_question(question):
70
- try:
71
- nodes = index_retriever.retrieve_nodes(question)
72
- if not nodes:
73
- return "<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; text-align: center;'>Источники не найдены</div>"
74
-
75
- sources_html = "<div style='background-color: #2d3748; color: white; padding: 15px; border-radius: 10px;'>"
76
- sources_html += "<h3 style='color: #4fd1c7; margin-top: 0;'>📚 Источники:</h3>"
77
-
78
- for i, node in enumerate(nodes[:5], 1):
79
- source_text = node.text[:200] + "..." if len(node.text) > 200 else node.text
80
- sources_html += f"<div style='margin: 10px 0; padding: 10px; background-color: #4a5568; border-radius: 5px;'>"
81
- sources_html += f"<strong>Источник {i}:</strong><br>"
82
- sources_html += f"<small>{source_text}</small>"
83
- sources_html += "</div>"
84
-
85
- sources_html += "</div>"
86
- return sources_html
87
-
88
- except Exception as e:
89
- log_message(f"Ошибка получения источников: {str(e)}")
90
- return "<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; text-align: center;'>Ошибка загрузки источников</div>"
91
 
92
  def handle_model_switch(model_name):
93
- try:
94
- return index_retriever.switch_model(model_name)
95
- except Exception as e:
96
- error_msg = f"Ошибка переключения модели: {str(e)}"
97
- log_message(error_msg)
98
- return f"❌ {error_msg}"
99
 
100
  def get_current_model_status():
101
- try:
102
- if not index_retriever.is_initialized():
103
- return "Система не инициализирована"
104
- return f"Текущая модель: {index_retriever.get_current_model()}"
105
- except Exception as e:
106
- return "Ошибка получения статуса модели"
107
 
108
  def get_chat_history_html():
109
  if chat_handler is None:
@@ -130,8 +92,8 @@ def create_demo_interface():
130
  with gr.Row():
131
  with gr.Column(scale=2):
132
  model_dropdown = gr.Dropdown(
133
- choices=list(AVAILABLE_MODELS.keys()),
134
- value=DEFAULT_MODEL,
135
  label="🤖 Выберите языковую модель",
136
  info="Выберите модель для генерации ответов"
137
  )
@@ -167,7 +129,7 @@ def create_demo_interface():
167
  with gr.Column(scale=2):
168
  answer_output = gr.HTML(
169
  label="",
170
- value=f"<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; text-align: center;'>Здесь появится ответ на ваш вопрос...<br><small>Текущая модель: {DEFAULT_MODEL}</small></div>",
171
  )
172
 
173
  with gr.Column(scale=1):
 
2
  import os
3
  import sys
4
  import logging
5
+ import config
6
  from documents_prep import DocumentsPreparation
7
+ from index_retriever import IndexRetriever
8
  from chat_handler import ChatHandler
9
 
10
  REPO_ID = "MrSimple01/AIEXP_RAG_FILES"
 
14
  logger = logging.getLogger(__name__)
15
 
16
  doc_prep = None
17
+ index_retriever = None
18
  chat_handler = None
19
 
20
  def log_message(message):
 
23
  sys.stdout.flush()
24
 
25
  def initialize_system():
26
+ global doc_prep, index_retriever, chat_handler
27
 
28
  try:
29
  log_message("Запуск инициализации системы AIEXP")
30
 
31
  doc_prep = DocumentsPreparation(REPO_ID, HF_TOKEN)
32
+ index_retriever = IndexRetriever(config=config)
33
 
34
  log_message("Подготовка документов")
35
  all_documents = doc_prep.prepare_all_documents()
 
43
  log_message("Не удалось инициализировать модели")
44
  return False
45
 
46
+ chat_handler = ChatHandler(index_retriever)
47
 
48
  log_message("Система успешно инициализирована")
49
  return True
 
55
  def handle_question(question):
56
  if chat_handler is None:
57
  return "Система не инициализирована", ""
58
+ return chat_handler.answer_question(question)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  def handle_model_switch(model_name):
61
+ if index_retriever is None:
62
+ return "Система не инициализирована"
63
+ return index_retriever.switch_model(model_name)
 
 
 
64
 
65
  def get_current_model_status():
66
+ if index_retriever is None:
67
+ return "Система не инициализирована"
68
+ return f"Текущая модель: {index_retriever.get_current_model()}"
 
 
 
69
 
70
  def get_chat_history_html():
71
  if chat_handler is None:
 
92
  with gr.Row():
93
  with gr.Column(scale=2):
94
  model_dropdown = gr.Dropdown(
95
+ choices=list(config.AVAILABLE_MODELS.keys()),
96
+ value=config.DEFAULT_MODEL,
97
  label="🤖 Выберите языковую модель",
98
  info="Выберите модель для генерации ответов"
99
  )
 
129
  with gr.Column(scale=2):
130
  answer_output = gr.HTML(
131
  label="",
132
+ value=f"<div style='background-color: #2d3748; color: white; padding: 20px; border-radius: 10px; text-align: center;'>Здесь появится ответ на ваш вопрос...<br><small>Текущая модель: {config.DEFAULT_MODEL}</small></div>",
133
  )
134
 
135
  with gr.Column(scale=1):
chat_handler.py CHANGED
@@ -147,7 +147,7 @@ class ChatHandler:
147
  html += f"<div style='margin-bottom: 10px; font-size: 14px;'>{entry['question']}</div>"
148
  html += f"<div style='color: #63b3ed; font-weight: bold; margin-bottom: 8px;'>Ответ ({entry['model']}):</div>"
149
  html += f"<div style='margin-bottom: 10px; font-size: 14px; line-height: 1.4;'>{entry['answer'][:300]}{'...' if len(entry['answer']) > 300 else ''}</div>"
150
- html += f"<div style='color: #a0aec0; font-size: 12px;'>Время: {entry['processing_time']:.2f}с | Источников: {entry['nodes_count']}</div>"
151
  html += "</div>"
152
 
153
  html += "</div>"
 
147
  html += f"<div style='margin-bottom: 10px; font-size: 14px;'>{entry['question']}</div>"
148
  html += f"<div style='color: #63b3ed; font-weight: bold; margin-bottom: 8px;'>Ответ ({entry['model']}):</div>"
149
  html += f"<div style='margin-bottom: 10px; font-size: 14px; line-height: 1.4;'>{entry['answer'][:300]}{'...' if len(entry['answer']) > 300 else ''}</div>"
150
+ html += f"<div style='color: #a0aec0; font-size: 12px;'>Время: {entry['processing_time']:.2f}с</div>"
151
  html += "</div>"
152
 
153
  html += "</div>"
index_retriever.py CHANGED
@@ -14,196 +14,194 @@ from config import *
14
 
15
  logger = logging.getLogger(__name__)
16
 
17
- vector_index = None
18
- query_engine = None
19
- reranker = None
20
- current_model = DEFAULT_MODEL
21
-
22
  def log_message(message):
23
  logger.info(message)
24
  print(message, flush=True)
25
 
26
- def get_llm_model(model_name):
27
- try:
28
- model_config = AVAILABLE_MODELS.get(model_name)
29
- if not model_config:
30
- log_message(f"Модель {model_name} не найдена, использую модель по умолчанию")
31
- model_config = AVAILABLE_MODELS[DEFAULT_MODEL]
32
-
33
- if not model_config.get("api_key"):
34
- raise Exception(f"API ключ не найден для модели {model_name}")
35
-
36
- if model_config["provider"] == "google":
37
- return GoogleGenAI(
38
- model=model_config["model_name"],
39
- api_key=model_config["api_key"]
40
- )
41
- elif model_config["provider"] == "openai":
42
- return OpenAI(
43
- model=model_config["model_name"],
44
- api_key=model_config["api_key"]
45
- )
46
- else:
47
- raise Exception(f"Неподдерживаемый провайдер: {model_config['provider']}")
48
 
49
- except Exception as e:
50
- log_message(f"Ошибка создания модели {model_name}: {str(e)}")
51
- return GoogleGenAI(model="gemini-2.0-flash", api_key=GOOGLE_API_KEY)
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- def initialize_models(documents):
54
- global vector_index, query_engine, reranker, current_model
55
-
56
- try:
57
- log_message("Инициализация моделей и индекса")
58
-
59
- embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
60
- llm = get_llm_model(current_model)
61
-
62
- log_message("Инициализирую переранкер")
63
- reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
64
-
65
- Settings.embed_model = embed_model
66
- Settings.llm = llm
67
-
68
- log_message(f"Строю векторный индекс из {len(documents)} документов")
69
- vector_index = VectorStoreIndex.from_documents(documents)
70
-
71
- create_query_engine()
72
-
73
- log_message(f"Модели и индекс успешно инициализированы с моделью: {current_model}")
74
- return True
75
-
76
- except Exception as e:
77
- log_message(f"Ошибка инициализации моделей: {str(e)}")
78
- return False
79
 
80
- def create_query_engine():
81
- global query_engine
82
-
83
- try:
84
- log_message(f"Применяется промпт: {PROMPT_SIMPLE_POISK[:100]}...")
85
-
86
- bm25_retriever = BM25Retriever.from_defaults(
87
- docstore=vector_index.docstore,
88
- similarity_top_k=15
89
- )
90
-
91
- vector_retriever = VectorIndexRetriever(
92
- index=vector_index,
93
- similarity_top_k=20,
94
- similarity_cutoff=0.5
95
- )
96
-
97
- hybrid_retriever = QueryFusionRetriever(
98
- [vector_retriever, bm25_retriever],
99
- similarity_top_k=30,
100
- num_queries=1
101
- )
102
-
103
- custom_prompt_template = PromptTemplate(PROMPT_SIMPLE_POISK)
104
- response_synthesizer = get_response_synthesizer(
105
- response_mode=ResponseMode.TREE_SUMMARIZE,
106
- text_qa_template=custom_prompt_template
107
- )
108
-
109
- query_engine = RetrieverQueryEngine(
110
- retriever=hybrid_retriever,
111
- response_synthesizer=response_synthesizer
112
- )
113
-
114
- log_message("Query engine успешно создан с кастомным промптом")
115
-
116
- except Exception as e:
117
- log_message(f"Ошибка создания query engine: {str(e)}")
118
- raise
119
 
120
- def query(question):
121
- if query_engine is None:
122
- log_message("❌ Query engine не инициализирован")
123
- return "❌ Система не инициализирована"
124
-
125
- try:
126
- log_message(f"Получен вопрос: {question}")
127
- log_message(f"Используется модель: {current_model}")
128
- log_message(f"Применяется промпт: {PROMPT_SIMPLE_POISK[:150]}...")
129
- log_message(f"Обрабатываю запрос: {question}")
130
-
131
- response = query_engine.query(question)
132
- log_message(f"Ответ получен, длина: {len(str(response))}")
133
-
134
- return str(response)
135
-
136
- except Exception as e:
137
- error_msg = f"Ошибка обработки запроса: {str(e)}"
138
- log_message(error_msg)
139
- return f"❌ {error_msg}"
 
140
 
141
- def switch_model(model_name):
142
- global current_model
143
-
144
- try:
145
- log_message(f"Переключение на модель: {model_name}")
146
-
147
- new_llm = get_llm_model(model_name)
148
- Settings.llm = new_llm
149
-
150
- if vector_index is not None:
151
- create_query_engine()
152
- current_model = model_name
153
- log_message(f"Модель успешно переключена на: {model_name}")
154
- return f" Модель переключена на: {model_name}"
155
- else:
156
- return "❌ Ошибка: система не инициализирована"
157
-
158
- except Exception as e:
159
- error_msg = f"Ошибка переключения модели: {str(e)}"
160
- log_message(error_msg)
161
- return f"❌ {error_msg}"
162
 
163
- def rerank_nodes(query_text, nodes, top_k=10):
164
- if not nodes or not reranker:
165
- return nodes[:top_k]
166
-
167
- try:
168
- log_message(f"Переранжирую {len(nodes)} узлов")
169
-
170
- pairs = []
171
- for node in nodes:
172
- pairs.append([query_text, node.text])
173
-
174
- scores = reranker.predict(pairs)
175
-
176
- scored_nodes = list(zip(nodes, scores))
177
- scored_nodes.sort(key=lambda x: x[1], reverse=True)
178
 
179
- reranked_nodes = [node for node, score in scored_nodes[:top_k]]
180
- log_message(f"Возвращаю топ-{len(reranked_nodes)} переранжированных узлов")
181
-
182
- return reranked_nodes
183
- except Exception as e:
184
- log_message(f"Ошибка переранжировки: {str(e)}")
185
- return nodes[:top_k]
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
- def retrieve_nodes(question):
188
- if query_engine is None:
189
- return []
190
-
191
- try:
192
- log_message(f"Извлекаю релевантные узлы для вопроса: {question}")
193
- retrieved_nodes = query_engine.retriever.retrieve(question)
194
- log_message(f"Извлечено {len(retrieved_nodes)} узлов")
195
-
196
- log_message("Применяю переранжировку")
197
- reranked_nodes = rerank_nodes(question, retrieved_nodes, top_k=10)
198
-
199
- return reranked_nodes
200
-
201
- except Exception as e:
202
- log_message(f"Ошибка извлечения узлов: {str(e)}")
203
- return []
204
 
205
- def get_current_model():
206
- return current_model
207
 
208
- def is_initialized():
209
- return query_engine is not None
 
14
 
15
  logger = logging.getLogger(__name__)
16
 
 
 
 
 
 
17
  def log_message(message):
18
  logger.info(message)
19
  print(message, flush=True)
20
 
21
+ class IndexRetriever:
22
+ def __init__(self, config):
23
+ self.config = config
24
+ self.vector_index = None
25
+ self.query_engine = None
26
+ self.reranker = None
27
+ self.current_model = config.DEFAULT_MODEL
28
+
29
+ def get_llm_model(self, model_name):
30
+ try:
31
+ model_config = self.config.AVAILABLE_MODELS.get(model_name)
32
+ if not model_config:
33
+ log_message(f"Модель {model_name} не найдена, использую модель по умолчанию")
34
+ model_config = self.config.AVAILABLE_MODELS[self.config.DEFAULT_MODEL]
35
+
36
+ if not model_config.get("api_key"):
37
+ raise Exception(f"API ключ не найден для модели {model_name}")
 
 
 
 
 
38
 
39
+ if model_config["provider"] == "google":
40
+ return GoogleGenAI(
41
+ model=model_config["model_name"],
42
+ api_key=model_config["api_key"]
43
+ )
44
+ elif model_config["provider"] == "openai":
45
+ return OpenAI(
46
+ model=model_config["model_name"],
47
+ api_key=model_config["api_key"]
48
+ )
49
+ else:
50
+ raise Exception(f"Неподдерживаемый провайдер: {model_config['provider']}")
51
+
52
+ except Exception as e:
53
+ log_message(f"Ошибка создания модели {model_name}: {str(e)}")
54
+ return GoogleGenAI(model="gemini-2.0-flash", api_key=self.config.GOOGLE_API_KEY)
55
 
56
+ def initialize_models(self, documents):
57
+ try:
58
+ log_message("Инициализация моделей и индекса")
59
+
60
+ embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
61
+ llm = self.get_llm_model(self.current_model)
62
+
63
+ log_message("Инициализирую переранкер")
64
+ self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
65
+
66
+ Settings.embed_model = embed_model
67
+ Settings.llm = llm
68
+
69
+ log_message(f"Строю векторный индекс из {len(documents)} документов")
70
+ self.vector_index = VectorStoreIndex.from_documents(documents)
71
+
72
+ self.create_query_engine()
73
+
74
+ log_message(f"Модели и индекс успешно инициализированы с моделью: {self.current_model}")
75
+ return True
76
+
77
+ except Exception as e:
78
+ log_message(f"Ошибка инициализации моделей: {str(e)}")
79
+ return False
 
 
80
 
81
+ def create_query_engine(self):
82
+ try:
83
+ log_message(f"Применяется промпт: {self.config.PROMPT_SIMPLE_POISK[:100]}...")
84
+
85
+ bm25_retriever = BM25Retriever.from_defaults(
86
+ docstore=self.vector_index.docstore,
87
+ similarity_top_k=15
88
+ )
89
+
90
+ vector_retriever = VectorIndexRetriever(
91
+ index=self.vector_index,
92
+ similarity_top_k=20,
93
+ similarity_cutoff=0.5
94
+ )
95
+
96
+ hybrid_retriever = QueryFusionRetriever(
97
+ [vector_retriever, bm25_retriever],
98
+ similarity_top_k=30,
99
+ num_queries=1
100
+ )
101
+
102
+ custom_prompt_template = PromptTemplate(self.config.PROMPT_SIMPLE_POISK)
103
+ response_synthesizer = get_response_synthesizer(
104
+ response_mode=ResponseMode.TREE_SUMMARIZE,
105
+ text_qa_template=custom_prompt_template
106
+ )
107
+
108
+ self.query_engine = RetrieverQueryEngine(
109
+ retriever=hybrid_retriever,
110
+ response_synthesizer=response_synthesizer
111
+ )
112
+
113
+ log_message("Query engine успешно создан с кастомным промптом")
114
+
115
+ except Exception as e:
116
+ log_message(f"Ошибка создания query engine: {str(e)}")
117
+ raise
 
 
118
 
119
+ def query(self, question):
120
+ """Метод для выполнения запроса с применением промпта"""
121
+ if self.query_engine is None:
122
+ log_message("❌ Query engine не инициализирован")
123
+ return "❌ Система не инициализирована"
124
+
125
+ try:
126
+ log_message(f"Получен вопрос: {question}")
127
+ log_message(f"Используется модель: {self.current_model}")
128
+ log_message(f"Применяется промпт: {self.config.PROMPT_SIMPLE_POISK[:150]}...")
129
+ log_message(f"Обрабатываю запрос: {question}")
130
+
131
+ response = self.query_engine.query(question)
132
+ log_message(f"Ответ получен, длина: {len(str(response))}")
133
+
134
+ return str(response)
135
+
136
+ except Exception as e:
137
+ error_msg = f"Ошибка обработки запроса: {str(e)}"
138
+ log_message(error_msg)
139
+ return f"❌ {error_msg}"
140
 
141
+ def switch_model(self, model_name):
142
+ try:
143
+ log_message(f"Переключение на модель: {model_name}")
144
+
145
+ new_llm = self.get_llm_model(model_name)
146
+ Settings.llm = new_llm
147
+
148
+ if self.vector_index is not None:
149
+ self.create_query_engine()
150
+ self.current_model = model_name
151
+ log_message(f"Модель успешно переключена на: {model_name}")
152
+ return f"✅ Модель переключена на: {model_name}"
153
+ else:
154
+ return " Ошибка: система не инициализирована"
155
+
156
+ except Exception as e:
157
+ error_msg = f"Ошибка переключения модели: {str(e)}"
158
+ log_message(error_msg)
159
+ return f" {error_msg}"
 
 
160
 
161
+ def rerank_nodes(self, query, nodes, top_k=10):
162
+ if not nodes or not self.reranker:
163
+ return nodes[:top_k]
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
+ try:
166
+ log_message(f"Переранжирую {len(nodes)} узлов")
167
+
168
+ pairs = []
169
+ for node in nodes:
170
+ pairs.append([query, node.text])
171
+
172
+ scores = self.reranker.predict(pairs)
173
+
174
+ scored_nodes = list(zip(nodes, scores))
175
+ scored_nodes.sort(key=lambda x: x[1], reverse=True)
176
+
177
+ reranked_nodes = [node for node, score in scored_nodes[:top_k]]
178
+ log_message(f"Возвращаю топ-{len(reranked_nodes)} переранжированных узлов")
179
+
180
+ return reranked_nodes
181
+ except Exception as e:
182
+ log_message(f"Ошибка переранжировки: {str(e)}")
183
+ return nodes[:top_k]
184
 
185
+ def retrieve_nodes(self, question):
186
+ if self.query_engine is None:
187
+ return []
188
+
189
+ try:
190
+ log_message(f"Извлекаю релевантные узлы для вопроса: {question}")
191
+ retrieved_nodes = self.query_engine.retriever.retrieve(question)
192
+ log_message(f"Извлечено {len(retrieved_nodes)} узлов")
193
+
194
+ log_message("Применяю переранжировку")
195
+ reranked_nodes = self.rerank_nodes(question, retrieved_nodes, top_k=10)
196
+
197
+ return reranked_nodes
198
+
199
+ except Exception as e:
200
+ log_message(f"Ошибка извлечения узлов: {str(e)}")
201
+ return []
202
 
203
+ def get_current_model(self):
204
+ return self.current_model
205
 
206
+ def is_initialized(self):
207
+ return self.query_engine is not None