linxinhua commited on
Commit
29555e6
·
verified ·
1 Parent(s): 2ad61c5

Update RAG_Learning_Assistant_with_Streaming.py from CIV3283/CIV3283_admin

Browse files
RAG_Learning_Assistant_with_Streaming.py ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from typing import List, Dict, Tuple, Generator, Set
4
+ from openai import OpenAI
5
+ from vectorize_knowledge_base import KnowledgeBaseVectorizer
6
+ import json
7
+ from datetime import datetime
8
+ import re
9
+
10
+
11
+ class RAGLearningAssistant:
12
+ def __init__(self, api_key: str, model: str = "gpt-4.1-nano-2025-04-14", vector_db_path: str = ""):
13
+ """
14
+ 初始化RAG学习助手
15
+
16
+ Args:
17
+ api_key: OpenAI API密钥(必需)
18
+ model: 使用的模型名称
19
+ vector_db_path: 向量数据库路径
20
+ """
21
+ self.client = OpenAI(api_key=api_key)
22
+ self.vectorizer = KnowledgeBaseVectorizer(
23
+ api_key=api_key,
24
+ #data_path=os.path.join(vector_db_path, "knowledge_base.md")
25
+ data_path="knowledge_base.md"
26
+ )
27
+
28
+ # 预加载向量数据库到缓存
29
+ print("预加载向量数据库...")
30
+ self.vectorizer.load_vector_database()
31
+
32
+ # 模型配置
33
+ self.model = model
34
+ self.temperature = 0.2
35
+ self.max_tokens = 2000
36
+
37
+ # 系统提示词
38
+ self.system_prompt = """You are a helpful learning assistant specializing in road engineering.
39
+ Students can ask you questions with the following intents:
40
+ 1. Clarification: Requests to confirm understanding of a concept, parameter, or calculation.
41
+ 2. Instruction: Seeking step-by-step guidance for tasks or calculations.
42
+ 3. Explanatory: Asking for the reasoning behind a method, parameter choice, or principle.
43
+ 4. Information-seeking: Asking for where to find specific information in course materials.
44
+
45
+ You have access to a knowledge base of course materials. When answering questions:
46
+ 1. Stick to the provided context from the knowledge base.
47
+ 2. At the end of your response, provide students the 'title' & 'from' fields of the chunks that were used to answer the question. So that they can refer to the original source.
48
+ 3. If the knowledge base doesn't contain relevant information, say so. Students can go to the teaching team for further assistance.
49
+
50
+ In the response, enclose full mathematical formulas with $$ for proper Markdown rendering. Do not enclose individual parameters or variables with $$.
51
+ Bold key words if applicable.
52
+ """
53
+
54
+ # 查询重写的系统提示词 - 改进版本
55
+ self.rewrite_prompt = """You are a query rewriting assistant. Your task is to provide a summary of the conversation history and then rewrite user queries based on conversation history to make them more clear and complete.
56
+
57
+ Please format your response as follows:
58
+ SUMMARY: [Brief summary of the conversation context. Include key points, user intent, and any relevant details]
59
+ REWRITTEN_QUERY: [The rewritten query that incorporates context]
60
+
61
+ Rules:
62
+ 1. If there's relevant context from previous messages, incorporate it into the rewritten query
63
+ 2. Make implicit references explicit
64
+ 3. Maintain the original intent while adding clarity
65
+ 4. If the query is already clear and complete, keep it as is
66
+ 5. Always provide both SUMMARY and REWRITTEN_QUERY sections"""
67
+
68
+ # 实体提取的系统提示词
69
+ self.entity_extraction_prompt = """You are an expert in road engineering. Extract key entities from the given query.
70
+ Focus on:
71
+ 1. Technical terms and jargon specific to road engineering
72
+ 2. Formulas, equations, or mathematical concepts
73
+ 3. Parameters, specifications, or measurements
74
+ 4. Standards, methods, or procedures
75
+ 5. Materials, equipment, or structures
76
+
77
+ Return the entities as a JSON array of strings. Only include the most important and specific entities."""
78
+
79
+ # 对话历史
80
+ self.conversation_history = []
81
+
82
+
83
+ def rewrite_query(self, query: str) -> Tuple[str, str]:
84
+ """
85
+ 基于对话历史重写查询,并返回历史总结
86
+
87
+ Args:
88
+ query: 原始查询
89
+
90
+ Returns:
91
+ (历史总结, 重写后的查询)
92
+ """
93
+ # 构建消息
94
+ messages = [
95
+ {"role": "system", "content": self.rewrite_prompt}
96
+ ]
97
+
98
+ # 添加对话历史上下文
99
+ if self.conversation_history:
100
+ context = "Previous conversation:\n"
101
+ for msg in self.conversation_history[-6:]: # 最近3轮对话
102
+ role = "User" if msg["role"] == "user" else "Assistant"
103
+ # 截取前200个字符避免过长
104
+ content = msg["content"][:200] + "..." if len(msg["content"]) > 200 else msg["content"]
105
+ context += f"{role}: {content}\n"
106
+
107
+ messages.append({
108
+ "role": "user",
109
+ "content": f"{context}\n\nCurrent query: {query}\n\nPlease provide summary and rewritten query following the specified format:"
110
+ })
111
+ else:
112
+ # 没有历史时也要按格式返回
113
+ messages.append({
114
+ "role": "user",
115
+ "content": f"Current query: {query}\n\nPlease provide summary and rewritten query following the specified format:"
116
+ })
117
+
118
+ try:
119
+ response = self.client.chat.completions.create(
120
+ model=self.model,
121
+ messages=messages,
122
+ temperature=0.3, # 低温度确保一致性
123
+ max_tokens=2000
124
+ )
125
+
126
+ content = response.choices[0].message.content.strip()
127
+
128
+ # 改进的解析逻辑
129
+ summary = ""
130
+ rewritten = query # 默认值
131
+
132
+ # 使用正则表达式提取SUMMARY和REWRITTEN_QUERY
133
+ summary_match = re.search(r'SUMMARY:\s*(.*?)(?=REWRITTEN_QUERY:|$)', content, re.DOTALL | re.IGNORECASE)
134
+ rewritten_match = re.search(r'REWRITTEN_QUERY:\s*(.*?)$', content, re.DOTALL | re.IGNORECASE)
135
+
136
+
137
+ if summary_match:
138
+ summary = summary_match.group(1).strip()
139
+
140
+ if rewritten_match:
141
+ rewritten = rewritten_match.group(1).strip()
142
+
143
+ # 备用解析方法 - 如果正则表达式失败
144
+ if not summary and not rewritten_match:
145
+ lines = content.split('\n')
146
+ current_section = None
147
+ summary_lines = []
148
+ rewritten_lines = []
149
+
150
+ for line in lines:
151
+ line = line.strip()
152
+ if line.upper().startswith("SUMMARY"):
153
+ current_section = "summary"
154
+ # 提取SUMMARY:后面的内容
155
+ summary_part = line[line.upper().find("SUMMARY"):].replace("SUMMARY:", "").strip()
156
+ if summary_part:
157
+ summary_lines.append(summary_part)
158
+ elif line.upper().startswith("REWRITTEN_QUERY") or line.upper().startswith("REWRITTEN QUERY"):
159
+ current_section = "rewritten"
160
+ # 提取REWRITTEN_QUERY:后面的内容
161
+ rewritten_part = re.sub(r'^REWRITTEN[_\s]*QUERY[:\s]*', '', line, flags=re.IGNORECASE).strip()
162
+ if rewritten_part:
163
+ rewritten_lines.append(rewritten_part)
164
+ elif current_section == "summary" and line:
165
+ summary_lines.append(line)
166
+ elif current_section == "rewritten" and line:
167
+ rewritten_lines.append(line)
168
+
169
+ if summary_lines:
170
+ summary = " ".join(summary_lines)
171
+ if rewritten_lines:
172
+ rewritten = " ".join(rewritten_lines)
173
+
174
+ # 如果仍然没有获得有效结果,使用更简单的方法
175
+ if not summary and self.conversation_history:
176
+ summary = "继续之前的讨论"
177
+
178
+ if not rewritten or rewritten == query:
179
+ rewritten = query
180
+
181
+ print(f"Raw query: {query}")
182
+ print(f"Chat history summary: {summary}")
183
+ print(f"Rewrite query: {rewritten}")
184
+ return summary, rewritten
185
+
186
+ except Exception as e:
187
+ print(f"查询重写失败: {e}")
188
+ # 生成简单的历史总结作为备用
189
+ simple_summary = ""
190
+ if self.conversation_history:
191
+ simple_summary = "基于之前的对话内容"
192
+ return simple_summary, query # 失败时返回简单总结和原始查询
193
+
194
+ def extract_entities(self, original_query: str, summary: str, rewritten_query: str) -> List[str]:
195
+ """
196
+ 从原始查询、历史总结和重写查询中提取关键实体(专业术语、公式、参数等)
197
+
198
+ Args:
199
+ original_query: 原始用户查询
200
+ summary: 历史总结
201
+ rewritten_query: 重写后的查询文本
202
+
203
+ Returns:
204
+ 提取的实体列表
205
+ """
206
+ # 合并所有文本作为实体提取的输入
207
+ text_parts = []
208
+
209
+ # 添加原始查询
210
+ if original_query:
211
+ text_parts.append(f"Original query: {original_query}")
212
+
213
+ # 添加历史总结
214
+ if summary:
215
+ text_parts.append(f"Context summary: {summary}")
216
+
217
+ # 添加重写查询
218
+ if rewritten_query and rewritten_query != original_query:
219
+ text_parts.append(f"Rewritten query: {rewritten_query}")
220
+
221
+ combined_text = " | ".join(text_parts)
222
+
223
+ messages = [
224
+ {"role": "system", "content": self.entity_extraction_prompt},
225
+ {"role": "user", "content": f"Text to extract entities from: {combined_text}\n\nExtract entities as JSON array:"}
226
+ ]
227
+
228
+ try:
229
+ response = self.client.chat.completions.create(
230
+ model=self.model,
231
+ messages=messages,
232
+ temperature=0.3,
233
+ max_tokens=200
234
+ )
235
+
236
+ content = response.choices[0].message.content.strip()
237
+
238
+ # 尝试解析JSON
239
+ try:
240
+ # 提取JSON数组(处理可能的markdown格式)
241
+ json_match = re.search(r'\[.*?\]', content, re.DOTALL)
242
+ if json_match:
243
+ entities = json.loads(json_match.group())
244
+ else:
245
+ entities = json.loads(content)
246
+
247
+ print(f"Extracted entities: {entities}")
248
+ return entities
249
+
250
+ except json.JSONDecodeError:
251
+ # 如果JSON解析失败,尝试简单的文本处理
252
+ print(f"JSON解析失败,使用备用方法")
253
+ # 查找引号中的内容
254
+ entities = re.findall(r'"([^"]+)"', content)
255
+ return entities if entities else self.simple_entity_extraction(combined_text)
256
+
257
+ except Exception as e:
258
+ print(f"实体提取失败: {e}")
259
+ # 失败时使用简单的关键词提取
260
+ return self.simple_entity_extraction(combined_text)
261
+
262
+ def simple_entity_extraction(self, query: str) -> List[str]:
263
+ """
264
+ 简单的实体提取备用方法
265
+
266
+ Args:
267
+ query: 查询文本
268
+
269
+ Returns:
270
+ 提取的关键词列表
271
+ """
272
+ # 移除常见停用词
273
+ stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
274
+ 'of', 'with', 'by', 'from', 'what', 'how', 'when', 'where', 'why',
275
+ 'is', 'are', 'was', 'were', 'been', 'be', 'have', 'has', 'had',
276
+ 'original', 'query', 'context', 'summary', 'rewritten'} # 添加新的停用词
277
+
278
+ # 分词并过滤
279
+ words = query.lower().split()
280
+ entities = [w for w in words if w not in stop_words and len(w) > 2]
281
+
282
+ # 查找可能的专业术语(包含大写字母或数字)
283
+ special_terms = re.findall(r'\b[A-Z][a-zA-Z]*\b|\b\w*\d+\w*\b', query)
284
+ entities.extend(special_terms)
285
+
286
+ # 去重并返回
287
+ return list(set(entities))[:5] # 最多返回5个实体
288
+
289
+ def enhanced_search(self, query: str, top_k: int = 3) -> Tuple[str, str, List[str], List[Tuple[Dict, float, Dict]]]:
290
+ """
291
+ 增强搜索:重写查询 -> 提取实体 -> 基于实体搜索(优化版本)
292
+
293
+ Args:
294
+ query: 原始查询
295
+ top_k: 返回的结果数
296
+
297
+ Returns:
298
+ (历史总结, 重写后的查询, 提取的实体, 搜索结果)
299
+ """
300
+ # 1. 重写查询并获取历史总结
301
+ summary, rewritten_query = self.rewrite_query(query)
302
+
303
+ # 2. 基于原始查询、总结和重写查询提取实体
304
+ entities = self.extract_entities(query, summary, rewritten_query)
305
+
306
+ # 3. 基于实体搜索(使用优化的批量搜索)
307
+ if entities:
308
+ # 使用优化的批量搜索方法
309
+ search_results = self.vectorizer.search_with_entities_optimized(entities, top_k)
310
+ else:
311
+ # 如果没有提取到实体,使用重写后的查询进行搜索
312
+ print("未提取到实体,使用完整查询搜索")
313
+ search_results = self.vectorizer.search_similar(
314
+ rewritten_query,
315
+ top_k=top_k,
316
+ title_weight=0.4,
317
+ content_weight=0.3,
318
+ full_weight=0.3
319
+ )
320
+
321
+ return summary, rewritten_query, entities, search_results
322
+
323
+ def format_context(self, search_results: List[Tuple[Dict, float, Dict]]) -> str:
324
+ """
325
+ 格式化搜索结果作为上下文
326
+
327
+ Args:
328
+ search_results: 搜索结果列表
329
+
330
+ Returns:
331
+ 格式化的上下文字符串
332
+ """
333
+ if not search_results:
334
+ return ""
335
+
336
+ context_parts = []
337
+ for i, result in enumerate(search_results, 1):
338
+ entry, combined_score, details = result
339
+ # 只显示 title, source, content,不显示 id
340
+ context_parts.append(
341
+ #f"[Source {i}]\n"
342
+ f"Title: {entry['title']}\n"
343
+ f"From: {entry['source']}\n"
344
+ f"Content: {entry['content']}\n"
345
+ )
346
+
347
+ return "RELEVANT KNOWLEDGE BASE CONTENT:\n" + "\n---\n".join(context_parts)
348
+
349
+ def build_messages(self, query: str, context: str) -> List[Dict[str, str]]:
350
+ """
351
+ 构建消息列表,包含系统提示、上下文和用户查询
352
+
353
+ Args:
354
+ query: 用户查询
355
+ context: 知识库上下文
356
+
357
+ Returns:
358
+ 消息列表
359
+ """
360
+ messages = [
361
+ {"role": "system", "content": self.system_prompt}
362
+ ]
363
+
364
+ # 添加对话历史(保留最近5轮对话)
365
+ for msg in self.conversation_history[-10:]: # 最多保留5轮对话(10条消息)
366
+ messages.append(msg)
367
+
368
+ # 构建用户消息,包含上下文
369
+ user_message = query
370
+ if context:
371
+ user_message = f"{context}\n\nUSER QUESTION: {query}"
372
+
373
+ messages.append({"role": "user", "content": user_message})
374
+
375
+ return messages
376
+
377
+ def generate_response_stream(self, query: str) -> Generator[str, None, None]:
378
+ """
379
+ 生成流式响应
380
+
381
+ Args:
382
+ query: 用户查询
383
+
384
+ Yields:
385
+ 响应文本片段
386
+ """
387
+ # 1. 增强搜索(现在使用优化版本)
388
+ print("正在处理查询...")
389
+ summary, rewritten_query, entities, search_results = self.enhanced_search(query)
390
+
391
+ # 2. 格式化上下文
392
+ context = self.format_context(search_results)
393
+
394
+ # 3. 构建消息(使用原始查询,但包含基于实体搜索的上下文)
395
+ messages = self.build_messages(query, context)
396
+
397
+ # 4. 调用OpenAI API进行流式生成
398
+ try:
399
+ stream = self.client.chat.completions.create(
400
+ model=self.model,
401
+ messages=messages,
402
+ temperature=self.temperature,
403
+ max_tokens=self.max_tokens,
404
+ stream=True
405
+ )
406
+
407
+ # 收集完整响应用于保存到历史
408
+ full_response = ""
409
+
410
+ # 首先返回搜索信息
411
+ search_info = f"\n**Query Analysis:**\n"
412
+ search_info += f"- Query: {query}\n"
413
+ if summary:
414
+ search_info += f"- Summary of history: {summary}\n"
415
+ if rewritten_query != query:
416
+ search_info += f"- Rewrite query: {rewritten_query}\n"
417
+ search_info += f"- Key entities: {', '.join(entities) if entities else 'No specific entities extracted'}\n"
418
+
419
+ if search_results:
420
+ search_info += f"\n**Relevant Sources:**\n"
421
+ for result in search_results:
422
+ entry, combined_score, details = result
423
+ # 给用户显示时包含 ID 和相关度分数
424
+ search_info += f"- [{entry['id']}] {entry['title']} (Relevance: {combined_score:.3f})\n"
425
+ search_info += "\n**Response:**\n"
426
+ else:
427
+ search_info += "\n**Response:** (No relevant knowledge base content found, answering based on general knowledge)\n"
428
+
429
+ # 添加缓存信息(调试用)
430
+ cache_info = self.vectorizer.get_cache_info()
431
+ if cache_info['is_cached']:
432
+ search_info += f"The vector db has been cached, containing {cache_info['cache_size']} entries\n\n"
433
+
434
+ yield search_info
435
+
436
+ # 流式返回生成的内容
437
+ for chunk in stream:
438
+ if chunk.choices[0].delta.content is not None:
439
+ content = chunk.choices[0].delta.content
440
+ full_response += content
441
+ yield content
442
+
443
+ # 保存到对话历史
444
+ self.conversation_history.append({"role": "user", "content": query})
445
+ self.conversation_history.append({"role": "assistant", "content": full_response})
446
+
447
+ except Exception as e:
448
+ yield f"\n\n错误:生成响应时出现问题 - {str(e)}"
449
+
450
+ def generate_response(self, query: str) -> str:
451
+ """
452
+ 生成完整响应(非流式)
453
+
454
+ Args:
455
+ query: 用户查询
456
+
457
+ Returns:
458
+ 完整的响应文本
459
+ """
460
+ response_parts = []
461
+ for part in self.generate_response_stream(query):
462
+ response_parts.append(part)
463
+ return "".join(response_parts)
464
+
465
+ def clear_history(self):
466
+ """清除对话历史"""
467
+ self.conversation_history = []
468
+ print("对话历史已清除")
469
+
470
+ def clear_vector_cache(self):
471
+ """清除向量数据库缓存"""
472
+ self.vectorizer.clear_cache()
473
+ print("向量数据库缓存已清除")
474
+
475
+ def reload_vector_database(self):
476
+ """重新加载向量数据库"""
477
+ print("重新加载向量数据库...")
478
+ self.vectorizer.load_vector_database(force_reload=True)
479
+ print("向量数据库重新加载完成")
480
+
481
+ def get_system_status(self) -> Dict:
482
+ """
483
+ 获取系统状态信息
484
+
485
+ Returns:
486
+ 系统状态字典
487
+ """
488
+ cache_info = self.vectorizer.get_cache_info()
489
+ return {
490
+ 'model': self.model,
491
+ 'conversation_turns': len(self.conversation_history) // 2,
492
+ 'vector_cache': cache_info,
493
+ 'last_update': datetime.now().isoformat()
494
+ }
495
+
496
+ def save_conversation(self, filepath: str = None):
497
+ """
498
+ 保存对话历史
499
+
500
+ Args:
501
+ filepath: 保存路径
502
+ """
503
+ if filepath is None:
504
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
505
+ filepath = f"conversation_{timestamp}.json"
506
+
507
+ conversation_data = {
508
+ "timestamp": datetime.now().isoformat(),
509
+ "model": self.model,
510
+ "system_status": self.get_system_status(),
511
+ "history": self.conversation_history
512
+ }
513
+
514
+ with open(filepath, 'w', encoding='utf-8') as f:
515
+ json.dump(conversation_data, f, ensure_ascii=False, indent=2)
516
+
517
+ print(f"对话已保存到: {filepath}")
518
+