linxinhua commited on
Commit
6fd3014
·
verified ·
1 Parent(s): 95b55e1

Update RAG_Learning_Assistant_with_Streaming.py via admin tool

Browse files
RAG_Learning_Assistant_with_Streaming.py CHANGED
@@ -11,27 +11,32 @@ import re
11
  class RAGLearningAssistant:
12
  def __init__(self, api_key: str, model: str = "gpt-4.1-nano-2025-04-14", vector_db_path: str = ""):
13
  """
14
- 初始化RAG学习助手
15
 
16
  Args:
17
  api_key: OpenAI API密钥(必需)
18
  model: 使用的模型名称
19
- vector_db_path: 向量数据库路径
20
  """
21
  self.client = OpenAI(api_key=api_key)
 
 
22
  self.vectorizer = KnowledgeBaseVectorizer(
23
  api_key=api_key,
24
- #data_path=os.path.join(vector_db_path, "knowledge_base.md")
25
- data_path="knowledge_base.md"
26
  )
27
 
28
  # 预加载向量数据库到缓存
29
- print("预加载向量数据库...")
30
- self.vectorizer.load_vector_database()
 
 
 
 
31
 
32
  # 模型配置
33
  self.model = model
34
- self.temperature = 0.2
35
  self.max_tokens = 2000
36
 
37
  # 系统提示词
@@ -119,7 +124,7 @@ Return the entities as a JSON array of strings. Only include the most important
119
  response = self.client.chat.completions.create(
120
  model=self.model,
121
  messages=messages,
122
- temperature=0.3, # 低温度确保一致性
123
  max_tokens=2000
124
  )
125
 
@@ -173,22 +178,22 @@ Return the entities as a JSON array of strings. Only include the most important
173
 
174
  # 如果仍然没有获得有效结果,使用更简单的方法
175
  if not summary and self.conversation_history:
176
- summary = "继续之前的讨论"
177
 
178
  if not rewritten or rewritten == query:
179
  rewritten = query
180
 
181
- print(f"Raw query: {query}")
182
- print(f"Chat history summary: {summary}")
183
- print(f"Rewrite query: {rewritten}")
184
  return summary, rewritten
185
 
186
  except Exception as e:
187
- print(f"查询重写失败: {e}")
188
  # 生成简单的历史总结作为备用
189
  simple_summary = ""
190
  if self.conversation_history:
191
- simple_summary = "基于之前的对话内容"
192
  return simple_summary, query # 失败时返回简单总结和原始查询
193
 
194
  def extract_entities(self, original_query: str, summary: str, rewritten_query: str) -> List[str]:
@@ -229,8 +234,8 @@ Return the entities as a JSON array of strings. Only include the most important
229
  response = self.client.chat.completions.create(
230
  model=self.model,
231
  messages=messages,
232
- temperature=0.3,
233
- max_tokens=200
234
  )
235
 
236
  content = response.choices[0].message.content.strip()
@@ -244,18 +249,18 @@ Return the entities as a JSON array of strings. Only include the most important
244
  else:
245
  entities = json.loads(content)
246
 
247
- print(f"Extracted entities: {entities}")
248
  return entities
249
 
250
  except json.JSONDecodeError:
251
  # 如果JSON解析失败,尝试简单的文本处理
252
- print(f"JSON解析失败,使用备用方法")
253
  # 查找引号中的内容
254
  entities = re.findall(r'"([^"]+)"', content)
255
  return entities if entities else self.simple_entity_extraction(combined_text)
256
 
257
  except Exception as e:
258
- print(f"实体提取失败: {e}")
259
  # 失败时使用简单的关键词提取
260
  return self.simple_entity_extraction(combined_text)
261
 
@@ -284,9 +289,9 @@ Return the entities as a JSON array of strings. Only include the most important
284
  entities.extend(special_terms)
285
 
286
  # 去重并返回
287
- return list(set(entities))[:5] # 最多返回5个实体
288
 
289
- def enhanced_search(self, query: str, top_k: int = 3) -> Tuple[str, str, List[str], List[Tuple[Dict, float, Dict]]]:
290
  """
291
  增强搜索:重写查询 -> 提取实体 -> 基于实体搜索(优化版本)
292
 
@@ -309,12 +314,12 @@ Return the entities as a JSON array of strings. Only include the most important
309
  search_results = self.vectorizer.search_with_entities_optimized(entities, top_k)
310
  else:
311
  # 如果没有提取到实体,使用重写后的查询进行搜索
312
- print("未提取到实体,使用完整查询搜索")
313
  search_results = self.vectorizer.search_similar(
314
  rewritten_query,
315
  top_k=top_k,
316
- title_weight=0.4,
317
- content_weight=0.3,
318
  full_weight=0.3
319
  )
320
 
@@ -338,7 +343,6 @@ Return the entities as a JSON array of strings. Only include the most important
338
  entry, combined_score, details = result
339
  # 只显示 title, source, content,不显示 id
340
  context_parts.append(
341
- #f"[Source {i}]\n"
342
  f"Title: {entry['title']}\n"
343
  f"From: {entry['source']}\n"
344
  f"Content: {entry['content']}\n"
@@ -385,7 +389,7 @@ Return the entities as a JSON array of strings. Only include the most important
385
  响应文本片段
386
  """
