Spaces:
Paused
Paused
| """ | |
| 路由器和评分器模块 | |
| 包含查询路由、文档相关性评分、答案质量评分和幻觉检测 | |
| """ | |
| try: | |
| from langchain_core.prompts import PromptTemplate | |
| except ImportError: | |
| from langchain.prompts import PromptTemplate | |
| from langchain_community.chat_models import ChatOllama | |
| from langchain_core.output_parsers import JsonOutputParser, StrOutputParser | |
| from config import LOCAL_LLM | |
| class QueryRouter: | |
| """查询路由器,决定使用向量存储还是网络搜索""" | |
| def __init__(self): | |
| self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0) | |
| self.prompt = PromptTemplate( | |
| template="""你是一个专家,负责将用户问题路由到向量存储或网络搜索。 | |
| 对于关于LLM智能体、提示工程和对抗性攻击的问题,使用向量存储。 | |
| 你不需要严格匹配问题中与这些主题相关的关键词。 | |
| 否则,使用网络搜索。根据问题给出二进制选择'web_search'或'vectorstore'。 | |
| 返回一个只包含'datasource'键的JSON,不要前言或解释。 | |
| 要路由的问题:{question}""", | |
| input_variables=["question"], | |
| ) | |
| self.router = self.prompt | self.llm | JsonOutputParser() | |
| def route(self, question: str) -> str: | |
| """路由问题到相应的数据源""" | |
| result = self.router.invoke({"question": question}) | |
| return result.get("datasource", "web_search") | |
| class DocumentGrader: | |
| """文档相关性评分器""" | |
| def __init__(self): | |
| self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0) | |
| self.prompt = PromptTemplate( | |
| template="""你是一个评分员,评估检索到的文档是否与用户问题相关。 | |
| 如果文档包含与用户问题相关的关键词或语义,请给出'yes'分数。 | |
| 给出二进制分数'yes'或'no',以表明文档是否与问题相关。 | |
| 将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。 | |
| 检索到的文档: | |
| {document} | |
| 用户问题:{question}""", | |
| input_variables=["question", "document"], | |
| ) | |
| self.grader = self.prompt | self.llm | JsonOutputParser() | |
| def grade(self, question: str, document: str) -> str: | |
| """评估文档与问题的相关性""" | |
| result = self.grader.invoke({"question": question, "document": document}) | |
| return result.get("score", "no") | |
| class AnswerGrader: | |
| """答案质量评分器""" | |
| def __init__(self): | |
| self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0) | |
| self.prompt = PromptTemplate( | |
| template="""你是一个评分员,评估答案是否有助于解决问题。 | |
| 这里是答案: | |
| \n ------- \n | |
| {generation} | |
| \n ------- \n | |
| 这里是问题:{question} | |
| 给出二进制分数'yes'或'no',表示答案是否有助于解决问题。 | |
| 将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。""", | |
| input_variables=["generation", "question"], | |
| ) | |
| self.grader = self.prompt | self.llm | JsonOutputParser() | |
| def grade(self, question: str, generation: str) -> str: | |
| """评估答案质量""" | |
| result = self.grader.invoke({"question": question, "generation": generation}) | |
| return result.get("score", "no") | |
| class HallucinationGrader: | |
| """ | |
| 幻觉检测器 - 使用专业模型(Vectara + NLI) | |
| 相比 LLM-as-a-Judge 方法: | |
| - 准确率从 60-75% 提升到 85-95% | |
| - 速度提升 5-10 倍 | |
| - 成本降低 90% | |
| """ | |
| def __init__(self, method: str = "hybrid"): | |
| """ | |
| 初始化幻觉检测器 | |
| Args: | |
| method: 'vectara', 'nli', 或 'hybrid' (推荐) | |
| """ | |
| # 尝试加载专业检测模型 | |
| try: | |
| from hallucination_detector import initialize_hallucination_detector | |
| self.detector = initialize_hallucination_detector(method=method) | |
| self.use_professional_detector = True | |
| print(f"✅ 使用专业幻觉检测器: {method}") | |
| except Exception as e: | |
| print(f"⚠️ 专业检测器加载失败,回退到 LLM 方法: {e}") | |
| self.use_professional_detector = False | |
| # 回退到原有的 LLM 方法 | |
| self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0) | |
| self.prompt = PromptTemplate( | |
| template="""你是一个评分员,评估LLM生成是否基于/支持一组检索到的事实。 | |
| 给出二进制分数'yes'或'no'。'yes'意味着答案基于/支持文档。 | |
| 将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。 | |
| 检索到的文档: | |
| {documents} | |
| LLM生成:{generation}""", | |
| input_variables=["generation", "documents"], | |
| ) | |
| self.grader = self.prompt | self.llm | JsonOutputParser() | |
| def grade(self, generation: str, documents) -> str: | |
| """ | |
| 检测生成内容是否存在幻觉 | |
| Args: | |
| generation: LLM 生成的内容 | |
| documents: 参考文档 | |
| Returns: | |
| "yes" 表示无幻觉,"no" 表示有幻觉 | |
| """ | |
| if self.use_professional_detector: | |
| # 使用专业检测器 | |
| return self.detector.grade(generation, documents) | |
| else: | |
| # 回退到 LLM 方法 | |
| result = self.grader.invoke({"generation": generation, "documents": documents}) | |
| return result.get("score", "no") | |
| class QueryRewriter: | |
| """查询重写器,优化查询以获得更好的检索结果""" | |
| def __init__(self): | |
| self.llm = ChatOllama(model=LOCAL_LLM, temperature=0) | |
| self.prompt = PromptTemplate( | |
| template="""你是一个问题重写器,将输入问题转换为更适合向量存储检索的更好版本。 | |
| 查看初始问题并制定一个改进的问题。 | |
| 这里是初始问题:\n\n {question}。改进的问题(无前言):\n """, | |
| input_variables=["question"], | |
| ) | |
| self.rewriter = self.prompt | self.llm | StrOutputParser() | |
| def rewrite(self, question: str) -> str: | |
| """重写查询以获得更好的检索效果""" | |
| print(f"---原始查询: {question}---") | |
| rewritten_query = self.rewriter.invoke({"question": question}) | |
| print(f"---重写查询: {rewritten_query}---") | |
| return rewritten_query | |
| def initialize_graders_and_router(): | |
| """初始化所有评分器和路由器""" | |
| # Load detection method from config | |
| try: | |
| from hallucination_config import HALLUCINATION_DETECTION_METHOD | |
| detection_method = HALLUCINATION_DETECTION_METHOD | |
| except ImportError: | |
| detection_method = "hybrid" # Default to hybrid | |
| query_router = QueryRouter() | |
| document_grader = DocumentGrader() | |
| answer_grader = AnswerGrader() | |
| hallucination_grader = HallucinationGrader(method=detection_method) | |
| query_rewriter = QueryRewriter() | |
| return { | |
| "query_router": query_router, | |
| "document_grader": document_grader, | |
| "answer_grader": answer_grader, | |
| "hallucination_grader": hallucination_grader, | |
| "query_rewriter": query_rewriter | |
| } |