from langgraph.graph import END, StateGraph
from typing import Annotated, Callable, List, Optional

from typing_extensions import TypedDict

from langchain_core.documents import Document
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import ToolMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import InjectedToolCallId, tool
from langchain_core.vectorstores import VectorStore
from langgraph.graph.message import add_messages
from langgraph.types import Command
class RagState(TypedDict):
    """LangGraph state shared by the retriever and generator nodes."""

    # Conversation history; the add_messages reducer appends new messages
    # instead of overwriting the list.
    messages: Annotated[list, add_messages]
    # Documents retrieved for the most recent user question.
    context: list[Document]
# Default RAG prompt: answer strictly from the supplied context, and admit
# ignorance rather than guess. Template variables: {context} and {question} —
# invocations of a chain built from this template must supply a "question" key.
RAG_PROMPT = """\
Given a provided context and a question, you must answer the question. If you do not know the answer, you must state that you do not know.
Context:
{context}
Question:
{question}
Answer:
"""
# Reusable default template for create_generator_node / make_rag_graph.
rag_prompt_template = ChatPromptTemplate.from_template(RAG_PROMPT)
def create_retriever_node(vector_store: VectorStore, search_kwargs: Optional[dict] = None) -> Callable:
    """Build a LangGraph node that retrieves documents for the latest user message.

    Args:
        vector_store: Vector store to search.
        search_kwargs: Options forwarded to ``as_retriever`` (defaults to ``{"k": 5}``).

    Returns:
        A node function that maps ``RagState`` to a partial state update with
        the retrieved documents under the ``context`` key.
    """
    # Fix: the previous signature used a mutable dict as a default argument,
    # which is shared across all calls; use a None sentinel instead.
    if search_kwargs is None:
        search_kwargs = {"k": 5}

    def retriever_node(state: RagState) -> RagState:
        retriever = vector_store.as_retriever(search_kwargs=search_kwargs)
        # The newest message is treated as the user's query.
        retrieved_docs = retriever.invoke(state["messages"][-1].content)
        return {"context": retrieved_docs}

    return retriever_node
def create_generator_node(model: BaseChatModel, template: ChatPromptTemplate = rag_prompt_template) -> Callable:
    """Build a LangGraph node that answers the latest question from retrieved context.

    Args:
        model: Chat model used to generate the answer.
        template: Prompt template with ``{context}`` and ``{question}`` variables.

    Returns:
        A node function that maps ``RagState`` to a partial state update
        appending the model's response to ``messages``.
    """
    generation_chain = template | model

    def generator_node(state: RagState) -> RagState:
        # Bug fix: the default template declares a {question} variable, but the
        # chain was previously invoked with a "query" key, which fails at
        # runtime with a missing-variable error. Use "question" to match.
        response = generation_chain.invoke(
            {
                "question": state["messages"][-1].content,
                "context": state["context"],
            }
        )
        return {"messages": response}

    return generator_node
def make_rag_graph(model: BaseChatModel, vector_store: VectorStore, template: ChatPromptTemplate = rag_prompt_template, search_kwargs: Optional[dict] = None) -> StateGraph:
    """Assemble and compile a two-node retrieve-then-generate RAG graph.

    Flow: entry -> "retriever" (fills ``context``) -> "generator" (appends the
    answer to ``messages``) -> END.

    Args:
        model: Chat model used by the generator node.
        vector_store: Vector store used by the retriever node.
        template: Prompt template for answer generation.
        search_kwargs: Retrieval options (defaults to ``{"k": 5}``).

    Returns:
        The compiled graph, ready to ``invoke``. (NOTE(review): ``compile()``
        returns a compiled graph rather than a raw ``StateGraph``; annotation
        kept loose for backward compatibility.)
    """
    # Fix: avoid a shared mutable default argument for search_kwargs.
    if search_kwargs is None:
        search_kwargs = {"k": 5}

    retriever_node = create_retriever_node(vector_store, search_kwargs)
    generator_node = create_generator_node(model, template)

    rag_graph = StateGraph(RagState)
    rag_graph.add_node("retriever", retriever_node)
    rag_graph.add_node("generator", generator_node)
    rag_graph.set_entry_point("retriever")
    rag_graph.add_edge("retriever", "generator")
    rag_graph.add_edge("generator", END)
    return rag_graph.compile()
def create_vector_search_tool(vector_store: VectorStore, search_kwargs: dict, inject_to_state: bool = False) -> Callable:
    """Create a "vector-search" tool over the given vector store.

    When ``inject_to_state`` is True, the tool returns a ``Command`` that writes
    the retrieved page contents into the graph state's ``context`` key and
    appends a ``ToolMessage`` to the history; otherwise it simply returns the
    list of page contents.
    """
    # WARNING: the graph state REQUIRES a 'context' key if inject_to_state is True

    def _retrieve_contents(query: str) -> List[str]:
        # A fresh retriever per tool call, using the configured search options.
        docs = vector_store.as_retriever(search_kwargs=search_kwargs).invoke(query)
        return [doc.page_content for doc in docs]

    if not inject_to_state:
        @tool("vector-search")
        def vector_search_tool(query: str) -> List[str]:
            """Searches a vector database for the given query and returns relevant document contents."""
            return _retrieve_contents(query)

        return vector_search_tool

    @tool("vector-search")
    def vector_search_tool(query: str, tool_call_id: Annotated[str, InjectedToolCallId]) -> Command:
        """Searches a vector database for the given query and returns relevant document contents."""
        contents = _retrieve_contents(query)
        return Command(update={
            "context": contents,
            # update the message history
            "messages": [
                ToolMessage(
                    f"{contents}",
                    tool_call_id=tool_call_id
                )
            ]
        })

    return vector_search_tool