# Author: ABAO77 — commit c0827a3 (3.01 kB)
# "Refactor roleplay agent implementation and update session handling for
# improved message processing"
from typing import TypedDict
from src.config.llm import model
from langgraph.prebuilt import create_react_agent
from langgraph_swarm import create_handoff_tool
from langchain_core.messages import RemoveMessage
from .prompt import roleplay_prompt, guiding_prompt
from typing_extensions import TypedDict, Annotated
from langchain_core.messages import AnyMessage
from langgraph.graph import add_messages
from loguru import logger
class State(TypedDict):
    """Graph state read and written by the roleplay / guiding agent nodes in this module."""

    # Name of the agent that should handle the next turn. trim_history fills in
    # "Roleplay Agent" when it is missing/falsy; route_to_active_agent
    # dispatches on it ("Roleplay Agent" or "Guiding Agent").
    active_agent: str | None
    # Conversation history. The add_messages reducer merges node updates into
    # this list; trim_history relies on it to apply RemoveMessage entries.
    messages: Annotated[list[AnyMessage], add_messages]
    # The scenario fields below are interpolated into roleplay_prompt /
    # guiding_prompt each time an agent is built.
    scenario_title: str
    scenario_description: str
    scenario_context: str
    your_role: str
    key_vocabulary: str
# Trim the history once it grows past this many messages...
_MAX_HISTORY_MESSAGES = 25
# ...down to (at least) this many of the most recent messages.
_KEEP_LAST_MESSAGES = 5


def trim_history(state: State):
    """Default the active agent and trim an overly long conversation history.

    The input ``state`` is not mutated; a new dict is returned containing every
    key of ``state`` plus:

    * ``active_agent`` -- set to ``"Roleplay Agent"`` when missing or falsy,
      otherwise left as-is.
    * ``messages`` -- when the history holds more than
      ``_MAX_HISTORY_MESSAGES`` messages, a list of ``RemoveMessage`` markers
      for the oldest ones. The ``add_messages`` reducer on ``State.messages``
      turns these into deletions. Otherwise the history is passed through
      unchanged.

    The cut point is moved earlier so that the kept window never starts with a
    tool result. Dropping the AI message that issued a tool call (e.g. a
    handoff) while keeping its tool response leaves an orphaned tool message.
    Chat APIs such as OpenAI's Chat Completions reject such an orphan.
    """
    update = dict(state)
    update["active_agent"] = state.get("active_agent") or "Roleplay Agent"

    history = state.get("messages", [])
    if len(history) > _MAX_HISTORY_MESSAGES:
        cut = len(history) - _KEEP_LAST_MESSAGES
        # Walk back over tool results so each stays with the message that requested it.
        while 0 < cut < len(history) and getattr(history[cut], "type", None) == "tool":
            cut -= 1
        update["messages"] = [RemoveMessage(id=msg.id) for msg in history[:cut]]
    return update
# State keys interpolated into the agent prompt templates, in the order they are read.
_SCENARIO_FIELDS = (
    "scenario_title",
    "scenario_description",
    "scenario_context",
    "your_role",
    "key_vocabulary",
)


async def _run_scenario_agent(
    state: State,
    config,
    *,
    name: str,
    prompt_template,
    handoff_to: str,
    handoff_description: str,
):
    """Build a ReAct agent for the current scenario and run it on the history.

    Args:
        state: graph state; the ``_SCENARIO_FIELDS`` keys and ``messages`` are read.
        config: the calling node's LangGraph config (may be None); forwarded to
            the sub-agent's ``ainvoke``.
        name: name given to the created agent.
        prompt_template: template whose ``.format(**fields)`` yields the system
            prompt (roleplay_prompt / guiding_prompt).
        handoff_to: agent name passed to ``create_handoff_tool``.
        handoff_description: description of that handoff tool shown to the model.

    Returns:
        ``{"messages": ...}`` with the agent's full resulting message list. The
        ``add_messages`` reducer merges it into the state by message id.
    """
    # Built per call because the prompt depends on the scenario in the state.
    agent = create_react_agent(
        model,
        [
            create_handoff_tool(
                agent_name=handoff_to,
                description=handoff_description,
            ),
        ],
        prompt=prompt_template.format(
            **{field: state[field] for field in _SCENARIO_FIELDS}
        ),
        name=name,
    )
    # Forward the node config explicitly: async code on Python < 3.11 does not
    # propagate it to nested runnables through contextvars.
    response = await agent.ainvoke({"messages": state["messages"]}, config)
    return {"messages": response["messages"]}


async def call_roleplay(state: State, config=None):
    """Run one turn of the Roleplay Agent.

    The agent gets a handoff tool to the Guiding Agent. ``config`` is filled in
    by LangGraph because of its parameter name. It is optional, so direct calls
    still work.
    """
    logger.info("Calling roleplay agent...")
    return await _run_scenario_agent(
        state,
        config,
        name="Roleplay Agent",
        prompt_template=roleplay_prompt,
        handoff_to="Guiding Agent",
        handoff_description="Hand off to Guiding Agent when user shows signs of needing help, guidance, or struggles with communication",
    )


async def call_guiding_agent(state: State, config=None):
    """Run one turn of the Guiding Agent.

    The agent gets a handoff tool back to the Roleplay Agent. ``config`` is
    filled in by LangGraph because of its parameter name. It is optional, so
    direct calls still work.
    """
    logger.info("Calling guiding agent...")
    return await _run_scenario_agent(
        state,
        config,
        name="Guiding Agent",
        prompt_template=guiding_prompt,
        handoff_to="Roleplay Agent",
        handoff_description="Hand off back to Roleplay Agent when user is ready for scenario practice and shows improved confidence",
    )
def route_to_active_agent(state: State) -> str:
    """Return the name of the agent node that should handle the next turn.

    Reads ``state["active_agent"]``. A missing or falsy value falls back to
    ``"Roleplay Agent"``, the same default that trim_history applies.

    Raises:
        ValueError: if ``active_agent`` names an agent this graph does not have.
            Previously such values fell through and returned None.
    """
    agent = state.get("active_agent") or "Roleplay Agent"
    if agent not in ("Roleplay Agent", "Guiding Agent"):
        raise ValueError(f"Unknown active_agent: {agent!r}")
    return agent