linxinhua committed on
Commit
e8fa581
·
verified ·
1 Parent(s): dbe46e8

Update RAG_Learning_Assistant_with_Streaming.py from CIV3283/CIV3283_admin

Browse files
RAG_Learning_Assistant_with_Streaming.py CHANGED
@@ -11,32 +11,27 @@ import re
11
  class RAGLearningAssistant:
12
  def __init__(self, api_key: str, model: str = "gpt-4.1-nano-2025-04-14", vector_db_path: str = ""):
13
  """
14
- 初始化RAG学习助手(适配学生Space)
15
 
16
  Args:
17
  api_key: OpenAI API密钥(必需)
18
  model: 使用的模型名称
19
- vector_db_path: 向量数据库所在目录路径(数据存储仓库的本地目录)
20
  """
21
  self.client = OpenAI(api_key=api_key)
22
-
23
- # 使用修改后的KnowledgeBaseVectorizer,指定vector_db_dir
24
  self.vectorizer = KnowledgeBaseVectorizer(
25
  api_key=api_key,
26
- vector_db_dir=vector_db_path # 传递数据存储仓库的本地目录
 
27
  )
28
 
29
  # 预加载向量数据库到缓存
30
- print("[RAGLearningAssistant] Preloading vector database...")
31
- load_result = self.vectorizer.load_vector_database()
32
- if load_result[0] is not None:
33
- print(f"[RAGLearningAssistant] Vector database loaded successfully")
34
- else:
35
- print(f"[RAGLearningAssistant] Warning: Failed to load vector database")
36
 
37
  # 模型配置
38
  self.model = model
39
- self.temperature = 0.1
40
  self.max_tokens = 2000
41
 
42
  # 系统提示词
@@ -51,6 +46,9 @@ You have access to a knowledge base of course materials. When answering question
51
  1. Stick to the provided context from the knowledge base.
52
  2. At the end of your response, provide students the 'title' & 'from' fields of the chunks that were used to answer the question. So that they can refer to the original source.
53
  3. If the knowledge base doesn't contain relevant information, say so. Students can go to the teaching team for further assistance.
 
 
 
54
  """
55
 
56
  # 查询重写的系统提示词 - 改进版本
@@ -121,7 +119,7 @@ Return the entities as a JSON array of strings. Only include the most important
121
  response = self.client.chat.completions.create(
122
  model=self.model,
123
  messages=messages,
124
- temperature=0.1, # 低温度确保一致性
125
  max_tokens=2000
126
  )
127
 
@@ -175,22 +173,22 @@ Return the entities as a JSON array of strings. Only include the most important
175
 
176
  # 如果仍然没有获得有效结果,使用更简单的方法
177
  if not summary and self.conversation_history:
178
- summary = "Continue previous discussion"
179
 
180
  if not rewritten or rewritten == query:
181
  rewritten = query
182
 
183
- print(f"[rewrite_query] Raw query: {query}")
184
- print(f"[rewrite_query] Chat history summary: {summary}")
185
- print(f"[rewrite_query] Rewritten query: {rewritten}")
186
  return summary, rewritten
187
 
188
  except Exception as e:
189
- print(f"[rewrite_query] Query rewriting failed: {e}")
190
  # 生成简单的历史总结作为备用
191
  simple_summary = ""
192
  if self.conversation_history:
193
- simple_summary = "Based on previous conversation content"
194
  return simple_summary, query # 失败时返回简单总结和原始查询
195
 
196
  def extract_entities(self, original_query: str, summary: str, rewritten_query: str) -> List[str]:
@@ -231,8 +229,8 @@ Return the entities as a JSON array of strings. Only include the most important
231
  response = self.client.chat.completions.create(
232
  model=self.model,
233
  messages=messages,
234
- temperature=self.temperature,
235
- max_tokens=self.max_tokens
236
  )
237
 
238
  content = response.choices[0].message.content.strip()
@@ -246,18 +244,18 @@ Return the entities as a JSON array of strings. Only include the most important
246
  else:
247
  entities = json.loads(content)
248
 
249
- print(f"[extract_entities] Extracted entities: {entities}")
250
  return entities
251
 
252
  except json.JSONDecodeError:
253
  # 如果JSON解析失败,尝试简单的文本处理
254
- print(f"[extract_entities] JSON parsing failed, using backup method")
255
  # 查找引号中的内容
256
  entities = re.findall(r'"([^"]+)"', content)
257
  return entities if entities else self.simple_entity_extraction(combined_text)
258
 
259
  except Exception as e:
260
- print(f"[extract_entities] Entity extraction failed: {e}")
261
  # 失败时使用简单的关键词提取
262
  return self.simple_entity_extraction(combined_text)
263
 
@@ -286,9 +284,9 @@ Return the entities as a JSON array of strings. Only include the most important
286
  entities.extend(special_terms)
287
 
288
  # 去重并返回
289
- return list(set(entities))[:10] # 最多返回5个实体
290
 
291
- def enhanced_search(self, query: str, top_k: int = 5) -> Tuple[str, str, List[str], List[Tuple[Dict, float, Dict]]]:
292
  """
