adaptive_rag / workflow_nodes.py
lanny xu
resolve conflict
90b33eb
raw
history blame
7.74 kB
"""
工作流节点模块
包含所有工作流节点函数和状态管理
"""
from typing import List
from typing_extensions import TypedDict
try:
from langchain_core.documents import Document
except ImportError:
from langchain.schema import Document
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import StrOutputParser
from langchain_community.tools.tavily_search import TavilySearchResults
try:
from langchain_core.prompts import PromptTemplate
except ImportError:
from langchain.prompts import PromptTemplate
from config import LOCAL_LLM, WEB_SEARCH_RESULTS_COUNT
from pprint import pprint
class GraphState(TypedDict):
"""
表示图的状态
属性:
question: 问题
generation: LLM生成
documents: 文档列表
"""
question: str
generation: str
documents: List[str]
class WorkflowNodes:
"""工作流节点类,包含所有节点函数"""
def __init__(self, retriever, graders):
self.retriever = retriever
self.graders = graders
# 设置RAG链 - 使用本地提示模板
rag_prompt_template = PromptTemplate(
template="""你是一个问答助手。使用以下检索到的上下文来回答问题。
如果你不知道答案,就说你不知道。最多使用三句话并保持答案简洁。
问题: {question}
上下文: {context}
答案:""",
input_variables=["question", "context"]
)
llm = ChatOllama(model=LOCAL_LLM, temperature=0)
self.rag_chain = rag_prompt_template | llm | StrOutputParser()
# 设置网络搜索
self.web_search_tool = TavilySearchResults(k=WEB_SEARCH_RESULTS_COUNT)
def retrieve(self, state):
"""
检索文档
Args:
state (dict): 当前图状态
Returns:
state (dict): 添加了documents键的新状态,包含检索到的文档
"""
print("---检索---")
question = state["question"]
# 检索
documents = self.retriever.get_relevant_documents(question)
return {"documents": documents, "question": question}
def generate(self, state):
"""
生成答案
Args:
state (dict): 当前图状态
Returns:
state (dict): 添加了generation键的新状态,包含LLM生成
"""
print("---生成---")
question = state["question"]
documents = state["documents"]
# RAG生成
generation = self.rag_chain.invoke({"context": documents, "question": question})
return {"documents": documents, "question": question, "generation": generation}
def grade_documents(self, state):
"""
确定检索到的文档是否与问题相关
Args:
state (dict): 当前图状态
Returns:
state (dict): 更新documents键,只包含过滤后的相关文档
"""
print("---检查文档与问题的相关性---")
question = state["question"]
documents = state["documents"]
# 为每个文档评分
filtered_docs = []
for d in documents:
score = self.graders["document_grader"].grade(question, d.page_content)
grade = score
if grade == "yes":
print("---评分:文档相关---")
filtered_docs.append(d)
else:
print("---评分:文档不相关---")
continue
return {"documents": filtered_docs, "question": question}
def transform_query(self, state):
"""
转换查询以产生更好的问题
Args:
state (dict): 当前图状态
Returns:
state (dict): 用重新表述的问题更新question键
"""
print("---转换查询---")
question = state["question"]
documents = state["documents"]
# 重写问题
better_question = self.graders["query_rewriter"].rewrite(question)
return {"documents": documents, "question": better_question}
def web_search(self, state):
"""
基于重新表述的问题进行网络搜索
Args:
state (dict): 当前图状态
Returns:
state (dict): 用附加的网络结果更新documents键
"""
print("---网络搜索---")
question = state["question"]
# 网络搜索
docs = self.web_search_tool.invoke({"query": question})
web_results = "\n".join([d["content"] for d in docs])
web_results = Document(page_content=web_results)
return {"documents": web_results, "question": question}
def route_question(self, state):
"""
将问题路由到网络搜索或RAG
Args:
state (dict): 当前图状态
Returns:
str: 要调用的下一个节点
"""
print("---路由问题---")
question = state["question"]
print(question)
source = self.graders["query_router"].route(question)
print(source)
if source == "web_search":
print("---将问题路由到网络搜索---")
return "web_search"
elif source == "vectorstore":
print("---将问题路由到RAG---")
return "vectorstore"
def decide_to_generate(self, state):
"""
确定是生成答案还是重新生成问题
Args:
state (dict): 当前图状态
Returns:
str: 要调用的下一个节点的二进制决策
"""
print("---评估已评分的文档---")
filtered_documents = state["documents"]
if not filtered_documents:
# 所有文档都被过滤掉了
# 我们将重新生成一个新查询
print("---决策:所有文档都与问题不相关,转换查询---")
return "transform_query"
else:
# 我们有相关文档,所以生成答案
print("---决策:生成---")
return "generate"
def grade_generation_v_documents_and_question(self, state):
"""
确定生成是否基于文档并回答问题
Args:
state (dict): 当前图状态
Returns:
str: 要调用的下一个节点的决策
"""
print("---检查幻觉---")
question = state["question"]
documents = state["documents"]
generation = state["generation"]
score = self.graders["hallucination_grader"].grade(generation, documents)
grade = score
# 检查幻觉
if grade == "yes":
print("---决策:生成基于文档---")
# 检查问题回答
print("---评分生成 vs 问题---")
score = self.graders["answer_grader"].grade(question, generation)
grade = score
if grade == "yes":
print("---决策:生成解决了问题---")
return "useful"
else:
print("---决策:生成没有解决问题---")
return "not useful"
else:
print("---决策:生成不基于文档,重试---")
return "not supported"
def format_docs(docs):
"""格式化文档用于显示"""
return "\n\n".join(doc.page_content for doc in docs)