Spaces:
Sleeping
Sleeping
| """Analysis agent subgraph.""" | |
| import ast | |
| from typing import Annotated, TypedDict | |
| from langchain_core.language_models import BaseChatModel | |
| from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage | |
| from langchain_core.tools import BaseTool | |
| from langgraph.graph import END, START, StateGraph | |
| from langgraph.graph.message import add_messages | |
| from langgraph.graph.state import CompiledStateGraph | |
| from langgraph.prebuilt import ToolNode | |
| _SYSTEM = ( | |
| "You are a harmonic analysis assistant. Given a chord sequence, use the available tools " | |
| "to analyse it. " | |
| "Use your judgement about which tools to call and in what order to get the best results." | |
| ) | |
| _TOOL_NAMES = {"analyze_chord_sequence_text", "analyze_music_file"} | |
class AnalysisState(TypedDict):
    """State schema of the harmonic analysis subgraph.

    ``user_input`` and ``limit`` are received from the parent graph;
    ``originality_score`` and ``neighbours`` are returned to it; ``messages``
    is the agent's private conversation history.
    """

    user_input: str  # received from parent; the user's request, forwarded verbatim to the model
    limit: int  # received from parent; upper bound on similar songs the prompt asks for
    originality_score: float | None  # returned to parent; None until a tool result has been parsed
    neighbours: list[dict]  # returned to parent; similar-song records parsed from the tool output
    messages: Annotated[list, add_messages]  # private; updates are merged via LangGraph's add_messages reducer
def _content_to_text(content: object) -> str:
    """Flatten a ToolMessage ``content`` value into plain text.

    Tool results may arrive as a single string or as a list of content blocks
    (plain strings, or dicts carrying a ``"text"`` entry). Text parts are joined
    with newlines; any other kind of block is ignored.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for block in content:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and isinstance(block.get("text"), str):
                parts.append(block["text"])
        return "\n".join(parts)
    return ""


def _parse_analysis_output(text: str) -> tuple[float, list]:
    """Parse an analysis tool's text result into ``(originality_score, neighbours)``.

    Expected format: two non-blank lines, a float score followed by a Python
    literal list of neighbour records. The list is parsed with
    ``ast.literal_eval``, which accepts literals only, never arbitrary expressions.

    :raises RuntimeError: if the text does not match that format.
    """
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    if len(lines) < 2:
        raise RuntimeError(
            "Analysis tool output must contain a score line and a neighbours line; "
            f"got {text!r}"
        )
    try:
        score = float(lines[0])
        neighbours = ast.literal_eval(lines[1])
    except (ValueError, TypeError, SyntaxError) as exc:
        raise RuntimeError(f"Could not parse analysis tool output: {text!r}") from exc
    if isinstance(neighbours, tuple):
        neighbours = list(neighbours)
    if not isinstance(neighbours, list):
        raise RuntimeError(
            "Expected a list of neighbours from the analysis tool, "
            f"got {type(neighbours).__name__}"
        )
    return score, neighbours


def build_analysis_subgraph(
    analysis_llm: BaseChatModel,
    mcp_tools: list[BaseTool],
) -> CompiledStateGraph:
    """Build the harmonic analysis agent subgraph.

    Flow: ``start`` seeds the conversation. ``agent`` and ``tools`` then
    alternate while the model keeps requesting tool calls. Finally, ``extract``
    parses the latest successful analysis-tool result into
    ``originality_score`` and ``neighbours``.

    :param analysis_llm: LLM for harmonic analysis (must support tool calling).
    :param mcp_tools: Harmonic analysis MCP tools.
    :return: The compiled subgraph.
    :raises RuntimeError: at run time, from the ``extract`` node, when no usable
        analysis-tool result exists or it cannot be parsed.
    """
    analysis_llm_with_tools = analysis_llm.bind_tools(mcp_tools)

    def start(state: AnalysisState) -> dict:
        # Reset the outputs and open the conversation with the system prompt
        # followed by the user's request.
        return {
            "originality_score": None,
            "neighbours": [],
            "messages": [
                SystemMessage(content=_SYSTEM),
                HumanMessage(content=f"{state['user_input']}\n\nReturn up to {state['limit']} similar songs."),
            ],
        }

    def agent(state: AnalysisState) -> dict:
        return {"messages": [analysis_llm_with_tools.invoke(state["messages"])]}

    def router(state: AnalysisState) -> str:
        # Loop back through the tools for as long as the model requests them.
        last = state["messages"][-1]
        return "tools" if getattr(last, "tool_calls", None) else "extract"

    def extract(state: AnalysisState) -> dict:
        analysis_msgs = [
            m for m in state["messages"]
            if isinstance(m, ToolMessage) and m.name in _TOOL_NAMES
        ]
        if not analysis_msgs:
            raise RuntimeError("Analysis agent did not call any analysis tool — cannot extract results.")
        # Ignore failed tool calls so an error message is never parsed as a result.
        succeeded = [m for m in analysis_msgs if getattr(m, "status", "success") != "error"]
        if not succeeded:
            raise RuntimeError(
                "Every analysis tool call failed — cannot extract results. "
                f"Last error: {_content_to_text(analysis_msgs[-1].content)!r}"
            )
        score, neighbours = _parse_analysis_output(_content_to_text(succeeded[-1].content))
        return {"originality_score": score, "neighbours": neighbours}

    graph = StateGraph(AnalysisState)
    graph.add_node("start", start)
    graph.add_node("agent", agent)
    graph.add_node("tools", ToolNode(mcp_tools))
    graph.add_node("extract", extract)
    graph.add_edge(START, "start")
    graph.add_edge("start", "agent")
    graph.add_conditional_edges("agent", router, ["tools", "extract"])
    graph.add_edge("tools", "agent")
    graph.add_edge("extract", END)
    return graph.compile()