linxinhua committed on
Commit
100af11
·
verified ·
1 Parent(s): 5dc4fb8

Upload 6 files

Browse files
RAG_Learning_Assistant_with_Streaming.py ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from typing import List, Dict, Tuple, Generator, Set
4
+ from openai import OpenAI
5
+ from vectorize_knowledge_base import KnowledgeBaseVectorizer
6
+ import json
7
+ from datetime import datetime
8
+ import re
9
+
10
+
11
class RAGLearningAssistant:
    """Retrieval-augmented learning assistant for road-engineering course material.

    For every user query the assistant (1) rewrites the query using the recent
    conversation, (2) extracts key entities, (3) searches the vector knowledge
    base, and (4) streams an answer from the chat model with the retrieved
    chunks supplied as context.
    """

    def __init__(self, api_key: str, model: str = "gpt-4.1-nano-2025-04-14", vector_db_path: str = ""):
        """
        Initialise the assistant (adapted for a student Space).

        Args:
            api_key: OpenAI API key (required).
            model: Name of the chat model to use.
            vector_db_path: Directory that holds the vector database
                (the local checkout of the data storage repository).
        """
        self.client = OpenAI(api_key=api_key)

        # The vectorizer is told where the vector database lives.
        self.vectorizer = KnowledgeBaseVectorizer(
            api_key=api_key,
            vector_db_dir=vector_db_path
        )

        # Preload the vector database so the first query hits the cache.
        print("[RAGLearningAssistant] Preloading vector database...")
        load_result = self.vectorizer.load_vector_database()
        if load_result[0] is not None:
            print(f"[RAGLearningAssistant] Vector database loaded successfully")
        else:
            print(f"[RAGLearningAssistant] Warning: Failed to load vector database")

        # Model configuration
        self.model = model
        self.temperature = 0.1
        self.max_tokens = 2000

        # System prompt for answering
        self.system_prompt = """You are a helpful learning assistant specializing in road engineering.
Students can ask you questions with the following intents:
1. Clarification: Requests to confirm understanding of a concept, parameter, or calculation.
2. Instruction: Seeking step-by-step guidance for tasks or calculations.
3. Explanatory: Asking for the reasoning behind a method, parameter choice, or principle.
4. Information-seeking: Asking for where to find specific information in course materials.

You have access to a knowledge base of course materials. When answering questions:
1. Stick to the provided context from the knowledge base.
2. At the end of your response, provide students the 'title' & 'from' fields of the chunks that were used to answer the question. So that they can refer to the original source.
3. If the knowledge base doesn't contain relevant information, say so. Students can go to the teaching team for further assistance.
"""

        # System prompt for query rewriting
        self.rewrite_prompt = """You are a query rewriting assistant. Your task is to provide a summary of the conversation history and then rewrite user queries based on conversation history to make them more clear and complete.

Please format your response as follows:
SUMMARY: [Brief summary of the conversation context. Include key points, user intent, and any relevant details]
REWRITTEN_QUERY: [The rewritten query that incorporates context]

Rules:
1. If there's relevant context from previous messages, incorporate it into the rewritten query
2. Make implicit references explicit
3. Maintain the original intent while adding clarity
4. If the query is already clear and complete, keep it as is
5. Always provide both SUMMARY and REWRITTEN_QUERY sections"""

        # System prompt for entity extraction
        self.entity_extraction_prompt = """You are an expert in road engineering. Extract key entities from the given query.
Focus on:
1. Technical terms and jargon specific to road engineering
2. Formulas, equations, or mathematical concepts
3. Parameters, specifications, or measurements
4. Standards, methods, or procedures
5. Materials, equipment, or structures

Return the entities as a JSON array of strings. Only include the most important and specific entities."""

        # Conversation history as a flat list of {"role", "content"} dicts
        self.conversation_history = []

    @staticmethod
    def _parse_rewrite_response(content: str, query: str) -> Tuple[str, str]:
        """Split a rewrite-model reply into (summary, rewritten query).

        The section labels are matched case-insensitively and the rewritten
        section may be spelled ``REWRITTEN_QUERY`` or ``REWRITTEN QUERY``.
        Falls back to ``query`` when no usable rewritten query is found.
        """
        summary = ""
        rewritten = query

        summary_match = re.search(
            r'SUMMARY:\s*(.*?)(?=REWRITTEN[_ ]QUERY:|$)', content, re.DOTALL | re.IGNORECASE
        )
        rewritten_match = re.search(
            r'REWRITTEN[_ ]QUERY:\s*(.*?)$', content, re.DOTALL | re.IGNORECASE
        )

        if summary_match:
            summary = summary_match.group(1).strip()
        if rewritten_match:
            rewritten = rewritten_match.group(1).strip()

        # Line-based fallback for replies whose labels lack the expected colon.
        if not summary and not rewritten_match:
            current_section = None
            summary_lines = []
            rewritten_lines = []

            for line in content.split('\n'):
                line = line.strip()
                upper = line.upper()
                if upper.startswith("SUMMARY"):
                    current_section = "summary"
                    # Strip the label case-insensitively, mirroring the match above.
                    summary_part = re.sub(r'^SUMMARY[:\s]*', '', line, flags=re.IGNORECASE).strip()
                    if summary_part:
                        summary_lines.append(summary_part)
                elif upper.startswith(("REWRITTEN_QUERY", "REWRITTEN QUERY")):
                    current_section = "rewritten"
                    rewritten_part = re.sub(r'^REWRITTEN[_\s]*QUERY[:\s]*', '', line, flags=re.IGNORECASE).strip()
                    if rewritten_part:
                        rewritten_lines.append(rewritten_part)
                elif current_section == "summary" and line:
                    summary_lines.append(line)
                elif current_section == "rewritten" and line:
                    rewritten_lines.append(line)

            if summary_lines:
                summary = " ".join(summary_lines)
            if rewritten_lines:
                rewritten = " ".join(rewritten_lines)

        # An empty rewritten section must never replace the user's query.
        if not rewritten:
            rewritten = query

        return summary, rewritten

    def rewrite_query(self, query: str) -> Tuple[str, str]:
        """
        Rewrite the query using the conversation history and summarise that history.

        Args:
            query: The original user query.

        Returns:
            (history summary, rewritten query). On any failure the original
            query is returned unchanged together with a generic summary.
        """
        messages = [
            {"role": "system", "content": self.rewrite_prompt}
        ]

        if self.conversation_history:
            context = "Previous conversation:\n"
            for msg in self.conversation_history[-6:]:  # last 3 turns
                role = "User" if msg["role"] == "user" else "Assistant"
                # Truncate to 200 characters to keep the prompt short.
                content = msg["content"][:200] + "..." if len(msg["content"]) > 200 else msg["content"]
                context += f"{role}: {content}\n"

            messages.append({
                "role": "user",
                "content": f"{context}\n\nCurrent query: {query}\n\nPlease provide summary and rewritten query following the specified format:"
            })
        else:
            # Without history the model is still asked for the same format.
            messages.append({
                "role": "user",
                "content": f"Current query: {query}\n\nPlease provide summary and rewritten query following the specified format:"
            })

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=0.1,  # low temperature for consistent output
                max_tokens=2000
            )

            # The API may return None content; treat it as an empty reply.
            content = (response.choices[0].message.content or "").strip()
            summary, rewritten = self._parse_rewrite_response(content, query)

            if not summary and self.conversation_history:
                summary = "Continue previous discussion"

            print(f"[rewrite_query] Raw query: {query}")
            print(f"[rewrite_query] Chat history summary: {summary}")
            print(f"[rewrite_query] Rewritten query: {rewritten}")
            return summary, rewritten

        except Exception as e:
            print(f"[rewrite_query] Query rewriting failed: {e}")
            simple_summary = ""
            if self.conversation_history:
                simple_summary = "Based on previous conversation content"
            return simple_summary, query

    @staticmethod
    def _normalize_entities(raw) -> List[str]:
        """Coerce a parsed JSON payload into an ordered, de-duplicated list of strings.

        Strings are stripped, numbers are stringified, and anything else
        (nested objects, booleans, nulls) is dropped so callers can safely
        join or embed the result.
        """
        if isinstance(raw, dict):
            # Tolerate {"entities": [...]}-style wrappers by flattening list values.
            flattened = []
            for value in raw.values():
                if isinstance(value, list):
                    flattened.extend(value)
                else:
                    flattened.append(value)
            raw = flattened
        elif isinstance(raw, str):
            raw = [raw]
        elif not isinstance(raw, list):
            return []

        entities = []
        seen = set()
        for item in raw:
            if isinstance(item, bool):
                continue
            if isinstance(item, (int, float)):
                item = str(item)
            if not isinstance(item, str):
                continue
            item = item.strip()
            if item and item not in seen:
                seen.add(item)
                entities.append(item)
        return entities

    def extract_entities(self, original_query: str, summary: str, rewritten_query: str) -> List[str]:
        """
        Extract key entities (technical terms, formulas, parameters, ...) from
        the original query, the history summary and the rewritten query.

        Args:
            original_query: The original user query.
            summary: Summary of the conversation history.
            rewritten_query: The rewritten query text.

        Returns:
            List of entity strings (possibly empty).
        """
        text_parts = []

        if original_query:
            text_parts.append(f"Original query: {original_query}")

        if summary:
            text_parts.append(f"Context summary: {summary}")

        if rewritten_query and rewritten_query != original_query:
            text_parts.append(f"Rewritten query: {rewritten_query}")

        combined_text = " | ".join(text_parts)

        messages = [
            {"role": "system", "content": self.entity_extraction_prompt},
            {"role": "user", "content": f"Text to extract entities from: {combined_text}\n\nExtract entities as JSON array:"}
        ]

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens
            )

            content = (response.choices[0].message.content or "").strip()

            # Try the shortest bracketed span first (handles markdown fences),
            # then the longest (handles "]" inside an entity string).
            parsed = None
            for pattern in (r'\[.*?\]', r'\[.*\]'):
                found = re.search(pattern, content, re.DOTALL)
                if not found:
                    continue
                try:
                    parsed = json.loads(found.group())
                    break
                except json.JSONDecodeError:
                    continue

            if parsed is None:
                try:
                    parsed = json.loads(content)
                except json.JSONDecodeError:
                    print(f"[extract_entities] JSON parsing failed, using backup method")
                    # Fall back to anything the model put in double quotes.
                    quoted = re.findall(r'"([^"]+)"', content)
                    return quoted if quoted else self.simple_entity_extraction(combined_text)

            entities = self._normalize_entities(parsed)
            print(f"[extract_entities] Extracted entities: {entities}")
            return entities

        except Exception as e:
            print(f"[extract_entities] Entity extraction failed: {e}")
            # Last resort: keyword extraction without the model.
            return self.simple_entity_extraction(combined_text)

    def simple_entity_extraction(self, query: str) -> List[str]:
        """
        Model-free fallback for entity extraction.

        Args:
            query: Text to extract keywords from.

        Returns:
            Up to 10 keywords in order of first appearance, without stop words
            or duplicates (compared case-insensitively).
        """
        stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
                      'of', 'with', 'by', 'from', 'what', 'how', 'when', 'where', 'why',
                      'is', 'are', 'was', 'were', 'been', 'be', 'have', 'has', 'had',
                      # labels added by extract_entities when it builds its input text
                      'original', 'query', 'context', 'summary', 'rewritten'}

        # Strip surrounding punctuation so "query:" and "value?" are recognised.
        punctuation = '.,;:!?()[]{}"\'|'
        candidates = [w.strip(punctuation) for w in query.lower().split()]
        candidates = [w for w in candidates if len(w) > 2]

        # Likely technical terms: capitalised words or tokens containing digits.
        candidates.extend(re.findall(r'\b[A-Z][a-zA-Z]*\b|\b\w*\d+\w*\b', query))

        entities = []
        seen = set()
        for term in candidates:
            key = term.lower()
            if key in stop_words or key in seen:
                continue
            seen.add(key)
            entities.append(term)

        # Order-preserving truncation keeps the result deterministic.
        return entities[:10]

    def enhanced_search(self, query: str, top_k: int = 5) -> Tuple[str, str, List[str], List[Tuple[Dict, float, Dict]]]:
        """
        Enhanced search: rewrite the query, extract entities, then search by entity.

        Args:
            query: The original user query.
            top_k: Number of results to return.

        Returns:
            (history summary, rewritten query, extracted entities, search results)
        """
        # 1. Rewrite the query and obtain the history summary.
        summary, rewritten_query = self.rewrite_query(query)

        # 2. Extract entities from the original query, summary and rewritten query.
        entities = self.extract_entities(query, summary, rewritten_query)

        # 3. Search by entity when possible, otherwise by the rewritten query.
        if entities:
            search_results = self.vectorizer.search_with_entities_optimized(entities, top_k)
        else:
            print("[enhanced_search] No entities extracted, using full query search")
            search_results = self.vectorizer.search_similar(
                rewritten_query,
                top_k=top_k,
                title_weight=0.2,
                content_weight=0.5,
                full_weight=0.3
            )

        return summary, rewritten_query, entities, search_results

    def format_context(self, search_results: List[Tuple[Dict, float, Dict]]) -> str:
        """
        Format search results as a context block for the chat model.

        Args:
            search_results: List of (entry, combined score, details) tuples.

        Returns:
            The formatted context string, or "" when there are no results.
        """
        if not search_results:
            return ""

        # Only title, source and content are exposed to the model; the id is not.
        context_parts = [
            f"Title: {entry['title']}\n"
            f"From: {entry['source']}\n"
            f"Content: {entry['content']}\n"
            for entry, _score, _details in search_results
        ]

        return "RELEVANT KNOWLEDGE BASE CONTENT:\n" + "\n---\n".join(context_parts)

    def build_messages(self, query: str, context: str) -> List[Dict[str, str]]:
        """
        Build the chat message list: system prompt, recent history, user query.

        Args:
            query: The user query.
            context: Knowledge-base context (may be empty).

        Returns:
            The message list to send to the chat model.
        """
        messages = [
            {"role": "system", "content": self.system_prompt}
        ]

        # Keep at most the last 5 turns (10 messages) of history.
        messages.extend(self.conversation_history[-10:])

        user_message = query
        if context:
            user_message = f"{context}\n\nUSER QUESTION: {query}"

        messages.append({"role": "user", "content": user_message})

        return messages

    def generate_response_stream(self, query: str) -> Generator[str, None, None]:
        """
        Generate a streamed response.

        Args:
            query: The user query.

        Yields:
            First a markdown block describing the query analysis and sources,
            then the answer text fragments as they arrive. On an API error a
            single error message is yielded instead of raising.
        """
        # 1. Retrieval
        print("[generate_response_stream] Processing query...")
        summary, rewritten_query, entities, search_results = self.enhanced_search(query)

        # 2. Context for the model
        context = self.format_context(search_results)

        # 3. Messages use the original query plus the retrieved context.
        messages = self.build_messages(query, context)

        # 4. Streamed generation
        try:
            stream = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                stream=True
            )

            # Full answer is accumulated so it can be stored in the history.
            full_response = ""

            search_info = f"\n**Query Analysis:**\n"
            search_info += f"- Query: {query}\n"
            if summary:
                search_info += f"- Summary of history: {summary}\n"
            if rewritten_query != query:
                search_info += f"- Rewritten query: {rewritten_query}\n"
            search_info += f"- Key entities: {', '.join(entities) if entities else 'No specific entities extracted'}\n"

            if search_results:
                search_info += f"\n**Relevant Sources:**\n"
                for result in search_results:
                    entry, combined_score, details = result
                    # The user-facing list includes the chunk id and relevance score.
                    search_info += f"- [{entry['id']}] {entry['title']} (Relevance: {combined_score:.3f})\n"
                search_info += "\n**Response:**\n"
            else:
                search_info += "\n**Response:** (No relevant knowledge base content found, answering based on general knowledge)\n"

            # Cache information (for debugging)
            cache_info = self.vectorizer.get_cache_info()
            if cache_info['is_cached']:
                search_info += f"💡 Vector database cached with {cache_info['cache_size']} entries\n\n"

            yield search_info

            for chunk in stream:
                # Some stream events carry no choices; skip them instead of
                # failing on choices[0].
                if not chunk.choices:
                    continue
                content = chunk.choices[0].delta.content
                if content is not None:
                    full_response += content
                    yield content

            # History is only updated after the stream completed successfully.
            self.conversation_history.append({"role": "user", "content": query})
            self.conversation_history.append({"role": "assistant", "content": full_response})

        except Exception as e:
            yield f"\n\nError: Problem occurred while generating response - {str(e)}"

    def generate_response(self, query: str) -> str:
        """
        Generate the complete (non-streamed) response.

        Args:
            query: The user query.

        Returns:
            Everything generate_response_stream yields, concatenated.
        """
        return "".join(self.generate_response_stream(query))

    def clear_history(self):
        """Clear the conversation history."""
        self.conversation_history = []
        print("[clear_history] Conversation history cleared")

    def clear_vector_cache(self):
        """Clear the vector database cache."""
        self.vectorizer.clear_cache()
        print("[clear_vector_cache] Vector database cache cleared")

    def reload_vector_database(self):
        """Force a reload of the vector database."""
        print("[reload_vector_database] Reloading vector database...")
        self.vectorizer.load_vector_database(force_reload=True)
        print("[reload_vector_database] Vector database reload completed")

    def get_system_status(self) -> Dict:
        """
        Return system status information.

        Returns:
            Dict with the model name, number of conversation turns, the
            vectorizer's cache info and the current timestamp (ISO format).
        """
        cache_info = self.vectorizer.get_cache_info()
        return {
            'model': self.model,
            'conversation_turns': len(self.conversation_history) // 2,
            'vector_cache': cache_info,
            'last_update': datetime.now().isoformat()
        }

    def save_conversation(self, filepath: str = None):
        """
        Save the conversation history as JSON.

        Args:
            filepath: Destination path; defaults to
                ``conversation_<YYYYmmdd_HHMMSS>.json`` in the working directory.
        """
        if filepath is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filepath = f"conversation_{timestamp}.json"

        conversation_data = {
            "timestamp": datetime.now().isoformat(),
            "model": self.model,
            "system_status": self.get_system_status(),
            "history": self.conversation_history
        }

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(conversation_data, f, ensure_ascii=False, indent=2)

        print(f"[save_conversation] Conversation saved to: {filepath}")
