File size: 3,580 Bytes
64dd0b5
 
 
 
f16141d
 
7d9df5e
64dd0b5
 
 
f16141d
64dd0b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f16141d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64dd0b5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from typing import Callable, List, Annotated, Optional

from langchain_core.documents import Document
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import ToolMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool, InjectedToolCallId
from langchain_core.vectorstores import VectorStore
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.types import Command
from typing_extensions import TypedDict

class RagState(TypedDict):
  """State passed between the nodes of the RAG graph defined in this module."""
  # Conversation messages; the field is annotated with the `add_messages`
  # reducer imported from langgraph.graph.message. The nodes in this file read
  # the last entry's `.content` as the user's question.
  messages: Annotated[list, add_messages]
  # Retrieval results: written by the retriever node and read by the generator
  # node when filling the prompt.
  context: list[Document]

# Default prompt text for answer generation. It declares two placeholders,
# `context` and `question`; anything that formats this template has to supply
# both names.
RAG_PROMPT = """\
Given a provided context and a question, you must answer the question. If you do not know the answer, you must state that you do not know.

Context:
{context}

Question:
{question}

Answer:
"""

# Chat prompt built from RAG_PROMPT; used as the default `template` argument of
# the generator node and graph factories below.
rag_prompt_template = ChatPromptTemplate.from_template(RAG_PROMPT)


def create_retriever_node(vector_store: VectorStore, search_kwargs: Optional[dict] = None) -> Callable:
  """Build a graph node that retrieves documents for the latest message.

  Args:
    vector_store: Store whose ``as_retriever`` is used to run the search.
    search_kwargs: Keyword arguments forwarded to ``as_retriever``. When
      omitted, ``{"k": 5}`` is used.

  Returns:
    A callable that takes the graph state and returns a ``{"context": docs}``
    update, where ``docs`` is whatever the retriever returns for the content
    of the last message.
  """
  # A dict literal in the signature would be one object shared by every call
  # of this factory (and handed to the vector store), so a mutation anywhere
  # would leak into later graphs. Build a fresh default per call instead.
  if search_kwargs is None:
    search_kwargs = {"k": 5}

  def retriever_node(state: RagState) -> RagState:
    retriever = vector_store.as_retriever(search_kwargs=search_kwargs)
    # The most recent message is treated as the search query.
    retrieved_docs = retriever.invoke(state["messages"][-1].content)
    return {"context": retrieved_docs}
  return retriever_node


def create_generator_node(model: BaseChatModel, template: ChatPromptTemplate = rag_prompt_template) -> Callable:
  """Build a graph node that answers the latest message using the stored context.

  Args:
    model: Chat model piped after the prompt template.
    template: Prompt with a ``context`` placeholder and a ``question``
      placeholder (templates that name it ``query`` are also supported).

  Returns:
    A callable that takes the graph state and returns a
    ``{"messages": response}`` update holding the model's reply.
  """
  generation_chain = template | model

  # The default template names its placeholder "question", whereas this node
  # used to always send "query", which left the default template with a
  # missing variable. Send the text under whichever of the two names the
  # template declares; fall back to the historical "query" if it declares
  # neither, so existing custom templates keep working.
  declared = getattr(template, "input_variables", ())
  question_keys = [key for key in ("question", "query") if key in declared] or ["query"]

  def generator_node(state: RagState) -> RagState:
    # The most recent message is treated as the question to answer.
    inputs = dict.fromkeys(question_keys, state["messages"][-1].content)
    inputs["context"] = state["context"]
    response = generation_chain.invoke(inputs)
    return {"messages": response}
  return generator_node


def make_rag_graph(model: BaseChatModel, vector_store: VectorStore, template: ChatPromptTemplate = rag_prompt_template, search_kwargs: Optional[dict] = None) -> StateGraph:
  """Assemble and compile a two-step retrieve-then-generate graph.

  The graph enters at the "retriever" node, flows to the "generator" node and
  then ends.

  Args:
    model: Chat model used by the generator node.
    vector_store: Store searched by the retriever node.
    template: Prompt used by the generator node.
    search_kwargs: Keyword arguments forwarded to the vector store's
      retriever. When omitted, ``{"k": 5}`` is used.

  Returns:
    The result of ``StateGraph.compile()``, i.e. the compiled graph rather
    than the builder (the annotation is kept as-is for compatibility).
  """
  # Avoid a dict literal as the default: it would be a single object shared
  # across every graph built without explicit search_kwargs.
  if search_kwargs is None:
    search_kwargs = {"k": 5}

  retriever_node = create_retriever_node(vector_store, search_kwargs)
  generator_node = create_generator_node(model, template)

  rag_graph = StateGraph(RagState)

  rag_graph.add_node("retriever", retriever_node)
  rag_graph.add_node("generator", generator_node)

  rag_graph.set_entry_point("retriever")

  rag_graph.add_edge("retriever", "generator")
  rag_graph.add_edge("generator", END)

  return rag_graph.compile()

def create_vector_search_tool(vector_store: VectorStore, search_kwargs: dict, inject_to_state: bool = False) -> Callable:
  """Create a "vector-search" tool backed by the given vector store.

  With ``inject_to_state`` left False the tool simply returns the page
  contents of the retrieved documents. With ``inject_to_state`` True the tool
  returns a ``Command`` that writes those page contents to the graph state's
  ``context`` key and appends a ``ToolMessage`` carrying their string form.

  WARNING: the graph state REQUIRES a 'context' key if inject_to_state is True.
  """
  def _page_contents(query: str) -> List[str]:
    # Shared by both tool variants: run the search and keep only the text.
    hits = vector_store.as_retriever(search_kwargs=search_kwargs).invoke(query)
    return [hit.page_content for hit in hits]

  if not inject_to_state:
    @tool("vector-search")
    def vector_search_tool(query: str) -> List[str]:
      """Searches a vector database for the given query and returns relevant document contents."""
      return _page_contents(query)
    return vector_search_tool

  @tool("vector-search")
  def vector_search_tool(query: str, tool_call_id: Annotated[str, InjectedToolCallId]) -> Command:
    """Searches a vector database for the given query and returns relevant document contents."""
    contents = _page_contents(query)
    # Answer the tool call in the message history with the same contents.
    reply = ToolMessage(f"{contents}", tool_call_id=tool_call_id)
    return Command(update={"context": contents, "messages": [reply]})
  return vector_search_tool