from typing import List, Optional, Dict, Any, TypedDict, Annotated, Sequence
from functools import partial
import os

import gradio as gr

from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.language_models.llms import LLM

from langgraph.prebuilt import tools_condition, ToolNode
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.state import StateGraph
from langgraph.graph.message import add_messages
from langgraph.constants import START, END
try:
    from utils import html_format_docs_chat, get_session_id
    from tools.question_reformulation import reformulate_question_using_history
    from tools.org_seach import (
        extract_org_links_from_chatbot,
        embed_org_links_in_text,
        generate_org_link_dict,
    )
    from retrieval.elastic import retriever_tool
except ImportError:
    from .utils import html_format_docs_chat, get_session_id
    from .tools.question_reformulation import reformulate_question_using_history
    from .tools.org_seach import (
        extract_org_links_from_chatbot,
        embed_org_links_in_text,
        generate_org_link_dict,
    )
    from .retrieval.elastic import retriever_tool
ROOT = os.path.dirname(os.path.abspath(__file__))

# TODO https://www.metadocs.co/2024/08/29/simple-domain-specific-corrective-rag-with-langchain-and-langgraph/
class AgentState(TypedDict):
    # The add_messages function defines how an update to "messages" is applied.
    # The default channel behavior is to replace; add_messages appends instead.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    user_input: str
    org_dict: Dict
    # routing key written by has_org_name and read by the conditional edge;
    # it must be declared here so nodes are allowed to update it
    next: str
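
# A minimal sketch (illustrative only, not part of the app) of what the
# add_messages reducer does with node updates; the message values here are
# assumptions made up for the example:
#
#     from langchain_core.messages import HumanMessage
#     existing = [HumanMessage(content="hi")]
#     update = [AIMessage(content="hello")]
#     add_messages(existing, update)
#     # -> [HumanMessage("hi"), AIMessage("hello")]  appended, not replaced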
def search_agent(state, llm: LLM, tools) -> AgentState:
    """Invoke the agent model to generate a response based on the current state.

    Given the question, the agent decides whether to retrieve using the
    retriever tool or simply end.

    Parameters
    ----------
    state : AgentState
        The current state
    llm : LLM
        The language model; it must support `bind_tools`
    tools : List
        The tools the agent may call (here, the retriever tool)

    Returns
    -------
    AgentState
        The updated state with the agent response appended to messages
    """
| print("---SEARCH AGENT---") | |
| messages = state["messages"] | |
| question = messages[-1].content | |
| model = llm.bind_tools(tools) | |
| response = model.invoke(messages) | |
| # return a list, because this will get added to the existing list | |
| return {"messages": [response], "user_input": question} | |
def generate_with_context(state, llm: LLM) -> AgentState:
    """Generate the answer from the retrieved context.

    Parameters
    ----------
    state : AgentState
        The current state
    llm : LLM
        The language model used to write the answer

    Returns
    -------
    AgentState
        The updated state with the agent response appended to messages
    """
| print("---GENERATE ANSWER---") | |
| messages = state["messages"] | |
| question = state["user_input"] | |
| last_message = messages[-1] | |
| sources_str = last_message.content | |
| sources_list = last_message.artifact # cannot use directly as list of Documents | |
| # converting to html string | |
| sources_html = html_format_docs_chat(sources_list) | |
| if sources_list: | |
| print("---ADD SOURCES---") | |
| state["messages"].append(BaseMessage(content=sources_html, type="HTML")) | |
    # Prompt
    qa_system_prompt = """
    You are an assistant for question-answering tasks in the social and philanthropic sector. \n
    Use the following pieces of retrieved context to answer the question at the end. \n
    If you don't know the answer, just say that you don't know. \n
    Keep the response professional, friendly, and as concise as possible. \n

    Question: {question}
    Context: {context}
    Answer:
    """
    qa_prompt = ChatPromptTemplate(
        [
            ("system", qa_system_prompt),
            # pass the {question} placeholder rather than the raw question string,
            # so literal curly braces in user input cannot break the template
            ("human", "{question}"),
        ]
    )
    rag_chain = qa_prompt | llm | StrOutputParser()
    response = rag_chain.invoke({"context": sources_str, "question": question})
    # wrap in AIMessage explicitly; returning the bare response was rendered as a HumanMessage
    return {"messages": [AIMessage(content=response)], "user_input": question}
def has_org_name(state: AgentState) -> AgentState:
    """
    Processes the latest message to extract organization links and determine the next step.

    Args:
        state (AgentState): The current state of the agent, including a list of messages.

    Returns:
        dict: A dictionary with the next agent action and, if available, a dictionary of organization links.
    """
| print("---HAS ORG NAMES?---") | |
| messages = state["messages"] | |
| last_message = messages[-1].content | |
| output_list = extract_org_links_from_chatbot(last_message) | |
| link_dict = generate_org_link_dict(output_list) if output_list else {} | |
| if link_dict: | |
| print("---FOUND ORG NAMES---") | |
| return {"next": "insert_org_link", "org_dict": link_dict} | |
| print("---NO ORG NAMES FOUND---") | |
| return {"next": END, "messages": messages} | |
def insert_org_link(state: AgentState) -> AgentState:
    """
    Embeds organization links in the latest message content and returns it as an AI message.

    Args:
        state (dict): The current state, including the organization links and latest message.

    Returns:
        dict: A dictionary with the updated message content as an AIMessage.
    """
    print("---INSERT ORG LINKS---")
    messages = state["messages"]
    last_message = messages[-1].content
    # delete the original message in place; we will append the same one with links embedded
    messages.pop(-1)
    link_dict = state["org_dict"]
    last_message = embed_org_links_in_text(last_message, link_dict)
    return {"messages": [AIMessage(content=last_message)]}
def build_compute_graph(llm: LLM, indices: List[str]) -> StateGraph:
    """Assemble the workflow: reformulate -> search_agent -> (retrieve ->
    generate_with_context) -> has_org_name -> optional insert_org_link."""
    candid_retriever_tool = retriever_tool(indices=indices)
    retrieve = ToolNode([candid_retriever_tool])
    tools = [candid_retriever_tool]

    G = StateGraph(AgentState)
    # Add nodes
    G.add_node("reformulate", partial(reformulate_question_using_history, llm=llm))
    G.add_node("search_agent", partial(search_agent, llm=llm, tools=tools))
    G.add_node("retrieve", retrieve)
    G.add_node("generate_with_context", partial(generate_with_context, llm=llm))
    G.add_node("has_org_name", has_org_name)
    G.add_node("insert_org_link", insert_org_link)

    # Add edges
    G.add_edge(START, "reformulate")
    G.add_edge("reformulate", "search_agent")
    # Conditional edges from search_agent: route to the retriever when the
    # model emitted tool calls, otherwise skip straight to org-name detection
    G.add_conditional_edges(
        source="search_agent",
        path=tools_condition,  # TODO just a conditional edge here?
        path_map={
            "tools": "retrieve",
            "__end__": "has_org_name",
        },
    )
    G.add_edge("retrieve", "generate_with_context")
    G.add_edge("generate_with_context", "has_org_name")
    # Conditional edges from has_org_name, routing on the "next" key it set
    G.add_conditional_edges(
        "has_org_name",
        lambda x: x["next"],
        {"insert_org_link": "insert_org_link", END: END},
    )
    G.add_edge("insert_org_link", END)
    return G
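
# A hedged usage sketch for the graph above; "my_llm" and the index name are
# placeholders, not values defined in this module:
#
#     workflow = build_compute_graph(llm=my_llm, indices=["my_index"])
#     graph = workflow.compile(checkpointer=MemorySaver())
#     print(graph.get_graph().draw_ascii())  # requires the grandalf package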
def run_chat(
    thread_id: str,
    user_input: Dict[str, Any],
    chatbot: List[Dict],
    llm: LLM,
    indices: Optional[List[str]] = None,
):
    """Run one conversational turn through the graph and update the Gradio chatbot.

    Returns a cleared, re-enabled MultimodalTextbox, the updated chatbot
    history, and the (possibly newly created) thread id.
    """
    # https://langchain-ai.github.io/langgraph/tutorials/rag/langgraph_agentic_rag/#graph
    chatbot.append({"role": "user", "content": user_input["text"]})
    inputs = {"messages": chatbot}
    # thread_id can be an email https://github.com/yurisasc/memory-enhanced-ai-assistant/blob/main/assistant.py
    thread_id = get_session_id(thread_id)
    config = {"configurable": {"thread_id": thread_id}}
    workflow = build_compute_graph(llm=llm, indices=indices)
    memory = MemorySaver()  # TODO: don't use for Prod
    graph = workflow.compile(checkpointer=memory)
    response = graph.invoke(inputs, config=config)
    messages = response["messages"]
    last_message = messages[-1]
    ai_answer = last_message.content
    sources_html = ""
    # the sources message (custom type "HTML"), if present, is one of the last two messages
    for message in messages[-2:]:
        if message.type == "HTML":
            sources_html = message.content
    chatbot.append({"role": "assistant", "content": ai_answer})
    if sources_html:
        chatbot.append(
            {
                "role": "assistant",
                "content": sources_html,
                "metadata": {"title": "Sources HTML"},
            }
        )
    return gr.MultimodalTextbox(value=None, interactive=True), chatbot, thread_id
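
# A minimal launch sketch, assuming an OpenAI chat model and an index name
# that are NOT part of this module; any chat model exposing bind_tools works.
if __name__ == "__main__":
    from langchain_openai import ChatOpenAI  # assumed provider

    demo_llm = ChatOpenAI(model="gpt-4o-mini")  # placeholder model name
    with gr.Blocks() as demo:
        chatbot_ui = gr.Chatbot(type="messages")
        msg_box = gr.MultimodalTextbox(interactive=True)
        thread_state = gr.State("")
        msg_box.submit(
            partial(run_chat, llm=demo_llm, indices=["my_index"]),  # placeholder index
            inputs=[thread_state, msg_box, chatbot_ui],
            outputs=[msg_box, chatbot_ui, thread_state],
        )
    demo.launch()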