import { create } from "zustand";
import {
  type Node,
  type Edge,
  type OnNodesChange,
  type OnEdgesChange,
  type OnConnect,
  type XYPosition,
  applyNodeChanges,
  applyEdgeChanges,
  addEdge as rfAddEdge,
  MarkerType,
} from "@xyflow/react";
import dagre from "dagre";
import { api } from "@/lib/api";
import type { AgentProfile } from "@/types/agent";
import type {
  GraphSaveRequest,
  GraphResponse,
  GraphListItem,
  GraphValidationResponse,
} from "@/types/graph";
import type { AgentExecutionStatus } from "@/types/execution";

/**
 * Data carried by every `agentNode` on the workflow canvas.
 *
 * The index signature is kept because React Flow constrains node data to
 * `Record<string, unknown>`.
 */
export interface AgentNodeData {
  /** Agent identifier; the store also uses it as the React Flow node id. */
  agentId: string;
  displayName: string;
  persona: string;
  description: string;
  llmBackbone: string | null;
  tools: string[];
  /** True when this node is the store's `startNode`. */
  isStartNode: boolean;
  /** True when this node is the store's `endNode`. */
  isEndNode: boolean;
  executionStatus: AgentExecutionStatus;
  /** Latest response passed to `updateNodeExecutionStatus`. */
  response?: string;
  /** Latest token count passed to `updateNodeExecutionStatus`. */
  tokensUsed?: number;
  [key: string]: unknown;
}

/** Data carried by every `conditionEdge` on the workflow canvas. */
export interface EdgeData {
  /** Condition text; `null` when none has been set. */
  condition?: string | null;
  /** Display label; `setEdgeCondition` sets it equal to the condition. */
  label?: string | null;
  /** Edge weight; this store falls back to 1.0 when it is missing. */
  weight?: number;
  /** Cleared by `resetExecutionStatus`. NOTE(review): never set to true inside this store. */
  evaluated?: boolean;
  [key: string]: unknown;
}

type AgentFlowNode = Node<AgentNodeData>;
type ConditionFlowEdge = Edge<EdgeData>;

interface GraphStore {
  // React Flow state
  nodes: AgentFlowNode[];
  edges: ConditionFlowEdge[];

  // Metadata
  graphId: string | null;
  graphName: string;
  graphDescription: string;
  taskQuery: string;
  startNode: string | null;
  endNode: string | null;

  // Saved graphs list
  savedGraphs: GraphListItem[];

  // Validation (last result of `validate`)
  validationErrors: string[];
  validationWarnings: string[];
  executionOrder: string[];

  // React Flow handlers
  /** Applies React Flow node changes; removals also detach edges and start/end/editing ids. */
  onNodesChange: OnNodesChange;
  onEdgesChange: OnEdgesChange;
  /** Adds an animated `conditionEdge` with no condition and weight 1.0. */
  onConnect: OnConnect;

  // Operations
  /** Adds a node for `agent` (no-op if that agent is already on the canvas). */
  addAgentNode: (agent: AgentProfile, position?: XYPosition) => void;
  /** Removes a node, its edges, and any start/end/editing reference to it. */
  removeNode: (nodeId: string) => void;
  /** Sets condition (mirrored into label); keeps the current weight when `weight` is omitted. */
  setEdgeCondition: (edgeId: string, condition: string, weight?: number) => void;
  setStartNode: (nodeId: string | null) => void;
  setEndNode: (nodeId: string | null) => void;
  setTaskQuery: (query: string) => void;
  setGraphName: (name: string) => void;
  setGraphDescription: (desc: string) => void;
  /** Updates a node's status; `response`/`tokens` keep their old values when omitted. */
  updateNodeExecutionStatus: (
    nodeId: string,
    status: AgentExecutionStatus,
    response?: string,
    tokens?: number
  ) => void;
  /** Sets every node back to "idle" and clears per-run node/edge fields. */
  resetExecutionStatus: () => void;

  // Auto-layout
  /** Re-positions all nodes top-to-bottom with dagre. */
  autoLayout: () => void;

  // Persistence (API errors propagate to the caller; none are caught here)
  /** PUTs when `graphId` is set, otherwise POSTs, then stores the returned `graph_id`. */
  saveGraph: () => Promise<GraphResponse>;
  /** Replaces the current document with the server's graph `graphId`. */
  loadGraph: (graphId: string) => Promise<void>;
  fetchSavedGraphs: () => Promise<void>;
  /** Deletes on the server, drops it from `savedGraphs`, and forgets `graphId` if it was loaded. */
  deleteGraph: (graphId: string) => Promise<void>;
  /** Sends the current graph to `/graphs/validate` and stores the result. */
  validate: () => Promise<GraphValidationResponse>;
  /** Resets the document to an empty, unsaved graph (keeps `savedGraphs`). */
  newGraph: () => void;
  /** Replaces the document with an unsaved copy of `template`. */
  loadTemplate: (template: GraphTemplate) => void;

