# Spaces:
# Sleeping
# Sleeping
import json
import os
import re
import sys
from datetime import datetime
from typing import Dict, Generator, List, Optional, Set, Tuple

from openai import OpenAI

from vectorize_knowledge_base import KnowledgeBaseVectorizer
class RAGLearningAssistant:
    """Retrieval-augmented chat assistant over a vectorized knowledge base.

    For every user query the assistant (1) rewrites the query using recent
    conversation history, (2) extracts key entities, (3) searches the vector
    database through ``KnowledgeBaseVectorizer``, and (4) asks the chat model
    to answer with the retrieved chunks as context.
    """

    # How many of the most recent history messages the query rewriter sees.
    _REWRITE_HISTORY_MESSAGES = 6
    # How many of the most recent history messages are replayed to the answering model.
    _ANSWER_HISTORY_MESSAGES = 10
    # History messages longer than this are truncated before being shown to the rewriter.
    _HISTORY_SNIPPET_CHARS = 200
    # Characters trimmed from both ends of a token by the keyword fallback.
    _TOKEN_PUNCTUATION = ".,;:!?()[]{}\"'|"
    # Words ignored by the keyword fallback; the last row holds the labels that
    # extract_entities puts in front of each text part.
    _STOP_WORDS = frozenset({
        'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
        'of', 'with', 'by', 'from', 'what', 'how', 'when', 'where', 'why',
        'is', 'are', 'was', 'were', 'been', 'be', 'have', 'has', 'had',
        'original', 'query', 'context', 'summary', 'rewritten',
    })

    def __init__(self, api_key: str, model: str = "gpt-4.1-nano-2025-04-14", vector_db_path: str = ""):
        """Create the API client and vectorizer, then preload the vector database.

        Args:
            api_key: OpenAI API key (required).
            model: Name of the chat model used for every completion call.
            vector_db_path: Local directory of the vector database; forwarded
                to the vectorizer as ``vector_db_dir``.
        """
        self.client = OpenAI(api_key=api_key)
        self.vectorizer = KnowledgeBaseVectorizer(
            api_key=api_key,
            vector_db_dir=vector_db_path,
        )

        # Load the vector database up front so the first query does not pay for it.
        print("[RAGLearningAssistant] Preloading vector database...")
        if self._database_loaded(self.vectorizer.load_vector_database()):
            print("[RAGLearningAssistant] Vector database loaded successfully")
        else:
            print("[RAGLearningAssistant] Warning: Failed to load vector database")

        # Model configuration.
        self.model = model
        self.temperature = 0.1
        self.max_tokens = 2000

        # System prompt for the answering call.
        self.system_prompt = """You are a helpful learning assistant specializing in road engineering.
You have access to a knowledge base of course materials. When answering questions:
1. Stick to the provided context from the knowledge base.
2. At the end of your response, provide students the 'title' & 'from' fields of the chunks that were used to answer the question.
3. If the knowledge base doesn't contain relevant information, say so. Students can go to the teaching team for further assistance.
In the response, enclose mathematical formulas and parameters for proper Markdown rendering.
Bold key words if applicable.
"""

        # System prompt for the query-rewriting call.
        self.rewrite_prompt = """You are a query rewriting assistant. Your task is to provide a summary of the conversation history and then rewrite user queries based on conversation history to make them more clear and complete.
Please format your response as follows:
SUMMARY: [Brief summary of the conversation context. Include key points, user intent, and any relevant details]
REWRITTEN_QUERY: [The rewritten query that incorporates context]
Rules:
1. If there's relevant context from previous messages, incorporate it into the rewritten query
2. Make implicit references explicit
3. Maintain the original intent while adding clarity
4. If the query is already clear and complete, keep it as is
5. Always provide both SUMMARY and REWRITTEN_QUERY sections"""

        # System prompt for the entity-extraction call.
        self.entity_extraction_prompt = """You are an expert in road engineering. Extract key entities from the given query.
Focus on:
1. Technical terms and jargon specific to road engineering
2. Formulas, equations, or mathematical concepts
3. Parameters, specifications, or measurements
4. Standards, methods, or procedures
5. Materials, equipment, or structures
Return the entities as a JSON array of strings. Only include the most important and specific entities."""

        # Alternating user/assistant messages of the completed turns.
        self.conversation_history = []

    @staticmethod
    def _database_loaded(load_result) -> bool:
        """Return True when a ``load_vector_database`` result carries data.

        A result counts as loaded when it is not None and its first element is
        not None; a result that cannot be indexed counts as not loaded.
        """
        try:
            return load_result is not None and load_result[0] is not None
        except (IndexError, KeyError, TypeError):
            return False

    @staticmethod
    def _parse_rewrite_response(content: str, query: str) -> Tuple[str, str]:
        """Split a rewriter reply into ``(summary, rewritten_query)``.

        Labels are matched case-insensitively, ``REWRITTEN_QUERY`` may be
        written with an underscore, whitespace, or nothing between the words,
        and Markdown bold markers around either label are ignored.

        Args:
            content: Raw text returned by the rewriting call.
            query: Original query, used when no rewritten query is found.

        Returns:
            The summary ("" when absent) and the rewritten query (``query``
            when absent or empty).
        """
        # Optional bold markers and whitespace around the colon after a label.
        label_tail = r'\s*\*{0,2}\s*:\s*\*{0,2}\s*'
        rewritten_label = r'REWRITTEN[_\s]*QUERY' + label_tail
        flags = re.DOTALL | re.IGNORECASE

        # The summary runs up to the rewritten-query label (or the end of the text).
        summary_match = re.search(
            r'SUMMARY' + label_tail + r'(.*?)(?=\*{0,2}' + rewritten_label + r'|$)',
            content,
            flags,
        )
        rewritten_match = re.search(rewritten_label + r'(.*)$', content, flags)

        summary = summary_match.group(1).strip() if summary_match else ""
        rewritten = rewritten_match.group(1).strip() if rewritten_match else ""
        return summary, rewritten or query

    def rewrite_query(self, query: str) -> Tuple[str, str]:
        """Rewrite ``query`` using recent history and summarize that history.

        Args:
            query: Original user query.

        Returns:
            ``(history_summary, rewritten_query)``. If the API call fails the
            original query is returned together with a generic summary
            (empty when there is no history).
        """
        request = (
            f"Current query: {query}\n\n"
            "Please provide summary and rewritten query following the specified format:"
        )
        if self.conversation_history:
            # Give the rewriter the latest turns, truncating long messages.
            context = "Previous conversation:\n"
            for msg in self.conversation_history[-self._REWRITE_HISTORY_MESSAGES:]:
                role = "User" if msg["role"] == "user" else "Assistant"
                text = msg["content"]
                if len(text) > self._HISTORY_SNIPPET_CHARS:
                    text = text[:self._HISTORY_SNIPPET_CHARS] + "..."
                context += f"{role}: {text}\n"
            request = f"{context}\n\n{request}"

        messages = [
            {"role": "system", "content": self.rewrite_prompt},
            {"role": "user", "content": request},
        ]

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=0.1,  # low temperature keeps rewrites consistent
                max_tokens=2000,
            )
            content = response.choices[0].message.content.strip()
            summary, rewritten = self._parse_rewrite_response(content, query)

            # With history present, never hand back an empty summary.
            if not summary and self.conversation_history:
                summary = "Continue previous discussion"

            print(f"[rewrite_query] Raw query: {query}")
            print(f"[rewrite_query] Chat history summary: {summary}")
            print(f"[rewrite_query] Rewritten query: {rewritten}")
            return summary, rewritten
        except Exception as e:
            # Best effort: rewriting is optional, so fall back to the raw query.
            print(f"[rewrite_query] Query rewriting failed: {e}")
            simple_summary = ""
            if self.conversation_history:
                simple_summary = "Based on previous conversation content"
            return simple_summary, query

    @staticmethod
    def _parse_entity_list(content: str) -> Optional[List[str]]:
        """Parse a model reply into a clean list of entity strings.

        Tries the whole reply as JSON, then the widest ``[...]`` span, then the
        narrowest one. A JSON object is accepted when one of its values is a
        list. Items are stripped and de-duplicated in order; numbers are
        converted to strings and every other non-string item is dropped.

        Returns:
            The entity list (possibly empty), or None when no JSON list could
            be parsed from ``content``.
        """
        candidates = [content]
        for pattern in (r'\[.*\]', r'\[.*?\]'):
            match = re.search(pattern, content, re.DOTALL)
            if match:
                candidates.append(match.group())

        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except (ValueError, TypeError):
                continue
            if isinstance(parsed, dict):
                parsed = next((v for v in parsed.values() if isinstance(v, list)), None)
            if not isinstance(parsed, list):
                continue

            entities = []
            for item in parsed:
                if isinstance(item, bool):
                    continue
                if isinstance(item, (int, float)):
                    item = str(item)
                if isinstance(item, str):
                    item = item.strip()
                    if item and item not in entities:
                        entities.append(item)
            return entities
        return None

    def extract_entities(self, original_query: str, summary: str, rewritten_query: str) -> List[str]:
        """Extract key entities (terms, formulas, parameters) for retrieval.

        Args:
            original_query: Original user query.
            summary: Summary of the conversation history.
            rewritten_query: Rewritten query; ignored when equal to the original.

        Returns:
            List of entity strings. Falls back to quoted substrings of the
            reply and then to ``simple_entity_extraction`` when the reply is
            not parseable or the API call fails.
        """
        text_parts = []
        if original_query:
            text_parts.append(f"Original query: {original_query}")
        if summary:
            text_parts.append(f"Context summary: {summary}")
        if rewritten_query and rewritten_query != original_query:
            text_parts.append(f"Rewritten query: {rewritten_query}")
        combined_text = " | ".join(text_parts)
        if not combined_text:
            # Nothing to extract from, so skip the API round trip.
            return []

        messages = [
            {"role": "system", "content": self.entity_extraction_prompt},
            {"role": "user", "content": f"Text to extract entities from: {combined_text}\n\nExtract entities as JSON array:"},
        ]
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )
            content = response.choices[0].message.content.strip()

            entities = self._parse_entity_list(content)
            if entities is not None:
                print(f"[extract_entities] Extracted entities: {entities}")
                return entities

            # Not valid JSON: salvage anything the model put in double quotes.
            print("[extract_entities] JSON parsing failed, using backup method")
            quoted = re.findall(r'"([^"]+)"', content)
            return quoted if quoted else self.simple_entity_extraction(combined_text)
        except Exception as e:
            # Best effort: fall back to local keyword extraction.
            print(f"[extract_entities] Entity extraction failed: {e}")
            return self.simple_entity_extraction(combined_text)

    def simple_entity_extraction(self, query: str, max_entities: int = 10) -> List[str]:
        """Keyword-based fallback for entity extraction.

        Lower-cases whitespace-separated tokens, trims surrounding punctuation,
        and keeps those longer than two characters that are not stop words.
        Capitalized words and tokens containing digits are then added in their
        original form. Duplicates are removed case-insensitively and the order
        of first appearance is kept, so the result is deterministic.

        Args:
            query: Text to extract keywords from.
            max_entities: Maximum number of entities returned.

        Returns:
            At most ``max_entities`` keywords.
        """
        if not query:
            return []

        entities = []
        seen = set()

        def add(term: str) -> None:
            key = term.lower()
            if key and key not in seen:
                seen.add(key)
                entities.append(term)

        for token in query.split():
            word = token.strip(self._TOKEN_PUNCTUATION).lower()
            if len(word) > 2 and word not in self._STOP_WORDS:
                add(word)

        # Likely technical terms: capitalized words or tokens containing digits.
        for term in re.findall(r'\b[A-Z][a-zA-Z]*\b|\b\w*\d+\w*\b', query):
            if term.lower() not in self._STOP_WORDS:
                add(term)

        return entities[:max(max_entities, 0)]

    def enhanced_search(self, query: str, top_k: int = 5) -> Tuple[str, str, List[str], List[Tuple[Dict, float, Dict]]]:
        """Run the retrieval pipeline: rewrite query, extract entities, search.

        Args:
            query: Original user query.
            top_k: Number of results requested from the vectorizer.

        Returns:
            ``(history_summary, rewritten_query, entities, search_results)``.
        """
        summary, rewritten_query = self.rewrite_query(query)
        entities = self.extract_entities(query, summary, rewritten_query)

        if entities:
            search_results = self.vectorizer.search_with_entities_optimized(entities, top_k)
        else:
            # No entities: search with the rewritten query as a whole.
            print("[enhanced_search] No entities extracted, using full query search")
            search_results = self.vectorizer.search_similar(
                rewritten_query,
                top_k=top_k,
                title_weight=0.2,
                content_weight=0.5,
                full_weight=0.3,
            )
        return summary, rewritten_query, entities, search_results

    def format_context(self, search_results: List[Tuple[Dict, float, Dict]]) -> str:
        """Render search results as the context block given to the model.

        Only ``title``, ``source`` and ``content`` of each entry are included.

        Args:
            search_results: ``(entry, score, details)`` tuples.

        Returns:
            The formatted context, or "" when there are no results.
        """
        if not search_results:
            return ""
        sections = [
            f"Title: {entry['title']}\n"
            f"From: {entry['source']}\n"
            f"Content: {entry['content']}\n"
            for entry, _score, _details in search_results
        ]
        return "RELEVANT KNOWLEDGE BASE CONTENT:\n" + "\n---\n".join(sections)

    def build_messages(self, query: str, context: str) -> List[Dict[str, str]]:
        """Assemble system prompt, recent history and the user message.

        Args:
            query: User query.
            context: Knowledge-base context; prepended to the query when non-empty.

        Returns:
            Message list for the chat completion call.
        """
        messages = [{"role": "system", "content": self.system_prompt}]
        messages.extend(self.conversation_history[-self._ANSWER_HISTORY_MESSAGES:])

        user_message = f"{context}\n\nUSER QUESTION: {query}" if context else query
        messages.append({"role": "user", "content": user_message})
        return messages

    @staticmethod
    def _format_search_info(query: str, summary: str, rewritten_query: str, entities: List[str],
                            search_results: List[Tuple[Dict, float, Dict]], cache_info: Dict) -> str:
        """Build the Markdown header shown to the user before the streamed answer."""
        lines = ["\n**Query Analysis:**\n", f"- Query: {query}\n"]
        if summary:
            lines.append(f"- Summary of history: {summary}\n")
        if rewritten_query != query:
            lines.append(f"- Rewritten query: {rewritten_query}\n")
        entity_text = ', '.join(str(e) for e in entities) if entities else 'No specific entities extracted'
        lines.append(f"- Key entities: {entity_text}\n")

        if search_results:
            lines.append("\n**Relevant Sources:**\n")
            for entry, combined_score, _details in search_results:
                # The user-facing list includes the chunk id and relevance score.
                lines.append(f"- [{entry['id']}] {entry['title']} (Relevance: {combined_score:.3f})\n")
            lines.append("\n**Response:**\n")
        else:
            lines.append("\n**Response:** (No relevant knowledge base content found, answering based on general knowledge)\n")

        # Cache note, kept for debugging.
        if cache_info['is_cached']:
            lines.append(f"💡 Vector database cached with {cache_info['cache_size']} entries\n\n")
        return "".join(lines)

    def generate_response_stream(self, query: str) -> Generator[str, None, None]:
        """Answer ``query`` as a stream of text fragments.

        The first fragment is a query-analysis header; the following fragments
        are the model's answer as it arrives. The turn is appended to the
        conversation history only after the stream completes.

        Args:
            query: User query.

        Yields:
            Response text fragments, or a single error message if the
            generation call fails.
        """
        print("[generate_response_stream] Processing query...")
        summary, rewritten_query, entities, search_results = self.enhanced_search(query)
        context = self.format_context(search_results)
        # The model sees the original query plus the retrieved context.
        messages = self.build_messages(query, context)

        try:
            stream = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                stream=True,
            )

            yield self._format_search_info(
                query, summary, rewritten_query, entities, search_results,
                self.vectorizer.get_cache_info(),
            )

            response_parts = []
            for chunk in stream:
                # Some stream chunks carry no choices; skip them instead of indexing.
                if not chunk.choices:
                    continue
                piece = chunk.choices[0].delta.content
                if piece is not None:
                    response_parts.append(piece)
                    yield piece

            self.conversation_history.append({"role": "user", "content": query})
            self.conversation_history.append({"role": "assistant", "content": "".join(response_parts)})
        except Exception as e:
            yield f"\n\nError: Problem occurred while generating response - {str(e)}"

    def generate_response(self, query: str) -> str:
        """Return the complete (non-streamed) response for ``query``."""
        return "".join(self.generate_response_stream(query))

    def clear_history(self):
        """Clear the conversation history."""
        self.conversation_history = []
        print("[clear_history] Conversation history cleared")

    def clear_vector_cache(self):
        """Clear the vectorizer's vector-database cache."""
        self.vectorizer.clear_cache()
        print("[clear_vector_cache] Vector database cache cleared")

    def reload_vector_database(self):
        """Force the vectorizer to reload the vector database and report the outcome."""
        print("[reload_vector_database] Reloading vector database...")
        load_result = self.vectorizer.load_vector_database(force_reload=True)
        if self._database_loaded(load_result):
            print("[reload_vector_database] Vector database reload completed")
        else:
            print("[reload_vector_database] Warning: Failed to reload vector database")

    def get_system_status(self) -> Dict:
        """Return model name, completed turn count, cache info and a timestamp."""
        return {
            'model': self.model,
            'conversation_turns': len(self.conversation_history) // 2,
            'vector_cache': self.vectorizer.get_cache_info(),
            'last_update': datetime.now().isoformat(),
        }

    def save_conversation(self, filepath: Optional[str] = None):
        """Write the conversation history and system status to a JSON file.

        Args:
            filepath: Destination path; defaults to
                ``conversation_<YYYYmmdd_HHMMSS>.json`` in the working directory.
        """
        if filepath is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filepath = f"conversation_{timestamp}.json"

        conversation_data = {
            "timestamp": datetime.now().isoformat(),
            "model": self.model,
            "system_status": self.get_system_status(),
            "history": self.conversation_history,
        }
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(conversation_data, f, ensure_ascii=False, indent=2)
        print(f"[save_conversation] Conversation saved to: {filepath}")