| import type {} from '@redux-devtools/extension'; |
| import { humanId } from 'human-id'; |
| import differenceWith from 'lodash/differenceWith'; |
| import intersectionWith from 'lodash/intersectionWith'; |
| import lodashSet from 'lodash/set'; |
| import { |
| Connection, |
| Edge, |
| EdgeChange, |
| Node, |
| NodeChange, |
| OnConnect, |
| OnEdgesChange, |
| OnNodesChange, |
| OnSelectionChangeFunc, |
| OnSelectionChangeParams, |
| addEdge, |
| applyEdgeChanges, |
| applyNodeChanges, |
| } from 'reactflow'; |
| import { create } from 'zustand'; |
| import { devtools } from 'zustand/middleware'; |
| import { immer } from 'zustand/middleware/immer'; |
| import { Operator, SwitchElseTo } from './constant'; |
| import { NodeData } from './interface'; |
| import { getOperatorIndex, isEdgeEqual } from './utils'; |
|
|
/**
 * Shape of the graph editor store: the React Flow nodes/edges, the current
 * selection, and every action that reads or updates them.
 */
export type RFState = {
  // ---- State -------------------------------------------------------------

  /** Nodes currently on the canvas. */
  nodes: Node<NodeData>[];
  /** Edges currently on the canvas. */
  edges: Edge[];
  /** Ids of the nodes reported by the last selection change. */
  selectedNodeIds: string[];
  /** Ids of the edges reported by the last selection change. */
  selectedEdgeIds: string[];
  /** Id recorded by `setClickedNodeId`; starts as an empty string. */
  clickedNodeId: string;

  // ---- React Flow event handlers -------------------------------------------

  /** Applies React Flow node changes to `nodes`. */
  onNodesChange: OnNodesChange;
  /** Applies React Flow edge changes to `edges`. */
  onEdgesChange: OnEdgesChange;
  /** Adds the new connection as an edge and syncs the source node's form. */
  onConnect: OnConnect;
  /** Mirrors the canvas selection into `selectedNodeIds` / `selectedEdgeIds`. */
  onSelectionChange: OnSelectionChangeFunc;

  // ---- Node actions ----------------------------------------------------------

  /** Replaces the whole node list. */
  setNodes: (nodes: Node[]) => void;
  /** Appends a single node to the node list. */
  addNode: (nodes: Node) => void;
  /** Looks a node up by id. */
  getNode: (id?: string | null) => Node<NodeData> | undefined;
  /** Returns the first node whose `data.label` equals the operator name. */
  findNodeByName: (operatorName: Operator) => Node | undefined;
  /** Returns `data.label` of the node with the given id, if any. */
  getOperatorTypeFromId: (id?: string | null) => string | undefined;
  /** Adds an offset copy of the node under a freshly generated id. */
  duplicateNode: (id: string) => void;
  /** Removes the node and every edge that starts or ends at it. */
  deleteNodeById: (id: string) => void;
  /** Sets `data.name` of the node with the given id. */
  updateNodeName: (id: string, name: string) => void;
  /** Records which node was clicked. */
  setClickedNodeId: (id?: string) => void;

  // ---- Edge actions ----------------------------------------------------------

  /** Replaces the whole edge list. */
  setEdges: (edges: Edge[]) => void;
  /** Replaces the outgoing edges of `nodeId` when they differ from `edges`. */
  setEdgesByNodeId: (nodeId: string, edges: Edge[]) => void;
  /** Adds the connection as an edge and syncs the source node's form. */
  addEdge: (connection: Connection) => void;
  /** Looks an edge up by id. */
  getEdge: (id: string) => Edge | undefined;
  /** Removes every edge listed in `selectedEdgeIds`. */
  deleteEdge: () => void;
  /** Removes one edge and clears the matching entry in its source node's form. */
  deleteEdgeById: (id: string) => void;
  /** Removes every edge leaving the given source node through the given handle. */
  deleteEdgeBySourceAndSourceHandle: (connection: Partial<Connection>) => void;
  /**
   * For Categorize / Relevant / Switch sources, removes an existing edge on the
   * same source handle that points at a different target.
   */
  deletePreviousEdgeOfClassificationNode: (connection: Connection) => void;

  // ---- Form actions ----------------------------------------------------------

  /**
   * Updates `data.form` of a node: merges `values` into the form when `path`
   * is omitted or empty, otherwise writes `values` at `path`.
   */
  updateNodeForm: (
    nodeId: string,
    values: any,
    path?: (string | number)[],
  ) => void;
  /** Writes a single form field directly on the current nodes array. */
  updateMutableNodeFormItem: (id: string, field: string, value: any) => void;
  /** Stores the connection target in the source node's form, per operator type. */
  updateFormDataOnConnect: (connection: Connection) => void;
  /** Stores `target` in a Switch node's form for the given source handle. */
  updateSwitchFormData: (
    source: string,
    sourceHandle?: string | null,
    target?: string | null,
  ) => void;
};
|
|
| |
/**
 * True when both edges leave the same source node through the same source
 * handle and arrive at the same target node.
 */