387
  # 1. 增强搜索(现在使用优化版本)
388
- print("正在处理查询...")
389
  summary, rewritten_query, entities, search_results = self.enhanced_search(query)
390
 
391
  # 2. 格式化上下文
@@ -413,7 +417,7 @@ Return the entities as a JSON array of strings. Only include the most important
413
  if summary:
414
  search_info += f"- Summary of history: {summary}\n"
415
  if rewritten_query != query:
416
- search_info += f"- Rewrite query: {rewritten_query}\n"
417
  search_info += f"- Key entities: {', '.join(entities) if entities else 'No specific entities extracted'}\n"
418
 
419
  if search_results:
@@ -429,7 +433,7 @@ Return the entities as a JSON array of strings. Only include the most important
429
  # 添加缓存信息(调试用)
430
  cache_info = self.vectorizer.get_cache_info()
431
  if cache_info['is_cached']:
432
- search_info += f"The vector db has been cached, containing {cache_info['cache_size']} entries\n\n"
433
 
434
  yield search_info
435
 
@@ -445,7 +449,7 @@ Return the entities as a JSON array of strings. Only include the most important
445
  self.conversation_history.append({"role": "assistant", "content": full_response})
446
 
447
  except Exception as e:
448
- yield f"\n\n错误:生成响应时出现问题 - {str(e)}"
449
 
450
  def generate_response(self, query: str) -> str:
451
  """
@@ -465,18 +469,18 @@ Return the entities as a JSON array of strings. Only include the most important
465
  def clear_history(self):
466
  """清除对话历史"""
467
  self.conversation_history = []
468
- print("对话历史已清除")
469
 
470
  def clear_vector_cache(self):
471
  """清除向量数据库缓存"""
472
  self.vectorizer.clear_cache()
473
- print("向量数据库缓存已清除")
474
 
475
  def reload_vector_database(self):
476
  """重新加载向量数据库"""
477
- print("重新加载向量数据库...")
478
  self.vectorizer.load_vector_database(force_reload=True)
479
- print("向量数据库重新加载完成")
480
 
481
  def get_system_status(self) -> Dict:
482
  """
@@ -514,5 +518,4 @@ Return the entities as a JSON array of strings. Only include the most important
514
  with open(filepath, 'w', encoding='utf-8') as f:
515
  json.dump(conversation_data, f, ensure_ascii=False, indent=2)
516
 
517
- print(f"对话已保存到: {filepath}")
518
-
 
11
  class RAGLearningAssistant:
12
  def __init__(self, api_key: str, model: str = "gpt-4.1-nano-2025-04-14", vector_db_path: str = ""):
13
  """
14
+ 初始化RAG学习助手(适配学生Space)
15
 
16
  Args:
17
  api_key: OpenAI API密钥(必需)
18
  model: 使用的模型名称
19
+ vector_db_path: 向量数据库所在目录路径(数据存储仓库的本地目录)
20
  """
21
  self.client = OpenAI(api_key=api_key)
22
+
23
+ # 使用修改后的KnowledgeBaseVectorizer,指定vector_db_dir
24
  self.vectorizer = KnowledgeBaseVectorizer(
25
  api_key=api_key,
26
+ vector_db_dir=vector_db_path # 传递数据存储仓库的本地目录
 
27
  )
28
 
29
  # 预加载向量数据库到缓存
30
+ print("[RAGLearningAssistant] Preloading vector database...")
31
+ load_result = self.vectorizer.load_vector_database()
32
+ if load_result[0] is not None:
33
+ print(f"[RAGLearningAssistant] Vector database loaded successfully")
34
+ else:
35
+ print(f"[RAGLearningAssistant] Warning: Failed to load vector database")
36
 
37
  # 模型配置
38
  self.model = model
39
+ self.temperature = 0.1
40
  self.max_tokens = 2000
41
 
42
  # 系统提示词
 
124
  response = self.client.chat.completions.create(
125
  model=self.model,
126
  messages=messages,
127
+ temperature=0.1, # 低温度确保一致性
128
  max_tokens=2000
129
  )
130
 
 
178
 
179
  # 如果仍然没有获得有效结果,使用更简单的方法
180
  if not summary and self.conversation_history:
181
+ summary = "Continue previous discussion"
182
 
183
  if not rewritten or rewritten == query:
184
  rewritten = query
185
 
186
+ print(f"[rewrite_query] Raw query: {query}")
187
+ print(f"[rewrite_query] Chat history summary: {summary}")
188
+ print(f"[rewrite_query] Rewritten query: {rewritten}")
189
  return summary, rewritten
