from typing import List, Optional, Dict, Any, TypedDict, Annotated, Sequence
from functools import partial
import os
import gradio as gr
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.language_models.llms import LLM
from langgraph.prebuilt import tools_condition, ToolNode
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.state import StateGraph
from langgraph.graph.message import add_messages
from langgraph.constants import START, END
try:
from utils import html_format_docs_chat, get_session_id
from tools.question_reformulation import reformulate_question_using_history
from tools.org_seach import (
extract_org_links_from_chatbot,
embed_org_links_in_text,
generate_org_link_dict,
)
from retrieval.elastic import retriever_tool
except ImportError:
from .utils import html_format_docs_chat, get_session_id
from .tools.question_reformulation import reformulate_question_using_history
from .tools.org_seach import (
extract_org_links_from_chatbot,
embed_org_links_in_text,
generate_org_link_dict,
)
from .retrieval.elastic import retriever_tool
ROOT = os.path.dirname(os.path.abspath(__file__))
# TODO https://www.metadocs.co/2024/08/29/simple-domain-specific-corrective-rag-with-langchain-and-langgraph/
class AgentState(TypedDict):
    """Shared state flowing between the LangGraph nodes of the chat workflow."""

    # The add_messages function defines how an update should be processed
    # Default is to replace. add_messages says "append"
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # The user's question for the current turn (set by search_agent).
    user_input: str
    # Organization links found by has_org_name; presumably {org name: url} —
    # TODO confirm against generate_org_link_dict.
    org_dict: Dict
def search_agent(state, llm: LLM, tools) -> AgentState:
    """Decide whether to retrieve or answer directly.

    Binds the available tools to the model and invokes it on the current
    message history; the model's reply may be a tool call (routed to the
    retriever by ``tools_condition``) or a plain answer.

    Parameters
    ----------
    state : AgentState-like mapping
        Current graph state; ``state["messages"]`` holds the chat history.
    llm : LLM
        Language model used to produce the response.
    tools : list
        Tools the model is allowed to call (e.g. the retriever tool).

    Returns
    -------
    AgentState
        Partial state update: the model reply (appended via ``add_messages``)
        plus the latest question stored as ``user_input``.
    """
    print("---SEARCH AGENT---")
    history = state["messages"]
    latest_question = history[-1].content
    tool_aware_model = llm.bind_tools(tools)
    reply = tool_aware_model.invoke(history)
    # Wrapped in a list so the add_messages reducer appends it to the history.
    return {"messages": [reply], "user_input": latest_question}
def generate_with_context(state, llm: LLM) -> AgentState:
    """Generate the final answer from the retrieved context.

    Reads the retriever tool's output from the last message, formats the
    retrieved documents as an HTML "sources" message, then runs a RAG prompt
    through the LLM to answer the user's question.

    Parameters
    ----------
    state : AgentState-like mapping
        Current graph state. The last message is expected to be the retriever
        tool's result (``content`` = concatenated source text, ``artifact`` =
        the raw retrieved-document list).
    llm : LLM
        Language model used to generate the answer.

    Returns
    -------
    AgentState
        Partial state update with the generated answer appended as an
        AIMessage.
    """
    print("---GENERATE ANSWER---")
    messages = state["messages"]
    question = state["user_input"]
    last_message = messages[-1]
    sources_str = last_message.content
    sources_list = last_message.artifact  # cannot use directly as list of Documents
    # converting to html string
    sources_html = html_format_docs_chat(sources_list)
    if sources_list:
        print("---ADD SOURCES---")
        # NOTE(review): relies on this in-place mutation surviving the
        # add_messages reducer — run_chat later scans for type == "HTML".
        # Verify this persists across LangGraph versions/checkpointers.
        state["messages"].append(BaseMessage(content=sources_html, type="HTML"))

    # Prompt
    qa_system_prompt = """
    You are an assistant for question-answering tasks in the social and philanthropic sector. \n
    Use the following pieces of retrieved context to answer the question at the end. \n
    If you don't know the answer, just say that you don't know. \n
    Keep the response professional, friendly, and as concise as possible. \n
    Question: {question}
    Context: {context}
    Answer:
    """
    # Bug fix: use the "{question}" placeholder instead of interpolating the
    # raw question into the template. User text containing "{" or "}" would
    # otherwise be treated as template variables and raise at format time.
    qa_prompt = ChatPromptTemplate(
        [
            ("system", qa_system_prompt),
            ("human", "{question}"),
        ]
    )

    rag_chain = qa_prompt | llm | StrOutputParser()
    response = rag_chain.invoke({"context": sources_str, "question": question})
    # couldn't figure out why returning usual "response" was seen as HumanMessage
    return {"messages": [AIMessage(content=response)], "user_input": question}
def has_org_name(state: AgentState) -> AgentState:
    """
    Scan the latest message for organization names and choose the next step.

    Args:
        state (AgentState): The current agent state, including the message list.

    Returns:
        dict: A routing dict with a "next" key — "insert_org_link" together
        with the found "org_dict" when organization links were detected,
        otherwise END with the messages passed through unchanged.
    """
    print("---HAS ORG NAMES?---")
    history = state["messages"]
    latest_text = history[-1].content
    found_orgs = extract_org_links_from_chatbot(latest_text)
    org_links = generate_org_link_dict(found_orgs) if found_orgs else {}
    if not org_links:
        print("---NO ORG NAMES FOUND---")
        return {"next": END, "messages": history}
    print("---FOUND ORG NAMES---")
    return {"next": "insert_org_link", "org_dict": org_links}