README.md CHANGED
@@ -1,13 +1,14 @@
1
  ---
2
- title: CIV3283 Student 17
3
- emoji: 🐠
4
- colorFrom: purple
5
- colorTo: gray
6
  sdk: gradio
7
- sdk_version: 5.38.2
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: CIV3283 Student 16
3
+ emoji:
4
+ colorFrom: gray
5
+ colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 5.34.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ short_description: alternative space 16
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,630 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import csv
3
+ import os
4
+ import re
5
+ from datetime import datetime, timedelta
6
+ from huggingface_hub import Repository
7
+ from RAG_Learning_Assistant_with_Streaming import RAGLearningAssistant
8
+
9
# Configuration for Student Space
# NOTE: everything below runs at import time; the two secret checks raise
# ValueError immediately when the corresponding environment variable is unset.
STUDENT_SPACE_NAME = "CIV3283_Student_16" # Replace with actual student space name (e.g., "student-group-1")
DATA_STORAGE_REPO = "CIV3283/Data_Storage" # Centralized data storage repo
DATA_BRANCH_NAME = "data_branch"
# Local directory the data storage repo is cloned into; log files are written here.
LOCAL_DATA_DIR = "temp_data_repo"

# Session timeout configuration (in minutes)
# A different student ID is blocked while the last logged query is at most this old.
SESSION_TIMEOUT_MINUTES = 30 # Adjust this value as needed

