Spaces:
Running
Running
| import { create } from "zustand"; | |
| import { | |
| type Node, | |
| type Edge, | |
| type OnNodesChange, | |
| type OnEdgesChange, | |
| type OnConnect, | |
| type XYPosition, | |
| applyNodeChanges, | |
| applyEdgeChanges, | |
| addEdge as rfAddEdge, | |
| MarkerType, | |
| } from "@xyflow/react"; | |
| import dagre from "dagre"; | |
| import { api } from "@/lib/api"; | |
| import type { AgentProfile } from "@/types/agent"; | |
| import type { | |
| GraphSaveRequest, | |
| GraphResponse, | |
| GraphListItem, | |
| GraphValidationResponse, | |
| } from "@/types/graph"; | |
| import type { AgentExecutionStatus } from "@/types/execution"; | |
/**
 * Data payload attached to every `agentNode` on the canvas.
 *
 * The profile fields are copied from the agent when the node is created. The
 * execution fields are updated while a run is being displayed and cleared by
 * `resetExecutionStatus`.
 */
export interface AgentNodeData {
  /* ----- identity & profile ----- */
  /** Agent id; the store also uses it as the React Flow node id. */
  agentId: string;
  /** Human-readable name shown on the node. */
  displayName: string;
  persona: string;
  description: string;
  /** LLM backbone name, or null when the agent does not specify one. */
  llmBackbone: string | null;
  /** Tools assigned to the agent. */
  tools: string[];

  /* ----- graph role (kept in sync with the store's startNode / endNode) ----- */
  isStartNode: boolean;
  isEndNode: boolean;

  /* ----- execution display ----- */
  executionStatus: AgentExecutionStatus;
  /** Latest response text reported for this agent, if any. */
  response?: string;
  /** Token count reported alongside the latest response, if any. */
  tokensUsed?: number;

  /**
   * Open-ended extra fields. This also lets the interface satisfy the
   * `Record<string, unknown>` constraint React Flow places on node data.
   */
  [key: string]: unknown;
}
/** Data payload attached to each `conditionEdge` on the canvas. */
export interface EdgeData {
  /** Edge weight; the store fills in 1.0 when none is given. */
  weight?: number;
  /** Condition text for this edge; null/absent when no condition has been set. */
  condition?: string | null;
  /** Text displayed on the edge; the store normally mirrors `condition` here. */
  label?: string | null;
  /**
   * Execution-time flag, cleared by `resetExecutionStatus`.
   * NOTE(review): not written anywhere in this file — presumably set by the run viewer.
   */
  evaluated?: boolean;
  /** Open-ended extra fields (also keeps the type compatible with React Flow edge data). */
  [key: string]: unknown;
}
/**
 * Shape of the graph-editor store: canvas state, document metadata, and the
 * actions that edit, lay out, validate and persist the graph.
 */
interface GraphStore {
  /* ---------- canvas state (React Flow) ---------- */
  nodes: Node<AgentNodeData>[];
  edges: Edge<EdgeData>[];

  /* ---------- document metadata ---------- */
  /** Server id of the open graph; null when it has not been saved (or was deleted). */
  graphId: string | null;
  graphName: string;
  graphDescription: string;
  taskQuery: string;
  /** Id of the node marked as start, or null. */
  startNode: string | null;
  /** Id of the node marked as end, or null. */
  endNode: string | null;

  /* ---------- server lists & validation results ---------- */
  savedGraphs: GraphListItem[];
  validationErrors: string[];
  validationWarnings: string[];
  /** Execution order returned by the last `validate` call. */
  executionOrder: string[];

  /* ---------- node editor ---------- */
  /** Id of the node whose editor is open, or null. */
  editingNodeId: string | null;
  setEditingNodeId: (nodeId: string | null) => void;
  /** Shallow-merges `data` into the node's data. */
  updateNodeData: (nodeId: string, data: Partial<AgentNodeData>) => void;

  /* ---------- React Flow callbacks ---------- */
  onNodesChange: OnNodesChange;
  onEdgesChange: OnEdgesChange;
  onConnect: OnConnect;

