| | """ |
| | RAG 模块 - GAIA 知识库检索增强生成 |
| | 基于 GAIA metadata 构建预置知识库,提供问题解题参考 |
| | """ |
| |
|
| | import os |
| | import csv |
| | import json |
| | from typing import Optional, List |
| |
|
| | from langchain_core.documents import Document |
| | from langchain_core.tools import tool |
| | from langchain_core.prompts import ChatPromptTemplate |
| | from langchain_text_splitters import RecursiveCharacterTextSplitter |
| | from langchain_openai import ChatOpenAI |
| |
|
| | try: |
| | from langchain_community.vectorstores import FAISS |
| | except ImportError: |
| | from langchain.vectorstores import FAISS |
| |
|
| | from config import ( |
| | OPENAI_BASE_URL, |
| | OPENAI_API_KEY, |
| | MODEL, |
| | TEMPERATURE, |
| | RAG_PERSIST_DIR, |
| | RAG_CSV_PATH, |
| | RAG_EMBEDDING_MODEL, |
| | RAG_TOP_K, |
| | DEBUG, |
| | ) |
| |
|
| | |
| | try: |
| | from langchain_huggingface import HuggingFaceEmbeddings |
| | USE_LOCAL_EMBEDDING = True |
| | except ImportError: |
| | try: |
| | from langchain_community.embeddings import HuggingFaceEmbeddings |
| | USE_LOCAL_EMBEDDING = True |
| | except ImportError: |
| | from langchain_openai import OpenAIEmbeddings |
| | USE_LOCAL_EMBEDDING = False |
| |
|
| |
|
| | |
| | |
| | |
| |
|
class GAIARAGManager:
    """Manage a FAISS-backed knowledge base built from GAIA metadata.

    Responsibilities:
    - build the vector index from a CSV file of questions and metadata;
    - retrieve stored questions that resemble a new one;
    - turn the matches into solving hints (steps and suggested tools).

    The embedding model, the chat model and the vector store are all created
    lazily, on first access to the corresponding property.
    """

    def __init__(self, persist_dir: str = RAG_PERSIST_DIR):
        """Create a manager whose index is persisted under ``persist_dir``.

        Nothing is loaded here; see the ``embeddings``, ``llm`` and
        ``vectorstore`` properties.
        """
        self.persist_dir = persist_dir

        # Lazily populated caches.
        self._embeddings = None
        self._llm = None
        self._vectorstore = None
        # True once the vector store has been loaded, built or explicitly set.
        self._initialized = False

        # Not used by any method of this class; kept as a public attribute.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            separators=["\n\n", "\n", "。", ".", " ", ""]
        )

        # Prompt used by query(): the retrieved questions are injected as
        # {context}, the new question as {question}.
        self.rag_prompt = ChatPromptTemplate.from_messages([
            ("system", """你是一个解题策略顾问。基于相似问题的解题经验,为新问题提供解题建议。

注意:
1. 只提供解题思路和工具建议,不要直接给出答案
2. 参考历史问题的解题步骤,但要根据新问题调整
3. 如果相似问题不太相关,明确说明

相似问题参考:
{context}"""),
            ("human", "新问题:{question}\n\n请给出解题建议:")
        ])

    @staticmethod
    def _clip_text(value, limit: int, default: str = "") -> str:
        """Return ``value`` as text, truncated to ``limit`` characters.

        Metadata is parsed from free-form JSON, so a field may be a number,
        a list or ``None`` instead of a string; slicing such values directly
        would raise or truncate by items rather than characters.
        """
        if value is None:
            return default
        return str(value)[:limit]

    @property
    def embeddings(self):
        """Embedding model, created on first access.

        Uses ``HuggingFaceEmbeddings`` on CPU when ``USE_LOCAL_EMBEDDING`` is
        true, otherwise ``OpenAIEmbeddings`` with the configured endpoint.
        """
        if self._embeddings is None:
            if DEBUG:
                print("[RAG] 正在加载嵌入模型...")
            if USE_LOCAL_EMBEDDING:
                self._embeddings = HuggingFaceEmbeddings(
                    model_name=RAG_EMBEDDING_MODEL,
                    model_kwargs={'device': 'cpu'},
                    encode_kwargs={'normalize_embeddings': True}
                )
            else:
                self._embeddings = OpenAIEmbeddings(
                    base_url=OPENAI_BASE_URL,
                    api_key=OPENAI_API_KEY,
                )
            if DEBUG:
                print("[RAG] 嵌入模型加载完成")
        return self._embeddings

    @property
    def llm(self):
        """Chat model, created on first access from the configured settings."""
        if self._llm is None:
            self._llm = ChatOpenAI(
                model=MODEL,
                temperature=TEMPERATURE,
                base_url=OPENAI_BASE_URL,
                api_key=OPENAI_API_KEY,
            )
        return self._llm

    @property
    def vectorstore(self) -> Optional[FAISS]:
        """Vector store, loaded (or built from CSV) on first access.

        Returns ``None`` when no index could be loaded or built.
        """
        if not self._initialized:
            self._load_index()
            # Also covers the paths of _load_index() that never assign a
            # store (no CSV found, or a CSV without usable rows).
            self._initialized = True
        return self._vectorstore

    @vectorstore.setter
    def vectorstore(self, value):
        # An explicit assignment supersedes lazy loading.  Marking the store
        # as initialized here also keeps code that runs *during* loading
        # (load_csv) from re-entering _load_index() through the getter.
        self._vectorstore = value
        self._initialized = True

    def _load_index(self):
        """Load the persisted index, or build one from CSV if none exists.

        A failure while loading an existing index leaves the store as
        ``None``; the error is only printed when ``DEBUG`` is set.
        """
        index_file = os.path.join(self.persist_dir, "index.faiss")
        if not os.path.exists(index_file):
            self._init_from_csv()
            return

        try:
            # SECURITY NOTE: allow_dangerous_deserialization=True lets FAISS
            # unpickle the stored docstore.  Only safe while persist_dir is
            # written exclusively by this application.
            self.vectorstore = FAISS.load_local(
                self.persist_dir,
                self.embeddings,
                allow_dangerous_deserialization=True
            )
            if DEBUG:
                print(f"[RAG] 已加载索引: {self.persist_dir}")
        except Exception as e:
            if DEBUG:
                print(f"[RAG] 加载索引失败: {e}")
            self.vectorstore = None

    def _init_from_csv(self):
        """Build the index from the first existing default CSV location."""
        module_dir = os.path.dirname(__file__)
        possible_paths = [
            RAG_CSV_PATH,
            os.path.join(module_dir, RAG_CSV_PATH),
            os.path.join(module_dir, "data_clean.csv"),
        ]

        for csv_path in possible_paths:
            if os.path.exists(csv_path):
                if DEBUG:
                    print(f"[RAG] 从 CSV 初始化: {csv_path}")
                self.load_csv(csv_path)
                return

        if DEBUG:
            print("[RAG] 未找到 CSV 文件,知识库为空")

    def load_csv(self, csv_path: str):
        """Build and persist the index from a CSV file.

        Expected columns:
        - ``content``: question text, used as the embedded document body;
        - ``metadata``: JSON object (the rest of this class reads the keys
          ``answer``, ``steps``, ``tools`` and ``has_file``).

        Rows without content are skipped.  Metadata that is missing, is not
        valid JSON, or is not a JSON object is replaced by an empty dict.

        Args:
            csv_path: Path of the CSV file to read.

        Raises:
            FileNotFoundError: If ``csv_path`` does not exist.
        """
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"CSV 文件不存在: {csv_path}")

        documents = []
        with open(csv_path, newline="", encoding="utf-8") as f:
            for row in csv.DictReader(f):
                content = row.get("content", "")
                if not content:
                    continue

                # DictReader yields None for cells missing from a short row,
                # hence the TypeError guard.
                try:
                    metadata = json.loads(row.get("metadata") or "{}")
                except (json.JSONDecodeError, TypeError):
                    metadata = {}
                if not isinstance(metadata, dict):
                    metadata = {}

                metadata["csv_source"] = csv_path
                documents.append(Document(page_content=content, metadata=metadata))

        if not documents:
            if DEBUG:
                print("[RAG] CSV 中没有有效文档")
            return

        # Work on a local reference: reading self.vectorstore here used to
        # go through the lazy getter while loading was still in progress and
        # re-entered _load_index() without end.
        store = FAISS.from_documents(documents, self.embeddings)
        self.vectorstore = store

        os.makedirs(self.persist_dir, exist_ok=True)
        store.save_local(self.persist_dir)

        if DEBUG:
            print(f"[RAG] 已加载 {len(documents)} 条文档")

    def retrieve(self, query: str, k: int = RAG_TOP_K) -> List[Document]:
        """Return up to ``k`` documents most similar to ``query``.

        Args:
            query: Text to search for.
            k: Number of documents to return.

        Returns:
            Matching documents, or an empty list when no index is available.
        """
        store = self.vectorstore
        if store is None:
            return []

        return store.similarity_search(query, k=k)

    def retrieve_with_scores(self, query: str, k: int = RAG_TOP_K) -> List[tuple]:
        """Return up to ``k`` ``(document, score)`` pairs for ``query``.

        Args:
            query: Text to search for.
            k: Number of documents to return.

        Returns:
            ``[(doc, score), ...]``, or an empty list when no index is
            available.
        """
        store = self.vectorstore
        if store is None:
            return []

        return store.similarity_search_with_score(query, k=k)

    def get_solving_hints(self, question: str, k: int = RAG_TOP_K, score_threshold: float = 1.5) -> str:
        """Format solving hints taken from stored questions similar to ``question``.

        Args:
            question: The new question.
            k: Number of candidates to retrieve.
            score_threshold: Only matches with a score strictly below this
                value are kept (smaller score = closer match).

        Returns:
            A Markdown-style text block with one section per kept match, or
            an empty string when nothing qualifies.
        """
        docs_with_scores = self.retrieve_with_scores(question, k=k)
        if not docs_with_scores:
            return ""

        relevant_docs = [
            (doc, score) for doc, score in docs_with_scores
            if score < score_threshold
        ]
        if not relevant_docs:
            return ""

        hints = []
        for i, (doc, score) in enumerate(relevant_docs, 1):
            meta = doc.metadata
            steps = meta.get('steps', '')
            tools = meta.get('tools', '')
            has_file = meta.get('has_file', False)

            # Shared module-level helper instead of a hand-rolled
            # 1 / (1 + score): it also copes with NaN and negative scores.
            similarity = _score_to_similarity(score)
            hint_parts = [f"### 参考 {i} (相似度: {similarity:.2f})"]
            hint_parts.append(f"**相似问题**: {doc.page_content[:100]}...")

            if steps:
                hint_parts.append(f"**解题步骤**: {self._clip_text(steps, 300)}...")
            if tools:
                hint_parts.append(f"**推荐工具**: {tools}")
            if has_file:
                hint_parts.append("**注意**: 该问题有附件文件")

            hints.append("\n".join(hint_parts))

        return "\n\n".join(hints)

    def query(self, question: str, k: int = RAG_TOP_K) -> str:
        """Retrieve similar questions and ask the chat model for solving advice.

        Args:
            question: The user's question.
            k: Number of documents to retrieve.

        Returns:
            The model's advice, or a fixed fallback message when nothing was
            retrieved.
        """
        docs = self.retrieve(question, k=k)
        if not docs:
            return "知识库中没有找到相似问题。建议直接分析问题并使用适当的工具。"

        context_parts = []
        for i, doc in enumerate(docs, 1):
            meta = doc.metadata
            steps = meta.get('steps', 'N/A')
            tools = meta.get('tools', 'N/A')
            has_file = '是' if meta.get('has_file') else '否'
            # Only a short prefix of the stored answer is shown; the value is
            # coerced to text because it may not be a string.
            answer_preview = self._clip_text(meta.get('answer'), 50, 'N/A')
            context_parts.append(f"""
[相似问题 {i}]
问题: {doc.page_content}
解题步骤: {steps}
使用工具: {tools}
有附件: {has_file}
答案格式参考: {answer_preview}...
""")

        context = "\n".join(context_parts)

        chain = self.rag_prompt | self.llm
        response = chain.invoke({
            "context": context,
            "question": question
        })

        return response.content

    def get_stats(self) -> dict:
        """Return the index status and document count.

        Returns:
            ``{"status": "empty", "doc_count": 0}`` when no index is
            available, otherwise a dict with ``status`` ``"loaded"``,
            ``doc_count`` (or ``"unknown"`` if it cannot be read) and
            ``persist_dir``.
        """
        store = self.vectorstore
        if store is None:
            return {"status": "empty", "doc_count": 0}

        try:
            doc_count = store.index.ntotal
        except Exception:
            # Best effort only; unlike a bare ``except`` this lets
            # KeyboardInterrupt and SystemExit propagate.
            doc_count = "unknown"

        return {
            "status": "loaded",
            "doc_count": doc_count,
            "persist_dir": self.persist_dir
        }
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Module-wide manager instance; stays None until get_rag_manager() is called.
_rag_manager: Optional[GAIARAGManager] = None


