# backend/src/agents/rag_assistant.py
# Author: anujjoshi3105 — commit 22dcdfd ("initial")
from datetime import datetime
from typing import Literal
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.runnables import (
RunnableConfig,
RunnableLambda,
RunnableSerializable,
)
from langgraph.graph import END, MessagesState, StateGraph
from langgraph.managed import RemainingSteps
from langgraph.prebuilt import ToolNode
from agents.llama_guard import LlamaGuard, LlamaGuardOutput, SafetyAssessment
from agents.tools import database_search
from core import get_model, settings
class AgentState(MessagesState, total=False):
    """State carried through the RAG assistant graph.

    Extends ``MessagesState`` with extra keys. ``total=False`` applies the
    PEP 589 totality rules, so these keys are not required to be present.
    documentation: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality
    """

    # LlamaGuard verdict. Written by the input-guard node and by the model
    # node when the model's reply is flagged.
    safety: LlamaGuardOutput
    # Step budget managed by LangGraph; the model node reads it to decide
    # whether there is room left to run tools.
    remaining_steps: RemainingSteps
tools = [database_search]

# System-prompt template. ``{current_date}`` is filled in every time the
# model is called (see ``_build_instructions``), so a long-running process
# does not keep telling the model the date on which it was started.
_INSTRUCTIONS_TEMPLATE = """
You are AcmeBot, a helpful and knowledgeable virtual assistant designed to support employees by retrieving
and answering questions based on AcmeTech's official Employee Handbook. Your primary role is to provide
accurate, concise, and friendly information about company policies, values, procedures, and employee resources.
Today's date is {current_date}.
NOTE: THE USER CAN'T SEE THE TOOL RESPONSE.
A few things to remember:
- If you have access to multiple databases, gather information from a diverse range of sources before crafting your response.
- Please include markdown-formatted links to any citations used in your response. Only include one
or two citations per response unless more are needed. ONLY USE LINKS RETURNED BY THE TOOLS.
- Only use information from the database. Do not use information from outside sources.
"""


def _build_instructions(now: datetime | None = None) -> str:
    """Render the system prompt for the date of ``now`` (default: current local time)."""
    moment = datetime.now() if now is None else now
    return _INSTRUCTIONS_TEMPLATE.format(current_date=moment.strftime("%B %d, %Y"))


# Import-time snapshots kept so existing imports of these module attributes
# keep working. ``wrap_model`` does not use them; it renders a fresh prompt
# on every call.
current_date = datetime.now().strftime("%B %d, %Y")
instructions = _INSTRUCTIONS_TEMPLATE.format(current_date=current_date)


def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]:
    """Bind the assistant's tools to ``model`` and prepend the system prompt.

    Args:
        model: Chat model to call.

    Returns:
        A runnable that takes the graph state and sends the system prompt
        followed by ``state["messages"]`` to the tool-bound model.
    """
    bound_model = model.bind_tools(tools)
    preprocessor = RunnableLambda(
        # Render the prompt inside the lambda so the date is current at
        # invocation time, not at import time.
        lambda state: [SystemMessage(content=_build_instructions())] + state["messages"],
        name="StateModifier",
    )
    return preprocessor | bound_model  # type: ignore[return-value]
def format_safety_message(safety: LlamaGuardOutput) -> AIMessage:
    """Build the assistant reply used when LlamaGuard flags the conversation.

    The flagged categories from ``safety.unsafe_categories`` are listed,
    separated by ", ".
    """
    categories = ", ".join(safety.unsafe_categories)
    return AIMessage(content="This conversation was flagged for unsafe content: " + categories)
