from typing import List, Optional, Dict, Any, TypedDict, Annotated, Sequence
from functools import partial
import os

import gradio as gr

from langchain_core.messages import AIMessage, BaseMessage, RemoveMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.language_models.llms import LLM

from langgraph.prebuilt import tools_condition, ToolNode
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.state import StateGraph
from langgraph.graph.message import add_messages
from langgraph.constants import START, END

try:
    from utils import html_format_docs_chat, get_session_id
    from tools.question_reformulation import reformulate_question_using_history
    from tools.org_seach import (
        extract_org_links_from_chatbot,
        embed_org_links_in_text,
        generate_org_link_dict,
    )
    from retrieval.elastic import retriever_tool
except ImportError:
    from .utils import html_format_docs_chat, get_session_id
    from .tools.question_reformulation import reformulate_question_using_history
    from .tools.org_seach import (
        extract_org_links_from_chatbot,
        embed_org_links_in_text,
        generate_org_link_dict,
    )
    from .retrieval.elastic import retriever_tool

ROOT = os.path.dirname(os.path.abspath(__file__))

# TODO https://www.metadocs.co/2024/08/29/simple-domain-specific-corrective-rag-with-langchain-and-langgraph/


class AgentState(TypedDict):
    # The add_messages function defines how an update to this key is
    # processed: the default channel behavior is to replace, add_messages
    # appends instead
    messages: Annotated[Sequence[BaseMessage], add_messages]
    user_input: str
    org_dict: Dict
    # Routing key set by has_org_name and read by the conditional edge;
    # it must be declared here so LangGraph accepts updates to it
    next: str

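# A minimal sketch of the reducer's behavior (illustrative only, not
# executed here): add_messages merges two message lists by id, appending
# messages with new ids and replacing messages whose id already exists.
#
#     from langchain_core.messages import HumanMessage
#     merged = add_messages(
#         [HumanMessage(content="hi", id="1")],
#         [AIMessage(content="hello", id="2")],
#     )
#     # merged == [HumanMessage("hi"), AIMessage("hello")] -- appended
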

def search_agent(state, llm: LLM, tools) -> AgentState:
    """Invoke the agent model to generate a response based on the current state.

    Given the question, the model decides whether to retrieve using the
    retriever tool or to simply end.

    Parameters
    ----------
    state : AgentState
        The current state.
    llm : LLM
        The language model, to which the tools are bound.
    tools : List
        The tools the model may call (here, the retriever tool).

    Returns
    -------
    AgentState
        The updated state with the agent response appended to messages.
    """

    print("---SEARCH AGENT---")
    messages = state["messages"]
    question = messages[-1].content

    model = llm.bind_tools(tools)
    response = model.invoke(messages)
    # return a list, because this will get added to the existing list
    return {"messages": [response], "user_input": question}


def generate_with_context(state, llm: LLM) -> AgentState:
    """Generate an answer from the retrieved context.

    Parameters
    ----------
    state : AgentState
        The current state.
    llm : LLM
        The language model used to generate the answer.

    Returns
    -------
    AgentState
        The updated state with the formatted sources and the answer appended
        to messages.
    """

    print("---GENERATE ANSWER---")
    messages = state["messages"]
    question = state["user_input"]
    last_message = messages[-1]

    sources_str = last_message.content
    sources_list = last_message.artifact  # cannot use directly as a list of Documents
    # Convert the retrieved documents to an HTML string for the UI.
    sources_html = html_format_docs_chat(sources_list)
    if sources_list:
        print("---ADD SOURCES---")
    # Return the sources through the add_messages reducer (see the return
    # below) instead of mutating state["messages"] in place; in-place appends
    # bypass the reducer and may not be checkpointed.
    sources_message = BaseMessage(content=sources_html, type="HTML")

    # Prompt
    qa_system_prompt = """
        You are an assistant for question-answering tasks in the social and philanthropic sector. \n
        Use the following pieces of retrieved context to answer the question at the end. \n
        If you don't know the answer, just say that you don't know. \n
        Keep the response professional, friendly, and as concise as possible. \n
        Question: {question}
        Context: {context}
        Answer:
        """

    # Use a template placeholder for the human turn; interpolating the raw
    # question here would break on any literal braces in the user input.
    qa_prompt = ChatPromptTemplate(
        [
            ("system", qa_system_prompt),
            ("human", "{question}"),
        ]
    )

    rag_chain = qa_prompt | llm | StrOutputParser()
    response = rag_chain.invoke({"context": sources_str, "question": question})
    # StrOutputParser returns a plain string, which the add_messages reducer
    # would coerce to a HumanMessage; wrap it in an AIMessage explicitly.
    return {
        "messages": [sources_message, AIMessage(content=response)],
        "user_input": question,
    }


def has_org_name(state: AgentState) -> AgentState:
    """
    Processes the latest message to extract organization links and determine the next step.

    Args:
        state (AgentState): The current state of the agent, including a list of messages.

    Returns:
        dict: A dictionary with the next agent action and, if available, a dictionary of organization links.
    """
    print("---HAS ORG NAMES?---")
    messages = state["messages"]
    last_message = messages[-1].content
    output_list = extract_org_links_from_chatbot(last_message)
    link_dict = generate_org_link_dict(output_list) if output_list else {}
    if link_dict:
        print("---FOUND ORG NAMES---")
        return {"next": "insert_org_link", "org_dict": link_dict}
    print("---NO ORG NAMES FOUND---")
    return {"next": END}


def insert_org_link(state: AgentState) -> AgentState:
    """
    Embeds organization links in the latest message content and returns it as an AI message.

    Args:
        state (dict): The current state, including the organization links and latest message.

    Returns:
        dict: A dictionary with the updated message content as an AIMessage.
    """
    print("---INSERT ORG LINKS---")
    messages = state["messages"]
    last_message = messages[-1].content
    messages.pop(-1) # Deleting the original message because we will append the same one but with links
    link_dict = state["org_dict"]
    last_message = embed_org_links_in_text(last_message, link_dict)
    return {"messages": [AIMessage(content=last_message)]}


def build_compute_graph(llm: LLM, indices: List[str]) -> StateGraph:
    candid_retriever_tool = retriever_tool(indices=indices)
    retrieve = ToolNode([candid_retriever_tool])
    tools = [candid_retriever_tool]

    G = StateGraph(AgentState)
    # Add nodes
    G.add_node("reformulate", partial(reformulate_question_using_history, llm=llm))
    G.add_node("search_agent", partial(search_agent, llm=llm, tools=tools))
    G.add_node("retrieve", retrieve)
    G.add_node("generate_with_context", partial(generate_with_context, llm=llm))
    G.add_node("has_org_name", has_org_name)
    G.add_node("insert_org_link", insert_org_link)

    # Add edges
    G.add_edge(START, "reformulate")
    G.add_edge("reformulate", "search_agent")
    # Conditional edges from search_agent
    G.add_conditional_edges(
        source="search_agent",
        path=tools_condition, # TODO just a conditional edge here?
        path_map={
            "tools": "retrieve",
            "__end__": "has_org_name",
        },
    )
    G.add_edge("retrieve", "generate_with_context")

    G.add_edge("generate_with_context", "has_org_name")
    # Route from has_org_name based on the "next" key it sets in the state
    G.add_conditional_edges(
        "has_org_name",
        lambda x: x["next"],
        {"insert_org_link": "insert_org_link", END: END},
    )
    G.add_edge("insert_org_link", END)

    return G
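
# A usage sketch for the compiled graph (illustrative; ``my_llm`` and the
# index name are placeholders, not defined in this module):
#
#     workflow = build_compute_graph(llm=my_llm, indices=["my_index"])
#     graph = workflow.compile()
#     out = graph.invoke({"messages": [("user", "Who funds food banks?")]})
#     print(out["messages"][-1].content)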


def run_chat(
    thread_id: str,
    user_input: Dict[str, Any],
    chatbot: List[Dict],
    llm: LLM,
    indices: Optional[List[str]] = None,
):
    """Run a single conversational turn through the graph and update the
    Gradio chat history.

    See https://langchain-ai.github.io/langgraph/tutorials/rag/langgraph_agentic_rag/#graph
    """

    chatbot.append({"role": "user", "content": user_input["text"]})
    inputs = {"messages": chatbot}
    # thread_id can be an email https://github.com/yurisasc/memory-enhanced-ai-assistant/blob/main/assistant.py
    thread_id = get_session_id(thread_id)
    config = {"configurable": {"thread_id": thread_id}}

    workflow = build_compute_graph(llm=llm, indices=indices)

    memory = MemorySaver()  # TODO: don't use for Prod
    graph = workflow.compile(checkpointer=memory)
    response = graph.invoke(inputs, config=config)
    messages = response["messages"]
    last_message = messages[-1]
    ai_answer = last_message.content
    sources_html = ""
    for message in messages[-2:]:
        if message.type == "HTML":
            sources_html = message.content

    chatbot.append({"role": "assistant", "content": ai_answer})
    if sources_html:
        chatbot.append(
            {
                "role": "assistant",
                "content": sources_html,
                "metadata": {"title": "Sources HTML"},
            }
        )

    return gr.MultimodalTextbox(value=None, interactive=True), chatbot, thread_id
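
# Hypothetical wiring sketch (the UI lives outside this module; ``my_llm``
# and ``my_indices`` are placeholders): run_chat matches the signature a
# gr.MultimodalTextbox submit handler expects, roughly:
#
#     with gr.Blocks() as demo:
#         thread_id = gr.State("")
#         chatbot = gr.Chatbot(type="messages")
#         msg = gr.MultimodalTextbox(interactive=True)
#         msg.submit(
#             partial(run_chat, llm=my_llm, indices=my_indices),
#             [thread_id, msg, chatbot],
#             [msg, chatbot, thread_id],
#         )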