const hasSameEndpoints = (a: Edge, b: Edge): boolean =>
  a.source === b.source &&
  a.target === b.target &&
  a.sourceHandle === b.sourceHandle;

/** True when every edge in `edges` has a counterpart in `others`. */
const isCoveredBy = (edges: Edge[], others: Edge[]): boolean =>
  edges.every((edge) => others.some((other) => hasSameEndpoints(edge, other)));

/**
 * Writes `value` at `path` inside `form` without touching objects that `form`
 * still shares with other state snapshots.
 *
 * `form` itself must already be a fresh copy. Every existing container on the
 * way to the leaf is shallow-cloned first; lodash `set` then performs the
 * actual write (and creates any missing containers, as before).
 *
 * NOTE(review): assumes the containers along `path` are plain objects/arrays;
 * a class instance would lose its prototype when spread — confirm form data
 * is plain JSON-like data.
 */
const setFormValueAtPath = (
  form: Record<string, unknown>,
  path: (string | number)[],
  value: unknown,
): void => {
  let cursor: Record<string, unknown> = form;
  for (const key of path.slice(0, -1)) {
    const child = cursor[key];
    if (child === null || typeof child !== 'object') {
      // Nothing to clone from here on; lodash `set` builds the rest.
      break;
    }
    const copy: object = Array.isArray(child) ? [...child] : { ...child };
    cursor[key] = copy;
    cursor = copy as Record<string, unknown>;
  }
  lodashSet(form, path, value);
};

/**
 * Zustand store (immer + devtools, registered as "graph") holding the React
 * Flow nodes/edges of the editor together with the actions that keep each
 * node's `data.form` in sync with its outgoing edges.
 */