def insert_org_link(state: AgentState) -> AgentState:
    """
    Embeds organization links in the latest message content and returns it as an AI message.
    Args:
        state (dict): The current state, including the organization links and latest message.
    Returns:
        dict: A dictionary with the updated message content as an AIMessage.
    """
    print("---INSERT ORG LINKS---")
    messages = state["messages"]
    # Text of the message we will re-emit with links embedded.
    last_message = messages[-1].content
    # NOTE(review): popping the local list mutates state in place; with the
    # add_messages reducer the canonical deletion is RemoveMessage(id=...) —
    # verify this mutation actually removes the message from checkpointed
    # state rather than leaving a duplicate alongside the linked copy below.
    messages.pop(-1)  # Deleting the original message because we will append the same one but with links
    link_dict = state["org_dict"]
    last_message = embed_org_links_in_text(last_message, link_dict)
    return {"messages": [AIMessage(content=last_message)]}
def build_compute_graph(llm: LLM, indices: List[str]) -> StateGraph:
    """Assemble the agentic-RAG state graph.

    Flow: reformulate the question -> search agent decides whether to call
    the retriever -> generate an answer from retrieved context -> check for
    organization names and, when found, embed links before ending.

    Parameters
    ----------
    llm : LLM
        Language model shared by all LLM-backed nodes.
    indices : List[str]
        Elasticsearch indices searched by the retriever tool.

    Returns
    -------
    StateGraph
        The uncompiled graph; the caller compiles it with a checkpointer.
    """
    candid_retriever_tool = retriever_tool(indices=indices)
    retrieve = ToolNode([candid_retriever_tool])
    tools = [candid_retriever_tool]

    G = StateGraph(AgentState)

    # Register nodes; LLM-backed nodes get the model bound via partial.
    for node_name, node_fn in (
        ("reformulate", partial(reformulate_question_using_history, llm=llm)),
        ("search_agent", partial(search_agent, llm=llm, tools=tools)),
        ("retrieve", retrieve),
        ("generate_with_context", partial(generate_with_context, llm=llm)),
        ("has_org_name", has_org_name),
        ("insert_org_link", insert_org_link),
    ):
        G.add_node(node_name, node_fn)

    # Linear entry path.
    G.add_edge(START, "reformulate")
    G.add_edge("reformulate", "search_agent")

    # search_agent either calls the retriever tool or goes straight to the
    # org-name check.
    G.add_conditional_edges(
        source="search_agent",
        path=tools_condition,  # TODO just a conditional edge here?
        path_map={
            "tools": "retrieve",
            "__end__": "has_org_name",
        },
    )
    G.add_edge("retrieve", "generate_with_context")
    G.add_edge("generate_with_context", "has_org_name")

    # Route on the "next" key that has_org_name writes into its update.
    G.add_conditional_edges(
        "has_org_name",
        lambda update: update["next"],
        {"insert_org_link": "insert_org_link", END: END},
    )
    G.add_edge("insert_org_link", END)
    return G
def run_chat(
    thread_id: str,
    user_input: Dict[str, Any],
    chatbot: List[Dict],
    llm: LLM,
    indices: Optional[List[str]] = None,
):
    """Execute one conversational turn of the agentic-RAG graph.

    Appends the user's text to the Gradio chat history, runs the compute
    graph for this session, then appends the assistant's answer (and an
    optional HTML "sources" message) back onto the history.

    Based on
    https://langchain-ai.github.io/langgraph/tutorials/rag/langgraph_agentic_rag/#graph

    Parameters
    ----------
    thread_id : str
        Session identifier, normalized via get_session_id. Can be an email —
        see https://github.com/yurisasc/memory-enhanced-ai-assistant/blob/main/assistant.py
    user_input : Dict[str, Any]
        Gradio multimodal input; only the "text" field is read.
    chatbot : List[Dict]
        Gradio "messages"-format history; mutated in place.
    llm : LLM
        Language model powering the graph's LLM nodes.
    indices : Optional[List[str]]
        Elasticsearch indices for the retriever tool.

    Returns
    -------
    tuple
        (cleared multimodal textbox, updated chatbot history, thread_id)
    """
    chatbot.append({"role": "user", "content": user_input["text"]})
    thread_id = get_session_id(thread_id)
    run_config = {"configurable": {"thread_id": thread_id}}

    workflow = build_compute_graph(llm=llm, indices=indices)
    checkpointer = MemorySaver()  # TODO: don't use for Prod
    graph = workflow.compile(checkpointer=checkpointer)
    result = graph.invoke({"messages": chatbot}, config=run_config)

    final_messages = result["messages"]
    ai_answer = final_messages[-1].content
    # The HTML sources message, if any, sits among the last two messages;
    # scan back-to-front so the latest one wins.
    sources_html = next(
        (m.content for m in reversed(final_messages[-2:]) if m.type == "HTML"),
        "",
    )

    chatbot.append({"role": "assistant", "content": ai_answer})
    if sources_html:
        chatbot.append(
            {
                "role": "assistant",
                "content": sources_html,
                "metadata": {"title": "Sources HTML"},
            }
        )
    return gr.MultimodalTextbox(value=None, interactive=True), chatbot, thread_id