  /* ---------- graph editing ---------- */
  addAgentNode: (agent: AgentProfile, position?: XYPosition) => void;
  removeNode: (nodeId: string) => void;
  setEdgeCondition: (edgeId: string, condition: string, weight?: number) => void;
  setStartNode: (nodeId: string | null) => void;
  setEndNode: (nodeId: string | null) => void;
  setTaskQuery: (query: string) => void;
  setGraphName: (name: string) => void;
  setGraphDescription: (desc: string) => void;
  /** Repositions all nodes with a top-to-bottom dagre layout. */
  autoLayout: () => void;

  /* ---------- execution display ---------- */
  updateNodeExecutionStatus: (nodeId: string, status: AgentExecutionStatus, response?: string, tokens?: number) => void;
  resetExecutionStatus: () => void;

  /* ---------- persistence & server calls ---------- */
  /** Serializes the current canvas into the server's save payload. */
  toGraphRequest: () => GraphSaveRequest;
  saveGraph: () => Promise<GraphResponse>;
  loadGraph: (graphId: string) => Promise<void>;
  fetchSavedGraphs: () => Promise<void>;
  deleteGraph: (graphId: string) => Promise<void>;
  validate: () => Promise<GraphValidationResponse>;
  newGraph: () => void;
  loadTemplate: (template: GraphTemplate) => void;
}
/**
 * A predefined workflow that `loadTemplate` expands into canvas nodes and edges.
 * Keys use the same snake_case spelling as the graph save request.
 */
export interface GraphTemplate {
  name: string;
  description: string;
  /** Agents to place on the canvas; each `agent_id` becomes a node id. */
  agents: {
    agent_id: string;
    display_name: string;
    persona: string;
    description: string;
    tools: string[];
  }[];
  /** Directed connections between agents, referenced by `agent_id`. */
  edges: {
    source: string;
    target: string;
    weight: number;
    condition: string | null;
  }[];
  /** `agent_id` of the node to mark as start. */
  start_node: string;
  /** `agent_id` of the node to mark as end. */
  end_node: string;
}
/**
 * Fixed node footprint handed to dagre by `autoLayout`. Layout uses these
 * constants instead of measured sizes, so they should roughly match the
 * rendered agent node (which is not measured in this file).
 */
const NODE_WIDTH = 220,
  NODE_HEIGHT = 120;
/** The per-document slice of the store: everything tied to the graph currently open. */
type GraphDocumentState = Pick<
  GraphStore,
  | "nodes"
  | "edges"
  | "graphId"
  | "graphName"
  | "graphDescription"
  | "taskQuery"
  | "startNode"
  | "endNode"
  | "validationErrors"
  | "validationWarnings"
  | "executionOrder"
  | "editingNodeId"
>;

/**
 * Returns a fresh, empty document state (new arrays on every call).
 *
 * Used for the initial store, `newGraph`, and as the base of `loadGraph` /
 * `loadTemplate`. Switching documents therefore never carries over the previous
 * graph's validation results or open node editor.
 */
function blankGraphState(): GraphDocumentState {
  return {
    nodes: [],
    edges: [],
    graphId: null,
    graphName: "Untitled Workflow",
    graphDescription: "",
    taskQuery: "",
    startNode: null,
    endNode: null,
    validationErrors: [],
    validationWarnings: [],
    executionOrder: [],
    editingNodeId: null,
  };
}

/** Minimal agent shape shared by `AgentProfile`, saved-graph agents and template agents. */
interface AgentSource {
  agent_id: string;
  display_name: string;
  persona?: string | null;
  description?: string | null;
  llm_backbone?: string | null;
  tools?: string[] | null;
}

/**
 * Builds a canvas node for an agent.
 *
 * The node id is the agent id. Missing optional fields fall back to "" / null / [].
 * The start/end flags are derived from the given start and end node ids.
 */
