Spaces:
Running
Running
Update components/CustomNode.tsx
Browse files- components/CustomNode.tsx +20 -4
components/CustomNode.tsx
CHANGED
|
@@ -1,10 +1,11 @@
|
|
|
|
|
| 1 |
import React, { memo } from 'react';
|
| 2 |
import { Handle, Position, NodeProps } from 'reactflow';
|
| 3 |
import {
|
| 4 |
Layers, Box, Activity, Zap, ArrowRight, Grid, Minimize,
|
| 5 |
Database, GitBranch, AlignJustify, Type, Combine, Maximize,
|
| 6 |
ArrowUpCircle, Sliders, RefreshCcw, Brain, Crosshair, Network, Clock, Eye, Workflow,
|
| 7 |
-
Terminal, MinusCircle
|
| 8 |
} from 'lucide-react';
|
| 9 |
import { NodeData, LayerType } from '../types';
|
| 10 |
|
|
@@ -13,19 +14,24 @@ const getIcon = (type: LayerType) => {
|
|
| 13 |
case LayerType.INPUT: return <ArrowRight className="w-4 h-4" />;
|
| 14 |
case LayerType.CONV1D: return <Activity className="w-4 h-4" />;
|
| 15 |
case LayerType.CONV2D: return <Layers className="w-4 h-4" />;
|
|
|
|
| 16 |
case LayerType.CONV_TRANSPOSE2D: return <ArrowUpCircle className="w-4 h-4" />;
|
| 17 |
case LayerType.LINEAR: return <Grid className="w-4 h-4" />;
|
| 18 |
case LayerType.RELU:
|
| 19 |
case LayerType.LEAKYRELU:
|
|
|
|
| 20 |
case LayerType.GELU:
|
| 21 |
case LayerType.SILU:
|
| 22 |
case LayerType.SWIGLU:
|
| 23 |
case LayerType.SIGMOID:
|
| 24 |
case LayerType.TANH: return <Zap className="w-4 h-4" />;
|
| 25 |
case LayerType.MAXPOOL:
|
|
|
|
| 26 |
case LayerType.AVGPOOL:
|
| 27 |
case LayerType.ADAPTIVEAVGPOOL: return <Minimize className="w-4 h-4" />;
|
|
|
|
| 28 |
case LayerType.UPSAMPLE: return <Maximize className="w-4 h-4" />;
|
|
|
|
| 29 |
|
| 30 |
// Transformer / GenAI
|
| 31 |
case LayerType.ATTENTION:
|
|
@@ -35,6 +41,7 @@ const getIcon = (type: LayerType) => {
|
|
| 35 |
case LayerType.TRANSFORMER_DECODER: return <Brain className="w-4 h-4" />;
|
| 36 |
case LayerType.MOE_BLOCK: return <Network className="w-4 h-4" />;
|
| 37 |
case LayerType.ACTION_HEAD: return <Crosshair className="w-4 h-4" />;
|
|
|
|
| 38 |
case LayerType.PATCH_EMBED:
|
| 39 |
case LayerType.SAM_PROMPT_ENCODER: return <Eye className="w-4 h-4" />;
|
| 40 |
case LayerType.SAM_MASK_DECODER: return <Workflow className="w-4 h-4" />;
|
|
@@ -46,6 +53,7 @@ const getIcon = (type: LayerType) => {
|
|
| 46 |
case LayerType.EMBEDDING: return <Database className="w-4 h-4" />;
|
| 47 |
case LayerType.LAYERNORM:
|
| 48 |
case LayerType.BATCHNORM:
|
|
|
|
| 49 |
case LayerType.INSTANCENORM:
|
| 50 |
case LayerType.RMSNORM: return <Sliders className="w-4 h-4" />;
|
| 51 |
case LayerType.FLATTEN: return <Type className="w-4 h-4" />;
|
|
@@ -53,7 +61,9 @@ const getIcon = (type: LayerType) => {
|
|
| 53 |
case LayerType.ADD: return <Combine className="w-4 h-4" />;
|
| 54 |
|
| 55 |
case LayerType.CUSTOM: return <Terminal className="w-4 h-4" />;
|
| 56 |
-
case LayerType.IDENTITY:
|
|
|
|
|
|
|
| 57 |
|
| 58 |
default: return <Box className="w-4 h-4" />;
|
| 59 |
}
|
|
@@ -66,11 +76,15 @@ const getColor = (type: LayerType) => {
|
|
| 66 |
|
| 67 |
case LayerType.CONV2D:
|
| 68 |
case LayerType.CONV1D:
|
|
|
|
| 69 |
case LayerType.CONV_TRANSPOSE2D:
|
| 70 |
case LayerType.MAXPOOL:
|
|
|
|
| 71 |
case LayerType.AVGPOOL:
|
|
|
|
| 72 |
case LayerType.ADAPTIVEAVGPOOL:
|
| 73 |
-
case LayerType.UPSAMPLE:
|
|
|
|
| 74 |
|
| 75 |
case LayerType.LINEAR:
|
| 76 |
case LayerType.ACTION_HEAD: return 'border-violet-500 shadow-violet-500/20';
|
|
@@ -82,6 +96,7 @@ const getColor = (type: LayerType) => {
|
|
| 82 |
case LayerType.TRANSFORMER_ENCODER:
|
| 83 |
case LayerType.TRANSFORMER_DECODER:
|
| 84 |
case LayerType.MOE_BLOCK:
|
|
|
|
| 85 |
case LayerType.PATCH_EMBED:
|
| 86 |
case LayerType.SAM_PROMPT_ENCODER:
|
| 87 |
case LayerType.SAM_MASK_DECODER: return 'border-amber-500 shadow-amber-500/20';
|
|
@@ -94,6 +109,7 @@ const getColor = (type: LayerType) => {
|
|
| 94 |
|
| 95 |
case LayerType.BATCHNORM:
|
| 96 |
case LayerType.LAYERNORM:
|
|
|
|
| 97 |
case LayerType.INSTANCENORM:
|
| 98 |
case LayerType.RMSNORM: return 'border-cyan-500 shadow-cyan-500/20';
|
| 99 |
|
|
@@ -180,4 +196,4 @@ const CustomNode = ({ data, selected }: NodeProps<NodeData>) => {
|
|
| 180 |
);
|
| 181 |
};
|
| 182 |
|
| 183 |
-
export default memo(CustomNode);
|
|
|
|
| 1 |
+
|
| 2 |
import React, { memo } from 'react';
|
| 3 |
import { Handle, Position, NodeProps } from 'reactflow';
|
| 4 |
import {
|
| 5 |
Layers, Box, Activity, Zap, ArrowRight, Grid, Minimize,
|
| 6 |
Database, GitBranch, AlignJustify, Type, Combine, Maximize,
|
| 7 |
ArrowUpCircle, Sliders, RefreshCcw, Brain, Crosshair, Network, Clock, Eye, Workflow,
|
| 8 |
+
Terminal, MinusCircle, Scaling, BoxSelect, Wifi
|
| 9 |
} from 'lucide-react';
|
| 10 |
import { NodeData, LayerType } from '../types';
|
| 11 |
|
|
|
|
| 14 |
case LayerType.INPUT: return <ArrowRight className="w-4 h-4" />;
|
| 15 |
case LayerType.CONV1D: return <Activity className="w-4 h-4" />;
|
| 16 |
case LayerType.CONV2D: return <Layers className="w-4 h-4" />;
|
| 17 |
+
case LayerType.CONV3D: return <Box className="w-4 h-4" />;
|
| 18 |
case LayerType.CONV_TRANSPOSE2D: return <ArrowUpCircle className="w-4 h-4" />;
|
| 19 |
case LayerType.LINEAR: return <Grid className="w-4 h-4" />;
|
| 20 |
case LayerType.RELU:
|
| 21 |
case LayerType.LEAKYRELU:
|
| 22 |
+
case LayerType.PRELU:
|
| 23 |
case LayerType.GELU:
|
| 24 |
case LayerType.SILU:
|
| 25 |
case LayerType.SWIGLU:
|
| 26 |
case LayerType.SIGMOID:
|
| 27 |
case LayerType.TANH: return <Zap className="w-4 h-4" />;
|
| 28 |
case LayerType.MAXPOOL:
|
| 29 |
+
case LayerType.MAXPOOL3D:
|
| 30 |
case LayerType.AVGPOOL:
|
| 31 |
case LayerType.ADAPTIVEAVGPOOL: return <Minimize className="w-4 h-4" />;
|
| 32 |
+
case LayerType.GLOBAL_AVG_POOL: return <BoxSelect className="w-4 h-4" />;
|
| 33 |
case LayerType.UPSAMPLE: return <Maximize className="w-4 h-4" />;
|
| 34 |
+
case LayerType.PIXEL_SHUFFLE: return <Scaling className="w-4 h-4" />;
|
| 35 |
|
| 36 |
// Transformer / GenAI
|
| 37 |
case LayerType.ATTENTION:
|
|
|
|
| 41 |
case LayerType.TRANSFORMER_DECODER: return <Brain className="w-4 h-4" />;
|
| 42 |
case LayerType.MOE_BLOCK: return <Network className="w-4 h-4" />;
|
| 43 |
case LayerType.ACTION_HEAD: return <Crosshair className="w-4 h-4" />;
|
| 44 |
+
case LayerType.SE_BLOCK: return <Wifi className="w-4 h-4" />;
|
| 45 |
case LayerType.PATCH_EMBED:
|
| 46 |
case LayerType.SAM_PROMPT_ENCODER: return <Eye className="w-4 h-4" />;
|
| 47 |
case LayerType.SAM_MASK_DECODER: return <Workflow className="w-4 h-4" />;
|
|
|
|
| 53 |
case LayerType.EMBEDDING: return <Database className="w-4 h-4" />;
|
| 54 |
case LayerType.LAYERNORM:
|
| 55 |
case LayerType.BATCHNORM:
|
| 56 |
+
case LayerType.GROUPNORM:
|
| 57 |
case LayerType.INSTANCENORM:
|
| 58 |
case LayerType.RMSNORM: return <Sliders className="w-4 h-4" />;
|
| 59 |
case LayerType.FLATTEN: return <Type className="w-4 h-4" />;
|
|
|
|
| 61 |
case LayerType.ADD: return <Combine className="w-4 h-4" />;
|
| 62 |
|
| 63 |
case LayerType.CUSTOM: return <Terminal className="w-4 h-4" />;
|
| 64 |
+
case LayerType.IDENTITY:
|
| 65 |
+
case LayerType.DROPOUT:
|
| 66 |
+
case LayerType.DROPPATH: return <MinusCircle className="w-4 h-4" />;
|
| 67 |
|
| 68 |
default: return <Box className="w-4 h-4" />;
|
| 69 |
}
|
|
|
|
| 76 |
|
| 77 |
case LayerType.CONV2D:
|
| 78 |
case LayerType.CONV1D:
|
| 79 |
+
case LayerType.CONV3D:
|
| 80 |
case LayerType.CONV_TRANSPOSE2D:
|
| 81 |
case LayerType.MAXPOOL:
|
| 82 |
+
case LayerType.MAXPOOL3D:
|
| 83 |
case LayerType.AVGPOOL:
|
| 84 |
+
case LayerType.GLOBAL_AVG_POOL:
|
| 85 |
case LayerType.ADAPTIVEAVGPOOL:
|
| 86 |
+
case LayerType.UPSAMPLE:
|
| 87 |
+
case LayerType.PIXEL_SHUFFLE: return 'border-blue-500 shadow-blue-500/20';
|
| 88 |
|
| 89 |
case LayerType.LINEAR:
|
| 90 |
case LayerType.ACTION_HEAD: return 'border-violet-500 shadow-violet-500/20';
|
|
|
|
| 96 |
case LayerType.TRANSFORMER_ENCODER:
|
| 97 |
case LayerType.TRANSFORMER_DECODER:
|
| 98 |
case LayerType.MOE_BLOCK:
|
| 99 |
+
case LayerType.SE_BLOCK:
|
| 100 |
case LayerType.PATCH_EMBED:
|
| 101 |
case LayerType.SAM_PROMPT_ENCODER:
|
| 102 |
case LayerType.SAM_MASK_DECODER: return 'border-amber-500 shadow-amber-500/20';
|
|
|
|
| 109 |
|
| 110 |
case LayerType.BATCHNORM:
|
| 111 |
case LayerType.LAYERNORM:
|
| 112 |
+
case LayerType.GROUPNORM:
|
| 113 |
case LayerType.INSTANCENORM:
|
| 114 |
case LayerType.RMSNORM: return 'border-cyan-500 shadow-cyan-500/20';
|
| 115 |
|
|
|
|
| 196 |
);
|
| 197 |
};
|
| 198 |
|
| 199 |
+
export default memo(CustomNode);
|