  // Node editing
  editingNodeId: string | null;
  setEditingNodeId: (nodeId: string | null) => void;
  updateNodeData: (nodeId: string, data: Partial<AgentNodeData>) => void;

  // Conversion
  /** Serialises the current document into the API's save/validate payload. */
  toGraphRequest: () => GraphSaveRequest;
}

/** A predefined workflow that `loadTemplate` can drop onto the canvas. */
export interface GraphTemplate {
  name: string;
  description: string;
  agents: {
    agent_id: string;
    display_name: string;
    persona: string;
    description: string;
    tools: string[];
  }[];
  edges: { source: string; target: string; weight: number; condition: string | null }[];
  start_node: string;
  end_node: string;
}

/** Footprint dagre reserves for each node in `autoLayout`. */
const NODE_WIDTH = 220;
const NODE_HEIGHT = 120;

/** Grid used by `addAgentNode` when no explicit position is given. */
const GRID_COLUMNS = 4;
const GRID_ORIGIN = 100;
const GRID_STEP_X = 260;
const GRID_STEP_Y = 160;

const DEFAULT_GRAPH_NAME = "Untitled Workflow";

/** Per-document data fields (everything except actions and `savedGraphs`). */
type GraphDocumentState = Pick<
  GraphStore,
  | "nodes"
  | "edges"
  | "graphId"
  | "graphName"
  | "graphDescription"
  | "taskQuery"
  | "startNode"
  | "endNode"
  | "validationErrors"
  | "validationWarnings"
  | "executionOrder"
  | "editingNodeId"
>;

/** Fields that reference node ids and must be cleaned up when nodes disappear. */
type NodeReferenceState = Pick<GraphStore, "edges" | "startNode" | "endNode" | "editingNodeId">;

/** Returns a fresh empty document; new arrays on every call so no state is shared. */
function blankDocument(): GraphDocumentState {
  return {
    nodes: [],
    edges: [],
    graphId: null,
    graphName: DEFAULT_GRAPH_NAME,
    graphDescription: "",
    taskQuery: "",
    startNode: null,
    endNode: null,
    validationErrors: [],
    validationWarnings: [],
    executionOrder: [],
    editingNodeId: null,
  };
}

/** Rendering props shared by every edge this store creates. */
function conditionEdgeStyle(): Pick<ConditionFlowEdge, "type" | "animated" | "markerEnd"> {
  return {
    type: "conditionEdge",
    animated: true,
    markerEnd: { type: MarkerType.ArrowClosed },
  };
}

/**
 * Builds the state patch for removing `removedIds`: drops every edge touching them
 * and clears start/end/editing ids that pointed at one of them.
 */
function detachNodes(s: NodeReferenceState, removedIds: ReadonlySet<string>): NodeReferenceState {
  const isRemoved = (id: string | null): boolean => id !== null && removedIds.has(id);
  return {
    edges: s.edges.filter((e) => !removedIds.has(e.source) && !removedIds.has(e.target)),
    startNode: isRemoved(s.startNode) ? null : s.startNode,
    endNode: isRemoved(s.endNode) ? null : s.endNode,
    editingNodeId: isRemoved(s.editingNodeId) ? null : s.editingNodeId,
  };
}

/**
 * Zustand store backing the workflow-graph editor: React Flow nodes/edges, graph
 * metadata, validation results, and persistence through `api`.
 */