const useGraphStore = create<RFState>()(
  devtools(
    immer((set, get) => ({
      nodes: [] as Node[],
      edges: [] as Edge[],
      selectedNodeIds: [] as string[],
      selectedEdgeIds: [] as string[],
      clickedNodeId: '',

      /** Applies React Flow node changes to the stored nodes. */
      onNodesChange: (changes: NodeChange[]) => {
        set({ nodes: applyNodeChanges(changes, get().nodes) });
      },

      /** Applies React Flow edge changes to the stored edges. */
      onEdgesChange: (changes: EdgeChange[]) => {
        set({ edges: applyEdgeChanges(changes, get().edges) });
      },

      /** React Flow connect handler; same behavior as the `addEdge` action. */
      onConnect: (connection: Connection) => {
        get().addEdge(connection);
      },

      /** Mirrors the canvas selection into the store as id lists. */
      onSelectionChange: ({ nodes, edges }: OnSelectionChangeParams) => {
        set({
          selectedEdgeIds: edges.map((edge) => edge.id),
          selectedNodeIds: nodes.map((node) => node.id),
        });
      },

      /** Replaces the whole node list. */
      setNodes: (nodes: Node[]) => {
        set({ nodes });
      },

      /** Replaces the whole edge list. */
      setEdges: (edges: Edge[]) => {
        set({ edges });
      },

      /**
       * Makes the outgoing edges of `nodeId` match `currentDownstreamEdges`.
       *
       * Nothing is written when the previous and the requested outgoing edges
       * already describe the same connections. Otherwise edges from other
       * sources are kept, previously stored edges that are still requested
       * are reused, and the remaining requested edges are appended.
       */
      setEdgesByNodeId: (nodeId: string, currentDownstreamEdges: Edge[]) => {
        const { edges, setEdges } = get();
        const previousDownstreamEdges = edges.filter(
          (edge) => edge.source === nodeId,
        );

        const isDifferent =
          previousDownstreamEdges.length !== currentDownstreamEdges.length ||
          !isCoveredBy(previousDownstreamEdges, currentDownstreamEdges) ||
          !isCoveredBy(currentDownstreamEdges, previousDownstreamEdges);
        if (!isDifferent) {
          return;
        }

        const irrelevantEdges = edges.filter((edge) => edge.source !== nodeId);
        // Stored edges that are still wanted.
        const intersectionDownstreamEdges = intersectionWith(
          previousDownstreamEdges,
          currentDownstreamEdges,
          isEdgeEqual,
        );
        // Requested edges that were not stored before.
        const selfAddedDownstreamEdges = differenceWith(
          currentDownstreamEdges,
          intersectionDownstreamEdges,
          isEdgeEqual,
        );
        setEdges([
          ...irrelevantEdges,
          ...intersectionDownstreamEdges,
          ...selfAddedDownstreamEdges,
        ]);
      },

      /** Appends a node to the node list. */
      addNode: (node: Node) => {
        set({ nodes: get().nodes.concat(node) });
      },

      /** Returns the node with the given id, or `undefined`. */
      getNode: (id?: string | null) => {
        return get().nodes.find((node) => node.id === id);
      },

      /** Returns `data.label` of the node with the given id, if it exists. */
      getOperatorTypeFromId: (id?: string | null) => {
        return get().getNode(id)?.data?.label;
      },

      /**
       * Adds the connection as an edge, drops the edge it supersedes on a
       * Categorize / Relevant / Switch source handle, and records the new
       * target in the source node's form.
       */
      addEdge: (connection: Connection) => {
        set({ edges: addEdge(connection, get().edges) });
        get().deletePreviousEdgeOfClassificationNode(connection);
        get().updateFormDataOnConnect(connection);
      },

      /** Returns the edge with the given id, or `undefined`. */
      getEdge: (id: string) => {
        return get().edges.find((edge) => edge.id === id);
      },

      /**
       * Stores the connection's target in the source node's form. Where it is
       * stored depends on the source node's operator type; other operator
       * types are left untouched.
       */
      updateFormDataOnConnect: (connection: Connection) => {
        const { getOperatorTypeFromId, updateNodeForm, updateSwitchFormData } =
          get();
        const { source, target, sourceHandle } = connection;
        if (!source) {
          return;
        }

        switch (getOperatorTypeFromId(source)) {
          case Operator.Relevant:
            // Without a handle there is no meaningful form key to write.
            if (sourceHandle) {
              updateNodeForm(source, { [sourceHandle]: target });
            }
            break;
          case Operator.Categorize:
            if (sourceHandle) {
              updateNodeForm(source, target, [
                'category_description',
                sourceHandle,
                'to',
              ]);
            }
            break;
          case Operator.Switch:
            updateSwitchFormData(source, sourceHandle, target);
            break;
          default:
            break;
        }
      },

      /**
       * For a Categorize / Relevant / Switch source node, deletes an existing
       * edge on the same source handle that points at a different target
       * than `connection`.
       */
      deletePreviousEdgeOfClassificationNode: (connection: Connection) => {
        const { edges, getOperatorTypeFromId, deleteEdgeById } = get();
        const anchoredNodes = [
          Operator.Categorize,
          Operator.Relevant,
          Operator.Switch,
        ];
        const sourceType = getOperatorTypeFromId(connection.source);
        if (!anchoredNodes.some((operator) => operator === sourceType)) {
          return;
        }

        const previousEdge = edges.find(
          (edge) =>
            edge.source === connection.source &&
            edge.sourceHandle === connection.sourceHandle &&
            edge.target !== connection.target,
        );
        if (previousEdge) {
          deleteEdgeById(previousEdge.id);
        }
      },

      /**
       * Adds a copy of the node, shifted by (30, 20), deselected, and with a
       * new id of the form `<label>:<humanId>`. Does nothing when no node has
       * the given id.
       */
      duplicateNode: (id: string) => {
        const { getNode, addNode } = get();
        const node = getNode(id);
        if (!node) {
          return;
        }

        addNode({
          ...node,
          selected: false,
          dragging: false,
          id: `${node.data?.label}:${humanId()}`,
          position: {
            x: (node.position?.x || 0) + 30,
            y: (node.position?.y || 0) + 20,
          },
        });
      },

      /** Removes every edge whose id is in `selectedEdgeIds`. */
      deleteEdge: () => {
        const { edges, selectedEdgeIds } = get();
        set({
          edges: edges.filter((edge) => !selectedEdgeIds.includes(edge.id)),
        });
      },

      /**
       * Removes the edge with the given id and clears the entry that pointed
       * at its target in the source node's form (Relevant / Categorize /
       * Switch sources only).
       */
      deleteEdgeById: (id: string) => {
        const {
          edges,
          updateNodeForm,
          getOperatorTypeFromId,
          updateSwitchFormData,
        } = get();
        const currentEdge = edges.find((edge) => edge.id === id);

        if (currentEdge) {
          const { source, sourceHandle } = currentEdge;
          switch (getOperatorTypeFromId(source)) {
            case Operator.Relevant:
              // Without a handle there is no form key to clear.
              if (sourceHandle) {
                updateNodeForm(source, { [sourceHandle]: undefined });
              }
              break;
            case Operator.Categorize:
              if (sourceHandle) {
                updateNodeForm(source, undefined, [
                  'category_description',
                  sourceHandle,
                  'to',
                ]);
              }
              break;
            case Operator.Switch:
              updateSwitchFormData(source, sourceHandle, undefined);
              break;
            default:
              break;
          }
        }

        // `updateNodeForm` only writes `nodes`, so `edges` read above is current.
        set({ edges: edges.filter((edge) => edge.id !== id) });
      },

      /** Removes every edge that leaves `source` through `sourceHandle`. */
      deleteEdgeBySourceAndSourceHandle: ({
        source,
        sourceHandle,
      }: Partial<Connection>) => {
        set({
          edges: get().edges.filter(
            (edge) =>
              edge.source !== source || edge.sourceHandle !== sourceHandle,
          ),
        });
      },

      /** Removes the node and every edge that starts or ends at it. */
      deleteNodeById: (id: string) => {
        const { nodes, edges } = get();
        set({
          nodes: nodes.filter((node) => node.id !== id),
          edges: edges.filter(
            (edge) => edge.source !== id && edge.target !== id,
          ),
        });
      },

      /** Returns the first node whose `data.label` equals `name`. */
      findNodeByName: (name: Operator) => {
        return get().nodes.find((node) => node.data.label === name);
      },

      /**
       * Updates `data.form` of the node with id `nodeId`.
       *
       * With an empty `path`, `values` is merged into the form. With a
       * non-empty `path`, `values` is written at that path. The node, its
       * data, its form, and the containers along `path` are copied, so
       * objects held by the previous state are not modified.
       */
      updateNodeForm: (
        nodeId: string,
        values: any,
        path: (string | number)[] = [],
      ) => {
        set({
          nodes: get().nodes.map((node) => {
            if (node.id !== nodeId) {
              return node;
            }

            const nextForm: Record<string, unknown> = { ...node.data.form };
            if (path.length === 0) {
              Object.assign(nextForm, values);
            } else {
              setFormValueAtPath(nextForm, path, values);
            }
            return {
              ...node,
              data: {
                ...node.data,
                form: nextForm,
              },
            } as any;
          }),
        });
      },

      /**
       * Stores `target` in a Switch node's form: under the `SwitchElseTo` key
       * for the else handle, otherwise under `conditions[index - 1].to` where
       * `index` is derived from the handle by `getOperatorIndex`.
       */
      updateSwitchFormData: (source, sourceHandle, target) => {
        if (!sourceHandle) {
          return;
        }

        const { updateNodeForm } = get();
        if (sourceHandle === SwitchElseTo) {
          updateNodeForm(source, target, [SwitchElseTo]);
          return;
        }

        const operatorIndex = getOperatorIndex(sourceHandle);
        // NOTE(review): a falsy result is treated as "no index" — confirm
        // getOperatorIndex never yields 0 or '' for a valid handle.
        if (operatorIndex) {
          updateNodeForm(source, target, [
            'conditions',
            Number(operatorIndex) - 1,
            'to',
          ]);
        }
      },

      /**
       * Writes `value` to `data.form[field]` of the node with the given id,
       * directly on the current nodes array and without calling `set`.
       * Does nothing when no node has that id.
       */
      updateMutableNodeFormItem: (id: string, field: string, value: any) => {
        const { nodes } = get();
        const idx = nodes.findIndex((node) => node.id === id);
        // findIndex reports "not found" as -1; index 0 is a valid match.
        if (idx !== -1) {
          lodashSet(nodes, [idx, 'data', 'form', field], value);
        }
      },

      /**
       * Sets `data.name` of the node with the given id. The node and its data
       * are copied so the change is visible to reference comparisons.
       */
      updateNodeName: (id, name) => {
        if (!id) {
          return;
        }
        set({
          nodes: get().nodes.map((node) =>
            node.id === id ? { ...node, data: { ...node.data, name } } : node,
          ),
        });
      },

      /** Records the clicked node id; an omitted id resets it to `''`. */
      setClickedNodeId: (id?: string) => {
        set({ clickedNodeId: id ?? '' });
      },
    })),
    { name: 'graph' },
  ),
);

export default useGraphStore;
|
|