""" RAG 模块 - GAIA 知识库检索增强生成 基于 GAIA metadata 构建预置知识库,提供问题解题参考 """ import os import csv import json from typing import Optional, List from langchain_core.documents import Document from langchain_core.tools import tool from langchain_core.prompts import ChatPromptTemplate from langchain_text_splitters import RecursiveCharacterTextSplitter from langchain_openai import ChatOpenAI try: from langchain_community.vectorstores import FAISS except ImportError: from langchain.vectorstores import FAISS from config import ( OPENAI_BASE_URL, OPENAI_API_KEY, MODEL, TEMPERATURE, RAG_PERSIST_DIR, RAG_CSV_PATH, RAG_EMBEDDING_MODEL, RAG_TOP_K, DEBUG, ) # 使用本地 HuggingFace Embedding(免费,无需 API) try: from langchain_huggingface import HuggingFaceEmbeddings USE_LOCAL_EMBEDDING = True except ImportError: try: from langchain_community.embeddings import HuggingFaceEmbeddings USE_LOCAL_EMBEDDING = True except ImportError: from langchain_openai import OpenAIEmbeddings USE_LOCAL_EMBEDDING = False # ======================================== # RAG Manager # ======================================== class GAIARAGManager: """ GAIA RAG 管理器 功能: - 从 GAIA metadata 构建知识库 - 检索相似问题,提供解题参考 - 不直接返回答案,只提供解题步骤和工具建议 """ def __init__(self, persist_dir: str = RAG_PERSIST_DIR): self.persist_dir = persist_dir # 延迟初始化(首次使用时加载) self._embeddings = None self._llm = None self._vectorstore = None self._initialized = False # 文本分割器(轻量级,可以立即初始化) self.text_splitter = RecursiveCharacterTextSplitter( chunk_size=1000, chunk_overlap=200, separators=["\n\n", "\n", "。", ".", " ", ""] ) # RAG Prompt(用于生成解题建议) self.rag_prompt = ChatPromptTemplate.from_messages([ ("system", """你是一个解题策略顾问。基于相似问题的解题经验,为新问题提供解题建议。 注意: 1. 只提供解题思路和工具建议,不要直接给出答案 2. 参考历史问题的解题步骤,但要根据新问题调整 3. 如果相似问题不太相关,明确说明 相似问题参考: {context}"""), ("human", "新问题:{question}\n\n请给出解题建议:") ]) @property def embeddings(self): """延迟加载嵌入模型""" if self._embeddings is None: if DEBUG: print("[RAG] 正在加载嵌入模型...") if USE_LOCAL_EMBEDDING: self._embeddings = HuggingFaceEmbeddings( model_name=RAG_EMBEDDING_MODEL, model_kwargs={'device': 'cpu'}, encode_kwargs={'normalize_embeddings': True} ) else: self._embeddings = OpenAIEmbeddings( base_url=OPENAI_BASE_URL, api_key=OPENAI_API_KEY, ) if DEBUG: print("[RAG] 嵌入模型加载完成") return self._embeddings @property def llm(self): """延迟加载 LLM""" if self._llm is None: self._llm = ChatOpenAI( model=MODEL, temperature=TEMPERATURE, base_url=OPENAI_BASE_URL, api_key=OPENAI_API_KEY, ) return self._llm @property def vectorstore(self) -> Optional[FAISS]: """延迟加载向量存储""" if not self._initialized: self._load_index() self._initialized = True return self._vectorstore @vectorstore.setter def vectorstore(self, value): self._vectorstore = value def _load_index(self): """加载已有的向量索引""" index_file = os.path.join(self.persist_dir, "index.faiss") if os.path.exists(index_file): try: self.vectorstore = FAISS.load_local( self.persist_dir, self.embeddings, allow_dangerous_deserialization=True ) if DEBUG: print(f"[RAG] 已加载索引: {self.persist_dir}") except Exception as e: if DEBUG: print(f"[RAG] 加载索引失败: {e}") self.vectorstore = None else: # 如果没有索引,尝试从默认 CSV 初始化 self._init_from_csv() def _init_from_csv(self): """从默认 CSV 文件初始化向量库""" # 检查多个可能的路径 possible_paths = [ RAG_CSV_PATH, os.path.join(os.path.dirname(__file__), RAG_CSV_PATH), os.path.join(os.path.dirname(__file__), "data_clean.csv"), ] for csv_path in possible_paths: if os.path.exists(csv_path): if DEBUG: print(f"[RAG] 从 CSV 初始化: {csv_path}") self.load_csv(csv_path) return if DEBUG: print("[RAG] 未找到 CSV 文件,知识库为空") def load_csv(self, csv_path: str): """ 从 CSV 文件加载文档 CSV 格式: - content: 问题文本(用于 embedding) - metadata: JSON 格式的元数据(answer, steps, tools, has_file) """ if not os.path.exists(csv_path): raise FileNotFoundError(f"CSV 文件不存在: {csv_path}") documents = [] with open(csv_path, newline="", encoding="utf-8") as f: reader = csv.DictReader(f) for row in reader: content = row.get("content", "") if not content: continue # 解析 metadata try: metadata = json.loads(row.get("metadata", "{}")) except json.JSONDecodeError: metadata = {} metadata["csv_source"] = csv_path documents.append(Document(page_content=content, metadata=metadata)) if not documents: if DEBUG: print("[RAG] CSV 中没有有效文档") return # 构建向量库 self.vectorstore = FAISS.from_documents(documents, self.embeddings) # 持久化 os.makedirs(self.persist_dir, exist_ok=True) self.vectorstore.save_local(self.persist_dir) if DEBUG: print(f"[RAG] 已加载 {len(documents)} 条文档") def retrieve(self, query: str, k: int = RAG_TOP_K) -> List[Document]: """ 检索相关文档 Args: query: 查询文本 k: 返回文档数量 Returns: 相关文档列表 """ if self.vectorstore is None: return [] return self.vectorstore.similarity_search(query, k=k) def retrieve_with_scores(self, query: str, k: int = RAG_TOP_K) -> List[tuple]: """ 检索相关文档(带相似度分数) Args: query: 查询文本 k: 返回文档数量 Returns: [(doc, score), ...] 列表 """ if self.vectorstore is None: return [] return self.vectorstore.similarity_search_with_score(query, k=k) def get_solving_hints(self, question: str, k: int = RAG_TOP_K, score_threshold: float = 1.5) -> str: """ 获取解题提示 根据相似问题,提取解题步骤和工具建议 Args: question: 新问题 k: 检索数量 score_threshold: 相似度阈值(越小越相似,FAISS L2距离) Returns: 解题提示文本 """ docs_with_scores = self.retrieve_with_scores(question, k=k) if not docs_with_scores: return "" # 过滤低相似度结果 relevant_docs = [(doc, score) for doc, score in docs_with_scores if score < score_threshold] if not relevant_docs: return "" hints = [] for i, (doc, score) in enumerate(relevant_docs, 1): meta = doc.metadata steps = meta.get('steps', '') tools = meta.get('tools', '') has_file = meta.get('has_file', False) hint_parts = [f"### 参考 {i} (相似度: {1/(1+score):.2f})"] hint_parts.append(f"**相似问题**: {doc.page_content[:100]}...") if steps: hint_parts.append(f"**解题步骤**: {steps[:300]}...") if tools: hint_parts.append(f"**推荐工具**: {tools}") if has_file: hint_parts.append("**注意**: 该问题有附件文件") hints.append("\n".join(hint_parts)) return "\n\n".join(hints) def query(self, question: str, k: int = RAG_TOP_K) -> str: """ RAG 查询:检索 + 生成解题建议 Args: question: 用户问题 k: 检索文档数量 Returns: 解题建议 """ # 1. 检索相关文档 docs = self.retrieve(question, k=k) if not docs: return "知识库中没有找到相似问题。建议直接分析问题并使用适当的工具。" # 2. 构建上下文 context_parts = [] for i, doc in enumerate(docs, 1): meta = doc.metadata context_parts.append(f""" [相似问题 {i}] 问题: {doc.page_content} 解题步骤: {meta.get('steps', 'N/A')} 使用工具: {meta.get('tools', 'N/A')} 有附件: {'是' if meta.get('has_file') else '否'} 答案格式参考: {meta.get('answer', 'N/A')[:50]}... """) context = "\n".join(context_parts) # 3. LLM 生成建议 chain = self.rag_prompt | self.llm response = chain.invoke({ "context": context, "question": question }) return response.content def get_stats(self) -> dict: """获取索引统计信息""" if self.vectorstore is None: return {"status": "empty", "doc_count": 0} try: doc_count = self.vectorstore.index.ntotal except: doc_count = "unknown" return { "status": "loaded", "doc_count": doc_count, "persist_dir": self.persist_dir } # ======================================== # 全局实例 # ======================================== _rag_manager: Optional[GAIARAGManager] = None def get_rag_manager() -> GAIARAGManager: """获取 RAG 管理器单例""" global _rag_manager if _rag_manager is None: _rag_manager = GAIARAGManager() return _rag_manager def _score_to_similarity(score) -> float: """FAISS L2 距离转 [0, 1] 相似度,处理异常值""" try: score_f = float(score) except Exception: return 0.0 if score_f != score_f: # NaN return 0.0 if score_f < 0.0: score_f = 0.0 return 1.0 / (1.0 + score_f) def rag_lookup_answer(question: str, min_similarity: float = 0.85): """ RAG 短路查找:高置信度匹配时直接返回答案。 Returns: 命中: {"answer": str, "similarity": float, "score": float, "metadata": dict} 未命中/异常: None """ if not question or not str(question).strip(): return None try: manager = get_rag_manager() results = manager.retrieve_with_scores(str(question).strip(), k=1) if not results: return None best_doc, best_score = results[0] similarity = _score_to_similarity(best_score) answer = (best_doc.metadata.get("answer") or "").strip() if not answer: return None if similarity > float(min_similarity): return { "answer": answer, "similarity": float(similarity), "score": float(best_score), "metadata": dict(best_doc.metadata), } return None except Exception as e: if DEBUG: print(f"[RAG] rag_lookup_answer failed: {type(e).__name__}: {e}") return None # ======================================== # Agent 工具 # ======================================== @tool def rag_query(question: str) -> str: """ 查询知识库。如果找到高度匹配的问题,直接返回答案;否则返回解题建议。 适用于: - 快速查找已知问题的答案 - 获取相似问题的解题思路和推荐工具 Args: question: 用户问题 Returns: 匹配答案或解题建议 """ manager = get_rag_manager() # 使用带分数的检索 results = manager.retrieve_with_scores(question, k=3) if not results: return "知识库中没有找到相似问题。建议使用 web_search 等工具获取信息。" best_doc, best_score = results[0] similarity = 1 / (1 + best_score) # 高相似度 (>0.85):直接返回答案 if similarity > 0.85: answer = best_doc.metadata.get('answer', '') if answer: return f"【知识库匹配成功】相似度: {similarity:.2f}\n直接答案: {answer}\n请直接使用此答案作为最终回答。" # 中等相似度:返回答案 + 解题参考 if similarity > 0.6: parts = [] for i, (doc, score) in enumerate(results[:2], 1): sim = 1 / (1 + score) meta = doc.metadata parts.append( f"[参考 {i}] 相似度: {sim:.2f}\n" f"问题: {doc.page_content[:100]}...\n" f"答案: {meta.get('answer', 'N/A')}\n" f"步骤: {meta.get('steps', 'N/A')[:200]}\n" f"工具: {meta.get('tools', 'N/A')}" ) return "【知识库参考】\n" + "\n---\n".join(parts) # 低相似度:仅返回工具建议 return manager.query(question) @tool def rag_retrieve(query: str) -> str: """ 仅检索知识库中的相关文档片段,不生成建议。 用于查看原始的相似问题和解题步骤。 Args: query: 检索查询 Returns: 相关文档片段 """ manager = get_rag_manager() docs_with_scores = manager.retrieve_with_scores(query, k=3) if not docs_with_scores: return "知识库为空或未找到相关文档。" results = [] for i, (doc, score) in enumerate(docs_with_scores, 1): meta = doc.metadata results.append(f"""[{i}] 相似度: {1/(1+score):.2f} 问题: {doc.page_content[:200]}... 解题步骤: {meta.get('steps', 'N/A')[:200]}... 工具: {meta.get('tools', 'N/A')} """) return "\n---\n".join(results) @tool def rag_stats() -> str: """ 获取知识库统计信息。 Returns: 知识库状态和文档数量 """ manager = get_rag_manager() stats = manager.get_stats() return f"知识库状态: {stats['status']}, 文档数量: {stats['doc_count']}" # ======================================== # 导出 RAG 工具 # ======================================== RAG_TOOLS = [rag_query, rag_retrieve, rag_stats]