# File names in data storage
KNOWLEDGE_FILE = "knowledge_base.md"
VECTOR_DB_FILE = "vector_database.csv"
METADATA_FILE = "vector_metadata.json"
VECTORIZER_FILE = "vectorize_knowledge_base.py"

# Student-specific log files (with space name prefix)
QUERY_LOG_FILE = f"{STUDENT_SPACE_NAME}_query_log.csv"
FEEDBACK_LOG_FILE = f"{STUDENT_SPACE_NAME}_feedback_log.csv"

# Environment variables
# Both secrets are mandatory; fail fast with a hint on where to configure them.
HF_HUB_TOKEN = os.environ.get("HF_HUB_TOKEN", None)
if HF_HUB_TOKEN is None:
    raise ValueError("Set HF_HUB_TOKEN in Space Settings -> Secrets")

OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None)
if OPENAI_API_KEY is None:
    raise ValueError("Set OPENAI_API_KEY in Space Settings -> Secrets")

# Chat model name
MODEL = "gpt-4.1-nano-2025-04-14"
38
+
39
def check_session_validity(check_id):
    """
    Decide whether ``check_id`` may use the assistant right now.

    The decision is based on the most recent row of the query log:
    - no log / header only / malformed last row -> allowed
    - last row belongs to the same student ID   -> allowed
    - last row belongs to someone else          -> allowed only when that
      query is older than SESSION_TIMEOUT_MINUTES (assistant is idle)

    Any unexpected error allows the query so legitimate users are not blocked.

    Returns:
        tuple: (is_valid: bool, error_message: str)
    """
    try:
        log_path = os.path.join(LOCAL_DATA_DIR, QUERY_LOG_FILE)

        # No log yet: this is the very first query.
        if not os.path.exists(log_path):
            print(f"[check_session_validity] No existing log file, allowing first query for student {check_id}")
            return True, ""

        with open(log_path, 'r', encoding='utf-8') as handle:
            records = list(csv.reader(handle))

        # A header on its own also counts as "no previous query".
        if len(records) <= 1:
            print(f"[check_session_validity] Only header in log file, allowing first query for student {check_id}")
            return True, ""

        # Row layout: [student_space, student_id, timestamp, search_info, query_and_response, thumb_feedback]
        latest = records[-1]
        if len(latest) < 3:
            print(f"[check_session_validity] Invalid last record format, allowing query")
            return True, ""

        previous_id, previous_stamp = latest[1], latest[2]

        print(f"[check_session_validity] Last record - Student ID: {previous_id}, Timestamp: {previous_stamp}")
        print(f"[check_session_validity] Current request - Student ID: {check_id}")

        # The same student simply continues their session.
        if previous_id == check_id:
            print(f"[check_session_validity] Same user, allowing continuation for student {check_id}")
            return True, ""

        # A different student: look at how long the assistant has been idle.
        try:
            previous_time = datetime.strptime(previous_stamp, '%Y-%m-%d %H:%M:%S')
            idle = datetime.now() - previous_time
            idle_seconds = idle.total_seconds()
            idle_minutes = idle.total_seconds()/60

            print(f"[check_session_validity] Different user - Time difference: {idle_seconds} seconds ({idle_minutes:.1f} minutes)")

            if idle > timedelta(minutes=SESSION_TIMEOUT_MINUTES):
                # Idle long enough: hand the assistant to the new student.
                print(f"[check_session_validity] Assistant has been idle for {idle_minutes:.1f} minutes, allowing new user {check_id}")
                return True, ""

            # The previous student was active too recently.
            error_msg = "⚠️ The assistant is currently being used by another user. Please return to the load distributor page."
            print(f"[check_session_validity] Blocking access - Previous user ({previous_id}) used assistant {idle_minutes:.1f} minutes ago")
            return False, error_msg

        except ValueError as e:
            # Unparseable timestamp: let the query through.
            print(f"[check_session_validity] Error parsing timestamp: {e}")
            return True, ""

    except Exception as e:
        print(f"[check_session_validity] Error checking session validity: {e}")
        import traceback
        print(f"[check_session_validity] Traceback: {traceback.format_exc()}")
        # Fail open to avoid blocking legitimate users.
        return True, ""