def get_rag_manager() -> GAIARAGManager:
    """Return the shared GAIARAGManager, creating it on the first call."""
    global _rag_manager
    if _rag_manager is not None:
        return _rag_manager
    _rag_manager = GAIARAGManager()
    return _rag_manager
| |
|
| |
|
| | def _score_to_similarity(score) -> float: |
| | """FAISS L2 距离转 [0, 1] 相似度,处理异常值""" |
| | try: |
| | score_f = float(score) |
| | except Exception: |
| | return 0.0 |
| | if score_f != score_f: |
| | return 0.0 |
| | if score_f < 0.0: |
| | score_f = 0.0 |
| | return 1.0 / (1.0 + score_f) |
| |
|
| |
|
def rag_lookup_answer(question: str, min_similarity: float = 0.85):
    """Short-circuit lookup: return the stored answer for a near-exact match.

    Only the single best match is considered.

    Args:
        question: Question text; blank or empty input is a miss.
        min_similarity: The match must have a similarity strictly greater
            than this value.

    Returns:
        On a hit, ``{"answer": str, "similarity": float, "score": float,
        "metadata": dict}``.  ``None`` on a miss, when the best match has no
        answer, or when any error occurs (errors are printed only when
        ``DEBUG`` is set).
    """
    if not question or not str(question).strip():
        return None
    try:
        manager = get_rag_manager()
        results = manager.retrieve_with_scores(str(question).strip(), k=1)
        if not results:
            return None

        best_doc, best_score = results[0]

        # Metadata comes from JSON, so the answer may be a number rather than
        # a string.  Coerce it instead of letting ``.strip()`` raise (which
        # the handler below would silently turn into a miss).
        raw_answer = best_doc.metadata.get("answer")
        answer = "" if raw_answer is None else str(raw_answer).strip()
        if not answer:
            return None

        similarity = _score_to_similarity(best_score)
        if similarity <= float(min_similarity):
            return None

        return {
            "answer": answer,
            "similarity": float(similarity),
            "score": float(best_score),
            "metadata": dict(best_doc.metadata),
        }
    except Exception as e:
        # Deliberately best-effort: a failed lookup must not break the caller.
        if DEBUG:
            print(f"[RAG] rag_lookup_answer failed: {type(e).__name__}: {e}")
        return None
