|
|
from datetime import datetime |
|
|
from typing import Literal |
|
|
|
|
|
from langchain_community.tools import DuckDuckGoSearchResults, OpenWeatherMapQueryRun |
|
|
from langchain_community.utilities import OpenWeatherMapAPIWrapper |
|
|
from langchain_core.language_models.chat_models import BaseChatModel |
|
|
from langchain_core.messages import AIMessage, SystemMessage |
|
|
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable |
|
|
from langgraph.graph import END, MessagesState, StateGraph |
|
|
from langgraph.managed import RemainingSteps |
|
|
from langgraph.prebuilt import ToolNode |
|
|
|
|
|
from agents.llama_guard import LlamaGuard, LlamaGuardOutput, SafetyAssessment |
|
|
from agents.tools import calculator |
|
|
from core import get_model, settings |
|
|
|
|
|
|
|
|
class AgentState(MessagesState, total=False):
    """`total=False` is PEP589 specs.

    documentation: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality
    """

    # Verdict from the most recent LlamaGuard screening (input or output).
    safety: LlamaGuardOutput
    # Managed by LangGraph: steps remaining before the recursion limit is hit.
    remaining_steps: RemainingSteps
|
|
|
|
|
|
|
|
# Base toolset: DuckDuckGo web search plus the project-local calculator.
web_search = DuckDuckGoSearchResults(name="WebSearch")
tools = [web_search, calculator]

# Optionally register a weather tool when an OpenWeatherMap API key is configured.
if settings.OPENWEATHERMAP_API_KEY:
    wrapper = OpenWeatherMapAPIWrapper(
        openweathermap_api_key=settings.OPENWEATHERMAP_API_KEY.get_secret_value()
    )
    tools.append(OpenWeatherMapQueryRun(name="Weather", api_wrapper=wrapper))

# System prompt for the agent.
# NOTE(review): the date is captured once at import time, not per request — a
# long-lived process will keep reporting its start date. Confirm this is intended.
current_date = datetime.now().strftime("%B %d, %Y")
instructions = f"""
    You are a helpful research assistant with the ability to search the web and use other tools.
    Today's date is {current_date}.

    NOTE: THE USER CAN'T SEE THE TOOL RESPONSE.

    A few things to remember:
    - Please include markdown-formatted links to any citations used in your response. Only include one
    or two citations per response unless more are needed. ONLY USE LINKS RETURNED BY THE TOOLS.
    - Use calculator tool with numexpr to answer math questions. The user does not understand numexpr,
    so for the final response, use human readable format - e.g. "300 * 200", not "(300 \\times 200)".
    """
|
|
|
|
|
|
|
|
def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]:
    """Bind the agent toolset to *model* and prepend the system prompt.

    The returned runnable maps an AgentState to a single AIMessage by
    injecting a SystemMessage built from the module-level `instructions`
    ahead of the conversation history.
    """
    prompt_injector = RunnableLambda(
        lambda state: [SystemMessage(content=instructions)] + state["messages"],
        name="StateModifier",
    )
    model_with_tools = model.bind_tools(tools)
    return prompt_injector | model_with_tools
|
|
|
|
|
|
|
|
def format_safety_message(safety: LlamaGuardOutput) -> AIMessage:
    """Build the user-facing AI message for content flagged as unsafe."""
    categories = ", ".join(safety.unsafe_categories)
    return AIMessage(
        content=f"This conversation was flagged for unsafe content: {categories}"
    )
|
|
|
|
|
|
|
|
async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
    """Run the LLM on the current state and screen its output with LlamaGuard.

    Returns a partial state update: either the model's reply, a safety
    refusal (when LlamaGuard flags the output), or a bail-out message
    (when the step budget is too low to complete requested tool calls).
    """
    # Resolve the model per request from the run config, falling back to the default.
    m = get_model(config["configurable"].get("model", settings.DEFAULT_MODEL))
    model_runnable = wrap_model(m)
    response = await model_runnable.ainvoke(state, config)

    # Screen the model's output; an unsafe verdict replaces the response entirely.
    llama_guard = LlamaGuard()
    safety_output = await llama_guard.ainvoke("Agent", state["messages"] + [response])
    if safety_output.safety_assessment == SafetyAssessment.UNSAFE:
        return {"messages": [format_safety_message(safety_output)], "safety": safety_output}

    # Not enough steps left to run tools AND produce a final answer — bail out.
    # Reusing response.id lets streaming clients overwrite the partial message.
    if state["remaining_steps"] < 2 and response.tool_calls:
        return {
            "messages": [
                AIMessage(
                    id=response.id,
                    content="Sorry, need more steps to process this request.",
                )
            ]
        }

    return {"messages": [response]}
|
|
|
|
|
|
|
|
async def llama_guard_input(state: AgentState, config: RunnableConfig) -> AgentState:
    """Screen the incoming user conversation with LlamaGuard before the model runs."""
    guard = LlamaGuard()
    verdict = await guard.ainvoke("User", state["messages"])
    return {"safety": verdict, "messages": []}
|
|
|
|
|
|
|
|
async def block_unsafe_content(state: AgentState, config: RunnableConfig) -> AgentState:
    """Emit the refusal message for input LlamaGuard flagged as unsafe."""
    verdict: LlamaGuardOutput = state["safety"]
    return {"messages": [format_safety_message(verdict)]}
|
|
|
|
|
|
|
|
|
|
|
# Assemble the agent graph: screen the input first, then loop model <-> tools.
agent = StateGraph(AgentState)
agent.add_node("model", acall_model)
agent.add_node("tools", ToolNode(tools))
agent.add_node("guard_input", llama_guard_input)
agent.add_node("block_unsafe_content", block_unsafe_content)
agent.set_entry_point("guard_input")
|
|
|
|
|
|
|
|
|
|
|
def check_safety(state: AgentState) -> Literal["unsafe", "safe"]:
    """Route on the stored LlamaGuard verdict: UNSAFE blocks, anything else proceeds."""
    verdict: LlamaGuardOutput = state["safety"]
    if verdict.safety_assessment == SafetyAssessment.UNSAFE:
        return "unsafe"
    return "safe"
|
|
|
|
|
|
|
|
# Safe input proceeds to the model; unsafe input gets the refusal message.
agent.add_conditional_edges(
    "guard_input", check_safety, {"unsafe": "block_unsafe_content", "safe": "model"}
)

# A blocked conversation ends immediately.
agent.add_edge("block_unsafe_content", END)

# Tool results always flow back into the model for another turn.
agent.add_edge("tools", "model")
|
|
|
|
|
|
|
|
|
|
|
def pending_tool_calls(state: AgentState) -> Literal["tools", "done"]:
    """Route to the tool node when the last AI message requested tool calls.

    Raises TypeError if the most recent message is not an AIMessage.
    """
    latest = state["messages"][-1]
    if not isinstance(latest, AIMessage):
        raise TypeError(f"Expected AIMessage, got {type(latest)}")
    return "tools" if latest.tool_calls else "done"
|
|
|
|
|
|
|
|
# After each model turn: run requested tools, or finish when there are none.
agent.add_conditional_edges("model", pending_tool_calls, {"tools": "tools", "done": END})

# Compile the graph into the runnable exported by this module.
research_assistant = agent.compile()
|
|
|