"""Analysis agent subgraph.""" import ast from typing import Annotated, TypedDict from langchain_core.language_models import BaseChatModel from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage from langchain_core.tools import BaseTool from langgraph.graph import END, START, StateGraph from langgraph.graph.message import add_messages from langgraph.graph.state import CompiledStateGraph from langgraph.prebuilt import ToolNode _SYSTEM = ( "You are a harmonic analysis assistant. Given a chord sequence, use the available tools " "to analyse it. " "Use your judgement about which tools to call and in what order to get the best results." ) _TOOL_NAMES = {"analyze_chord_sequence_text", "analyze_music_file"} class AnalysisState(TypedDict): user_input: str # received from parent limit: int # received from parent originality_score: float | None # returned to parent neighbours: list[dict] # returned to parent messages: Annotated[list, add_messages] # private def build_analysis_subgraph( analysis_llm: BaseChatModel, mcp_tools: list[BaseTool], ) -> CompiledStateGraph: """Build the harmonic analysis agent subgraph. :param analysis_llm: LLM for harmonic analysis (must support tool calling). :param mcp_tools: Harmonic analysis MCP tools. """ analysis_llm_with_tools = analysis_llm.bind_tools(mcp_tools) def start(state: AnalysisState) -> dict: return { "originality_score": None, "neighbours": [], "messages": [ SystemMessage(content=_SYSTEM), HumanMessage(content=f"{state['user_input']}\n\nReturn up to {state['limit']} similar songs."), ], } def agent(state: AnalysisState) -> dict: return {"messages": [analysis_llm_with_tools.invoke(state["messages"])]} def router(state: AnalysisState) -> str: last = state["messages"][-1] return "tools" if getattr(last, "tool_calls", None) else "extract" def extract(state: AnalysisState) -> dict: tool_msgs = [ m for m in state["messages"] if isinstance(m, ToolMessage) and m.name in _TOOL_NAMES ] if not tool_msgs: raise RuntimeError("Analysis agent did not call any analysis tool — cannot extract results.") content = tool_msgs[-1].content lines = content.strip().splitlines() if isinstance(content, str) else [] score = float(lines[0]) neighbours = ast.literal_eval(lines[1]) return {"originality_score": score, "neighbours": neighbours} graph = StateGraph(AnalysisState) graph.add_node("start", start) graph.add_node("agent", agent) graph.add_node("tools", ToolNode(mcp_tools)) graph.add_node("extract", extract) graph.add_edge(START, "start") graph.add_edge("start", "agent") graph.add_conditional_edges("agent", router, ["tools", "extract"]) graph.add_edge("tools", "agent") graph.add_edge("extract", END) return graph.compile()