File size: 3,012 Bytes
6cbca40
 
 
c0827a3
 
6cbca40
c0827a3
 
 
 
6cbca40
 
 
c0827a3
 
 
 
 
 
 
6cbca40
 
c0827a3
 
 
 
 
 
 
 
 
 
 
6cbca40
c0827a3
 
 
6cbca40
 
 
 
 
 
 
 
 
c0827a3
 
 
 
 
6cbca40
 
 
c0827a3
 
 
6cbca40
c0827a3
 
 
6cbca40
 
 
 
 
 
 
 
 
c0827a3
 
 
 
 
6cbca40
 
 
c0827a3
 
6cbca40
 
c0827a3
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from typing import TypedDict
from src.config.llm import model
from langgraph.prebuilt import create_react_agent
from langgraph_swarm import create_handoff_tool
from langchain_core.messages import RemoveMessage
from .prompt import roleplay_prompt, guiding_prompt
from typing_extensions import TypedDict, Annotated
from langchain_core.messages import AnyMessage
from langgraph.graph import add_messages
from loguru import logger


class State(TypedDict):
    """Shared state read by the node and routing functions in this module.

    The five scenario fields are substituted into the roleplay and guiding
    prompt templates each time one of the agents is called.
    """

    # Name of the agent that should handle the turn. ``trim_history`` fills in
    # "Roleplay Agent" when this is unset; ``route_to_active_agent`` reads it.
    active_agent: str | None
    # Conversation history, annotated with the ``add_messages`` reducer.
    messages: Annotated[list[AnyMessage], add_messages]
    # Scenario fields passed verbatim to the prompt templates.
    scenario_title: str
    scenario_description: str
    scenario_context: str
    your_role: str
    key_vocabulary: str


def trim_history(state: State):
    """Default the active agent and schedule old messages for removal.

    When ``active_agent`` is missing or falsy it becomes "Roleplay Agent".
    When the history holds more than 25 messages, ``messages`` is replaced by
    ``RemoveMessage`` markers for every entry except the 5 most recent ones.

    The incoming ``state`` is not modified; the adjustments are applied to a
    shallow copy, which is returned.

    Args:
        state: Current graph state.

    Returns:
        A new dict with the same keys as ``state`` and the changes above.
    """
    max_history = 25  # trimming only starts above this many messages
    keep_last = 5  # number of most recent messages that survive a trim

    # Work on a copy so the caller's dict is never mutated behind its back.
    updated = dict(state)
    if not updated.get("active_agent"):
        updated["active_agent"] = "Roleplay Agent"

    history = updated.get("messages", [])
    if len(history) > max_history:
        # NOTE(review): the cut is purely positional, so the first retained
        # message may be a tool result whose originating call was removed —
        # confirm the model provider tolerates that.
        updated["messages"] = [
            RemoveMessage(id=message.id) for message in history[:-keep_last]
        ]
    return updated


async def call_roleplay(state: State):
    """Run the Roleplay Agent on the current conversation.

    A ReAct agent named "Roleplay Agent" is created on every call. Its only
    tool is a handoff to "Guiding Agent", and its prompt is the roleplay
    template filled with the scenario fields from ``state``.

    Args:
        state: Current graph state; the scenario fields and ``messages`` are
            read from it.

    Returns:
        A dict whose ``"messages"`` entry is the message list from the agent's
        response.
    """
    logger.info("Calling roleplay agent...")

    to_guiding_agent = create_handoff_tool(
        agent_name="Guiding Agent",
        description="Hand off to Guiding Agent when user shows signs of needing help, guidance, or struggles with communication",
    )

    # Template placeholders share their names with the state keys.
    scenario_keys = (
        "scenario_title",
        "scenario_description",
        "scenario_context",
        "your_role",
        "key_vocabulary",
    )
    system_prompt = roleplay_prompt.format(
        **{key: state[key] for key in scenario_keys}
    )

    agent = create_react_agent(
        model,
        [to_guiding_agent],
        prompt=system_prompt,
        name="Roleplay Agent",
    )
    result = await agent.ainvoke({"messages": state["messages"]})
    return {"messages": result["messages"]}


async def call_guiding_agent(state: State):
    """Run the Guiding Agent on the current conversation.

    A ReAct agent named "Guiding Agent" is created on every call. Its only
    tool is a handoff back to "Roleplay Agent", and its prompt is the guiding
    template filled with the scenario fields from ``state``.

    Args:
        state: Current graph state; the scenario fields and ``messages`` are
            read from it.

    Returns:
        A dict whose ``"messages"`` entry is the message list from the agent's
        response.
    """
    logger.info("Calling guiding agent...")

    to_roleplay_agent = create_handoff_tool(
        agent_name="Roleplay Agent",
        description="Hand off back to Roleplay Agent when user is ready for scenario practice and shows improved confidence",
    )

    # Template placeholders share their names with the state keys.
    scenario_keys = (
        "scenario_title",
        "scenario_description",
        "scenario_context",
        "your_role",
        "key_vocabulary",
    )
    system_prompt = guiding_prompt.format(
        **{key: state[key] for key in scenario_keys}
    )

    agent = create_react_agent(
        model,
        [to_roleplay_agent],
        prompt=system_prompt,
        name="Guiding Agent",
    )
    result = await agent.ainvoke({"messages": state["messages"]})
    return {"messages": result["messages"]}


def route_to_active_agent(state: State) -> str:
    """Return the name of the agent that should handle the current turn.

    An unset (missing, ``None`` or empty) ``active_agent`` falls back to
    "Roleplay Agent", the same default ``trim_history`` applies, so the
    function never silently returns ``None`` despite its ``str`` return type.

    Args:
        state: Current graph state; only ``active_agent`` is read.

    Returns:
        Either "Roleplay Agent" or "Guiding Agent".

    Raises:
        ValueError: If ``active_agent`` names an agent this router does not
            know about.
    """
    known_agents = ("Roleplay Agent", "Guiding Agent")

    active_agent = state.get("active_agent") or "Roleplay Agent"
    if active_agent not in known_agents:
        # Fail with a clear message instead of handing the caller a None route.
        raise ValueError(
            f"Unknown active_agent {active_agent!r}; expected one of {known_agents}"
        )
    return active_agent