| |
|
| |
|
| | |
| | |
| | |
| |
|
# NOTE: the docstring below is kept verbatim because LangChain's ``@tool``
# uses the function docstring as the tool description shown to the model.
# Summary: query the knowledge base; return the stored answer on a very close
# match, reference entries on a moderate match, otherwise generated advice.
@tool
def rag_query(question: str) -> str:
    """
    查询知识库。如果找到高度匹配的问题,直接返回答案;否则返回解题建议。

    适用于:
    - 快速查找已知问题的答案
    - 获取相似问题的解题思路和推荐工具

    Args:
        question: 用户问题

    Returns:
        匹配答案或解题建议
    """
    manager = get_rag_manager()

    results = manager.retrieve_with_scores(question, k=3)
    if not results:
        return "知识库中没有找到相似问题。建议使用 web_search 等工具获取信息。"

    best_doc, best_score = results[0]
    # Same conversion as rag_lookup_answer(); unlike a raw 1 / (1 + score) it
    # cannot divide by zero and maps NaN to 0.
    similarity = _score_to_similarity(best_score)

    # Very close match: hand back the stored answer directly.
    if similarity > 0.85:
        answer = best_doc.metadata.get('answer', '')
        if answer:
            return f"【知识库匹配成功】相似度: {similarity:.2f}\n直接答案: {answer}\n请直接使用此答案作为最终回答。"

    # Moderate match: show the two best entries as reference material.
    if similarity > 0.6:
        parts = []
        for i, (doc, score) in enumerate(results[:2], 1):
            sim = _score_to_similarity(score)
            meta = doc.metadata
            # Steps may not be a string in the stored metadata.
            steps = str(meta.get('steps', 'N/A'))[:200]
            parts.append(
                f"[参考 {i}] 相似度: {sim:.2f}\n"
                f"问题: {doc.page_content[:100]}...\n"
                f"答案: {meta.get('answer', 'N/A')}\n"
                f"步骤: {steps}\n"
                f"工具: {meta.get('tools', 'N/A')}"
            )
        return "【知识库参考】\n" + "\n---\n".join(parts)

    # Weak match: let the manager generate advice with the chat model.
    return manager.query(question)
