Spaces:
Running
Running
| from typing import List, Optional, Callable, Any | |
| from functools import partial | |
| import logging | |
| from langchain_core.messages import AIMessage, BaseMessage | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.language_models.llms import LLM | |
| from langgraph.prebuilt import tools_condition, ToolNode | |
| from langgraph.graph.state import StateGraph | |
| from langgraph.constants import START, END | |
| from ask_candid.tools.recommendation import ( | |
| detect_intent_with_llm, | |
| determine_context, | |
| make_recommendation | |
| ) | |
| from ask_candid.tools.question_reformulation import reformulate_question_using_history | |
| from ask_candid.tools.org_seach import has_org_name, insert_org_link | |
| from ask_candid.tools.search import search_agent, retriever_tool | |
| from ask_candid.agents.schema import AgentState | |
| from ask_candid.base.config.data import DataIndices | |
| from ask_candid.utils import html_format_docs_chat | |
# Module-wide logging setup. basicConfig only installs a root handler when the root
# logger has none yet; this module's own logger is then pinned to INFO.
_LOG_FORMAT = "[%(levelname)s] (%(asctime)s) :: %(message)s"
logging.basicConfig(format=_LOG_FORMAT)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def generate_with_context(
    state: AgentState,
    llm: LLM,
    user_callback: Optional[Callable[[str], Any]] = None
) -> AgentState:
    """Generate an answer grounded in the most recently retrieved documents.

    In the compute graph this node runs right after the retrieval tool node, so the
    last message in ``state["messages"]`` is treated as the retrieval output: its
    ``content`` becomes the prompt context and its ``artifact`` (if present) holds the
    source documents, which are rendered to HTML and appended to the message history.

    Parameters
    ----------
    state : AgentState
        The current state; reads ``messages`` (must be non-empty) and ``user_input``
    llm : LLM
        Language model used to write the answer
    user_callback : Optional[Callable[[str], Any]], optional
        Optional UI callback to inform the user of apps states, by default None

    Returns
    -------
    AgentState
        State update containing the agent response as a new message and the
        unchanged ``user_input``
    """
    logger.info("---GENERATE ANSWER---")
    if user_callback is not None:
        try:
            user_callback("Writing a response...")
        except Exception as ex:
            # UI notifications are best-effort; a failing callback must not abort generation
            logger.warning("User callback was passed in but failed: %s", ex)

    messages = state["messages"]
    question = state["user_input"]
    last_message = messages[-1]
    sources_str = last_message.content
    # Only tool messages carry an ``artifact``; tolerate any other message type
    sources_list = getattr(last_message, "artifact", None)

    if sources_list:
        logger.info("---ADD SOURCES---")
        sources_html = html_format_docs_chat(sources_list)
        # NOTE(review): this mutates the incoming state in place instead of returning the
        # message in the update below; presumably relied on by the UI to show sources — confirm
        state["messages"].append(BaseMessage(content=sources_html, type="HTML"))

    # Prompt
    qa_system_prompt = """
You are an assistant for question-answering tasks in the social and philanthropic sector. \n
Use the following pieces of retrieved context to answer the question at the end. \n
If you don't know the answer, just say that you don't know. \n
Keep the response professional, friendly, and as concise as possible. \n
Question: {question}
Context: {context}
Answer:
"""
    qa_prompt = ChatPromptTemplate([
        ("system", qa_system_prompt),
        # The question must be injected as a template *variable*: using the raw user text
        # as the template itself breaks (or spawns bogus variables) when it contains braces
        ("human", "{question}"),
    ])

    rag_chain = qa_prompt | llm | StrOutputParser()
    response = rag_chain.invoke({"context": sources_str, "question": question})
    return {"messages": [AIMessage(content=response)], "user_input": question}
