File size: 5,265 Bytes
22dcdfd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
from datetime import datetime
from typing import Literal
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.runnables import (
RunnableConfig,
RunnableLambda,
RunnableSerializable,
)
from langgraph.graph import END, MessagesState, StateGraph
from langgraph.managed import RemainingSteps
from langgraph.prebuilt import ToolNode
from agents.llama_guard import LlamaGuard, LlamaGuardOutput, SafetyAssessment
from agents.tools import database_search
from core import get_model, settings
class AgentState(MessagesState, total=False):
    """Graph state for the RAG assistant: chat messages plus safety metadata.

    `total=False` is PEP589 specs — every key declared in this subclass is
    optional, so nodes may return partial state updates.
    documentation: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality
    """
    # Latest LlamaGuard verdict, written by `llama_guard_input` and `acall_model`.
    safety: LlamaGuardOutput
    # Managed by langgraph: steps remaining before the recursion limit is reached.
    remaining_steps: RemainingSteps
# Tools exposed to the model; `database_search` queries the handbook database.
tools = [database_search]
# Captured once at import time and baked into the system prompt below.
# NOTE(review): datetime.now() is naive local time — presumably intentional; confirm.
current_date = datetime.now().strftime("%B %d, %Y")
# System prompt prepended to every model invocation (see `wrap_model`).
instructions = f"""
You are AcmeBot, a helpful and knowledgeable virtual assistant designed to support employees by retrieving
and answering questions based on AcmeTech's official Employee Handbook. Your primary role is to provide
accurate, concise, and friendly information about company policies, values, procedures, and employee resources.
Today's date is {current_date}.
NOTE: THE USER CAN'T SEE THE TOOL RESPONSE.
A few things to remember:
- If you have access to multiple databases, gather information from a diverse range of sources before crafting your response.
- Please include markdown-formatted links to any citations used in your response. Only include one
or two citations per response unless more are needed. ONLY USE LINKS RETURNED BY THE TOOLS.
- Only use information from the database. Do not use information from outside sources.
"""
def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]:
    """Bind the module's tools to *model* and prepend the system instructions.

    Returns a runnable mapping AgentState -> AIMessage: the state's messages
    are prefixed with a SystemMessage built from `instructions`, then fed to
    the tool-bound model.
    """
    with_tools = model.bind_tools(tools)

    def _prepend_system(state: AgentState) -> list:
        return [SystemMessage(content=instructions), *state["messages"]]

    preprocessor = RunnableLambda(_prepend_system, name="StateModifier")
    return preprocessor | with_tools  # type: ignore[return-value]
def format_safety_message(safety: LlamaGuardOutput) -> AIMessage:
    """Build the user-facing AIMessage for a conversation flagged by LlamaGuard."""
    categories = ", ".join(safety.unsafe_categories)
    return AIMessage(
        content=f"This conversation was flagged for unsafe content: {categories}"
    )
async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
    """Invoke the configured chat model and screen its reply with LlamaGuard.

    Returns a partial state update: the model's response as a new message,
    or a safety refusal (plus the verdict) when the reply is flagged unsafe,
    or an apology when too few graph steps remain to run requested tools.
    """
    model_name = config["configurable"].get("model", settings.DEFAULT_MODEL)
    runnable = wrap_model(get_model(model_name))
    response = await runnable.ainvoke(state, config)

    # Run llama guard check here to avoid returning the message if it's unsafe
    verdict = await LlamaGuard().ainvoke("Agent", state["messages"] + [response])
    if verdict.safety_assessment == SafetyAssessment.UNSAFE:
        return {"messages": [format_safety_message(verdict)], "safety": verdict}

    # Not enough step budget left to execute the requested tool calls.
    if state["remaining_steps"] < 2 and response.tool_calls:
        apology = AIMessage(
            id=response.id,
            content="Sorry, need more steps to process this request.",
        )
        return {"messages": [apology]}

    # We return a list, because this will get added to the existing list
    return {"messages": [response]}
async def llama_guard_input(state: AgentState, config: RunnableConfig) -> AgentState:
    """Screen the incoming user messages with LlamaGuard before the model runs."""
    verdict = await LlamaGuard().ainvoke("User", state["messages"])
    return {"safety": verdict, "messages": []}
async def block_unsafe_content(state: AgentState, config: RunnableConfig) -> AgentState:
    """Emit a refusal message built from the safety verdict stored in state."""
    return {"messages": [format_safety_message(state["safety"])]}
# Define the graph
agent = StateGraph(AgentState)
# Nodes: the LLM call, tool execution, input screening, and the refusal emitter.
agent.add_node("model", acall_model)
agent.add_node("tools", ToolNode(tools))
agent.add_node("guard_input", llama_guard_input)
agent.add_node("block_unsafe_content", block_unsafe_content)
# Every run starts by screening the user's input.
agent.set_entry_point("guard_input")
# Check for unsafe input and block further processing if found
def check_safety(state: AgentState) -> Literal["unsafe", "safe"]:
    """Route on the LlamaGuard verdict stored in state: "unsafe" or "safe"."""
    if state["safety"].safety_assessment == SafetyAssessment.UNSAFE:
        return "unsafe"
    return "safe"
# Unsafe input is diverted to the refusal node; safe input reaches the model.
agent.add_conditional_edges(
    "guard_input", check_safety, {"unsafe": "block_unsafe_content", "safe": "model"}
)
# Always END after blocking unsafe content
agent.add_edge("block_unsafe_content", END)
# Always run "model" after "tools"
agent.add_edge("tools", "model")
# After "model", if there are tool calls, run "tools". Otherwise END.
def pending_tool_calls(state: AgentState) -> Literal["tools", "done"]:
    """Route to "tools" when the model's last message requests tool calls.

    Raises TypeError if the last message is not an AIMessage (the graph
    guarantees "model" ran immediately before this router).
    """
    tail = state["messages"][-1]
    if not isinstance(tail, AIMessage):
        raise TypeError(f"Expected AIMessage, got {type(tail)}")
    return "tools" if tail.tool_calls else "done"
agent.add_conditional_edges("model", pending_tool_calls, {"tools": "tools", "done": END})
# Compiled, runnable graph — the module's public export.
rag_assistant = agent.compile()
|