| |
|
| |
|
# NOTE: the docstring below is kept verbatim because LangChain's ``@tool``
# uses the function docstring as the tool description shown to the model.
# Summary: list the three closest stored questions with their solving steps
# and tools, without generating advice.
@tool
def rag_retrieve(query: str) -> str:
    """
    仅检索知识库中的相关文档片段,不生成建议。

    用于查看原始的相似问题和解题步骤。

    Args:
        query: 检索查询

    Returns:
        相关文档片段
    """
    manager = get_rag_manager()
    docs_with_scores = manager.retrieve_with_scores(query, k=3)

    if not docs_with_scores:
        return "知识库为空或未找到相关文档。"

    results = []
    for i, (doc, score) in enumerate(docs_with_scores, 1):
        meta = doc.metadata
        # Shared helper instead of a raw 1 / (1 + score): no division by zero
        # and NaN maps to 0.
        similarity = _score_to_similarity(score)
        # Steps may not be a string in the stored metadata.
        steps = str(meta.get('steps', 'N/A'))[:200]
        results.append(
            f"[{i}] 相似度: {similarity:.2f}\n"
            f"问题: {doc.page_content[:200]}...\n"
            f"解题步骤: {steps}...\n"
            f"工具: {meta.get('tools', 'N/A')}\n"
        )

    return "\n---\n".join(results)
| |
|
| |
|
# NOTE: the docstring below is kept verbatim because LangChain's ``@tool``
# uses the function docstring as the tool description shown to the model.
# Summary: report the knowledge-base status and document count.
@tool
def rag_stats() -> str:
    """
    获取知识库统计信息。

    Returns:
        知识库状态和文档数量
    """
    info = get_rag_manager().get_stats()
    status = info['status']
    doc_count = info['doc_count']
    return f"知识库状态: {status}, 文档数量: {doc_count}"
| |
|
| |
|
| | |
| | |
| | |
| |
|
# The tool objects defined in this module, in declaration order.
RAG_TOOLS = [
    rag_query,
    rag_retrieve,
    rag_stats,
]
| |
|