# GAIA / rag.py
# (Hugging Face page residue preserved as a comment: uploaded by hapda12,
#  "Upload 12 files", commit 358eb7e verified)
"""
RAG 模块 - GAIA 知识库检索增强生成
基于 GAIA metadata 构建预置知识库,提供问题解题参考
"""
import os
import csv
import json
from typing import Optional, List
from langchain_core.documents import Document
from langchain_core.tools import tool
from langchain_core.prompts import ChatPromptTemplate
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_openai import ChatOpenAI
try:
from langchain_community.vectorstores import FAISS
except ImportError:
from langchain.vectorstores import FAISS
from config import (
OPENAI_BASE_URL,
OPENAI_API_KEY,
MODEL,
TEMPERATURE,
RAG_PERSIST_DIR,
RAG_CSV_PATH,
RAG_EMBEDDING_MODEL,
RAG_TOP_K,
DEBUG,
)
# 使用本地 HuggingFace Embedding(免费,无需 API)
try:
from langchain_huggingface import HuggingFaceEmbeddings
USE_LOCAL_EMBEDDING = True
except ImportError:
try:
from langchain_community.embeddings import HuggingFaceEmbeddings
USE_LOCAL_EMBEDDING = True
except ImportError:
from langchain_openai import OpenAIEmbeddings
USE_LOCAL_EMBEDDING = False
# ========================================
# RAG Manager
# ========================================
class GAIARAGManager:
    """
    GAIA RAG manager.

    Responsibilities:
    - Build a knowledge base from GAIA metadata (a CSV of past questions).
    - Retrieve similar past questions as solving references.
    - `query` returns solving steps / tool suggestions only, never the raw
      answer (direct answering is handled by `rag_lookup_answer`).
    """

    def __init__(self, persist_dir: str = RAG_PERSIST_DIR):
        """
        Args:
            persist_dir: Directory where the FAISS index is persisted.
        """
        self.persist_dir = persist_dir
        # Heavy resources (embeddings, LLM, FAISS index) are created lazily
        # on first access to keep import/start-up cheap.
        self._embeddings = None
        self._llm = None
        self._vectorstore = None
        self._initialized = False
        # The text splitter is lightweight, so it is built eagerly.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            separators=["\n\n", "\n", "。", ".", " ", ""]
        )
        # Prompt used to turn retrieved context into solving suggestions.
        # NOTE: the prompt text is runtime behavior and is kept verbatim.
        self.rag_prompt = ChatPromptTemplate.from_messages([
            ("system", """你是一个解题策略顾问。基于相似问题的解题经验,为新问题提供解题建议。
注意:
1. 只提供解题思路和工具建议,不要直接给出答案
2. 参考历史问题的解题步骤,但要根据新问题调整
3. 如果相似问题不太相关,明确说明
相似问题参考:
{context}"""),
            ("human", "新问题:{question}\n\n请给出解题建议:")
        ])

    @property
    def embeddings(self):
        """Lazily load the embedding model (local HuggingFace when available,
        otherwise OpenAI embeddings)."""
        if self._embeddings is None:
            if DEBUG:
                print("[RAG] 正在加载嵌入模型...")
            if USE_LOCAL_EMBEDDING:
                self._embeddings = HuggingFaceEmbeddings(
                    model_name=RAG_EMBEDDING_MODEL,
                    model_kwargs={'device': 'cpu'},
                    encode_kwargs={'normalize_embeddings': True}
                )
            else:
                self._embeddings = OpenAIEmbeddings(
                    base_url=OPENAI_BASE_URL,
                    api_key=OPENAI_API_KEY,
                )
            if DEBUG:
                print("[RAG] 嵌入模型加载完成")
        return self._embeddings

    @property
    def llm(self):
        """Lazily create the chat LLM used to generate solving suggestions."""
        if self._llm is None:
            self._llm = ChatOpenAI(
                model=MODEL,
                temperature=TEMPERATURE,
                base_url=OPENAI_BASE_URL,
                api_key=OPENAI_API_KEY,
            )
        return self._llm

    @property
    def vectorstore(self) -> Optional[FAISS]:
        """Lazily load (or bootstrap) the FAISS vector store on first access.

        May be None when no index exists and no CSV could be found.
        """
        if not self._initialized:
            self._load_index()
            self._initialized = True
        return self._vectorstore

    @vectorstore.setter
    def vectorstore(self, value):
        self._vectorstore = value

    def _load_index(self):
        """Load a persisted FAISS index, or fall back to building one from
        the default CSV when no index file exists."""
        index_file = os.path.join(self.persist_dir, "index.faiss")
        if os.path.exists(index_file):
            try:
                self.vectorstore = FAISS.load_local(
                    self.persist_dir,
                    self.embeddings,
                    # The index is produced locally by save_local() below,
                    # so deserializing it is trusted here.
                    allow_dangerous_deserialization=True
                )
                if DEBUG:
                    print(f"[RAG] 已加载索引: {self.persist_dir}")
            except Exception as e:
                # Best effort: a corrupt index degrades to an empty store.
                if DEBUG:
                    print(f"[RAG] 加载索引失败: {e}")
                self.vectorstore = None
        else:
            # No persisted index — try to bootstrap from the default CSV.
            self._init_from_csv()

    def _init_from_csv(self):
        """Initialize the vector store from the first existing CSV among a
        few candidate paths; leave the store empty when none is found."""
        possible_paths = [
            RAG_CSV_PATH,
            os.path.join(os.path.dirname(__file__), RAG_CSV_PATH),
            os.path.join(os.path.dirname(__file__), "data_clean.csv"),
        ]
        for csv_path in possible_paths:
            if os.path.exists(csv_path):
                if DEBUG:
                    print(f"[RAG] 从 CSV 初始化: {csv_path}")
                self.load_csv(csv_path)
                return
        if DEBUG:
            print("[RAG] 未找到 CSV 文件,知识库为空")

    def load_csv(self, csv_path: str):
        """
        Load documents from a CSV file and (re)build + persist the index.

        Expected CSV columns:
        - content: question text (embedded for similarity search)
        - metadata: JSON metadata (answer, steps, tools, has_file)

        Raises:
            FileNotFoundError: if ``csv_path`` does not exist.
        """
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"CSV 文件不存在: {csv_path}")
        documents = []
        with open(csv_path, newline="", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for row in reader:
                content = row.get("content", "")
                if not content:
                    continue
                # Malformed metadata degrades to an empty dict instead of
                # aborting the whole load.
                try:
                    metadata = json.loads(row.get("metadata", "{}"))
                except json.JSONDecodeError:
                    metadata = {}
                metadata["csv_source"] = csv_path
                documents.append(Document(page_content=content, metadata=metadata))
        if not documents:
            if DEBUG:
                print("[RAG] CSV 中没有有效文档")
            return
        # Build the vector store, then persist it for the next process.
        self.vectorstore = FAISS.from_documents(documents, self.embeddings)
        os.makedirs(self.persist_dir, exist_ok=True)
        self.vectorstore.save_local(self.persist_dir)
        if DEBUG:
            print(f"[RAG] 已加载 {len(documents)} 条文档")

    def retrieve(self, query: str, k: int = RAG_TOP_K) -> List[Document]:
        """
        Retrieve the k documents most similar to ``query``.

        Args:
            query: Query text.
            k: Number of documents to return.

        Returns:
            Matching documents; empty list when the store is unavailable.
        """
        if self.vectorstore is None:
            return []
        return self.vectorstore.similarity_search(query, k=k)

    def retrieve_with_scores(self, query: str, k: int = RAG_TOP_K) -> List[tuple]:
        """
        Retrieve similar documents together with their scores.

        Args:
            query: Query text.
            k: Number of documents to return.

        Returns:
            ``[(doc, score), ...]`` where score is a FAISS L2 distance
            (smaller means more similar); empty list when the store is
            unavailable.
        """
        if self.vectorstore is None:
            return []
        return self.vectorstore.similarity_search_with_score(query, k=k)

    def get_solving_hints(self, question: str, k: int = RAG_TOP_K, score_threshold: float = 1.5) -> str:
        """
        Build solving hints from similar past questions.

        Args:
            question: The new question.
            k: Number of documents to retrieve.
            score_threshold: Maximum FAISS L2 distance to keep (smaller
                distance means more similar).

        Returns:
            Markdown-formatted hint text, or "" when nothing relevant.
        """
        docs_with_scores = self.retrieve_with_scores(question, k=k)
        if not docs_with_scores:
            return ""
        # Drop weak matches (large L2 distance).
        relevant_docs = [(doc, score) for doc, score in docs_with_scores if score < score_threshold]
        if not relevant_docs:
            return ""
        hints = []
        for i, (doc, score) in enumerate(relevant_docs, 1):
            meta = doc.metadata
            steps = meta.get('steps', '')
            tools = meta.get('tools', '')
            has_file = meta.get('has_file', False)
            # 1/(1+distance) maps the L2 distance into (0, 1] for display.
            hint_parts = [f"### 参考 {i} (相似度: {1/(1+score):.2f})"]
            hint_parts.append(f"**相似问题**: {doc.page_content[:100]}...")
            if steps:
                hint_parts.append(f"**解题步骤**: {steps[:300]}...")
            if tools:
                hint_parts.append(f"**推荐工具**: {tools}")
            if has_file:
                hint_parts.append("**注意**: 该问题有附件文件")
            hints.append("\n".join(hint_parts))
        return "\n\n".join(hints)

    def query(self, question: str, k: int = RAG_TOP_K) -> str:
        """
        Full RAG query: retrieve similar questions, then ask the LLM for
        solving suggestions (not a direct answer).

        Args:
            question: User question.
            k: Number of documents to retrieve.

        Returns:
            LLM-generated solving suggestions.
        """
        # 1. Retrieve related documents.
        docs = self.retrieve(question, k=k)
        if not docs:
            return "知识库中没有找到相似问题。建议直接分析问题并使用适当的工具。"
        # 2. Build the context block fed into the prompt.
        context_parts = []
        for i, doc in enumerate(docs, 1):
            meta = doc.metadata
            # str() guards the slice: a non-string (or None) 'answer' value
            # stored in the CSV metadata would otherwise raise TypeError.
            context_parts.append(f"""
[相似问题 {i}]
问题: {doc.page_content}
解题步骤: {meta.get('steps', 'N/A')}
使用工具: {meta.get('tools', 'N/A')}
有附件: {'是' if meta.get('has_file') else '否'}
答案格式参考: {str(meta.get('answer', 'N/A'))[:50]}...
""")
        context = "\n".join(context_parts)
        # 3. Generate suggestions with the LLM.
        chain = self.rag_prompt | self.llm
        response = chain.invoke({
            "context": context,
            "question": question
        })
        return response.content

    def get_stats(self) -> dict:
        """Return index statistics: {status, doc_count, persist_dir}."""
        if self.vectorstore is None:
            return {"status": "empty", "doc_count": 0}
        try:
            doc_count = self.vectorstore.index.ntotal
        except Exception:
            # Was a bare `except:`; narrowed so KeyboardInterrupt/SystemExit
            # are no longer swallowed while stats stay best-effort.
            doc_count = "unknown"
        return {
            "status": "loaded",
            "doc_count": doc_count,
            "persist_dir": self.persist_dir
        }
# ========================================
# 全局实例
# ========================================
# Module-level singleton holder; populated on first use.
_rag_manager: Optional[GAIARAGManager] = None


def get_rag_manager() -> GAIARAGManager:
    """Return the process-wide RAG manager, creating it on first call."""
    global _rag_manager
    if _rag_manager is not None:
        return _rag_manager
    _rag_manager = GAIARAGManager()
    return _rag_manager
def _score_to_similarity(score) -> float:
"""FAISS L2 距离转 [0, 1] 相似度,处理异常值"""
try:
score_f = float(score)
except Exception:
return 0.0
if score_f != score_f: # NaN
return 0.0
if score_f < 0.0:
score_f = 0.0
return 1.0 / (1.0 + score_f)
def rag_lookup_answer(question: str, min_similarity: float = 0.85):
    """
    RAG short-circuit lookup: return the stored answer directly when the
    best match is above the confidence threshold.

    Args:
        question: The incoming question text.
        min_similarity: Similarity (in [0, 1]) the best hit must exceed.

    Returns:
        On a hit: {"answer": str, "similarity": float, "score": float,
                   "metadata": dict}
        On a miss or any error: None (best-effort — never raises).
    """
    normalized = str(question).strip() if question else ""
    if not normalized:
        return None
    try:
        hits = get_rag_manager().retrieve_with_scores(normalized, k=1)
        if not hits:
            return None
        doc, distance = hits[0]
        answer = (doc.metadata.get("answer") or "").strip()
        if not answer:
            # A match without a usable stored answer cannot short-circuit.
            return None
        similarity = _score_to_similarity(distance)
        if similarity <= float(min_similarity):
            return None
        return {
            "answer": answer,
            "similarity": float(similarity),
            "score": float(distance),
            "metadata": dict(doc.metadata),
        }
    except Exception as e:
        # Deliberate catch-all: lookup is an optimization, never fatal.
        if DEBUG:
            print(f"[RAG] rag_lookup_answer failed: {type(e).__name__}: {e}")
        return None
# ========================================
# Agent 工具
# ========================================
@tool
def rag_query(question: str) -> str:
    """
    查询知识库。如果找到高度匹配的问题,直接返回答案;否则返回解题建议。
    适用于:
    - 快速查找已知问题的答案
    - 获取相似问题的解题思路和推荐工具
    Args:
        question: 用户问题
    Returns:
        匹配答案或解题建议
    """
    # NOTE: the docstring above doubles as the agent-facing tool
    # description, so it is intentionally left in its original form.
    manager = get_rag_manager()
    # Retrieve with raw FAISS L2 distances so confidence can be graded.
    results = manager.retrieve_with_scores(question, k=3)
    if not results:
        return "知识库中没有找到相似问题。建议使用 web_search 等工具获取信息。"
    best_doc, best_score = results[0]
    # Use the module's shared converter instead of inline 1/(1+score):
    # it guards NaN / negative distances, keeping similarity in [0, 1].
    similarity = _score_to_similarity(best_score)
    # High similarity (> 0.85): return the stored answer directly.
    if similarity > 0.85:
        answer = best_doc.metadata.get('answer', '')
        if answer:
            return f"【知识库匹配成功】相似度: {similarity:.2f}\n直接答案: {answer}\n请直接使用此答案作为最终回答。"
    # Medium similarity (> 0.6): return answers plus solving references.
    if similarity > 0.6:
        parts = []
        for i, (doc, score) in enumerate(results[:2], 1):
            sim = _score_to_similarity(score)
            meta = doc.metadata
            # str() guards the slice: a None 'steps' metadata value would
            # otherwise raise TypeError here.
            parts.append(
                f"[参考 {i}] 相似度: {sim:.2f}\n"
                f"问题: {doc.page_content[:100]}...\n"
                f"答案: {meta.get('answer', 'N/A')}\n"
                f"步骤: {str(meta.get('steps', 'N/A'))[:200]}\n"
                f"工具: {meta.get('tools', 'N/A')}"
            )
        return "【知识库参考】\n" + "\n---\n".join(parts)
    # Low similarity: fall back to LLM-generated tool suggestions.
    return manager.query(question)
@tool
def rag_retrieve(query: str) -> str:
    """
    仅检索知识库中的相关文档片段,不生成建议。
    用于查看原始的相似问题和解题步骤。
    Args:
        query: 检索查询
    Returns:
        相关文档片段
    """
    # NOTE: the docstring above is the agent-facing tool description and is
    # intentionally left in its original form.
    manager = get_rag_manager()
    docs_with_scores = manager.retrieve_with_scores(query, k=3)
    if not docs_with_scores:
        return "知识库为空或未找到相关文档。"
    results = []
    for i, (doc, score) in enumerate(docs_with_scores, 1):
        meta = doc.metadata
        # _score_to_similarity replaces inline 1/(1+score) so NaN / negative
        # distances cannot blow up formatting; str() guards the slice
        # against a None 'steps' metadata value.
        sim = _score_to_similarity(score)
        results.append(f"""[{i}] 相似度: {sim:.2f}
问题: {doc.page_content[:200]}...
解题步骤: {str(meta.get('steps', 'N/A'))[:200]}...
工具: {meta.get('tools', 'N/A')}
""")
    return "\n---\n".join(results)
@tool
def rag_stats() -> str:
    """
    获取知识库统计信息。
    Returns:
        知识库状态和文档数量
    """
    # Delegate to the manager; format its stats dict for the agent.
    stats = get_rag_manager().get_stats()
    status = stats['status']
    doc_count = stats['doc_count']
    return f"知识库状态: {status}, 文档数量: {doc_count}"
# ========================================
# Exported RAG tools (consumed by the agent's tool registry)
# ========================================
RAG_TOOLS = [rag_query, rag_retrieve, rag_stats]