190
 
191
  except Exception as e:
192
+ print(f"[rewrite_query] Query rewriting failed: {e}")
193
  # 生成简单的历史总结作为备用
194
  simple_summary = ""
195
  if self.conversation_history:
196
+ simple_summary = "Based on previous conversation content"
197
  return simple_summary, query # 失败时返回简单总结和原始查询
198
 
199
  def extract_entities(self, original_query: str, summary: str, rewritten_query: str) -> List[str]:
 
234
  response = self.client.chat.completions.create(
235
  model=self.model,
236
  messages=messages,
237
+ temperature=self.temperature,
238
+ max_tokens=self.max_tokens
239
  )
240
 
241
  content = response.choices[0].message.content.strip()
 
249
  else:
250
  entities = json.loads(content)
251
 
252
+ print(f"[extract_entities] Extracted entities: {entities}")
253
  return entities
254
 
255
  except json.JSONDecodeError:
256
  # 如果JSON解析失败,尝试简单的文本处理
257
+ print(f"[extract_entities] JSON parsing failed, using backup method")
258
  # 查找引号中的内容
259
  entities = re.findall(r'"([^"]+)"', content)
260
  return entities if entities else self.simple_entity_extraction(combined_text)
261
 
262
  except Exception as e:
263
+ print(f"[extract_entities] Entity extraction failed: {e}")
264
  # 失败时使用简单的关键词提取
265
  return self.simple_entity_extraction(combined_text)
266
 
 
289
  entities.extend(special_terms)
290
 
291
  # 去重并返回
292
+ return list(set(entities))[:10] # 最多返回10个实体
293
 
294
+ def enhanced_search(self, query: str, top_k: int = 5) -> Tuple[str, str, List[str], List[Tuple[Dict, float, Dict]]]:
295
  """
296
  增强搜索:重写查询 -> 提取实体 -> 基于实体搜索(优化版本)
297
 
 
314
  search_results = self.vectorizer.search_with_entities_optimized(entities, top_k)
315
  else:
316
  # 如果没有提取到实体,使用重写后的查询进行搜索
317
+ print("[enhanced_search] No entities extracted, using full query search")
318
  search_results = self.vectorizer.search_similar(
319
  rewritten_query,
320
  top_k=top_k,
321
+ title_weight=0.2,
322
+ content_weight=0.5,
323
  full_weight=0.3
324
  )
325
 
 
343
  entry, combined_score, details = result
344
  # 只显示 title, source, content,不显示 id
345
  context_parts.append(
 
346
  f"Title: {entry['title']}\n"
347
  f"From: {entry['source']}\n"
348
  f"Content: {entry['content']}\n"
 
389
  响应文本片段
390
  """
391
  # 1. 增强搜索(现在使用优化版本)
392
+ print("[generate_response_stream] Processing query...")
393
  summary, rewritten_query, entities, search_results = self.enhanced_search(query)
394
 
395
  # 2. 格式化上下文
 
417
  if summary:
418
  search_info += f"- Summary of history: {summary}\n"
419
  if rewritten_query != query:
420
+ search_info += f"- Rewritten query: {rewritten_query}\n"
421
  search_info += f"- Key entities: {', '.join(entities) if entities else 'No specific entities extracted'}\n"
422
 
423
  if search_results:
 
433
  # 添加缓存信息(调试用)
434
  cache_info = self.vectorizer.get_cache_info()
435
  if cache_info['is_cached']:
436
+ search_info += f"💡 Vector database cached with {cache_info['cache_size']} entries\n\n"
437
 
438
  yield search_info
439
 
 
449
  self.conversation_history.append({"role": "assistant", "content": full_response})
450
 
451
  except Exception as e:
452
+ yield f"\n\nError: Problem occurred while generating response - {str(e)}"
453
 
454
  def generate_response(self, query: str) -> str:
455
  """
 
469
  def clear_history(self):
470
  """清除对话历史"""
471
  self.conversation_history = []
472
+ print("[clear_history] Conversation history cleared")
473
 
474
  def clear_vector_cache(self):
475
  """清除向量数据库缓存"""
476
  self.vectorizer.clear_cache()
477
+ print("[clear_vector_cache] Vector database cache cleared")
478
 
479
  def reload_vector_database(self):
480
  """重新加载向量数据库"""
481
+ print("[reload_vector_database] Reloading vector database...")
482
  self.vectorizer.load_vector_database(force_reload=True)
483
+ print("[reload_vector_database] Vector database reload completed")
484
 
485
  def get_system_status(self) -> Dict:
486
  """
 
518
  with open(filepath, 'w', encoding='utf-8') as f:
519
  json.dump(conversation_data, f, ensure_ascii=False, indent=2)
520
 
521
+ print(f"[save_conversation] Conversation saved to: {filepath}")