116
+
117
def init_data_storage_repo():
    """
    Clone (or reuse) the centralized data storage repository and pull the latest data.

    Returns:
        The ``Repository`` object on success, or ``None`` when anything fails
        (the error and traceback are printed).
    """
    try:
        repo = Repository(
            local_dir=LOCAL_DATA_DIR,
            clone_from=DATA_STORAGE_REPO,
            revision=DATA_BRANCH_NAME,
            repo_type="space",
            use_auth_token=HF_HUB_TOKEN
        )
        # Configure the git identity in one call. The parameters are
        # (git_user, git_email); passing the literal strings "git_user" /
        # "git_email" positionally would set them as the user name instead.
        repo.git_config_username_and_email(
            git_user=f"Student_Space_{STUDENT_SPACE_NAME}",
            git_email=f"{STUDENT_SPACE_NAME}@student.space"
        )

        # Pull latest changes
        print(f"[init_data_storage_repo] Pulling latest changes from {DATA_STORAGE_REPO}...")
        repo.git_pull(rebase=True)

        print(f"[init_data_storage_repo] Successfully connected to data storage repo: {DATA_STORAGE_REPO}")
        print(f"[init_data_storage_repo] Local directory: {LOCAL_DATA_DIR}")
        print(f"[init_data_storage_repo] Branch: {DATA_BRANCH_NAME}")

        # Report (but do not fail on) missing knowledge-base files.
        required_files = [KNOWLEDGE_FILE, VECTOR_DB_FILE, METADATA_FILE]
        for file_name in required_files:
            file_path = os.path.join(LOCAL_DATA_DIR, file_name)
            if os.path.exists(file_path):
                print(f"[init_data_storage_repo] Found required file: {file_name}")
            else:
                print(f"[init_data_storage_repo] Warning: Missing required file: {file_name}")

        return repo

    except Exception as e:
        print(f"[init_data_storage_repo] Error initializing repository: {e}")
        import traceback
        print(f"[init_data_storage_repo] Traceback: {traceback.format_exc()}")
        return None
155
+
156
def commit_student_logs(commit_message: str):
    """
    Commit the student log files and push them to the data storage repository.

    The log files are staged and committed once; the pull-rebase + push step
    is then retried (up to 3 attempts, with a growing pause) when the remote
    rejects the push or the rebase cannot proceed. The local commit is never
    discarded, so freshly written log rows are not lost on a retry.

    Args:
        commit_message: Message for the git commit.

    Returns:
        True when the logs were pushed or there was nothing new to commit,
        False when no log file exists, the repository is not initialised, or
        committing/pushing failed.
    """
    if repo is None:
        print("[commit_student_logs] Error: Repository not initialized")
        return False

    import subprocess
    import time

    max_retries = 3

    # ---- Stage and commit locally (done exactly once) ----
    try:
        files_to_add = []
        for log_name, label in ((QUERY_LOG_FILE, "query"), (FEEDBACK_LOG_FILE, "feedback")):
            log_path = os.path.join(LOCAL_DATA_DIR, log_name)
            if os.path.exists(log_path):
                files_to_add.append(log_name)
                print(f"[commit_student_logs] Found {label} log: {log_path}")

        if not files_to_add:
            print("[commit_student_logs] No log files to commit")
            return False

        for file_name in files_to_add:
            print(f"[commit_student_logs] Adding file: {file_name}")
            repo.git_add(pattern=file_name)

        # Only look at the log files, so unrelated worktree noise cannot
        # trigger an empty commit.
        try:
            result = subprocess.run(
                ["git", "status", "--porcelain", "--", *files_to_add],
                cwd=LOCAL_DATA_DIR,
                capture_output=True,
                text=True,
                check=True
            )

            if not result.stdout.strip():
                print("[commit_student_logs] No changes to commit")
                return True

            print(f"[commit_student_logs] Changes detected: {result.stdout.strip()}")

        except Exception as status_error:
            # Best effort: if the status check itself fails, still try to commit.
            print(f"[commit_student_logs] Warning: Could not check git status: {status_error}")

        print(f"[commit_student_logs] Committing locally: {commit_message}")
        repo.git_commit(commit_message)

    except Exception as e:
        print(f"[commit_student_logs] Non-conflict error, not retrying: {e}")
        return False

    # ---- Synchronise with the remote, retrying on conflicts ----
    for attempt in range(1, max_retries + 1):
        try:
            print(f"[commit_student_logs] Attempt {attempt}/{max_retries}: Pulling latest changes...")
            repo.git_pull(rebase=True)

            print("[commit_student_logs] Pushing to remote...")
            repo.git_push()

            print(f"[commit_student_logs] Success: {commit_message}")
            return True

        except Exception as e:
            error_msg = str(e)
            print(f"[commit_student_logs] Attempt {attempt} failed: {error_msg}")

            is_conflict = (
                ("rejected" in error_msg and "fetch first" in error_msg)
                or ("cannot pull with rebase" in error_msg)
            )
            if not is_conflict:
                print(f"[commit_student_logs] Non-conflict error, not retrying: {error_msg}")
                return False

            print("[commit_student_logs] Detected Git conflict, will retry...")

            # Leave any half-finished rebase; this is a no-op (non-zero exit,
            # ignored) when no rebase is in progress.
            subprocess.run(
                ["git", "rebase", "--abort"],
                cwd=LOCAL_DATA_DIR,
                capture_output=True,
                text=True,
                check=False
            )

            if attempt < max_retries:
                # Back off a little so concurrent Spaces do not collide again.
                wait_time = attempt * 2  # 2, 4 seconds
                print(f"[commit_student_logs] Waiting {wait_time} seconds before retry...")
                time.sleep(wait_time)
            else:
                print("[commit_student_logs] Max retries reached, giving up")

    print("[commit_student_logs] Failed after all retry attempts")
    return False
270
+
271
def save_student_query_to_csv(query, search_info, response, check_id, thumb_feedback=None):
    """
    Append one query record to the student's CSV log and push it to the data repo.

    Args:
        query: The student's question.
        search_info: Retrieval/analysis text shown with the answer.
        response: The assistant's answer.
        check_id: Student ID; a falsy value aborts the save.
        thumb_feedback: Optional feedback value, stored as "" when absent.

    Returns:
        True once the row is written locally (even if the remote commit
        fails), False on a missing check_id or any error.
    """
    try:
        if not check_id:
            print("[save_student_query_to_csv] Error: No valid check_id provided")
            return False

        # The data directory may not exist yet if the clone failed.
        os.makedirs(LOCAL_DATA_DIR, exist_ok=True)

        log_path = os.path.join(LOCAL_DATA_DIR, QUERY_LOG_FILE)
        already_present = os.path.isfile(log_path)

        print(f"[save_student_query_to_csv] Saving to: {log_path}")
        print(f"[save_student_query_to_csv] File exists: {already_present}")
        print(f"[save_student_query_to_csv] Student ID: {check_id}")

        with open(log_path, 'a', newline='', encoding='utf-8') as handle:
            out = csv.writer(handle)
            # A brand-new file gets the header row first.
            if not already_present:
                print("[save_student_query_to_csv] Writing header row")
                out.writerow(['student_space', 'student_id', 'timestamp', 'search_info', 'query_and_response', 'thumb_feedback'])

            timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            record = [
                STUDENT_SPACE_NAME,
                check_id,
                timestamp,
                search_info,
                f"Query: {query}\nResponse: {response}",
                thumb_feedback or "",
            ]
            out.writerow(record)

        print(f"[save_student_query_to_csv] Query saved to local file: {log_path}")

        # Push the updated log; a failed push does not invalidate the local save.
        print("[save_student_query_to_csv] Attempting to commit to remote repository...")
        pushed = commit_student_logs(f"Add query log from student {check_id} at {timestamp}")

        if not pushed:
            print("[save_student_query_to_csv] Failed to commit to remote repository")
        else:
            print("[save_student_query_to_csv] Successfully committed to remote repository")

        return True
    except Exception as e:
        print(f"[save_student_query_to_csv] Error: {e}")
        import traceback
        print(f"[save_student_query_to_csv] Traceback: {traceback.format_exc()}")
        return False
