from typing import TypedDict

from typing_extensions import TypedDict, Annotated
from langchain_core.messages import AnyMessage
from langchain_core.messages import RemoveMessage, ToolMessage
from langgraph.graph import add_messages
from langgraph.prebuilt import create_react_agent
from langgraph_swarm import create_handoff_tool
from loguru import logger

from src.config.llm import model
from .prompt import practice_agent_prompt, teaching_agent_prompt

# Agent names; used both as react-agent names and as handoff/routing targets.
PRACTICE_AGENT = "Practice Agent"
TEACHING_AGENT = "Teaching Agent"

# Once the conversation holds more than MAX_HISTORY messages, trim_history
# drops older messages so that (about) the last KEEP_HISTORY remain.
MAX_HISTORY = 25
KEEP_HISTORY = 5


class State(TypedDict):
    # Name of the agent that should handle the next turn (see route_to_active_agent).
    active_agent: str | None
    # Conversation history; merged via the add_messages reducer, which also
    # honours RemoveMessage markers.
    messages: Annotated[list[AnyMessage], add_messages]
    # Lesson context that is formatted into both agent prompts.
    unit: str
    vocabulary: list
    key_structures: list
    practice_questions: list
    student_level: str


def trim_history(state: State) -> dict:
    """Default the active agent and prune old messages once history grows long.

    Returns a shallow copy of ``state``; the input mapping is not mutated.
    If ``active_agent`` is unset/falsy it is set to ``PRACTICE_AGENT``. If
    there are more than ``MAX_HISTORY`` messages, ``messages`` in the returned
    copy is replaced by ``RemoveMessage`` markers for every message before the
    kept tail; the ``add_messages`` reducer applies those removals when the
    update is merged. Otherwise ``messages`` is passed through unchanged.
    """
    update = dict(state)
    if not update.get("active_agent"):
        update["active_agent"] = PRACTICE_AGENT

    history = update.get("messages", [])
    if len(history) > MAX_HISTORY:
        cut = len(history) - KEEP_HISTORY
        # Do not start the kept window on a ToolMessage: its parent AI message
        # (the one holding the tool call) would be removed, leaving an orphaned
        # tool result. Walk back so the tool call is kept with its results.
        while 0 < cut < len(history) and isinstance(history[cut], ToolMessage):
            cut -= 1
        update["messages"] = [RemoveMessage(id=msg.id) for msg in history[:cut]]
    return update


def _prompt_kwargs(state: State) -> dict:
    """Collect the lesson fields that both agent prompt templates are formatted with."""
    return {
        "unit": state["unit"],
        "vocabulary": state["vocabulary"],
        "key_structures": state["key_structures"],
        "practice_questions": state["practice_questions"],
        "student_level": state["student_level"],
    }


async def _run_agent(
    name: str,
    prompt_template,
    handoff_to: str,
    handoff_description: str,
    state: State,
) -> dict:
    """Build a react agent with a single handoff tool and run it on the current messages.

    Args:
        name: Name given to the react agent.
        prompt_template: Prompt object exposing ``.format(**kwargs)``; it is
            formatted with the lesson fields from ``state``.
        handoff_to: Agent name the handoff tool targets.
        handoff_description: Description shown to the model for the handoff tool.
        state: Current graph state; ``messages`` is passed to the agent.

    Returns:
        ``{"messages": ...}`` holding the agent's resulting message list.
    """
    # The agent is rebuilt per call because its prompt depends on the state.
    agent = create_react_agent(
        model,
        [
            create_handoff_tool(
                agent_name=handoff_to,
                description=handoff_description,
            ),
        ],
        prompt=prompt_template.format(**_prompt_kwargs(state)),
        name=name,
    )
    response = await agent.ainvoke({"messages": state["messages"]})
    return {"messages": response["messages"]}


async def call_practice_agent(state: State):
    """Run the Practice Agent, which can hand off to the Teaching Agent."""
    logger.info("Calling practice agent...")
    return await _run_agent(
        PRACTICE_AGENT,
        practice_agent_prompt,
        TEACHING_AGENT,
        "Hand off to Teaching Agent when user makes same grammar mistake twice in a row, asks for grammar/structure help, or shows fundamental misunderstanding",
        state,
    )


async def call_teaching_agent(state: State):
    """Run the Teaching Agent, which can hand back to the Practice Agent."""
    logger.info("Calling teaching agent...")
    return await _run_agent(
        TEACHING_AGENT,
        teaching_agent_prompt,
        PRACTICE_AGENT,
        "Hand off back to Practice Agent when user demonstrates understanding and is ready for conversation practice",
        state,
    )


def route_to_active_agent(state: State) -> str:
    """Return the name of the agent node that should handle the next turn.

    Falls back to ``PRACTICE_AGENT`` (the same default ``trim_history`` applies)
    when ``active_agent`` is missing, ``None``, or not a known agent, instead
    of implicitly returning ``None``.
    """
    active = state.get("active_agent")
    if active in (PRACTICE_AGENT, TEACHING_AGENT):
        return active
    if active is not None:
        logger.warning(f"Unknown active_agent {active!r}; routing to {PRACTICE_AGENT}")
    return PRACTICE_AGENT