293
  增强搜索:重写查询 -> 提取实体 -> 基于实体搜索(优化版本)
294
 
@@ -311,12 +309,12 @@ Return the entities as a JSON array of strings. Only include the most important
311
  search_results = self.vectorizer.search_with_entities_optimized(entities, top_k)
312
  else:
313
  # 如果没有提取到实体,使用重写后的查询进行搜索
314
- print("[enhanced_search] No entities extracted, using full query search")
315
  search_results = self.vectorizer.search_similar(
316
  rewritten_query,
317
  top_k=top_k,
318
- title_weight=0.2,
319
- content_weight=0.5,
320
  full_weight=0.3
321
  )
322
 
@@ -340,6 +338,7 @@ Return the entities as a JSON array of strings. Only include the most important
340
  entry, combined_score, details = result
341
  # 只显示 title, source, content,不显示 id
342
  context_parts.append(
 
343
  f"Title: {entry['title']}\n"
344
  f"From: {entry['source']}\n"
345
  f"Content: {entry['content']}\n"
@@ -386,7 +385,7 @@ Return the entities as a JSON array of strings. Only include the most important
386
  响应文本片段
387
  """
388
  # 1. 增强搜索(现在使用优化版本)
389
- print("[generate_response_stream] Processing query...")
390
  summary, rewritten_query, entities, search_results = self.enhanced_search(query)
391
 
392
  # 2. 格式化上下文
@@ -414,7 +413,7 @@ Return the entities as a JSON array of strings. Only include the most important
414
  if summary:
415
  search_info += f"- Summary of history: {summary}\n"
416
  if rewritten_query != query:
417
- search_info += f"- Rewritten query: {rewritten_query}\n"
418
  search_info += f"- Key entities: {', '.join(entities) if entities else 'No specific entities extracted'}\n"
419
 
420
  if search_results:
@@ -430,7 +429,7 @@ Return the entities as a JSON array of strings. Only include the most important
430
  # 添加缓存信息(调试用)
431
  cache_info = self.vectorizer.get_cache_info()
432
  if cache_info['is_cached']:
433
- search_info += f"💡 Vector database cached with {cache_info['cache_size']} entries\n\n"
434
 
435
  yield search_info
436
 
@@ -446,7 +445,7 @@ Return the entities as a JSON array of strings. Only include the most important
446
  self.conversation_history.append({"role": "assistant", "content": full_response})
447
 
448
  except Exception as e:
449
- yield f"\n\nError: Problem occurred while generating response - {str(e)}"
450
 
451
  def generate_response(self, query: str) -> str:
452
  """
@@ -466,18 +465,18 @@ Return the entities as a JSON array of strings. Only include the most important
466
  def clear_history(self):
467
  """清除对话历史"""
468
  self.conversation_history = []
469
- print("[clear_history] Conversation history cleared")
470
 
471
  def clear_vector_cache(self):
472
  """清除向量数据库缓存"""
473
  self.vectorizer.clear_cache()
474
- print("[clear_vector_cache] Vector database cache cleared")
475
 
476
  def reload_vector_database(self):
477
  """重新加载向量数据库"""
478
- print("[reload_vector_database] Reloading vector database...")
479
  self.vectorizer.load_vector_database(force_reload=True)
480
- print("[reload_vector_database] Vector database reload completed")
481
 
482
  def get_system_status(self) -> Dict:
483
  """
@@ -515,4 +514,5 @@ Return the entities as a JSON array of strings. Only include the most important
515
  with open(filepath, 'w', encoding='utf-8') as f:
516
  json.dump(conversation_data, f, ensure_ascii=False, indent=2)
517
 
518
- print(f"[save_conversation] Conversation saved to: {filepath}")
 
 
11
  class RAGLearningAssistant:
12
  def __init__(self, api_key: str, model: str = "gpt-4.1-nano-2025-04-14", vector_db_path: str = ""):
13
  """
14
+ 初始化RAG学习助手
15
 
16
  Args:
17
  api_key: OpenAI API密钥(必需)
18
  model: 使用的模型名称
19
+ vector_db_path: 向量数据库路径
20
  """
21
  self.client = OpenAI(api_key=api_key)
 
 
22
  self.vectorizer = KnowledgeBaseVectorizer(
23
  api_key=api_key,
24
+ #data_path=os.path.join(vector_db_path, "knowledge_base.md")
25
+ data_path="knowledge_base.md"
26
  )
27
 
28
  # 预加载向量数据库到缓存
29
+ print("预加载向量数据库...")
30
+ self.vectorizer.load_vector_database()
 
 
 
 
31
 
32
  # 模型配置
33
  self.model = model
34
+ self.temperature = 0.2
35
  self.max_tokens = 2000
36
 
37
  # 系统提示词
 
46
  1. Stick to the provided context from the knowledge base.
47
  2. At the end of your response, provide students the 'title' & 'from' fields of the chunks that were used to answer the question. So that they can refer to the original source.
48
  3. If the knowledge base doesn't contain relevant information, say so. Students can go to the teaching team for further assistance.
49
+
50
+ In the response, enclose full mathematical formulas with $$ for proper Markdown rendering. Do not enclose individual parameters or variables with $$.
51
+ Bold key words if applicable.
52
  """
53
 
54
  # 查询重写的系统提示词 - 改进版本
 
119
  response = self.client.chat.completions.create(
120
  model=self.model,
121
  messages=messages,
122
+ temperature=0.3, # 低温度确保一致性
123
  max_tokens=2000
124
  )
125
 
 
173
 
174
  # 如果仍然没有获得有效结果,使用更简单的方法
175
  if not summary and self.conversation_history:
176
+ summary = "继续之前的讨论"
177
 
178
  if not rewritten or rewritten == query:
179
  rewritten = query
180
 
181
+ print(f"Raw query: {query}")
182
+ print(f"Chat history summary: {summary}")
183
+ print(f"Rewrite query: {rewritten}")
184
  return summary, rewritten
185
 
186
  except Exception as e:
187
+ print(f"查询重写失败: {e}")
188
  # 生成简单的历史总结作为备用
189
  simple_summary = ""
190
  if self.conversation_history:
191
+ simple_summary = "基于之前的对话内容"
192
  return simple_summary, query # 失败时返回简单总结和原始查询
193
 
194
  def extract_entities(self, original_query: str, summary: str, rewritten_query: str) -> List[str]:
 
229
  response = self.client.chat.completions.create(
230
  model=self.model,
231
  messages=messages,
232
+ temperature=0.3,
233
+ max_tokens=200
234
  )
235
 
236
  content = response.choices[0].message.content.strip()
 
244
  else:
245
  entities = json.loads(content)
246
 
247
+ print(f"Extracted entities: {entities}")
248
  return entities
249
 
250
  except json.JSONDecodeError:
251
  # 如果JSON解析失败,尝试简单的文本处理
252
+ print(f"JSON解析失败,使用备用方法")
253
  # 查找引号中的内容
254
  entities = re.findall(r'"([^"]+)"', content)
255
  return entities if entities else self.simple_entity_extraction(combined_text)
256
 
257
  except Exception as e:
258
+ print(f"实体提取失败: {e}")
259
  # 失败时使用简单的关键词提取
260
  return self.simple_entity_extraction(combined_text)
261
 
 
284
  entities.extend(special_terms)
285
 
286
  # 去重并返回
287
+ return list(set(entities))[:5] # 最多返回5个实体
288
 
289
+ def enhanced_search(self, query: str, top_k: int = 3) -> Tuple[str, str, List[str], List[Tuple[Dict, float, Dict]]]:
290
  """
291
  增强搜索:重写查询 -> 提取实体 -> 基于实体搜索(优化版本)
292
 
 
309
  search_results = self.vectorizer.search_with_entities_optimized(entities, top_k)
310
  else:
311
  # 如果没有提取到实体,使用重写后的查询进行搜索
312
+ print("未提取到实体,使用完整查询搜索")
313
  search_results = self.vectorizer.search_similar(
314
  rewritten_query,
315
  top_k=top_k,
316
+ title_weight=0.4,
317
+ content_weight=0.3,
318
  full_weight=0.3
319
  )
320
 
 
338
  entry, combined_score, details = result
339
  # 只显示 title, source, content,不显示 id
340
  context_parts.append(
341
+ #f"[Source {i}]\n"
342
  f"Title: {entry['title']}\n"
343
  f"From: {entry['source']}\n"
344
  f"Content: {entry['content']}\n"
 
385
  响应文本片段
386
  """
387
  # 1. 增强搜索(现在使用优化版本)
388
+ print("正在处理查询...")
389
  summary, rewritten_query, entities, search_results = self.enhanced_search(query)
390
 
391
  # 2. 格式化上下文
 
413
  if summary:
414
  search_info += f"- Summary of history: {summary}\n"
415
  if rewritten_query != query:
416
+ search_info += f"- Rewrite query: {rewritten_query}\n"
417
  search_info += f"- Key entities: {', '.join(entities) if entities else 'No specific entities extracted'}\n"
418
 
419
  if search_results:
 
429
  # 添加缓存信息(调试用)
430
  cache_info = self.vectorizer.get_cache_info()
431
  if cache_info['is_cached']:
432
+ search_info += f"The vector db has been cached, containing {cache_info['cache_size']} entries\n\n"
433
 
434
  yield search_info
435
 
 
445
  self.conversation_history.append({"role": "assistant", "content": full_response})
446
 
447
  except Exception as e:
448
+ yield f"\n\n错误:生成响应时出现问题 - {str(e)}"
449
 
450
  def generate_response(self, query: str) -> str:
451
  """
 
465
  def clear_history(self):
466
  """清除对话历史"""
467
  self.conversation_history = []
468
+ print("对话历史已清除")
469
 
470
  def clear_vector_cache(self):
471
  """清除向量数据库缓存"""
472
  self.vectorizer.clear_cache()
473
+ print("向量数据库缓存已清除")
474
 
475
  def reload_vector_database(self):
476
  """重新加载向量数据库"""
477
+ print("重新加载向量数据库...")
478
  self.vectorizer.load_vector_database(force_reload=True)
479
+ print("向量数据库重新加载完成")
480
 
481
  def get_system_status(self) -> Dict:
482
  """
 
514
  with open(filepath, 'w', encoding='utf-8') as f:
515
  json.dump(conversation_data, f, ensure_ascii=False, indent=2)
516
 
517
+ print(f"对话已保存到: {filepath}")
518
+