316
+
317
def update_latest_student_query_feedback(feedback_type, check_id):
    """Record thumb feedback on the most recent query logged for a student.

    The query log is shared by every student of this space, so the row to
    update is the newest data row whose ``student_id`` column equals
    ``check_id`` rather than simply the last row of the file (which may
    belong to somebody else).

    Args:
        feedback_type: Value stored in the ``thumb_feedback`` column.
        check_id: Student identifier; a falsy value aborts the update.

    Returns:
        True when a row was updated and written back; False when
        ``check_id`` is missing, the log file does not exist, the student
        has no logged query, or any exception occurs.
    """
    # Column positions as written by save_student_query_to_csv's header row
    student_id_col = 1
    thumb_feedback_col = 5

    try:
        if not check_id:
            print("[update_latest_student_query_feedback] Error: No valid check_id provided")
            return False

        filepath = os.path.join(LOCAL_DATA_DIR, QUERY_LOG_FILE)
        if not os.path.exists(filepath):
            print("[update_latest_student_query_feedback] Error: Query log file not found")
            return False

        # newline='' lets the csv module handle newlines embedded in the
        # quoted query_and_response field itself
        with open(filepath, 'r', newline='', encoding='utf-8') as csvfile:
            rows = list(csv.reader(csvfile))

        # Walk backwards over the data rows (index 0 is the header) to find
        # this student's newest query
        wanted_id = str(check_id)
        target_index = None
        for index in range(len(rows) - 1, 0, -1):
            row = rows[index]
            if len(row) > student_id_col and row[student_id_col] == wanted_id:
                target_index = index
                break

        if target_index is None:
            print("[update_latest_student_query_feedback] Error: No query found for this student")
            return False

        target_row = rows[target_index]
        # Pad rows that are shorter than the expected layout instead of
        # failing with an IndexError
        if len(target_row) <= thumb_feedback_col:
            target_row.extend([""] * (thumb_feedback_col + 1 - len(target_row)))
        target_row[thumb_feedback_col] = feedback_type

        with open(filepath, 'w', newline='', encoding='utf-8') as csvfile:
            csv.writer(csvfile).writerows(rows)

        print(f"[update_latest_student_query_feedback] Updated feedback: {feedback_type}")

        # Commit the update (best effort; the local file is already updated)
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        commit_student_logs(f"Update feedback from student {check_id}: {feedback_type} at {timestamp}")
        return True
    except Exception as e:
        print(f"[update_latest_student_query_feedback] Error: {e}")
        return False
356
+
357
def save_student_comment_feedback(comment, check_id):
    """Append a free-text student comment to the shared feedback log.

    The row is written to ``FEEDBACK_LOG_FILE`` inside ``LOCAL_DATA_DIR``
    (with a header row when the file is new) and a commit is then attempted
    through ``commit_student_logs``.

    Args:
        comment: The comment text to store.
        check_id: Student identifier; a falsy value aborts the save.

    Returns:
        True once the row has been written locally; False when ``check_id``
        is missing or any exception occurs.
    """
    try:
        if not check_id:
            print("[save_student_comment_feedback] Error: No valid check_id provided")
            return False

        # Same safeguard as save_student_query_to_csv: the first write must
        # not fail just because the data directory is missing
        os.makedirs(LOCAL_DATA_DIR, exist_ok=True)

        filepath = os.path.join(LOCAL_DATA_DIR, FEEDBACK_LOG_FILE)
        file_exists = os.path.isfile(filepath)

        with open(filepath, 'a', newline='', encoding='utf-8') as csvfile:
            writer = csv.writer(csvfile)
            if not file_exists:
                writer.writerow(['student_space', 'student_id', 'timestamp', 'comment'])

            timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            writer.writerow([STUDENT_SPACE_NAME, check_id, timestamp, comment])

        print(f"[save_student_comment_feedback] Saved comment to {filepath}")

        # Commit student logs (best effort; the result is not reported)
        commit_student_logs(f"Add comment feedback from student {check_id} at {timestamp}")

        return True
    except Exception as e:
        print(f"[save_student_comment_feedback] Error: {e}")
        return False
385
+
386
def get_url_params(request: gr.Request):
    """Read the student ID from the ``check`` query parameter of the page URL.

    Args:
        request: The incoming Gradio request (may be falsy).

    Returns:
        A ``(title, check_id)`` tuple. The title is always the same fixed
        string; ``check_id`` is the value of the ``check`` parameter, or
        None when the request or a non-empty parameter is absent.
    """
    title = "RAG Learning Assistant - Student"
    if not request:
        return title, None

    supplied_id = dict(request.query_params).get('check', None)
    # An empty value counts as "not provided"
    return title, (supplied_id if supplied_id else None)
396
+
397
def chat_response(message, history, search_info_display, check_id, has_query):
    """Stream the assistant's answer for one user message.

    This is a generator: each ``yield`` hands Gradio a
    ``(history, search_info, has_query)`` tuple. Because the value of a
    ``return`` statement inside a generator is discarded, the final state
    (including ``has_query = True`` after a successful save) is yielded
    rather than returned.

    Args:
        message: The text the user submitted.
        history: Chat history, either in messages format (list of
            ``{"role", "content"}`` dicts) or as ``[user, assistant]`` pairs.
        search_info_display: Current content of the search-info panel.
        check_id: Student identifier taken from the URL; required.
        has_query: Whether a rateable query already exists in this session.

    Yields:
        ``(history, search_info_content, has_query)`` after every streamed
        chunk, and once more when the query has been logged.

    Raises:
        gr.Error: When no ``check_id`` is present or the session is no
            longer valid according to ``check_session_validity``.
    """
    if not message.strip():
        # Nothing to do: hand the unchanged values back explicitly
        yield history, search_info_display, has_query
        return

    # Check access permission first
    if not check_id:
        print(f"[chat_response] Access denied: No valid check ID provided")
        # Raise error dialog for access denial
        raise gr.Error(
            "⚠️ Access Restricted\n\n"
            "Please access this system through the link provided in Moodle.\n\n"
            "If you are a student in this course:\n"
            "1. Go to your Moodle course page\n"
            "2. Find the 'CivASK' link\n"
            "3. Click the link to access the system\n\n"
            "If you continue to experience issues, please contact your instructor.",
            duration=8
        )

    # Check session validity before proceeding
    session_valid, error_message = check_session_validity(check_id)
    if not session_valid:
        print(f"[chat_response] Session invalid for student {check_id}")
        raise gr.Error(error_message, duration=10)

    # Valid access and valid session - proceed with normal AI conversation
    print(f"[chat_response] Valid access and session for student ID: {check_id}")

    if history is None:
        history = []

    # Convert pair-style history ([user, assistant]) to messages format
    if history and isinstance(history[0], (list, tuple)):
        messages_history = []
        for user_msg, assistant_msg in history:
            messages_history.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages_history.append({"role": "assistant", "content": assistant_msg})
        history = messages_history

    # Add the user message plus an empty assistant slot that is filled
    # incrementally while streaming
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": ""})

    search_info_collected = False
    search_info_content = ""
    content_part = ""

    # Chunks up to and including the one carrying the "**Response:**" marker
    # go to the search-info panel; everything after it is the answer text
    for chunk in assistant.generate_response_stream(message):
        if not search_info_collected:
            search_info_content += chunk
            if "**Response:**" in chunk:
                search_info_collected = True
        else:
            content_part += chunk
            # Update the last assistant message
            history[-1]["content"] = content_part
        yield history, search_info_content, has_query

    # After streaming is complete, save to CSV (only for valid access)
    try:
        print(f"[chat_response] Saving student query to CSV...")
        print(f"Student Space: {STUDENT_SPACE_NAME}")
        print(f"Student ID: {check_id}")
        print(f"Query: {message}")

        save_success = save_student_query_to_csv(message, search_info_content, content_part, check_id)
        if save_success:
            print(f"[chat_response] Student query saved successfully")
            has_query = True  # Mark that we have a query to rate
        else:
            print(f"[chat_response] Failed to save student query")

    except Exception as e:
        print(f"[chat_response] Error saving student query: {e}")

    # Final yield so the possibly updated has_query flag reaches the UI state
    yield history, search_info_content, has_query
