File size: 3,004 Bytes
dfea997
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import React, { useCallback, useMemo } from 'react'
import ReactFlow, {
  Background,
  Controls,
  Handle,
  MiniMap,
  Position,
  useNodesState,
  useEdgesState,
  addEdge,
} from 'react-flow-renderer'

const nodeTypes = {
  checkpoint: ({ data }) => (
    <div className="p-3 rounded-lg bg-dark-800 border border-primary-500 text-white">
      <div className="font-bold text-primary-500">{data.label}</div>
      <div className="text-sm mt-1">{data.ckpt_name}</div>
    </div>
  ),
  textEncode: ({ data }) => (
    <div className="p-3 rounded-lg bg-dark-800 border border-green-500 text-white max-w-xs">
      <div className="font-bold text-green-500">{data.label}</div>
      <div className="text-sm mt-1 truncate">{data.text}</div>
    </div>
  ),
  sampler: ({ data }) => (
    <div className="p-3 rounded-lg bg-dark-800 border border-purple-500 text-white">
      <div className="font-bold text-purple-500">{data.label}</div>
      <div className="text-xs mt-1">
        Steps: {data.steps}, CFG: {data.cfg}
      </div>
    </div>
  ),
  vae: ({ data }) => (
    <div className="p-3 rounded-lg bg-dark-800 border border-yellow-500 text-white">
      <div className="font-bold text-yellow-500">{data.label}</div>
    </div>
  ),
  save: ({ data }) => (
    <div className="p-3 rounded-lg bg-dark-800 border border-red-500 text-white">
      <div className="font-bold text-red-500">{data.label}</div>
    </div>
  ),
}

export default function NodeEditor({ workflow }) {
  const initialNodes = workflow.nodes.map(node => {
    const baseNode = {
      id: node.id.toString(),
      position: { x: node.pos[0], y: node.pos[1] },
      data: { ...node.inputs, label: node.type },
    }

    switch(node.type) {
      case 'CheckpointLoaderSimple':
        return { ...baseNode, type: 'checkpoint' }
      case 'CLIPTextEncode':
        return { ...baseNode, type: 'textEncode' }
      case 'KSampler':
        return { ...baseNode, type: 'sampler' }
      case 'VAEDecode':
        return { ...baseNode, type: 'vae' }
      case 'SaveImage':
        return { ...baseNode, type: 'save' }
      default:
        return baseNode
    }
  })

  const initialEdges = workflow.links.map(link => ({
    id: `e${link[0]}-${link[1]}-${link[2]}-${link[3]}`,
    source: link[0].toString(),
    sourceHandle: link[1].toString(),
    target: link[2].toString(),
    targetHandle: link[3].toString(),
  }))

  const [nodes, setNodes, onNodesChange] = useNodesState(initialNodes)
  const [edges, setEdges, onEdgesChange] = useEdgesState(initialEdges)

  const onConnect = useCallback(
    params => setEdges(eds => addEdge(params, eds)),
    [setEdges]
  )

  return (
    <div className="h-[calc(100vh-64px)] bg-dark-900">
      <ReactFlow
        nodes={nodes}
        edges={edges}
        onNodesChange={onNodesChange}
        onEdgesChange={onEdgesChange}
        onConnect={onConnect}
        nodeTypes={nodeTypes}
        fitView
      >
        <Background />
        <Controls />
        <MiniMap />
      </ReactFlow>
    </div>
  )
}