wuhp commited on
Commit
b30ebd6
·
verified ·
1 Parent(s): 69aaee3

Update components/CustomNode.tsx

Browse files
Files changed (1) hide show
  1. components/CustomNode.tsx +20 -4
components/CustomNode.tsx CHANGED
@@ -1,10 +1,11 @@
 
1
  import React, { memo } from 'react';
2
  import { Handle, Position, NodeProps } from 'reactflow';
3
  import {
4
  Layers, Box, Activity, Zap, ArrowRight, Grid, Minimize,
5
  Database, GitBranch, AlignJustify, Type, Combine, Maximize,
6
  ArrowUpCircle, Sliders, RefreshCcw, Brain, Crosshair, Network, Clock, Eye, Workflow,
7
- Terminal, MinusCircle
8
  } from 'lucide-react';
9
  import { NodeData, LayerType } from '../types';
10
 
@@ -13,19 +14,24 @@ const getIcon = (type: LayerType) => {
13
  case LayerType.INPUT: return <ArrowRight className="w-4 h-4" />;
14
  case LayerType.CONV1D: return <Activity className="w-4 h-4" />;
15
  case LayerType.CONV2D: return <Layers className="w-4 h-4" />;
 
16
  case LayerType.CONV_TRANSPOSE2D: return <ArrowUpCircle className="w-4 h-4" />;
17
  case LayerType.LINEAR: return <Grid className="w-4 h-4" />;
18
  case LayerType.RELU:
19
  case LayerType.LEAKYRELU:
 
20
  case LayerType.GELU:
21
  case LayerType.SILU:
22
  case LayerType.SWIGLU:
23
  case LayerType.SIGMOID:
24
  case LayerType.TANH: return <Zap className="w-4 h-4" />;
25
  case LayerType.MAXPOOL:
 
26
  case LayerType.AVGPOOL:
27
  case LayerType.ADAPTIVEAVGPOOL: return <Minimize className="w-4 h-4" />;
 
28
  case LayerType.UPSAMPLE: return <Maximize className="w-4 h-4" />;
 
29
 
30
  // Transformer / GenAI
31
  case LayerType.ATTENTION:
@@ -35,6 +41,7 @@ const getIcon = (type: LayerType) => {
35
  case LayerType.TRANSFORMER_DECODER: return <Brain className="w-4 h-4" />;
36
  case LayerType.MOE_BLOCK: return <Network className="w-4 h-4" />;
37
  case LayerType.ACTION_HEAD: return <Crosshair className="w-4 h-4" />;
 
38
  case LayerType.PATCH_EMBED:
39
  case LayerType.SAM_PROMPT_ENCODER: return <Eye className="w-4 h-4" />;
40
  case LayerType.SAM_MASK_DECODER: return <Workflow className="w-4 h-4" />;
@@ -46,6 +53,7 @@ const getIcon = (type: LayerType) => {
46
  case LayerType.EMBEDDING: return <Database className="w-4 h-4" />;
47
  case LayerType.LAYERNORM:
48
  case LayerType.BATCHNORM:
 
49
  case LayerType.INSTANCENORM:
50
  case LayerType.RMSNORM: return <Sliders className="w-4 h-4" />;
51
  case LayerType.FLATTEN: return <Type className="w-4 h-4" />;
@@ -53,7 +61,9 @@ const getIcon = (type: LayerType) => {
53
  case LayerType.ADD: return <Combine className="w-4 h-4" />;
54
 
55
  case LayerType.CUSTOM: return <Terminal className="w-4 h-4" />;
56
- case LayerType.IDENTITY: return <MinusCircle className="w-4 h-4" />;
 
 
57
 
58
  default: return <Box className="w-4 h-4" />;
59
  }
@@ -66,11 +76,15 @@ const getColor = (type: LayerType) => {
66
 
67
  case LayerType.CONV2D:
68
  case LayerType.CONV1D:
 
69
  case LayerType.CONV_TRANSPOSE2D:
70
  case LayerType.MAXPOOL:
 
71
  case LayerType.AVGPOOL:
 
72
  case LayerType.ADAPTIVEAVGPOOL:
73
- case LayerType.UPSAMPLE: return 'border-blue-500 shadow-blue-500/20';
 
74
 
75
  case LayerType.LINEAR:
76
  case LayerType.ACTION_HEAD: return 'border-violet-500 shadow-violet-500/20';
@@ -82,6 +96,7 @@ const getColor = (type: LayerType) => {
82
  case LayerType.TRANSFORMER_ENCODER:
83
  case LayerType.TRANSFORMER_DECODER:
84
  case LayerType.MOE_BLOCK:
 
85
  case LayerType.PATCH_EMBED:
86
  case LayerType.SAM_PROMPT_ENCODER:
87
  case LayerType.SAM_MASK_DECODER: return 'border-amber-500 shadow-amber-500/20';
@@ -94,6 +109,7 @@ const getColor = (type: LayerType) => {
94
 
95
  case LayerType.BATCHNORM:
96
  case LayerType.LAYERNORM:
 
97
  case LayerType.INSTANCENORM:
98
  case LayerType.RMSNORM: return 'border-cyan-500 shadow-cyan-500/20';
99
 
@@ -180,4 +196,4 @@ const CustomNode = ({ data, selected }: NodeProps<NodeData>) => {
180
  );
181
  };
182
 
183
- export default memo(CustomNode);
 
1
+
2
  import React, { memo } from 'react';
3
  import { Handle, Position, NodeProps } from 'reactflow';
4
  import {
5
  Layers, Box, Activity, Zap, ArrowRight, Grid, Minimize,
6
  Database, GitBranch, AlignJustify, Type, Combine, Maximize,
7
  ArrowUpCircle, Sliders, RefreshCcw, Brain, Crosshair, Network, Clock, Eye, Workflow,
8
+ Terminal, MinusCircle, Scaling, BoxSelect, Wifi
9
  } from 'lucide-react';
10
  import { NodeData, LayerType } from '../types';
11
 
 
14
  case LayerType.INPUT: return <ArrowRight className="w-4 h-4" />;
15
  case LayerType.CONV1D: return <Activity className="w-4 h-4" />;
16
  case LayerType.CONV2D: return <Layers className="w-4 h-4" />;
17
+ case LayerType.CONV3D: return <Box className="w-4 h-4" />;
18
  case LayerType.CONV_TRANSPOSE2D: return <ArrowUpCircle className="w-4 h-4" />;
19
  case LayerType.LINEAR: return <Grid className="w-4 h-4" />;
20
  case LayerType.RELU:
21
  case LayerType.LEAKYRELU:
22
+ case LayerType.PRELU:
23
  case LayerType.GELU:
24
  case LayerType.SILU:
25
  case LayerType.SWIGLU:
26
  case LayerType.SIGMOID:
27
  case LayerType.TANH: return <Zap className="w-4 h-4" />;
28
  case LayerType.MAXPOOL:
29
+ case LayerType.MAXPOOL3D:
30
  case LayerType.AVGPOOL:
31
  case LayerType.ADAPTIVEAVGPOOL: return <Minimize className="w-4 h-4" />;
32
+ case LayerType.GLOBAL_AVG_POOL: return <BoxSelect className="w-4 h-4" />;
33
  case LayerType.UPSAMPLE: return <Maximize className="w-4 h-4" />;
34
+ case LayerType.PIXEL_SHUFFLE: return <Scaling className="w-4 h-4" />;
35
 
36
  // Transformer / GenAI
37
  case LayerType.ATTENTION:
 
41
  case LayerType.TRANSFORMER_DECODER: return <Brain className="w-4 h-4" />;
42
  case LayerType.MOE_BLOCK: return <Network className="w-4 h-4" />;
43
  case LayerType.ACTION_HEAD: return <Crosshair className="w-4 h-4" />;
44
+ case LayerType.SE_BLOCK: return <Wifi className="w-4 h-4" />;
45
  case LayerType.PATCH_EMBED:
46
  case LayerType.SAM_PROMPT_ENCODER: return <Eye className="w-4 h-4" />;
47
  case LayerType.SAM_MASK_DECODER: return <Workflow className="w-4 h-4" />;
 
53
  case LayerType.EMBEDDING: return <Database className="w-4 h-4" />;
54
  case LayerType.LAYERNORM:
55
  case LayerType.BATCHNORM:
56
+ case LayerType.GROUPNORM:
57
  case LayerType.INSTANCENORM:
58
  case LayerType.RMSNORM: return <Sliders className="w-4 h-4" />;
59
  case LayerType.FLATTEN: return <Type className="w-4 h-4" />;
 
61
  case LayerType.ADD: return <Combine className="w-4 h-4" />;
62
 
63
  case LayerType.CUSTOM: return <Terminal className="w-4 h-4" />;
64
+ case LayerType.IDENTITY:
65
+ case LayerType.DROPOUT:
66
+ case LayerType.DROPPATH: return <MinusCircle className="w-4 h-4" />;
67
 
68
  default: return <Box className="w-4 h-4" />;
69
  }
 
76
 
77
  case LayerType.CONV2D:
78
  case LayerType.CONV1D:
79
+ case LayerType.CONV3D:
80
  case LayerType.CONV_TRANSPOSE2D:
81
  case LayerType.MAXPOOL:
82
+ case LayerType.MAXPOOL3D:
83
  case LayerType.AVGPOOL:
84
+ case LayerType.GLOBAL_AVG_POOL:
85
  case LayerType.ADAPTIVEAVGPOOL:
86
+ case LayerType.UPSAMPLE:
87
+ case LayerType.PIXEL_SHUFFLE: return 'border-blue-500 shadow-blue-500/20';
88
 
89
  case LayerType.LINEAR:
90
  case LayerType.ACTION_HEAD: return 'border-violet-500 shadow-violet-500/20';
 
96
  case LayerType.TRANSFORMER_ENCODER:
97
  case LayerType.TRANSFORMER_DECODER:
98
  case LayerType.MOE_BLOCK:
99
+ case LayerType.SE_BLOCK:
100
  case LayerType.PATCH_EMBED:
101
  case LayerType.SAM_PROMPT_ENCODER:
102
  case LayerType.SAM_MASK_DECODER: return 'border-amber-500 shadow-amber-500/20';
 
109
 
110
  case LayerType.BATCHNORM:
111
  case LayerType.LAYERNORM:
112
+ case LayerType.GROUPNORM:
113
  case LayerType.INSTANCENORM:
114
  case LayerType.RMSNORM: return 'border-cyan-500 shadow-cyan-500/20';
115
 
 
196
  );
197
  };
198
 
199
+ export default memo(CustomNode);