from datetime import datetime
from typing import Literal

from langchain_community.tools import DuckDuckGoSearchResults, OpenWeatherMapQueryRun
from langchain_community.utilities import OpenWeatherMapAPIWrapper
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable
from langgraph.graph import END, MessagesState, StateGraph
from langgraph.managed import RemainingSteps
from langgraph.prebuilt import ToolNode

from agents.llama_guard import LlamaGuard, LlamaGuardOutput, SafetyAssessment
from agents.tools import calculator
from core import get_model, settings


class AgentState(MessagesState, total=False):
    """`total=False` marks the extra keys as optional, per PEP 589 (TypedDict totality).

    See https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality
    """

    safety: LlamaGuardOutput
    remaining_steps: RemainingSteps


web_search = DuckDuckGoSearchResults(name="WebSearch")
tools = [web_search, calculator]

# Add weather tool if API key is set
# Register for an API key at https://openweathermap.org/api/
if settings.OPENWEATHERMAP_API_KEY:
    wrapper = OpenWeatherMapAPIWrapper(
        openweathermap_api_key=settings.OPENWEATHERMAP_API_KEY.get_secret_value()
    )
    tools.append(OpenWeatherMapQueryRun(name="Weather", api_wrapper=wrapper))

current_date = datetime.now().strftime("%B %d, %Y")
instructions = f"""
    You are a helpful research assistant with the ability to search the web and use other tools.
    Today's date is {current_date}.

    NOTE: THE USER CAN'T SEE THE TOOL RESPONSE.

    A few things to remember:
    - Please include markdown-formatted links to any citations used in your response. Only include one
      or two citations per response unless more are needed. ONLY USE LINKS RETURNED BY THE TOOLS.
    - Use the calculator tool with numexpr to answer math questions. The user does not understand
      numexpr, so for the final response, use a human-readable format - e.g. "300 * 200",
      not "(300 \\times 200)".
""" def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]: bound_model = model.bind_tools(tools) preprocessor = RunnableLambda( lambda state: [SystemMessage(content=instructions)] + state["messages"], name="StateModifier", ) return preprocessor | bound_model # type: ignore[return-value] def format_safety_message(safety: LlamaGuardOutput) -> AIMessage: content = ( f"This conversation was flagged for unsafe content: {', '.join(safety.unsafe_categories)}" ) return AIMessage(content=content) async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState: m = get_model(config["configurable"].get("model", settings.DEFAULT_MODEL)) model_runnable = wrap_model(m) response = await model_runnable.ainvoke(state, config) # Run llama guard check here to avoid returning the message if it's unsafe llama_guard = LlamaGuard() safety_output = await llama_guard.ainvoke("Agent", state["messages"] + [response]) if safety_output.safety_assessment == SafetyAssessment.UNSAFE: return {"messages": [format_safety_message(safety_output)], "safety": safety_output} if state["remaining_steps"] < 2 and response.tool_calls: return { "messages": [ AIMessage( id=response.id, content="Sorry, need more steps to process this request.", ) ] } # We return a list, because this will get added to the existing list return {"messages": [response]} async def llama_guard_input(state: AgentState, config: RunnableConfig) -> AgentState: llama_guard = LlamaGuard() safety_output = await llama_guard.ainvoke("User", state["messages"]) return {"safety": safety_output, "messages": []} async def block_unsafe_content(state: AgentState, config: RunnableConfig) -> AgentState: safety: LlamaGuardOutput = state["safety"] return {"messages": [format_safety_message(safety)]} # Define the graph agent = StateGraph(AgentState) agent.add_node("model", acall_model) agent.add_node("tools", ToolNode(tools)) agent.add_node("guard_input", llama_guard_input) agent.add_node("block_unsafe_content", block_unsafe_content) agent.set_entry_point("guard_input") # Check for unsafe input and block further processing if found def check_safety(state: AgentState) -> Literal["unsafe", "safe"]: safety: LlamaGuardOutput = state["safety"] match safety.safety_assessment: case SafetyAssessment.UNSAFE: return "unsafe" case _: return "safe" agent.add_conditional_edges( "guard_input", check_safety, {"unsafe": "block_unsafe_content", "safe": "model"} ) # Always END after blocking unsafe content agent.add_edge("block_unsafe_content", END) # Always run "model" after "tools" agent.add_edge("tools", "model") # After "model", if there are tool calls, run "tools". Otherwise END. def pending_tool_calls(state: AgentState) -> Literal["tools", "done"]: last_message = state["messages"][-1] if not isinstance(last_message, AIMessage): raise TypeError(f"Expected AIMessage, got {type(last_message)}") if last_message.tool_calls: return "tools" return "done" agent.add_conditional_edges("model", pending_tool_calls, {"tools": "tools", "done": END}) research_assistant = agent.compile()