Spaces:
Paused
Paused
File size: 5,977 Bytes
399f3c6 90b33eb 399f3c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
"""
路由器和评分器模块
包含查询路由、文档相关性评分、答案质量评分和幻觉检测
"""
try:
from langchain_core.prompts import PromptTemplate
except ImportError:
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from config import LOCAL_LLM
class QueryRouter:
"""查询路由器,决定使用向量存储还是网络搜索"""
def __init__(self):
self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
self.prompt = PromptTemplate(
template="""你是一个专家,负责将用户问题路由到向量存储或网络搜索。
对于关于LLM智能体、提示工程和对抗性攻击的问题,使用向量存储。
你不需要严格匹配问题中与这些主题相关的关键词。
否则,使用网络搜索。根据问题给出二进制选择'web_search'或'vectorstore'。
返回一个只包含'datasource'键的JSON,不要前言或解释。
要路由的问题:{question}""",
input_variables=["question"],
)
self.router = self.prompt | self.llm | JsonOutputParser()
def route(self, question: str) -> str:
"""路由问题到相应的数据源"""
result = self.router.invoke({"question": question})
return result.get("datasource", "web_search")
class DocumentGrader:
"""文档相关性评分器"""
def __init__(self):
self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
self.prompt = PromptTemplate(
template="""你是一个评分员,评估检索到的文档是否与用户问题相关。
如果文档包含与用户问题相关的关键词或语义,请给出'yes'分数。
给出二进制分数'yes'或'no',以表明文档是否与问题相关。
将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。
检索到的文档:
{document}
用户问题:{question}""",
input_variables=["question", "document"],
)
self.grader = self.prompt | self.llm | JsonOutputParser()
def grade(self, question: str, document: str) -> str:
"""评估文档与问题的相关性"""
result = self.grader.invoke({"question": question, "document": document})
return result.get("score", "no")
class AnswerGrader:
"""答案质量评分器"""
def __init__(self):
self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
self.prompt = PromptTemplate(
template="""你是一个评分员,评估答案是否有助于解决问题。
这里是答案:
\n ------- \n
{generation}
\n ------- \n
这里是问题:{question}
给出二进制分数'yes'或'no',表示答案是否有助于解决问题。
将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。""",
input_variables=["generation", "question"],
)
self.grader = self.prompt | self.llm | JsonOutputParser()
def grade(self, question: str, generation: str) -> str:
"""评估答案质量"""
result = self.grader.invoke({"question": question, "generation": generation})
return result.get("score", "no")
class HallucinationGrader:
"""幻觉检测器"""
def __init__(self):
self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
self.prompt = PromptTemplate(
template="""你是一个评分员,评估LLM生成是否基于/支持一组检索到的事实。
给出二进制分数'yes'或'no'。'yes'意味着答案基于/支持文档。
将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。
检索到的文档:
{documents}
LLM生成:{generation}""",
input_variables=["generation", "documents"],
)
self.grader = self.prompt | self.llm | JsonOutputParser()
def grade(self, generation: str, documents) -> str:
"""检测生成内容是否存在幻觉"""
result = self.grader.invoke({"generation": generation, "documents": documents})
return result.get("score", "no")
class QueryRewriter:
"""查询重写器,优化查询以获得更好的检索结果"""
def __init__(self):
self.llm = ChatOllama(model=LOCAL_LLM, temperature=0)
self.prompt = PromptTemplate(
template="""你是一个问题重写器,将输入问题转换为更适合向量存储检索的更好版本。
查看初始问题并制定一个改进的问题。
这里是初始问题:\n\n {question}。改进的问题(无前言):\n """,
input_variables=["question"],
)
self.rewriter = self.prompt | self.llm | StrOutputParser()
def rewrite(self, question: str) -> str:
"""重写查询以获得更好的检索效果"""
print(f"---原始查询: {question}---")
rewritten_query = self.rewriter.invoke({"question": question})
print(f"---重写查询: {rewritten_query}---")
return rewritten_query
def initialize_graders_and_router():
"""初始化所有评分器和路由器"""
query_router = QueryRouter()
document_grader = DocumentGrader()
answer_grader = AnswerGrader()
hallucination_grader = HallucinationGrader()
query_rewriter = QueryRewriter()
return {
"query_router": query_router,
"document_grader": document_grader,
"answer_grader": answer_grader,
"hallucination_grader": hallucination_grader,
"query_rewriter": query_rewriter
} |