Spaces:
Running
Running
Update components/CustomNode.tsx
Browse files- components/CustomNode.tsx +63 -10
components/CustomNode.tsx
CHANGED
|
@@ -1,37 +1,44 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
import React, { memo } from 'react';
|
| 4 |
import { Handle, Position, NodeProps } from 'reactflow';
|
| 5 |
import {
|
| 6 |
Layers, Box, Activity, Zap, ArrowRight, Grid, Minimize,
|
| 7 |
-
Database, GitBranch,
|
| 8 |
ArrowUpCircle, Sliders, RefreshCcw, Brain, Crosshair, Network, Clock, Eye, Workflow,
|
| 9 |
Terminal, MinusCircle, Scaling, BoxSelect, Wifi,
|
| 10 |
Mic, Speaker, Radio, Cuboid, Target, Scan,
|
| 11 |
AudioLines, FileAudio, Hexagon, Component,
|
| 12 |
ScanText, Mountain, Move, Radar, Map, Orbit, Wind,
|
| 13 |
Film, Video, FastForward, Timer, Clapperboard, Merge,
|
| 14 |
-
Share2, Atom, Dna, Gamepad2, Sparkles, FlipVertical, RefreshCw
|
|
|
|
| 15 |
} from 'lucide-react';
|
| 16 |
import { NodeData, LayerType } from '../types';
|
| 17 |
|
| 18 |
const getIcon = (type: LayerType) => {
|
| 19 |
switch (type) {
|
| 20 |
case LayerType.INPUT: return <ArrowRight className="w-4 h-4" />;
|
|
|
|
|
|
|
| 21 |
case LayerType.CONV1D: return <Activity className="w-4 h-4" />;
|
| 22 |
case LayerType.CONV2D: return <Layers className="w-4 h-4" />;
|
| 23 |
case LayerType.CONV3D: return <Box className="w-4 h-4" />;
|
| 24 |
case LayerType.CONV_TRANSPOSE2D: return <ArrowUpCircle className="w-4 h-4" />;
|
| 25 |
case LayerType.DEFORMABLE_CONV: return <Hexagon className="w-4 h-4" />;
|
|
|
|
|
|
|
| 26 |
case LayerType.LINEAR: return <Grid className="w-4 h-4" />;
|
|
|
|
|
|
|
| 27 |
case LayerType.RELU:
|
| 28 |
case LayerType.LEAKYRELU:
|
| 29 |
case LayerType.PRELU:
|
| 30 |
case LayerType.GELU:
|
| 31 |
case LayerType.SILU:
|
| 32 |
case LayerType.SWIGLU:
|
| 33 |
-
case LayerType.
|
| 34 |
-
case LayerType.
|
|
|
|
|
|
|
| 35 |
case LayerType.MAXPOOL:
|
| 36 |
case LayerType.MAXPOOL3D:
|
| 37 |
case LayerType.AVGPOOL:
|
|
@@ -40,6 +47,19 @@ const getIcon = (type: LayerType) => {
|
|
| 40 |
case LayerType.UPSAMPLE: return <Maximize className="w-4 h-4" />;
|
| 41 |
case LayerType.PIXEL_SHUFFLE: return <Scaling className="w-4 h-4" />;
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
// Transformer / GenAI
|
| 44 |
case LayerType.ATTENTION:
|
| 45 |
case LayerType.CROSS_ATTENTION:
|
|
@@ -152,12 +172,22 @@ const getIcon = (type: LayerType) => {
|
|
| 152 |
case LayerType.UNFLATTEN: return <Type className="w-4 h-4" />;
|
| 153 |
case LayerType.RESHAPE: return <FlipVertical className="w-4 h-4" />;
|
| 154 |
case LayerType.PERMUTE: return <RefreshCw className="w-4 h-4" />;
|
|
|
|
|
|
|
| 155 |
case LayerType.CONCAT:
|
| 156 |
-
case LayerType.ADD:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
|
|
|
|
| 158 |
case LayerType.CUSTOM: return <Terminal className="w-4 h-4" />;
|
| 159 |
case LayerType.IDENTITY:
|
| 160 |
case LayerType.DROPOUT:
|
|
|
|
| 161 |
case LayerType.DROPPATH: return <MinusCircle className="w-4 h-4" />;
|
| 162 |
|
| 163 |
default: return <Box className="w-4 h-4" />;
|
|
@@ -174,6 +204,8 @@ const getColor = (type: LayerType) => {
|
|
| 174 |
case LayerType.CONV3D:
|
| 175 |
case LayerType.CONV_TRANSPOSE2D:
|
| 176 |
case LayerType.DEFORMABLE_CONV:
|
|
|
|
|
|
|
| 177 |
case LayerType.MAXPOOL:
|
| 178 |
case LayerType.MAXPOOL3D:
|
| 179 |
case LayerType.AVGPOOL:
|
|
@@ -185,7 +217,20 @@ const getColor = (type: LayerType) => {
|
|
| 185 |
case LayerType.LINEAR:
|
| 186 |
case LayerType.ACTION_HEAD: return 'border-violet-500 shadow-violet-500/20';
|
| 187 |
|
| 188 |
-
//
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
case LayerType.ATTENTION:
|
| 190 |
case LayerType.CROSS_ATTENTION:
|
| 191 |
case LayerType.WINDOW_ATTENTION:
|
|
@@ -196,7 +241,7 @@ const getColor = (type: LayerType) => {
|
|
| 196 |
case LayerType.SE_BLOCK:
|
| 197 |
case LayerType.PATCH_EMBED:
|
| 198 |
case LayerType.SAM_PROMPT_ENCODER:
|
| 199 |
-
case LayerType.SAM_MASK_DECODER: return 'border-
|
| 200 |
|
| 201 |
case LayerType.LSTM:
|
| 202 |
case LayerType.GRU:
|
|
@@ -294,9 +339,17 @@ const getColor = (type: LayerType) => {
|
|
| 294 |
case LayerType.ECHO_STATE: return 'border-white shadow-white/30';
|
| 295 |
|
| 296 |
case LayerType.CONCAT:
|
| 297 |
-
case LayerType.ADD:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
|
|
|
|
| 299 |
case LayerType.CUSTOM: return 'border-lime-500 shadow-lime-500/20';
|
|
|
|
| 300 |
case LayerType.IDENTITY:
|
| 301 |
case LayerType.FLATTEN:
|
| 302 |
case LayerType.RESHAPE:
|
|
|
|
|
|
|
|
|
|
| 1 |
import React, { memo } from 'react';
|
| 2 |
import { Handle, Position, NodeProps } from 'reactflow';
|
| 3 |
import {
|
| 4 |
Layers, Box, Activity, Zap, ArrowRight, Grid, Minimize,
|
| 5 |
+
Database, GitBranch, Type, Combine, Maximize,
|
| 6 |
ArrowUpCircle, Sliders, RefreshCcw, Brain, Crosshair, Network, Clock, Eye, Workflow,
|
| 7 |
Terminal, MinusCircle, Scaling, BoxSelect, Wifi,
|
| 8 |
Mic, Speaker, Radio, Cuboid, Target, Scan,
|
| 9 |
AudioLines, FileAudio, Hexagon, Component,
|
| 10 |
ScanText, Mountain, Move, Radar, Map, Orbit, Wind,
|
| 11 |
Film, Video, FastForward, Timer, Clapperboard, Merge,
|
| 12 |
+
Share2, Atom, Dna, Gamepad2, Sparkles, FlipVertical, RefreshCw,
|
| 13 |
+
Scissors, Hash, Sigma, Calculator, BarChart3, Binary, X, Circle
|
| 14 |
} from 'lucide-react';
|
| 15 |
import { NodeData, LayerType } from '../types';
|
| 16 |
|
| 17 |
const getIcon = (type: LayerType) => {
|
| 18 |
switch (type) {
|
| 19 |
case LayerType.INPUT: return <ArrowRight className="w-4 h-4" />;
|
| 20 |
+
case LayerType.OUTPUT: return <Circle className="w-4 h-4" />;
|
| 21 |
+
|
| 22 |
case LayerType.CONV1D: return <Activity className="w-4 h-4" />;
|
| 23 |
case LayerType.CONV2D: return <Layers className="w-4 h-4" />;
|
| 24 |
case LayerType.CONV3D: return <Box className="w-4 h-4" />;
|
| 25 |
case LayerType.CONV_TRANSPOSE2D: return <ArrowUpCircle className="w-4 h-4" />;
|
| 26 |
case LayerType.DEFORMABLE_CONV: return <Hexagon className="w-4 h-4" />;
|
| 27 |
+
case LayerType.SEPARABLE_CONV2D: return <Layers className="w-4 h-4" />;
|
| 28 |
+
case LayerType.DEPTHWISE_CONV2D: return <Layers className="w-4 h-4" />;
|
| 29 |
case LayerType.LINEAR: return <Grid className="w-4 h-4" />;
|
| 30 |
+
|
| 31 |
+
case LayerType.SIGMOID: return <Sigma className="w-4 h-4" />;
|
| 32 |
case LayerType.RELU:
|
| 33 |
case LayerType.LEAKYRELU:
|
| 34 |
case LayerType.PRELU:
|
| 35 |
case LayerType.GELU:
|
| 36 |
case LayerType.SILU:
|
| 37 |
case LayerType.SWIGLU:
|
| 38 |
+
case LayerType.TANH:
|
| 39 |
+
case LayerType.SOFTPLUS:
|
| 40 |
+
case LayerType.SOFTSIGN: return <Zap className="w-4 h-4" />;
|
| 41 |
+
|
| 42 |
case LayerType.MAXPOOL:
|
| 43 |
case LayerType.MAXPOOL3D:
|
| 44 |
case LayerType.AVGPOOL:
|
|
|
|
| 47 |
case LayerType.UPSAMPLE: return <Maximize className="w-4 h-4" />;
|
| 48 |
case LayerType.PIXEL_SHUFFLE: return <Scaling className="w-4 h-4" />;
|
| 49 |
|
| 50 |
+
// TF / Preprocessing
|
| 51 |
+
case LayerType.RESCALING:
|
| 52 |
+
case LayerType.RESIZING: return <Scaling className="w-4 h-4" />;
|
| 53 |
+
case LayerType.CENTER_CROP: return <Scissors className="w-4 h-4" />;
|
| 54 |
+
case LayerType.RANDOM_FLIP:
|
| 55 |
+
case LayerType.RANDOM_ROTATION:
|
| 56 |
+
case LayerType.RANDOM_ZOOM:
|
| 57 |
+
case LayerType.RANDOM_CONTRAST: return <RefreshCw className="w-4 h-4" />;
|
| 58 |
+
case LayerType.TEXT_VECTORIZATION: return <Type className="w-4 h-4" />;
|
| 59 |
+
case LayerType.NORMALIZATION_LAYER: return <BarChart3 className="w-4 h-4" />;
|
| 60 |
+
case LayerType.DISCRETIZATION: return <Binary className="w-4 h-4" />;
|
| 61 |
+
case LayerType.CATEGORY_ENCODING: return <Hash className="w-4 h-4" />;
|
| 62 |
+
|
| 63 |
// Transformer / GenAI
|
| 64 |
case LayerType.ATTENTION:
|
| 65 |
case LayerType.CROSS_ATTENTION:
|
|
|
|
| 172 |
case LayerType.UNFLATTEN: return <Type className="w-4 h-4" />;
|
| 173 |
case LayerType.RESHAPE: return <FlipVertical className="w-4 h-4" />;
|
| 174 |
case LayerType.PERMUTE: return <RefreshCw className="w-4 h-4" />;
|
| 175 |
+
|
| 176 |
+
// Merge
|
| 177 |
case LayerType.CONCAT:
|
| 178 |
+
case LayerType.ADD:
|
| 179 |
+
case LayerType.SUBTRACT:
|
| 180 |
+
case LayerType.MULTIPLY:
|
| 181 |
+
case LayerType.AVERAGE:
|
| 182 |
+
case LayerType.MAXIMUM:
|
| 183 |
+
case LayerType.MINIMUM: return <Combine className="w-4 h-4" />;
|
| 184 |
+
case LayerType.DOT: return <X className="w-4 h-4" />;
|
| 185 |
|
| 186 |
+
case LayerType.LAMBDA: return <Calculator className="w-4 h-4" />;
|
| 187 |
case LayerType.CUSTOM: return <Terminal className="w-4 h-4" />;
|
| 188 |
case LayerType.IDENTITY:
|
| 189 |
case LayerType.DROPOUT:
|
| 190 |
+
case LayerType.SPATIAL_DROPOUT:
|
| 191 |
case LayerType.DROPPATH: return <MinusCircle className="w-4 h-4" />;
|
| 192 |
|
| 193 |
default: return <Box className="w-4 h-4" />;
|
|
|
|
| 204 |
case LayerType.CONV3D:
|
| 205 |
case LayerType.CONV_TRANSPOSE2D:
|
| 206 |
case LayerType.DEFORMABLE_CONV:
|
| 207 |
+
case LayerType.SEPARABLE_CONV2D:
|
| 208 |
+
case LayerType.DEPTHWISE_CONV2D:
|
| 209 |
case LayerType.MAXPOOL:
|
| 210 |
case LayerType.MAXPOOL3D:
|
| 211 |
case LayerType.AVGPOOL:
|
|
|
|
| 217 |
case LayerType.LINEAR:
|
| 218 |
case LayerType.ACTION_HEAD: return 'border-violet-500 shadow-violet-500/20';
|
| 219 |
|
| 220 |
+
// TF / Preprocessing (Orange/Amber)
|
| 221 |
+
case LayerType.RESCALING:
|
| 222 |
+
case LayerType.RESIZING:
|
| 223 |
+
case LayerType.CENTER_CROP:
|
| 224 |
+
case LayerType.RANDOM_FLIP:
|
| 225 |
+
case LayerType.RANDOM_ROTATION:
|
| 226 |
+
case LayerType.RANDOM_ZOOM:
|
| 227 |
+
case LayerType.RANDOM_CONTRAST:
|
| 228 |
+
case LayerType.TEXT_VECTORIZATION:
|
| 229 |
+
case LayerType.NORMALIZATION_LAYER:
|
| 230 |
+
case LayerType.DISCRETIZATION:
|
| 231 |
+
case LayerType.CATEGORY_ENCODING: return 'border-amber-600 shadow-amber-600/20';
|
| 232 |
+
|
| 233 |
+
// GenAI / Transformer - Gold
|
| 234 |
case LayerType.ATTENTION:
|
| 235 |
case LayerType.CROSS_ATTENTION:
|
| 236 |
case LayerType.WINDOW_ATTENTION:
|
|
|
|
| 241 |
case LayerType.SE_BLOCK:
|
| 242 |
case LayerType.PATCH_EMBED:
|
| 243 |
case LayerType.SAM_PROMPT_ENCODER:
|
| 244 |
+
case LayerType.SAM_MASK_DECODER: return 'border-yellow-500 shadow-yellow-500/20';
|
| 245 |
|
| 246 |
case LayerType.LSTM:
|
| 247 |
case LayerType.GRU:
|
|
|
|
| 339 |
case LayerType.ECHO_STATE: return 'border-white shadow-white/30';
|
| 340 |
|
| 341 |
case LayerType.CONCAT:
|
| 342 |
+
case LayerType.ADD:
|
| 343 |
+
case LayerType.SUBTRACT:
|
| 344 |
+
case LayerType.MULTIPLY:
|
| 345 |
+
case LayerType.AVERAGE:
|
| 346 |
+
case LayerType.MAXIMUM:
|
| 347 |
+
case LayerType.MINIMUM:
|
| 348 |
+
case LayerType.DOT: return 'border-pink-500 shadow-pink-500/20';
|
| 349 |
|
| 350 |
+
case LayerType.LAMBDA: return 'border-lime-500 shadow-lime-500/20';
|
| 351 |
case LayerType.CUSTOM: return 'border-lime-500 shadow-lime-500/20';
|
| 352 |
+
|
| 353 |
case LayerType.IDENTITY:
|
| 354 |
case LayerType.FLATTEN:
|
| 355 |
case LayerType.RESHAPE:
|