function toAgentNode(
  agent: AgentSource,
  position: XYPosition,
  startNode: string | null,
  endNode: string | null
): Node<AgentNodeData> {
  return {
    id: agent.agent_id,
    type: "agentNode",
    position,
    data: {
      agentId: agent.agent_id,
      displayName: agent.display_name,
      persona: agent.persona || "",
      description: agent.description || "",
      llmBackbone: agent.llm_backbone || null,
      tools: agent.tools || [],
      isStartNode: agent.agent_id === startNode,
      isEndNode: agent.agent_id === endNode,
      executionStatus: "idle",
    },
  };
}

/** Visual props every conditional edge gets, whichever code path created it. */
function conditionEdgeStyle() {
  return {
    type: "conditionEdge",
    animated: true,
    markerEnd: { type: MarkerType.ArrowClosed },
  };
}

/** Builds a styled `conditionEdge` with the given id, endpoints and data. */
function toConditionEdge(id: string, source: string, target: string, data: EdgeData): Edge<EdgeData> {
  return { id, source, target, ...conditionEdgeStyle(), data };
}

/** Clears start/end/editing references that point at nodes missing from `nodes`. */
function pruneNodeRefs(
  s: Pick<GraphStore, "startNode" | "endNode" | "editingNodeId">,
  nodes: Node<AgentNodeData>[]
): Pick<GraphStore, "startNode" | "endNode" | "editingNodeId"> {
  const alive = new Set(nodes.map((n) => n.id));
  const keep = (id: string | null) => (id !== null && alive.has(id) ? id : null);
  return {
    startNode: keep(s.startNode),
    endNode: keep(s.endNode),
    editingNodeId: keep(s.editingNodeId),
  };
}

/**
 * Global Zustand store for the graph editor.
 *
 * Holds the React Flow nodes and edges plus document metadata. It also exposes
 * actions that edit the graph, lay it out with dagre, and save / load /
 * validate it through `api`.
 */