async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
    """Call the chat model, screen its reply with LlamaGuard, and return a state update.

    Args:
        state: Current graph state. ``messages`` is sent to the model, and
            ``remaining_steps`` decides whether tool calls may still run.
        config: Runnable config. ``config["configurable"]["model"]`` selects
            the model; ``settings.DEFAULT_MODEL`` is used when it is absent.

    Returns:
        A partial state update whose ``messages`` list is added to the
        existing history. It contains one of:
        - a safety notice, together with ``safety``, if the reply is flagged;
        - an apology, if the reply requests tools but too few steps remain;
        - otherwise, the model's reply itself.
    """
    # ``configurable`` is an optional key of RunnableConfig; don't raise
    # KeyError when the graph is invoked without it.
    configurable = config.get("configurable") or {}
    model = get_model(configurable.get("model", settings.DEFAULT_MODEL))
    response = await wrap_model(model).ainvoke(state, config)

    # Run llama guard here so an unsafe reply is never returned into the state.
    safety_output = await LlamaGuard().ainvoke("Agent", state["messages"] + [response])
    if safety_output.safety_assessment == SafetyAssessment.UNSAFE:
        return {
            "messages": [format_safety_message(safety_output)],
            "safety": safety_output,
        }

    # Tool calls need two more steps ("tools", then "model" again per the
    # graph edges). Without them, reply with a message that has no
    # tool_calls, so routing ends the run.
    if state["remaining_steps"] < 2 and response.tool_calls:
        return {
            "messages": [
                AIMessage(
                    id=response.id,
                    content="Sorry, need more steps to process this request.",
                )
            ]
        }

    # We return a list, because this will get added to the existing list
    return {"messages": [response]}
async def llama_guard_input(state: AgentState, config: RunnableConfig) -> AgentState:
    """Screen the conversation's messages with LlamaGuard before the model runs.

    Stores the verdict under ``safety`` for the routing step, and adds no
    messages to the history.
    """
    verdict = await LlamaGuard().ainvoke("User", state["messages"])
    return {"safety": verdict, "messages": []}
async def block_unsafe_content(state: AgentState, config: RunnableConfig) -> AgentState:
    """Reply with a safety notice built from the stored LlamaGuard verdict."""
    notice = format_safety_message(state["safety"])
    return {"messages": [notice]}
# Define the graph
# Nodes:
#   "guard_input"          - LlamaGuard check on the user's messages
#   "model"                - calls the chat model (with its own output check)
#   "tools"                - ToolNode over ``tools``, runs the model's tool calls
#   "block_unsafe_content" - replies with a safety notice instead of the model
agent = StateGraph(AgentState)
agent.add_node("model", acall_model)
agent.add_node("tools", ToolNode(tools))
agent.add_node("guard_input", llama_guard_input)
agent.add_node("block_unsafe_content", block_unsafe_content)
# Every run starts with the input guard, before the model sees anything.
agent.set_entry_point("guard_input")
# Route on the input-guard verdict: unsafe input is blocked, anything else proceeds.
def check_safety(state: AgentState) -> Literal["unsafe", "safe"]:
    """Return ``"unsafe"`` only for an UNSAFE assessment; every other verdict is ``"safe"``."""
    is_unsafe = state["safety"].safety_assessment == SafetyAssessment.UNSAFE
    return "unsafe" if is_unsafe else "safe"
# After the input guard: unsafe input goes to the blocker, safe input to the model.
agent.add_conditional_edges(
    "guard_input",
    check_safety,
    {"unsafe": "block_unsafe_content", "safe": "model"},
)
# Tool results always flow back to the model.
agent.add_edge("tools", "model")
# The blocker's notice is the final reply; the run ends there.
agent.add_edge("block_unsafe_content", END)
# Route after "model": run "tools" while the reply requests them, otherwise finish.
def pending_tool_calls(state: AgentState) -> Literal["tools", "done"]:
    """Inspect the newest message and decide whether tools must run next.

    Raises:
        TypeError: if the newest message is not an ``AIMessage``.
    """
    newest = state["messages"][-1]
    if isinstance(newest, AIMessage):
        return "tools" if newest.tool_calls else "done"
    raise TypeError(f"Expected AIMessage, got {type(newest)}")
# Model replies with tool calls go to "tools"; any other reply ends the run.
agent.add_conditional_edges(
    "model",
    pending_tool_calls,
    {"tools": "tools", "done": END},
)
# Compile the wired graph into the runnable this module exposes.
rag_assistant = agent.compile()