Spaces:
Sleeping
Sleeping
File size: 3,012 Bytes
6cbca40 c0827a3 6cbca40 c0827a3 6cbca40 c0827a3 6cbca40 c0827a3 6cbca40 c0827a3 6cbca40 c0827a3 6cbca40 c0827a3 6cbca40 c0827a3 6cbca40 c0827a3 6cbca40 c0827a3 6cbca40 c0827a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
from typing import TypedDict
from src.config.llm import model
from langgraph.prebuilt import create_react_agent
from langgraph_swarm import create_handoff_tool
from langchain_core.messages import RemoveMessage
from .prompt import roleplay_prompt, guiding_prompt
from typing_extensions import TypedDict, Annotated
from langchain_core.messages import AnyMessage
from langgraph.graph import add_messages
from loguru import logger
class State(TypedDict):
    """Shared LangGraph state for the Roleplay Agent / Guiding Agent pair."""

    # Name of the agent node that should handle the next turn
    # ("Roleplay Agent" or "Guiding Agent"); may be unset/None initially.
    active_agent: str | None
    # Conversation history. The add_messages reducer merges returned message
    # lists into this field (RemoveMessage entries are applied as deletions).
    messages: Annotated[list[AnyMessage], add_messages]
    # The remaining fields are substituted into both agents' prompt templates.
    scenario_title: str
    scenario_description: str
    scenario_context: str
    your_role: str
    key_vocabulary: str
def trim_history(state: State):
    """Default the active agent and prune old conversation history.

    If ``active_agent`` is missing or empty it is set to ``"Roleplay Agent"``.
    Once the history grows past 25 messages, ``messages`` is replaced by a list
    of ``RemoveMessage`` markers covering everything except (at least) the last
    5 messages; the ``add_messages`` reducer on ``State.messages`` applies
    these markers as deletions. Shorter histories are left untouched.

    The cut point is moved earlier whenever it would land on a tool result, so
    a retained tool message always keeps the AI message that requested it.

    Args:
        state: Current graph state. It is mutated in place.

    Returns:
        The same ``state`` object.
    """
    max_messages = 25  # trim only once the history exceeds this length
    keep_last = 5  # minimum number of most recent messages to retain

    if not state.get("active_agent"):
        state["active_agent"] = "Roleplay Agent"

    history = state.get("messages", [])
    if len(history) > max_messages:
        cut = len(history) - keep_last
        # Dropping an AI tool-call message while keeping its tool response
        # leaves an orphaned tool message, which OpenAI-style chat APIs
        # reject. Walk back so the retained window starts on a non-tool message.
        while cut > 0 and getattr(history[cut], "type", None) == "tool":
            cut -= 1
        state["messages"] = [RemoveMessage(id=msg.id) for msg in history[:cut]]
    return state
async def call_roleplay(state: State):
    """Run one turn of the Roleplay Agent over the current conversation.

    A ReAct agent is built on each call so that its prompt is rendered from
    the scenario fields currently held in ``state``. The agent's only tool
    hands control over to the Guiding Agent.

    Args:
        state: Graph state; must provide ``messages`` and the five scenario
            fields consumed by ``roleplay_prompt``.

    Returns:
        A partial state update containing the ``messages`` list from the
        agent's result.
    """
    logger.info("Calling roleplay agent...")
    handoff_to_guide = create_handoff_tool(
        agent_name="Guiding Agent",
        description="Hand off to Guiding Agent when user shows signs of needing help, guidance, or struggles with communication",
    )
    scenario_fields = {
        field: state[field]
        for field in (
            "scenario_title",
            "scenario_description",
            "scenario_context",
            "your_role",
            "key_vocabulary",
        )
    }
    agent = create_react_agent(
        model,
        [handoff_to_guide],
        prompt=roleplay_prompt.format(**scenario_fields),
        name="Roleplay Agent",
    )
    result = await agent.ainvoke({"messages": state["messages"]})
    return {"messages": result["messages"]}
async def call_guiding_agent(state: State):
    """Run one turn of the Guiding Agent over the current conversation.

    Mirrors the roleplay node: a ReAct agent is built per call with a prompt
    rendered from the scenario fields in ``state``. Its single tool hands
    control back to the Roleplay Agent.

    Args:
        state: Graph state; must provide ``messages`` and the five scenario
            fields consumed by ``guiding_prompt``.

    Returns:
        A partial state update containing the ``messages`` list from the
        agent's result.
    """
    logger.info("Calling guiding agent...")
    tools = [
        create_handoff_tool(
            agent_name="Roleplay Agent",
            description="Hand off back to Roleplay Agent when user is ready for scenario practice and shows improved confidence",
        )
    ]
    prompt_keys = (
        "scenario_title",
        "scenario_description",
        "scenario_context",
        "your_role",
        "key_vocabulary",
    )
    system_prompt = guiding_prompt.format(**{key: state[key] for key in prompt_keys})
    guide = create_react_agent(
        model, tools, prompt=system_prompt, name="Guiding Agent"
    )
    outcome = await guide.ainvoke({"messages": state["messages"]})
    return {"messages": outcome["messages"]}
def route_to_active_agent(state: State) -> str:
    """Return the name of the agent node that should handle the next turn.

    A missing or empty ``active_agent`` falls back to ``"Roleplay Agent"``,
    the same default ``trim_history`` applies.

    Args:
        state: Graph state; ``active_agent`` is read if present.

    Returns:
        ``"Roleplay Agent"`` or ``"Guiding Agent"``.

    Raises:
        ValueError: if ``active_agent`` names an agent this graph does not know.
    """
    active = state.get("active_agent") or "Roleplay Agent"
    if active in ("Roleplay Agent", "Guiding Agent"):
        return active
    # Fail loudly instead of implicitly returning None, which is not a node name.
    raise ValueError(f"Unknown active agent: {active!r}")
|