478
+
479
# Global variables: module-level singletons, both None until main() assigns them
repo = None  # assigned in main() from init_data_storage_repo()
assistant = None  # assigned in main() to the RAGLearningAssistant instance; read by chat_response()
482
+
483
def main():
    """Initialize the data repository and RAG assistant, then launch the UI.

    Sets the module-level ``repo`` and ``assistant`` globals, builds the
    Gradio Blocks interface (chat, thumb feedback, comments, search-info
    panel), wires the event handlers and finally calls ``launch()``.
    """
    global repo, assistant

    # Initialize data storage repository connection
    repo = init_data_storage_repo()

    # Initialize RAG assistant with centralized data storage directory
    print(f"[main] Initializing RAG assistant with data directory: {LOCAL_DATA_DIR}")
    print(f"[main] Session timeout set to: {SESSION_TIMEOUT_MINUTES} minutes")
    assistant = RAGLearningAssistant(
        api_key=OPENAI_API_KEY,
        model=MODEL,
        vector_db_path=LOCAL_DATA_DIR  # Pass the data storage repo directory
    )

    print(f"[main] RAG assistant initialized successfully")
    print(f"[main] Student space: {STUDENT_SPACE_NAME}")
    print(f"[main] Data storage repo: {DATA_STORAGE_REPO}")
    print(f"[main] Query log file: {QUERY_LOG_FILE}")
    print(f"[main] Feedback log file: {FEEDBACK_LOG_FILE}")

    # Create interface
    with gr.Blocks(title=f"RAG Assistant - {STUDENT_SPACE_NAME}") as interface:
        # Start without a student ID so every access check fails until
        # init_from_url has extracted a real one from the page URL
        check_id_state = gr.State(None)
        has_query_state = gr.State(False)  # Track if there's a query to rate
        title_display = gr.Markdown(f"# RAG Learning Assistant - {STUDENT_SPACE_NAME}", elem_id="title")

        # Only Query Check functionality for students
        with gr.Row():
            with gr.Column(scale=4):
                chatbot = gr.Chatbot(label="Ask Your Questions", height=500, type="messages", render_markdown=True)
                msg = gr.Textbox(placeholder="Type your message here...", label="Your Message", show_label=True)

                # Feedback buttons row
                with gr.Row():
                    thumbs_up_btn = gr.Button("👍 Good Answer", variant="secondary", size="sm")
                    thumbs_down_btn = gr.Button("👎 Poor Answer", variant="secondary", size="sm")

                feedback_status = gr.Textbox(label="Feedback Status", interactive=False, lines=1)

                # Comment section
                with gr.Row():
                    comment_input = gr.Textbox(placeholder="Share your comments or suggestions...", label="Comments", lines=2)
                    # NOTE(review): "outline" does not appear to be one of
                    # Gradio's documented Button variants - confirm
                    submit_comment_btn = gr.Button("Submit Comment", variant="outline")

            with gr.Column(scale=1):
                search_info = gr.Markdown(label="Search Analysis Information", value="")

        # Event handlers
        def init_from_url(request: gr.Request):
            """Populate the title and student ID from the page URL on load."""
            title, check_id = get_url_params(request)
            print(f"[init_from_url] Extracted check_id: {check_id}")
            return f"# {title}", check_id, False  # Reset has_query state

        def _rate_latest_query(check_id, feedback_type, thanks_message):
            """Shared body of both thumb handlers; returns the status text."""
            if not check_id:
                raise gr.Error(
                    "⚠️ Access Restricted\n\n"
                    "Please access this system through the CivASK link provided in Moodle to use the feedback features.",
                    duration=5
                )

            print(f"[handle_{feedback_type}] Student: {STUDENT_SPACE_NAME}, check_id: {check_id}")

            # Only attempt an update when the log holds a header plus at
            # least one data row
            filepath = os.path.join(LOCAL_DATA_DIR, QUERY_LOG_FILE)
            if os.path.exists(filepath):
                with open(filepath, 'r', newline='', encoding='utf-8') as csvfile:
                    has_data_rows = len(list(csv.reader(csvfile))) > 1
                if has_data_rows:
                    success = update_latest_student_query_feedback(feedback_type, check_id)
                    return thanks_message if success else "Failed to save feedback"

            return "No query to rate yet"

        # Feedback handlers (has_query is accepted because the state is
        # wired in as an input, but it is not consulted)
        def handle_thumbs_up(check_id, has_query):
            return _rate_latest_query(check_id, "thumbs_up", "👍 Thank you for your positive feedback!")

        def handle_thumbs_down(check_id, has_query):
            return _rate_latest_query(check_id, "thumbs_down", "👎 Thank you for your feedback. We'll work to improve!")

        def handle_comment_submission(comment, check_id):
            """Store a comment; returns (status text, new comment box value)."""
            if not check_id:
                raise gr.Error(
                    "⚠️ Access Restricted\n\n"
                    "Please access this system through the CivASK link provided in Moodle to submit comments.",
                    duration=5
                )

            if comment.strip():
                success = save_student_comment_feedback(comment.strip(), check_id)
                if success:
                    # Clear the input box after a successful save
                    return "💬 Thank you for your comment!", ""
                else:
                    return "Failed to save comment", comment
            return "Please enter a comment", comment

        interface.load(fn=init_from_url, outputs=[title_display, check_id_state, has_query_state])

        # Query events: stream the answer, then clear the message box
        msg.submit(
            chat_response,
            [msg, chatbot, search_info, check_id_state, has_query_state],
            [chatbot, search_info, has_query_state]
        ).then(lambda: "", outputs=[msg])

        # Feedback events
        thumbs_up_btn.click(
            handle_thumbs_up,
            inputs=[check_id_state, has_query_state],
            outputs=[feedback_status]
        )

        thumbs_down_btn.click(
            handle_thumbs_down,
            inputs=[check_id_state, has_query_state],
            outputs=[feedback_status]
        )

        submit_comment_btn.click(
            handle_comment_submission,
            inputs=[comment_input, check_id_state],
            outputs=[feedback_status, comment_input]
        )

    interface.launch()
628
+
629
# Run the student application only when this file is executed as a script
if __name__ == "__main__":
    main()
gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio==5.34.0
2
+ openai==1.86.0
3
+ pandas==2.2.3
4
+ numpy==2.2.3
5
+ huggingface-hub==0.33.0
6
+ scipy==1.15.2
vectorize_knowledge_base.py ADDED
@@ -0,0 +1,515 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ import numpy as np
5
+ import pandas as pd
6
+ from typing import List, Dict, Tuple, Optional
7
+ from openai import OpenAI
8
+ from datetime import datetime
9
+ import csv
10
+
11
class KnowledgeBaseVectorizer:
    """Build, cache and query an embedding index over ``knowledge_base.md``.

    Each knowledge-base entry gets three embeddings (title, content and
    title+content). They are stored column-wise in a CSV file next to a JSON
    metadata file, loaded once into memory, and searched with a weighted
    cosine similarity.
    """

    # Length of the all-zero placeholder vector used when an embedding
    # request fails (the default dimension assumed by the original code)
    _FALLBACK_DIM = 1536

    # Embedding variants stored per entry; columns are named `<type>_dim_<j>`
    _VECTOR_TYPES = ('title', 'content', 'full')

    # Entry layout: "# xx-xx-xx title", "**source** ...", "**content** ..."
    # running up to the next entry header or the end of the file
    _ENTRY_PATTERN = re.compile(
        r'#\s+(\d{2}-\d{2}-\d{2})\s+([^\n]+)\s+\*\*source\*\*\s+([^\n]+)\s+\*\*content\*\*\s+(.*?)(?=\n#\s+\d{2}-\d{2}-\d{2}|$)',
        re.DOTALL,
    )

    def __init__(self, api_key: str, data_path: str = "", vector_db_dir: str = ""):
        """
        Initialize the vectorizer (adapted for the student Space).

        Args:
            api_key: OpenAI API key
            data_path: Path of knowledge_base.md (only used when
                vector_db_dir is empty)
            vector_db_dir: Directory holding the vector database (usually the
                local checkout of the data storage repository)
        """
        self.client = OpenAI(api_key=api_key)
        self.embedding_model = "text-embedding-3-small"

        # When vector_db_dir is given, the files inside it take precedence
        if vector_db_dir:
            self.data_path = os.path.join(vector_db_dir, "knowledge_base.md")
            self.vector_db_path = os.path.join(vector_db_dir, "vector_database.csv")
            self.metadata_path = os.path.join(vector_db_dir, "vector_metadata.json")
        else:
            # Original behaviour, kept for backward compatibility
            self.data_path = data_path if data_path else "knowledge_base.md"
            self.vector_db_path = "vector_database.csv"
            self.metadata_path = "vector_metadata.json"

        # Cache state
        self._cached_df = None
        self._cached_metadata = None
        self._cached_embeddings = {}  # per vector type: raw and normalized matrices
        self._last_load_time = None

        print(f"[KnowledgeBaseVectorizer] Initialized with:")
        print(f"  - Knowledge base: {self.data_path}")
        print(f"  - Vector database: {self.vector_db_path}")
        print(f"  - Metadata: {self.metadata_path}")

    def parse_knowledge_base(self) -> List[Dict]:
        """
        Parse knowledge_base.md and extract every entry, keeping table rows
        that are part of an entry's content.

        Returns:
            A list of dicts with the keys 'id', 'title', 'source', 'content'
            and 'full_text'; empty when the file cannot be read.
        """
        entries = []

        try:
            with open(self.data_path, 'r', encoding='utf-8') as f:
                content = f.read()
            print(f"[parse_knowledge_base] Successfully read file: {self.data_path}")
        except FileNotFoundError:
            print(f"[parse_knowledge_base] Error: File not found - {self.data_path}")
            return entries
        except Exception as e:
            print(f"[parse_knowledge_base] Error reading file: {e}")
            return entries

        for entry_id, title, source, raw_content in self._ENTRY_PATTERN.findall(content):
            # Drop blank lines and trailing whitespace; table rows are
            # non-blank, so their structure is preserved
            cleaned_content = '\n'.join(
                line.rstrip()
                for line in raw_content.strip().split('\n')
                if line.strip()
            )
            title = title.strip()

            entries.append({
                'id': entry_id.strip(),
                'title': title,
                'source': source.strip(),
                'content': cleaned_content,
                'full_text': f"{title} {cleaned_content}"  # text used for the 'full' embedding
            })

        print(f"[parse_knowledge_base] Successfully parsed {len(entries)} entries")

        # Debug output for the first few entries
        if entries:
            print("[parse_knowledge_base] First 3 entries info:")
            for entry in entries[:3]:
                content_lines = entry['content'].count('\n') + 1
                has_table = '|' in entry['content']
                print(f"  Entry {entry['id']}: {len(entry['content'])} chars, {content_lines} lines, has table: {has_table}")

        return entries

    def get_embedding(self, text: str) -> List[float]:
        """
        Fetch the embedding of a single text through the OpenAI API.

        Args:
            text: Text to embed

        Returns:
            The embedding vector, or an empty list when the request fails.
        """
        try:
            response = self.client.embeddings.create(
                input=text,
                model=self.embedding_model
            )
            return response.data[0].embedding
        except Exception as e:
            print(f"[get_embedding] Error: {e}")
            return []

    def batch_get_embeddings(self, texts: List[str], batch_size: int = 10) -> List[List[float]]:
        """
        Fetch embeddings for many texts in batches.

        When a batch request fails, each text of that batch is retried on
        its own; a text that still fails gets an all-zero placeholder so the
        result always lines up with ``texts``.

        Args:
            texts: Texts to embed
            batch_size: Number of texts per request (values below 1 are
                treated as 1)

        Returns:
            One embedding per input text, in input order.
        """
        # A step of 0 would make range() raise
        batch_size = max(1, batch_size)
        total_batches = (len(texts) + batch_size - 1) // batch_size
        embeddings = []

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            print(f"[batch_get_embeddings] Processing batch {i//batch_size + 1}/{total_batches}")

            try:
                response = self.client.embeddings.create(
                    input=batch,
                    model=self.embedding_model
                )
                embeddings.extend(item.embedding for item in response.data)
            except Exception as e:
                print(f"[batch_get_embeddings] Batch error: {e}")
                # Batch failed: fall back to one request per text
                for text in batch:
                    embedding = self.get_embedding(text)
                    embeddings.append(embedding if embedding else [0] * self._FALLBACK_DIM)

        return embeddings

    def create_vector_database(self):
        """
        Create the vector database and save it as a CSV file, together with
        a JSON metadata file. Title, content and full text are embedded
        separately.
        """
        print("[create_vector_database] Starting to create vector database...")

        # 1. Parse the knowledge base
        entries = self.parse_knowledge_base()
        if not entries:
            print("[create_vector_database] No entries found")
            return

        # 2. Collect the texts to embed
        titles = [entry['title'] for entry in entries]
        contents = [entry['content'] for entry in entries]
        full_texts = [entry['full_text'] for entry in entries]

        # 3. Embed them in batches
        print("[create_vector_database] Vectorizing titles...")
        title_embeddings = self.batch_get_embeddings(titles)

        print("[create_vector_database] Vectorizing contents...")
        content_embeddings = self.batch_get_embeddings(contents)

        print("[create_vector_database] Vectorizing full texts...")
        full_embeddings = self.batch_get_embeddings(full_texts)

        # 4. One row per entry: the text fields followed by one column per
        #    embedding dimension of each vector type
        print("[create_vector_database] Creating DataFrame...")
        vectors_by_type = (
            ('title', title_embeddings),
            ('content', content_embeddings),
            ('full', full_embeddings),
        )
        rows = []
        for i, entry in enumerate(entries):
            row = {
                'index': i,
                'id': entry['id'],
                'title': entry['title'],
                'source': entry['source'],
                'content': entry['content'],
                'full_text': entry['full_text']
            }
            for prefix, vectors in vectors_by_type:
                for j, val in enumerate(vectors[i]):
                    row[f'{prefix}_dim_{j}'] = val
            rows.append(row)

        df = pd.DataFrame(rows)

        # 5. Save the CSV file
        print(f"[create_vector_database] Saving to {self.vector_db_path}...")
        df.to_csv(self.vector_db_path, index=False, encoding='utf-8')

        # 6. Save the metadata (JSON, easy to inspect by hand)
        dimensions = len(title_embeddings[0]) if title_embeddings else 0
        metadata = {
            'embedding_model': self.embedding_model,
            'created_at': datetime.now().isoformat(),
            'num_entries': len(entries),
            'embedding_dimensions': dimensions,
            'vector_types': ['title', 'content', 'full'],
            'columns': list(df.columns),
            'entries_summary': [
                {
                    'id': entry['id'],
                    'title': entry['title'],
                    'source': entry['source']
                } for entry in entries
            ]
        }

        with open(self.metadata_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)

        print(f"[create_vector_database] Vector database created successfully!")
        print(f"  - Vector database saved to: {self.vector_db_path}")
        print(f"  - Metadata saved to: {self.metadata_path}")
        print(f"  - Processed {len(entries)} entries")
        print(f"  - Vector dimensions: {dimensions}")

        # Drop the cache so the next load picks up the new files
        self.clear_cache()

    def clear_cache(self):
        """Clear every cached object."""
        self._cached_df = None
        self._cached_metadata = None
        self._cached_embeddings = {}
        self._last_load_time = None
        print("[clear_cache] Vector database cache cleared")

    def load_vector_database(self, force_reload: bool = False) -> Tuple[Optional[pd.DataFrame], Optional[Dict]]:
        """
        Load the vector database from its CSV file (with caching).

        Args:
            force_reload: Re-read the files even when a cached copy exists

        Returns:
            A ``(DataFrame, metadata dict)`` tuple, or ``(None, None)`` when
            loading fails.
        """
        # Serve from the cache unless a reload was requested
        if not force_reload and self._cached_df is not None and self._cached_metadata is not None:
            return self._cached_df, self._cached_metadata

        try:
            print(f"[load_vector_database] Loading from {self.vector_db_path}")
            df = pd.read_csv(self.vector_db_path, encoding='utf-8')

            print(f"[load_vector_database] Loading metadata from {self.metadata_path}")
            with open(self.metadata_path, 'r', encoding='utf-8') as f:
                metadata = json.load(f)

            self._cached_df = df
            self._cached_metadata = metadata
            self._last_load_time = datetime.now()

            # Discard matrices derived from a previously loaded file;
            # otherwise a forced reload would keep searching stale vectors
            self._cached_embeddings = {}
            self._preload_embeddings()

            print(f"[load_vector_database] Successfully loaded vector database with {len(df)} entries")
            return df, metadata
        except FileNotFoundError as e:
            print(f"[load_vector_database] Error: File not found - {e}")
            return None, None
        except Exception as e:
            print(f"[load_vector_database] Error loading vector database: {e}")
            return None, None

    def _preload_embeddings(self):
        """Cache the raw and row-normalized matrix of every vector type."""
        if self._cached_df is None:
            return

        for vector_type in self._VECTOR_TYPES:
            if vector_type in self._cached_embeddings:
                continue
            embeddings = self.get_embeddings_from_df(self._cached_df, vector_type)
            norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
            # All-zero rows (placeholders for failed embeddings) would give
            # NaN when divided by their norm; dividing by 1 keeps them zero,
            # so their similarity to any query is 0
            norms[norms == 0] = 1.0
            self._cached_embeddings[vector_type] = {
                'raw': embeddings,
                'normalized': embeddings / norms
            }

        print(f"[_preload_embeddings] Preloaded {len(self._VECTOR_TYPES)} types of vector matrices")

    def get_embeddings_from_df(self, df: pd.DataFrame, vector_type: str = 'full') -> np.ndarray:
        """
        Extract one embedding matrix from the DataFrame.

        Args:
            df: DataFrame holding the embedding columns
            vector_type: 'title', 'content' or 'full'; any other value is
                treated as 'full'

        Returns:
            Array with one row per entry and one column per dimension.
        """
        prefix = vector_type if vector_type in ('title', 'content') else 'full'
        embedding_cols = [col for col in df.columns if col.startswith(f'{prefix}_dim_')]
        return df[embedding_cols].values

    def batch_search_similar(self, queries: List[str], top_k: int = 5,
                             title_weight: float = 0.4,
                             content_weight: float = 0.3,
                             full_weight: float = 0.3) -> List[List[Tuple[Dict, float, Dict]]]:
        """
        Search several queries at once, loading the vector database only once.

        Args:
            queries: Query texts
            top_k: Number of best matches returned per query
            title_weight: Weight of the title similarity
            content_weight: Weight of the content similarity
            full_weight: Weight of the full-text similarity

        Returns:
            One list per query (in query order) of
            ``(entry, combined_score, similarity_details)`` tuples sorted by
            descending score. A query whose embedding could not be obtained
            gets an empty list; so does every query when the database cannot
            be loaded.
        """
        if not queries:
            return []

        # Normalize the weights so they sum to 1
        total_weight = title_weight + content_weight + full_weight
        title_weight /= total_weight
        content_weight /= total_weight
        full_weight /= total_weight

        # Load the vector database (once for all queries)
        df, metadata = self.load_vector_database()
        if df is None:
            return [[] for _ in queries]

        print(f"[batch_search_similar] Generating vectors for {len(queries)} queries...")
        query_embeddings = self.batch_get_embeddings(queries, batch_size=min(10, len(queries)))

        if len(query_embeddings) != len(queries):
            print("[batch_search_similar] Query vector generation failed")
            return [[] for _ in queries]

        # Cached, already normalized matrices
        title_embeddings_norm = self._cached_embeddings['title']['normalized']
        content_embeddings_norm = self._cached_embeddings['content']['normalized']
        full_embeddings_norm = self._cached_embeddings['full']['normalized']

        all_results = []

        for i, (query, query_embedding) in enumerate(zip(queries, query_embeddings)):
            if not query_embedding:
                all_results.append([])
                continue

            query_vec = np.array(query_embedding)
            query_norm = np.linalg.norm(query_vec)
            if query_norm == 0:
                # All-zero placeholder from a failed embedding request:
                # there is nothing meaningful to rank against
                all_results.append([])
                continue
            query_vec_norm = query_vec / query_norm

            # Cosine similarity per vector type
            title_similarities = np.dot(title_embeddings_norm, query_vec_norm)
            content_similarities = np.dot(content_embeddings_norm, query_vec_norm)
            full_similarities = np.dot(full_embeddings_norm, query_vec_norm)

            # Weighted combination
            combined_similarities = (
                title_weight * title_similarities +
                content_weight * content_similarities +
                full_weight * full_similarities
            )

            # Indices of the top_k highest combined scores
            top_indices = np.argsort(combined_similarities)[::-1][:top_k]

            query_results = []
            for idx in top_indices:
                row = df.iloc[idx]
                entry = {
                    'id': row['id'],
                    'title': row['title'],
                    'source': row['source'],
                    'content': row['content']
                }

                similarity_details = {
                    'combined': float(combined_similarities[idx]),
                    'title': float(title_similarities[idx]),
                    'content': float(content_similarities[idx]),
                    'full': float(full_similarities[idx])
                }

                query_results.append((entry, float(combined_similarities[idx]), similarity_details))

            all_results.append(query_results)
            print(f"[batch_search_similar] Completed query {i+1}/{len(queries)}: '{query[:50]}...'")

        return all_results

    def search_similar(self, query: str, top_k: int = 5,
                       title_weight: float = 0.4,
                       content_weight: float = 0.3,
                       full_weight: float = 0.3) -> List[Tuple[Dict, float, Dict]]:
        """
        Find the entries most similar to one query, combining title, content
        and full-text similarity. Delegates to ``batch_search_similar``.

        Args:
            query: Query text
            top_k: Number of best matches to return
            title_weight: Weight of the title similarity
            content_weight: Weight of the content similarity
            full_weight: Weight of the full-text similarity

        Returns:
            List of ``(entry, combined_score, similarity_details)`` tuples.
        """
        results = self.batch_search_similar([query], top_k, title_weight, content_weight, full_weight)
        return results[0] if results else []

    def search_with_entities_optimized(self, entities: List[str], top_k: int = 5) -> List[Tuple[Dict, float, Dict]]:
        """
        Search the knowledge base with a list of entities, loading the vector
        database only once.

        Args:
            entities: Entity strings, each searched as its own query
            top_k: Number of results fetched per entity

        Returns:
            The merged results with one tuple per entry id (its
            best-scoring hit), sorted by descending score.
        """
        if not entities:
            return []

        # NOTE(review): the original comment said the title weight is raised
        # for entity search, yet the values below weight content most heavily
        # - confirm which was intended
        batch_results = self.batch_search_similar(
            entities,
            top_k=top_k,
            title_weight=0.3,
            content_weight=0.5,
            full_weight=0.2
        )

        # Merge per-entity results, keeping the highest-scoring hit for
        # every entry id so the final ranking reflects each entry's best match
        best_hits = {}
        for entity_results in batch_results:
            for hit in entity_results:
                entry_id = hit[0]['id']
                if entry_id not in best_hits or hit[1] > best_hits[entry_id][1]:
                    best_hits[entry_id] = hit

        return sorted(best_hits.values(), key=lambda x: x[1], reverse=True)

    def get_cache_info(self) -> Dict:
        """
        Describe the current cache state.

        Returns:
            Dict with the keys 'is_cached', 'cache_size',
            'cached_embeddings', 'last_load_time' and 'data_paths'.
        """
        return {
            'is_cached': self._cached_df is not None,
            'cache_size': len(self._cached_df) if self._cached_df is not None else 0,
            'cached_embeddings': list(self._cached_embeddings.keys()),
            'last_load_time': self._last_load_time.isoformat() if self._last_load_time else None,
            'data_paths': {
                'knowledge_base': self.data_path,
                'vector_database': self.vector_db_path,
                'metadata': self.metadata_path
            }
        }