harmonic-analysis / agents /multi /analysis.py
ohollo's picture
Langsmith evals
7d65f7a
Raw
History Blame Contribute Delete
3.11 kB
"""Analysis agent subgraph."""
import ast
from typing import Annotated, TypedDict
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import BaseTool
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode
# System prompt for the analysis LLM, assembled sentence by sentence.
_SYSTEM = " ".join(
    (
        "You are a harmonic analysis assistant.",
        "Given a chord sequence, use the available tools to analyse it.",
        "Use your judgement about which tools to call and in what order to get the best results.",
    )
)
# Names of the MCP tools whose ToolMessage output `extract` is allowed to parse.
_TOOL_NAMES = set(("analyze_chord_sequence_text", "analyze_music_file"))
class AnalysisState(TypedDict):
    """State schema for the harmonic analysis subgraph.

    ``user_input`` and ``limit`` are inputs from the parent graph,
    ``originality_score`` and ``neighbours`` are the results handed back to it,
    and ``messages`` is the subgraph's private LLM/tool conversation.
    """

    user_input: str  # received from parent
    # Only forwarded to the LLM inside the prompt text; not enforced on the results here.
    limit: int  # received from parent
    originality_score: float | None  # returned to parent; None until a result is extracted
    neighbours: list[dict]  # returned to parent
    # ``add_messages`` is the channel reducer: node updates (e.g. the single reply the
    # agent node returns) are merged into the existing history instead of replacing it.
    messages: Annotated[list, add_messages]  # private
def _content_to_text(content: object) -> str:
    """Flatten a ``ToolMessage.content`` payload into plain text.

    A ``str`` is returned unchanged. For a ``list``, string items and the
    ``"text"`` value of dict items are joined with newlines, and other items
    are ignored. This way a tool that emits its result as several text parts
    still yields the line-oriented layout :func:`_parse_analysis_output`
    expects. Any other type yields ``""``.
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return ""
    parts: list[str] = []
    for item in content:
        if isinstance(item, str):
            parts.append(item)
        elif isinstance(item, dict) and isinstance(item.get("text"), str):
            parts.append(item["text"])
    return "\n".join(parts)


def _parse_analysis_output(content: object) -> tuple[float, list[dict]]:
    """Parse analysis-tool output into ``(originality_score, neighbours)``.

    Expected text layout:

    - Line 1 is the originality score, parsed with ``float``.
    - Line 2 is a Python-literal sequence of neighbour entries, parsed with
      ``ast.literal_eval``. It only evaluates literals and never executes code.
    - Any further lines are ignored.

    :param content: ``ToolMessage.content`` from an analysis tool call.
    :returns: The score and the neighbours (always as a ``list``).
    :raises RuntimeError: If the output does not follow that layout.
    """
    text = _content_to_text(content).strip()
    lines = text.splitlines()
    if len(lines) < 2:
        raise RuntimeError(
            "Analysis tool output is not in the expected '<score>\\n<neighbours>' "
            f"format: {text[:200]!r}"
        )
    try:
        score = float(lines[0])
    except ValueError as err:
        raise RuntimeError(f"Analysis tool returned a non-numeric score: {lines[0]!r}") from err
    try:
        neighbours = ast.literal_eval(lines[1])
    except (ValueError, TypeError, SyntaxError, RecursionError) as err:
        raise RuntimeError(
            f"Analysis tool returned unparseable neighbours: {lines[1][:200]!r}"
        ) from err
    if not isinstance(neighbours, (list, tuple)):
        raise RuntimeError(
            f"Analysis tool returned neighbours of type {type(neighbours).__name__}; "
            "expected a list."
        )
    return score, list(neighbours)


def build_analysis_subgraph(
    analysis_llm: BaseChatModel,
    mcp_tools: list[BaseTool],
) -> CompiledStateGraph:
    """Build the harmonic analysis agent subgraph.

    Flow:

    1. ``start`` seeds the conversation.
    2. ``agent`` (the LLM) and ``tools`` alternate for as long as the LLM
       keeps requesting tool calls.
    3. ``extract`` parses the last successful analysis-tool result into
       ``originality_score`` and ``neighbours``.

    :param analysis_llm: LLM for harmonic analysis (must support tool calling).
    :param mcp_tools: Harmonic analysis MCP tools.
    :returns: The compiled subgraph.
    :raises RuntimeError: When the graph runs, if no analysis tool call
        succeeded or its output cannot be parsed.
    """
    analysis_llm_with_tools = analysis_llm.bind_tools(mcp_tools)

    def start(state: AnalysisState) -> dict:
        """Reset the outputs and seed the system + user prompt messages."""
        return {
            "originality_score": None,
            "neighbours": [],
            "messages": [
                SystemMessage(content=_SYSTEM),
                HumanMessage(content=f"{state['user_input']}\n\nReturn up to {state['limit']} similar songs."),
            ],
        }

    def agent(state: AnalysisState) -> dict:
        """Run one LLM step over the full history; the reply is added to ``messages``."""
        return {"messages": [analysis_llm_with_tools.invoke(state["messages"])]}

    def router(state: AnalysisState) -> str:
        """Go to ``tools`` if the latest message requests tool calls, otherwise ``extract``."""
        last = state["messages"][-1]
        return "tools" if getattr(last, "tool_calls", None) else "extract"

    def extract(state: AnalysisState) -> dict:
        """Parse the most recent successful analysis-tool result into the outputs."""
        tool_msgs = [
            m
            for m in state["messages"]
            if isinstance(m, ToolMessage) and m.name in _TOOL_NAMES
        ]
        if not tool_msgs:
            raise RuntimeError("Analysis agent did not call any analysis tool — cannot extract results.")
        # With its default error handling, ToolNode turns a raising tool into a
        # ToolMessage with status="error". Its content is an error description,
        # not analysis output, so only successful calls are parsed.
        successful = [m for m in tool_msgs if getattr(m, "status", "success") != "error"]
        if not successful:
            raise RuntimeError(
                "Every analysis tool call failed — cannot extract results. "
                f"Last error: {_content_to_text(tool_msgs[-1].content)[:200]!r}"
            )
        score, neighbours = _parse_analysis_output(successful[-1].content)
        return {"originality_score": score, "neighbours": neighbours}

    graph = StateGraph(AnalysisState)
    graph.add_node("start", start)
    graph.add_node("agent", agent)
    graph.add_node("tools", ToolNode(mcp_tools))
    graph.add_node("extract", extract)
    graph.add_edge(START, "start")
    graph.add_edge("start", "agent")
    graph.add_conditional_edges("agent", router, ["tools", "extract"])
    graph.add_edge("tools", "agent")
    graph.add_edge("extract", END)
    return graph.compile()