File size: 3,109 Bytes
9482535
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d65f7a
 
9482535
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d65f7a
9482535
 
 
 
 
7d65f7a
9482535
 
 
7d65f7a
9482535
 
7d65f7a
9482535
 
 
7d65f7a
9482535
 
 
 
 
7d65f7a
9482535
 
 
 
 
 
 
7d65f7a
 
9482535
7d65f7a
9482535
 
7d65f7a
9482535
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""Analysis agent subgraph."""

import ast
from typing import Annotated, TypedDict

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import BaseTool
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode

# System prompt for the analysis agent, assembled one sentence at a time.
_SYSTEM = " ".join(
    (
        "You are a harmonic analysis assistant.",
        "Given a chord sequence, use the available tools to analyse it.",
        "Use your judgement about which tools to call and in what order to get the best results.",
    )
)

# Names of the tools whose ToolMessage output the subgraph's extract step parses.
_TOOL_NAMES = {
    "analyze_chord_sequence_text",
    "analyze_music_file",
}


class AnalysisState(TypedDict):
    """State schema for the harmonic analysis subgraph.

    ``user_input`` and ``limit`` are supplied by the parent graph,
    ``originality_score`` and ``neighbours`` are the results handed back to it,
    and ``messages`` is the subgraph's private conversation history, whose
    updates are combined through the ``add_messages`` reducer.
    """

    # Inputs received from the parent graph.
    user_input: str
    limit: int
    # Outputs returned to the parent graph.
    originality_score: float | None
    neighbours: list[dict]
    # Private to this subgraph.
    messages: Annotated[list, add_messages]


def _tool_content_to_text(content: object) -> str:
    """Flatten a ``ToolMessage.content`` value into plain text.

    A string is returned unchanged. For a list of content blocks, plain string
    blocks and ``{"type": "text", "text": ...}`` blocks are joined with
    newlines; other block kinds are skipped. Anything else yields ``""``.
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, (list, tuple)):
        return ""
    parts: list[str] = []
    for block in content:
        if isinstance(block, str):
            parts.append(block)
        elif isinstance(block, dict) and block.get("type") == "text":
            parts.append(str(block.get("text", "")))
    return "\n".join(parts)


def _parse_analysis_result(content: object) -> tuple[float, list[dict]]:
    """Parse an analysis tool's output into ``(originality_score, neighbours)``.

    The expected text has two non-blank lines. The first is the originality
    score as a float. The second is a Python literal list (neighbour records),
    parsed with ``ast.literal_eval``, which evaluates literals only and never
    executes code. Any further lines are ignored.

    :param content: ``ToolMessage.content`` (a string or a list of content blocks).
    :raises RuntimeError: if the output is empty or does not match that format.
    """
    text = _tool_content_to_text(content)
    # Blank lines are skipped so that surrounding/separating whitespace is tolerated.
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    if len(lines) < 2:
        raise RuntimeError(
            "Analysis tool output must contain a score line and a neighbours line; "
            f"got {text!r}."
        )
    try:
        score = float(lines[0])
    except ValueError as err:
        raise RuntimeError(
            f"Analysis tool returned a non-numeric originality score: {lines[0]!r}."
        ) from err
    try:
        neighbours = ast.literal_eval(lines[1])
    except (ValueError, TypeError, SyntaxError, RecursionError) as err:
        raise RuntimeError(
            f"Analysis tool returned unparsable neighbours: {lines[1]!r}."
        ) from err
    if not isinstance(neighbours, list):
        raise RuntimeError(
            f"Analysis tool neighbours must be a list, got {type(neighbours).__name__}."
        )
    return score, neighbours


def build_analysis_subgraph(
    analysis_llm: BaseChatModel,
    mcp_tools: list[BaseTool],
) -> CompiledStateGraph:
    """Build the harmonic analysis agent subgraph.

    Flow: ``START -> start -> agent``. While the agent's latest message requests
    tool calls, control goes ``agent -> tools -> agent``. Otherwise it goes to
    ``extract``, which parses the most recent successful analysis-tool result
    into ``originality_score`` and ``neighbours``, and then to ``END``.

    :param analysis_llm: LLM for harmonic analysis (must support tool calling).
    :param mcp_tools: Harmonic analysis MCP tools.
    :returns: The compiled subgraph operating on ``AnalysisState``.
    """
    analysis_llm_with_tools = analysis_llm.bind_tools(mcp_tools)

    def start(state: AnalysisState) -> dict:
        """Reset the outputs and seed the conversation with the system and user prompts."""
        return {
            "originality_score": None,
            "neighbours": [],
            "messages": [
                SystemMessage(content=_SYSTEM),
                HumanMessage(content=f"{state['user_input']}\n\nReturn up to {state['limit']} similar songs."),
            ],
        }

    def agent(state: AnalysisState) -> dict:
        """Run one LLM turn over the conversation so far."""
        return {"messages": [analysis_llm_with_tools.invoke(state["messages"])]}

    def router(state: AnalysisState) -> str:
        """Route to the tools node if the last message requested tool calls, else to extract."""
        last = state["messages"][-1]
        return "tools" if getattr(last, "tool_calls", None) else "extract"

    def extract(state: AnalysisState) -> dict:
        """Parse the latest successful analysis-tool output into the returned fields.

        :raises RuntimeError: if no analysis tool was called, if every such call
            failed, or if the tool output is malformed.
        """
        tool_msgs = [
            m for m in state["messages"]
            if isinstance(m, ToolMessage) and m.name in _TOOL_NAMES
        ]
        if not tool_msgs:
            raise RuntimeError("Analysis agent did not call any analysis tool — cannot extract results.")
        # Skip tool errors (status="error"), which carry an error text
        # rather than an analysis result, and fall back to the latest success.
        successful = [m for m in tool_msgs if getattr(m, "status", "success") != "error"]
        if not successful:
            raise RuntimeError(
                "Every analysis tool call failed — cannot extract results. "
                f"Last error: {_tool_content_to_text(tool_msgs[-1].content)!r}"
            )
        score, neighbours = _parse_analysis_result(successful[-1].content)
        return {"originality_score": score, "neighbours": neighbours}

    graph = StateGraph(AnalysisState)
    graph.add_node("start", start)
    graph.add_node("agent", agent)
    graph.add_node("tools", ToolNode(mcp_tools))
    graph.add_node("extract", extract)
    graph.add_edge(START, "start")
    graph.add_edge("start", "agent")
    graph.add_conditional_edges("agent", router, ["tools", "extract"])
    graph.add_edge("tools", "agent")
    graph.add_edge("extract", END)
    return graph.compile()