File size: 5,977 Bytes
399f3c6
 
 
 
 
90b33eb
 
 
 
 
399f3c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""
路由器和评分器模块
包含查询路由、文档相关性评分、答案质量评分和幻觉检测
"""

try:
    from langchain_core.prompts import PromptTemplate
except ImportError:
    from langchain.prompts import PromptTemplate

from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser

from config import LOCAL_LLM


class QueryRouter:
    """查询路由器,决定使用向量存储还是网络搜索"""
    
    def __init__(self):
        self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
        self.prompt = PromptTemplate(
            template="""你是一个专家,负责将用户问题路由到向量存储或网络搜索。
            对于关于LLM智能体、提示工程和对抗性攻击的问题,使用向量存储。
            你不需要严格匹配问题中与这些主题相关的关键词。
            否则,使用网络搜索。根据问题给出二进制选择'web_search'或'vectorstore'。
            返回一个只包含'datasource'键的JSON,不要前言或解释。
            要路由的问题:{question}""",
            input_variables=["question"],
        )
        self.router = self.prompt | self.llm | JsonOutputParser()
    
    def route(self, question: str) -> str:
        """路由问题到相应的数据源"""
        result = self.router.invoke({"question": question})
        return result.get("datasource", "web_search")


class DocumentGrader:
    """文档相关性评分器"""
    
    def __init__(self):
        self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
        self.prompt = PromptTemplate(
            template="""你是一个评分员,评估检索到的文档是否与用户问题相关。
            如果文档包含与用户问题相关的关键词或语义,请给出'yes'分数。
            给出二进制分数'yes'或'no',以表明文档是否与问题相关。
            将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。
            
            检索到的文档:

 {document} 


            用户问题:{question}""",
            input_variables=["question", "document"],
        )
        self.grader = self.prompt | self.llm | JsonOutputParser()
    
    def grade(self, question: str, document: str) -> str:
        """评估文档与问题的相关性"""
        result = self.grader.invoke({"question": question, "document": document})
        return result.get("score", "no")


class AnswerGrader:
    """答案质量评分器"""
    
    def __init__(self):
        self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
        self.prompt = PromptTemplate(
            template="""你是一个评分员,评估答案是否有助于解决问题。
            这里是答案:
            \n ------- \n
            {generation} 
            \n ------- \n
            这里是问题:{question}
            给出二进制分数'yes'或'no',表示答案是否有助于解决问题。
            将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。""",
            input_variables=["generation", "question"],
        )
        self.grader = self.prompt | self.llm | JsonOutputParser()
    
    def grade(self, question: str, generation: str) -> str:
        """评估答案质量"""
        result = self.grader.invoke({"question": question, "generation": generation})
        return result.get("score", "no")


class HallucinationGrader:
    """幻觉检测器"""
    
    def __init__(self):
        self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
        self.prompt = PromptTemplate(
            template="""你是一个评分员,评估LLM生成是否基于/支持一组检索到的事实。
            给出二进制分数'yes'或'no'。'yes'意味着答案基于/支持文档。
            将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。
            
            检索到的文档:

 {documents} 


            LLM生成:{generation}""",
            input_variables=["generation", "documents"],
        )
        self.grader = self.prompt | self.llm | JsonOutputParser()
    
    def grade(self, generation: str, documents) -> str:
        """检测生成内容是否存在幻觉"""
        result = self.grader.invoke({"generation": generation, "documents": documents})
        return result.get("score", "no")


class QueryRewriter:
    """查询重写器,优化查询以获得更好的检索结果"""
    
    def __init__(self):
        self.llm = ChatOllama(model=LOCAL_LLM, temperature=0)
        self.prompt = PromptTemplate(
            template="""你是一个问题重写器,将输入问题转换为更适合向量存储检索的更好版本。
            查看初始问题并制定一个改进的问题。
            这里是初始问题:\n\n {question}。改进的问题(无前言):\n """,
            input_variables=["question"],
        )
        self.rewriter = self.prompt | self.llm | StrOutputParser()
    
    def rewrite(self, question: str) -> str:
        """重写查询以获得更好的检索效果"""
        print(f"---原始查询: {question}---")
        rewritten_query = self.rewriter.invoke({"question": question})
        print(f"---重写查询: {rewritten_query}---")
        return rewritten_query


def initialize_graders_and_router():
    """初始化所有评分器和路由器"""
    query_router = QueryRouter()
    document_grader = DocumentGrader()
    answer_grader = AnswerGrader()
    hallucination_grader = HallucinationGrader()
    query_rewriter = QueryRewriter()
    
    return {
        "query_router": query_router,
        "document_grader": document_grader,
        "answer_grader": answer_grader,
        "hallucination_grader": hallucination_grader,
        "query_rewriter": query_rewriter
    }