export const useGraphStore = create<GraphStore>((set, get) => ({
  ...blankDocument(),
  savedGraphs: [],

  onNodesChange: (changes) =>
    set((s) => {
      const nodes = applyNodeChanges(changes, s.nodes) as AgentFlowNode[];
      // Deletions performed inside React Flow (e.g. keyboard delete) never go through
      // removeNode, so apply the same bookkeeping here. Only do it when something was
      // removed, to avoid rebuilding the edge array on every drag/selection change.
      const removed = new Set<string>();
      for (const change of changes) {
        if (change.type === "remove") removed.add(change.id);
      }
      return removed.size === 0 ? { nodes } : { nodes, ...detachNodes(s, removed) };
    }),

  onEdgesChange: (changes) =>
    set((s) => ({ edges: applyEdgeChanges(changes, s.edges) as ConditionFlowEdge[] })),

  onConnect: (connection) =>
    set((s) => ({
      edges: rfAddEdge(
        {
          ...connection,
          ...conditionEdgeStyle(),
          data: { condition: null, label: null, weight: 1.0 },
        },
        s.edges
      ),
    })),

  addAgentNode: (agent, position) => {
    const existing = get().nodes;
    // The agent id doubles as the node id (edges, start/end and the save payload all
    // key on it), so a second node for the same agent would collide.
    if (existing.some((n) => n.id === agent.agent_id)) return;

    const pos = position || {
      x: GRID_ORIGIN + (existing.length % GRID_COLUMNS) * GRID_STEP_X,
      y: GRID_ORIGIN + Math.floor(existing.length / GRID_COLUMNS) * GRID_STEP_Y,
    };
    const newNode: AgentFlowNode = {
      id: agent.agent_id,
      type: "agentNode",
      position: pos,
      data: {
        agentId: agent.agent_id,
        displayName: agent.display_name,
        persona: agent.persona || "",
        description: agent.description || "",
        llmBackbone: agent.llm_backbone || null,
        tools: agent.tools || [],
        isStartNode: false,
        isEndNode: false,
        executionStatus: "idle",
      },
    };
    set((s) => ({ nodes: [...s.nodes, newNode] }));
  },

  removeNode: (nodeId) =>
    set((s) => ({
      nodes: s.nodes.filter((n) => n.id !== nodeId),
      ...detachNodes(s, new Set([nodeId])),
    })),

  setEdgeCondition: (edgeId, condition, weight) =>
    set((s) => ({
      edges: s.edges.map((e) =>
        e.id === edgeId
          ? {
              ...e,
              data: { ...e.data, condition, label: condition, weight: weight ?? e.data?.weight ?? 1.0 },
            }
          : e
      ),
    })),

  setStartNode: (nodeId) =>
    set((s) => ({
      startNode: nodeId,
      nodes: s.nodes.map((n) => ({ ...n, data: { ...n.data, isStartNode: n.id === nodeId } })),
    })),

  setEndNode: (nodeId) =>
    set((s) => ({
      endNode: nodeId,
      nodes: s.nodes.map((n) => ({ ...n, data: { ...n.data, isEndNode: n.id === nodeId } })),
    })),

  setEditingNodeId: (nodeId) => set({ editingNodeId: nodeId }),

  updateNodeData: (nodeId, data) =>
    set((s) => ({
      nodes: s.nodes.map((n) => (n.id === nodeId ? { ...n, data: { ...n.data, ...data } } : n)),
    })),

  setTaskQuery: (query) => set({ taskQuery: query }),
  setGraphName: (name) => set({ graphName: name }),
  setGraphDescription: (desc) => set({ graphDescription: desc }),

  updateNodeExecutionStatus: (nodeId, status, response, tokens) =>
    set((s) => ({
      nodes: s.nodes.map((n) =>
        n.id === nodeId
          ? {
              ...n,
              data: {
                ...n.data,
                executionStatus: status,
                response: response ?? n.data.response,
                tokensUsed: tokens ?? n.data.tokensUsed,
              },
            }
          : n
      ),
    })),

  resetExecutionStatus: () =>
    set((s) => ({
      nodes: s.nodes.map((n) => ({
        ...n,
        data: { ...n.data, executionStatus: "idle", response: undefined, tokensUsed: undefined },
      })),
      edges: s.edges.map((e) => ({ ...e, data: { ...e.data, evaluated: undefined } })),
    })),

  autoLayout: () => {
    const { nodes, edges } = get();
    if (nodes.length === 0) return;

    const g = new dagre.graphlib.Graph();
    g.setGraph({ rankdir: "TB", ranksep: 80, nodesep: 60 });
    g.setDefaultEdgeLabel(() => ({}));
    for (const node of nodes) g.setNode(node.id, { width: NODE_WIDTH, height: NODE_HEIGHT });
    for (const edge of edges) g.setEdge(edge.source, edge.target);
    dagre.layout(g);

    // dagre reports node centres; React Flow positions are top-left corners.
    const layoutNodes = nodes.map((node) => {
      const laidOut = g.node(node.id);
      return {
        ...node,
        position: { x: laidOut.x - NODE_WIDTH / 2, y: laidOut.y - NODE_HEIGHT / 2 },
      };
    });
    set({ nodes: layoutNodes });
  },

  toGraphRequest: () => {
    const s = get();
    const positions: Record<string, XYPosition> = {};
    for (const n of s.nodes) positions[n.id] = { x: n.position.x, y: n.position.y };

    return {
      name: s.graphName,
      description: s.graphDescription,
      agents: s.nodes.map((n) => ({
        agent_id: n.data.agentId,
        display_name: n.data.displayName,
        persona: n.data.persona,
        description: n.data.description,
        llm_backbone: n.data.llmBackbone,
        tools: n.data.tools,
      })),
      edges: s.edges.map((e) => ({
        source: e.source,
        target: e.target,
        weight: e.data?.weight ?? 1.0,
        condition: e.data?.condition || null,
        label: e.data?.label || null,
      })),
      positions,
      start_node: s.startNode,
      end_node: s.endNode,
      task_query: s.taskQuery,
    };
  },

  saveGraph: async () => {
    const s = get();
    const data = s.toGraphRequest();
    let result: GraphResponse;
    if (s.graphId) {
      result = await api.put(`/graphs/${s.graphId}`, data);
    } else {
      result = await api.post("/graphs", data);
    }
    set({ graphId: result.graph_id });
    return result;
  },

  loadGraph: async (graphId) => {
    const graph: GraphResponse = await api.get(`/graphs/${graphId}`);

    const nodes: AgentFlowNode[] = graph.agents.map((a) => ({
      id: a.agent_id,
      type: "agentNode",
      position: graph.positions[a.agent_id] ?? { x: 0, y: 0 },
      data: {
        agentId: a.agent_id,
        displayName: a.display_name,
        persona: a.persona || "",
        description: a.description || "",
        llmBackbone: a.llm_backbone || null,
        tools: a.tools || [],
        isStartNode: a.agent_id === graph.start_node,
        isEndNode: a.agent_id === graph.end_node,
        executionStatus: "idle" as const,
      },
    }));

    const edges: ConditionFlowEdge[] = graph.edges.map((e, i) => ({
      id: `e-${e.source}-${e.target}-${i}`,
      source: e.source,
      target: e.target,
      ...conditionEdgeStyle(),
      data: { condition: e.condition, label: e.label || e.condition, weight: e.weight ?? 1.0 },
    }));

    // Start from a blank document so the previous graph's validation results and
    // node-editing selection do not leak into the newly loaded one.
    set({
      ...blankDocument(),
      graphId: graph.graph_id,
      graphName: graph.name,
      graphDescription: graph.description,
      taskQuery: graph.task_query,
      startNode: graph.start_node || null,
      endNode: graph.end_node || null,
      nodes,
      edges,
    });
  },

  fetchSavedGraphs: async () => {
    const graphs: GraphListItem[] = await api.get("/graphs");
    set({ savedGraphs: graphs });
  },

  deleteGraph: async (graphId) => {
    await api.delete(`/graphs/${graphId}`);
    set((s) => ({
      savedGraphs: s.savedGraphs.filter((g) => g.graph_id !== graphId),
      ...(s.graphId === graphId ? { graphId: null } : {}),
    }));
  },

  validate: async () => {
    const data = get().toGraphRequest();
    const result: GraphValidationResponse = await api.post("/graphs/validate", data);
    set({
      validationErrors: result.errors,
      validationWarnings: result.warnings,
      executionOrder: result.execution_order,
    });
    return result;
  },

  newGraph: () => set(blankDocument()),

  loadTemplate: (template) => {
    // Initial column placement; replaced by autoLayout below.
    const nodes: AgentFlowNode[] = template.agents.map((a, i) => ({
      id: a.agent_id,
      type: "agentNode",
      position: { x: 300, y: 80 + i * 160 },
      data: {
        agentId: a.agent_id,
        displayName: a.display_name,
        persona: a.persona,
        description: a.description,
        llmBackbone: null,
        tools: a.tools,
        isStartNode: a.agent_id === template.start_node,
        isEndNode: a.agent_id === template.end_node,
        executionStatus: "idle" as const,
      },
    }));

    const edges: ConditionFlowEdge[] = template.edges.map((e, i) => ({
      id: `e-${e.source}-${e.target}-${i}`,
      source: e.source,
      target: e.target,
      ...conditionEdgeStyle(),
      data: { condition: e.condition, label: e.condition, weight: e.weight },
    }));

    set({
      ...blankDocument(),
      nodes,
      edges,
      graphName: template.name,
      graphDescription: template.description,
      startNode: template.start_node,
      endNode: template.end_node,
    });

    // Run dagre ~50 ms later, as before. autoLayout reads only store state, so this
    // does not depend on React Flow having measured the nodes.
    setTimeout(() => get().autoLayout(), 50);
  },
}));