def add_recommendations_pipeline_(
    G: StateGraph,
    llm: LLM,
    reformulation_node_name: str = "reformulate",
    search_node_name: str = "search_agent"
) -> None:
    """Wire the recommendation sub-graph into ``G``; the graph is modified in place.

    After reformulation, an LLM intent-detection node classifies the request. The
    intents ``"rfp"`` and ``"funder"`` are routed through ``determine_context`` and
    ``make_recommendation`` straight to END; every other intent continues on to the
    search node.

    Parameters
    ----------
    G : StateGraph
        Execution graph to extend
    llm : LLM
        Language model bound into the intent-detection node
    reformulation_node_name : str, optional
        Name of the node which reformulates input queries, by default "reformulate"
    search_node_name : str, optional
        Name of the node which executes document search + retrieval, by default "search_agent"
    """
    recommendation_intents = ["rfp", "funder"]

    # Register the recommendation nodes
    recommendation_nodes = (
        ("detect_intent_with_llm", partial(detect_intent_with_llm, llm=llm)),
        ("determine_context", determine_context),
        ("make_recommendation", make_recommendation),
    )
    for node_name, node_action in recommendation_nodes:
        G.add_node(node=node_name, action=node_action)

    # Intent detection sits between reformulation and search; recommendation
    # requests short-circuit to END without running document search.
    G.add_edge(start_key=reformulation_node_name, end_key="detect_intent_with_llm")
    G.add_conditional_edges(
        source="detect_intent_with_llm",
        path=lambda state: (
            "determine_context" if state["intent"] in recommendation_intents else search_node_name
        ),
        path_map={
            "determine_context": "determine_context",
            search_node_name: search_node_name,
        },
    )
    G.add_edge(start_key="determine_context", end_key="make_recommendation")
    G.add_edge(start_key="make_recommendation", end_key=END)
def build_compute_graph(
    llm: LLM,
    indices: List[DataIndices],
    enable_recommendations: bool = False,
    user_callback: Optional[Callable[[str], Any]] = None
) -> StateGraph:
    """Build the (uncompiled) execution graph for one interaction with the assistant.

    Flow: START -> reformulate -> (recommendation sub-graph, when enabled) -> search_agent.
    When the search agent requests a tool call, documents are retrieved and an answer is
    generated from them; either way the flow then passes through ``has_org_name``, which
    may route to ``insert_org_link`` before END.

    Parameters
    ----------
    llm : LLM
        Language model shared by every LLM-backed node
    indices : List[DataIndices]
        Semantic index names to search over
    enable_recommendations : bool, optional
        Set to `True` to allow the flow to generate recommendations based on context, by default False
    user_callback : Optional[Callable[[str], Any]], optional
        Optional UI callback to inform the user of apps states, by default None

    Returns
    -------
    StateGraph
        Execution graph (not yet compiled)
    """
    retriever = retriever_tool(indices=indices, user_callback=user_callback)
    search_tools = [retriever]

    graph = StateGraph(AgentState)

    # Node name -> action; dict insertion order keeps the registration order stable
    node_actions = {
        "reformulate": partial(
            reformulate_question_using_history,
            llm=llm,
            focus_on_recommendations=enable_recommendations
        ),
        "search_agent": partial(search_agent, llm=llm, tools=search_tools),
        "retrieve": ToolNode([retriever]),
        "generate_with_context": partial(generate_with_context, llm=llm, user_callback=user_callback),
        "has_org_name": partial(has_org_name, llm=llm, user_callback=user_callback),
        "insert_org_link": insert_org_link,
    }
    for node_name, node_action in node_actions.items():
        graph.add_node(node=node_name, action=node_action)

    # Reformulation either feeds the recommendation sub-graph or goes straight to search
    if enable_recommendations:
        add_recommendations_pipeline_(
            graph, llm=llm,
            reformulation_node_name="reformulate",
            search_node_name="search_agent"
        )
    else:
        graph.add_edge(start_key="reformulate", end_key="search_agent")

    graph.add_edge(start_key=START, end_key="reformulate")
    graph.add_conditional_edges(
        source="search_agent",
        path=tools_condition,
        path_map={"tools": "retrieve", END: "has_org_name"},
    )
    graph.add_edge(start_key="retrieve", end_key="generate_with_context")
    graph.add_edge(start_key="generate_with_context", end_key="has_org_name")
    graph.add_conditional_edges(
        source="has_org_name",
        # Routing reads the "next" key from state (presumably set by has_org_name)
        path=lambda state: state["next"],
        path_map={"insert_org_link": "insert_org_link", END: END},
    )
    graph.add_edge(start_key="insert_org_link", end_key=END)
    return graph