export const useGraphStore = create<GraphStore>((set, get) => ({
  ...blankGraphState(),
  savedGraphs: [],

  // ----- React Flow callbacks -----

  onNodesChange: (changes) =>
    set((s) => {
      const nodes = applyNodeChanges(changes, s.nodes) as Node<AgentNodeData>[];
      // Only "remove" changes can orphan an id reference (e.g. deleting the start
      // node with the keyboard). Skip the bookkeeping for the far more frequent
      // drag/select/dimension updates.
      if (!changes.some((c) => c.type === "remove")) return { nodes };
      return { nodes, ...pruneNodeRefs(s, nodes) };
    }),

  onEdgesChange: (changes) =>
    set((s) => ({ edges: applyEdgeChanges(changes, s.edges) as Edge<EdgeData>[] })),

  onConnect: (connection) =>
    set((s) => ({
      edges: rfAddEdge(
        {
          ...connection,
          ...conditionEdgeStyle(),
          data: { condition: null, label: null, weight: 1.0 },
        },
        s.edges
      ),
    })),

  // ----- graph editing -----

  /**
   * Adds a node for `agent` at `position`, or at the next 4-column grid slot.
   * Does nothing if the agent is already on the canvas. Node ids are agent ids,
   * and positions / start_node / end_node are keyed by agent_id, so a second
   * copy cannot be represented and would give React Flow duplicate ids.
   */
  addAgentNode: (agent, position) => {
    const { nodes, startNode, endNode } = get();
    if (nodes.some((n) => n.id === agent.agent_id)) return;
    const slot = nodes.length;
    const node = toAgentNode(
      agent,
      position ?? { x: 100 + (slot % 4) * 260, y: 100 + Math.floor(slot / 4) * 160 },
      startNode,
      endNode
    );
    set((s) => ({ nodes: [...s.nodes, node] }));
  },

  /** Removes a node, its incident edges, and any start/end/editor reference to it. */
  removeNode: (nodeId) =>
    set((s) => {
      const nodes = s.nodes.filter((n) => n.id !== nodeId);
      return {
        nodes,
        edges: s.edges.filter((e) => e.source !== nodeId && e.target !== nodeId),
        ...pruneNodeRefs(s, nodes),
      };
    }),

  /** Sets an edge's condition (also used as its label) and, optionally, its weight. */
  setEdgeCondition: (edgeId, condition, weight) =>
    set((s) => ({
      edges: s.edges.map((e) =>
        e.id === edgeId
          ? { ...e, data: { ...e.data, condition, label: condition, weight: weight ?? e.data?.weight ?? 1.0 } }
          : e
      ),
    })),

  /** Marks `nodeId` as start (null clears it) and refreshes every node's flag. */
  setStartNode: (nodeId) =>
    set((s) => ({
      startNode: nodeId,
      nodes: s.nodes.map((n) => ({ ...n, data: { ...n.data, isStartNode: n.id === nodeId } })),
    })),

  /** Marks `nodeId` as end (null clears it) and refreshes every node's flag. */
  setEndNode: (nodeId) =>
    set((s) => ({
      endNode: nodeId,
      nodes: s.nodes.map((n) => ({ ...n, data: { ...n.data, isEndNode: n.id === nodeId } })),
    })),

  setEditingNodeId: (nodeId) => set({ editingNodeId: nodeId }),

  updateNodeData: (nodeId, data) =>
    set((s) => ({
      nodes: s.nodes.map((n) => (n.id === nodeId ? { ...n, data: { ...n.data, ...data } } : n)),
    })),

  setTaskQuery: (query) => set({ taskQuery: query }),
  setGraphName: (name) => set({ graphName: name }),
  setGraphDescription: (desc) => set({ graphDescription: desc }),

  // ----- execution display -----

  /** Updates one node's status; response/tokens keep their old values when omitted. */
  updateNodeExecutionStatus: (nodeId, status, response, tokens) =>
    set((s) => ({
      nodes: s.nodes.map((n) =>
        n.id === nodeId
          ? {
              ...n,
              data: {
                ...n.data,
                executionStatus: status,
                response: response ?? n.data.response,
                tokensUsed: tokens ?? n.data.tokensUsed,
              },
            }
          : n
      ),
    })),

  /** Returns every node to "idle" and clears per-run node and edge fields. */
  resetExecutionStatus: () =>
    set((s) => ({
      nodes: s.nodes.map((n) => ({
        ...n,
        data: { ...n.data, executionStatus: "idle", response: undefined, tokensUsed: undefined },
      })),
      edges: s.edges.map((e) => ({ ...e, data: { ...e.data, evaluated: undefined } })),
    })),

  // ----- layout -----

  /** Lays the graph out top-to-bottom with dagre, using the fixed node footprint. */
  autoLayout: () => {
    const { nodes, edges } = get();
    if (nodes.length === 0) return;

    const g = new dagre.graphlib.Graph();
    g.setGraph({ rankdir: "TB", ranksep: 80, nodesep: 60 });
    g.setDefaultEdgeLabel(() => ({}));
    for (const node of nodes) g.setNode(node.id, { width: NODE_WIDTH, height: NODE_HEIGHT });
    for (const edge of edges) {
      // An edge to an unknown id would make graphlib implicitly create a node with
      // no size, which can skew or break the layout; skip such dangling edges.
      if (g.hasNode(edge.source) && g.hasNode(edge.target)) g.setEdge(edge.source, edge.target);
    }
    dagre.layout(g);

    set({
      nodes: nodes.map((node) => {
        const { x, y } = g.node(node.id);
        // dagre reports node centres; React Flow positions are top-left corners.
        return { ...node, position: { x: x - NODE_WIDTH / 2, y: y - NODE_HEIGHT / 2 } };
      }),
    });
  },

  // ----- persistence -----

  /** Serializes the canvas into the save/validate payload (node positions keyed by node id). */
  toGraphRequest: () => {
    const s = get();
    const positions: Record<string, { x: number; y: number }> = {};
    for (const n of s.nodes) positions[n.id] = { x: n.position.x, y: n.position.y };
    return {
      name: s.graphName,
      description: s.graphDescription,
      agents: s.nodes.map((n) => ({
        agent_id: n.data.agentId,
        display_name: n.data.displayName,
        persona: n.data.persona,
        description: n.data.description,
        llm_backbone: n.data.llmBackbone,
        tools: n.data.tools,
      })),
      edges: s.edges.map((e) => ({
        source: e.source,
        target: e.target,
        weight: e.data?.weight ?? 1.0,
        condition: e.data?.condition || null,
        label: e.data?.label || null,
      })),
      positions,
      start_node: s.startNode,
      end_node: s.endNode,
      task_query: s.taskQuery,
    };
  },

  /** Creates the graph on first save (POST), updates it afterwards (PUT); stores the returned id. */
  saveGraph: async () => {
    const s = get();
    const data = s.toGraphRequest();
    const result = s.graphId
      ? await api.put<GraphResponse>(`/graphs/${s.graphId}`, data)
      : await api.post<GraphResponse>("/graphs", data);
    set({ graphId: result.graph_id });
    return result;
  },

  /**
   * Replaces the whole document with a saved graph. It starts from a blank
   * state, so validation results and the open editor from the previous graph
   * are cleared.
   */
  loadGraph: async (graphId) => {
    const graph = await api.get<GraphResponse>(`/graphs/${graphId}`);
    const startNode = graph.start_node || null;
    const endNode = graph.end_node || null;
    const nodes = graph.agents.map((a) =>
      toAgentNode(a, graph.positions?.[a.agent_id] ?? { x: 0, y: 0 }, startNode, endNode)
    );
    const edges = graph.edges.map((e, i) =>
      toConditionEdge(`e-${e.source}-${e.target}-${i}`, e.source, e.target, {
        condition: e.condition,
        label: e.label || e.condition,
        weight: e.weight ?? 1.0,
      })
    );
    set({
      ...blankGraphState(),
      graphId: graph.graph_id,
      graphName: graph.name,
      // Guard against the server sending null for optional text fields.
      graphDescription: graph.description ?? "",
      taskQuery: graph.task_query ?? "",
      startNode,
      endNode,
      nodes,
      edges,
    });
  },

  fetchSavedGraphs: async () => {
    const graphs = await api.get<GraphListItem[]>("/graphs");
    set({ savedGraphs: graphs });
  },

  /** Deletes a saved graph; if it is the open one, the canvas becomes unsaved (graphId null). */
  deleteGraph: async (graphId) => {
    await api.delete(`/graphs/${graphId}`);
    set((s) => ({
      savedGraphs: s.savedGraphs.filter((g) => g.graph_id !== graphId),
      ...(s.graphId === graphId ? { graphId: null } : {}),
    }));
  },

  /** Sends the current canvas for server-side validation and stores the result. */
  validate: async () => {
    const data = get().toGraphRequest();
    const result = await api.post<GraphValidationResponse>("/graphs/validate", data);
    set({
      validationErrors: result.errors,
      validationWarnings: result.warnings,
      executionOrder: result.execution_order,
    });
    return result;
  },

  newGraph: () => set(blankGraphState()),

  /** Replaces the document with an unsaved copy of `template`, then auto-lays it out. */
  loadTemplate: (template) => {
    // Column positions are placeholders only; autoLayout replaces them below.
    const nodes = template.agents.map((a, i) =>
      toAgentNode(a, { x: 300, y: 80 + i * 160 }, template.start_node, template.end_node)
    );
    const edges = template.edges.map((e, i) =>
      toConditionEdge(`e-${e.source}-${e.target}-${i}`, e.source, e.target, {
        condition: e.condition,
        label: e.condition,
        weight: e.weight,
      })
    );
    set({
      ...blankGraphState(),
      nodes,
      edges,
      graphName: template.name,
      graphDescription: template.description,
      startNode: template.start_node,
      endNode: template.end_node,
    });
    // autoLayout reads only store state and the fixed NODE_WIDTH/NODE_HEIGHT, not
    // rendered DOM sizes, so it can run immediately. A deferred call could fire
    // after a subsequent loadGraph and overwrite that graph's saved positions.
    get().autoLayout();
  },
}));