Spaces:
Running
Running
| import React, { memo } from 'react'; | |
| import { Handle, Position, NodeProps } from 'reactflow'; | |
| import { | |
| Layers, Box, Activity, Zap, ArrowRight, Grid, Minimize, | |
| Database, GitBranch, Type, Combine, Maximize, | |
| ArrowUpCircle, Sliders, RefreshCcw, Brain, Crosshair, Network, Clock, Eye, Workflow, | |
| Terminal, MinusCircle, Scaling, BoxSelect, Wifi, | |
| Mic, Speaker, Radio, Cuboid, Target, Scan, | |
| AudioLines, FileAudio, Hexagon, Component, | |
| ScanText, Mountain, Move, Radar, Map, Orbit, Wind, | |
| Film, Video, FastForward, Timer, Clapperboard, Merge, | |
| Share2, Atom, Dna, Gamepad2, Sparkles, FlipVertical, RefreshCw, | |
| Scissors, Hash, Sigma, Calculator, BarChart3, Binary, X, Circle | |
| } from 'lucide-react'; | |
| import { NodeData, LayerType } from '../types'; | |
/**
 * Ordered icon table for layer types.
 *
 * Each entry pairs a lucide-react icon component with the layer types that use it.
 * `getIcon` scans the entries top to bottom and the first group containing the
 * requested type wins, so entry order matters if two types ever share a value.
 * Types that appear in no group fall back to `Box`.
 */
const ICON_GROUPS: ReadonlyArray<readonly [React.ElementType, readonly LayerType[]]> = [
  [ArrowRight, [LayerType.INPUT]],
  [Circle, [LayerType.OUTPUT]],
  // Convolutions / dense
  [Activity, [LayerType.CONV1D]],
  [Layers, [LayerType.CONV2D]],
  [Box, [LayerType.CONV3D]],
  [ArrowUpCircle, [LayerType.CONV_TRANSPOSE2D]],
  [Hexagon, [LayerType.DEFORMABLE_CONV]],
  [Layers, [LayerType.SEPARABLE_CONV2D, LayerType.DEPTHWISE_CONV2D]],
  [Grid, [LayerType.LINEAR]],
  // Activations
  [Sigma, [LayerType.SIGMOID]],
  [
    Zap,
    [
      LayerType.RELU,
      LayerType.LEAKYRELU,
      LayerType.PRELU,
      LayerType.GELU,
      LayerType.SILU,
      LayerType.SWIGLU,
      LayerType.TANH,
      LayerType.SOFTPLUS,
      LayerType.SOFTSIGN,
    ],
  ],
  // Pooling / resampling
  [Minimize, [LayerType.MAXPOOL, LayerType.MAXPOOL3D, LayerType.AVGPOOL, LayerType.ADAPTIVEAVGPOOL]],
  [BoxSelect, [LayerType.GLOBAL_AVG_POOL]],
  [Maximize, [LayerType.UPSAMPLE]],
  // Pixel shuffle shares its icon with the TF rescaling / resizing preprocessors
  [Scaling, [LayerType.PIXEL_SHUFFLE, LayerType.RESCALING, LayerType.RESIZING]],
  // TF / preprocessing
  [Scissors, [LayerType.CENTER_CROP]],
  [
    RefreshCw,
    [LayerType.RANDOM_FLIP, LayerType.RANDOM_ROTATION, LayerType.RANDOM_ZOOM, LayerType.RANDOM_CONTRAST],
  ],
  [Type, [LayerType.TEXT_VECTORIZATION]],
  [BarChart3, [LayerType.NORMALIZATION_LAYER]],
  [Binary, [LayerType.DISCRETIZATION]],
  [Hash, [LayerType.CATEGORY_ENCODING]],
  // Transformer / GenAI
  [RefreshCcw, [LayerType.ATTENTION, LayerType.CROSS_ATTENTION, LayerType.WINDOW_ATTENTION]],
  [Brain, [LayerType.TRANSFORMER_BLOCK, LayerType.TRANSFORMER_ENCODER, LayerType.TRANSFORMER_DECODER]],
  [Network, [LayerType.MOE_BLOCK]],
  [Crosshair, [LayerType.ACTION_HEAD]],
  [Wifi, [LayerType.SE_BLOCK]],
  [Eye, [LayerType.PATCH_EMBED, LayerType.SAM_PROMPT_ENCODER]],
  [Workflow, [LayerType.SAM_MASK_DECODER]],
  [Clock, [LayerType.TIME_EMBEDDING, LayerType.ROPE]],
  // Audio
  [Radio, [LayerType.MEL_SPECTROGRAM, LayerType.STFT]],
  [AudioLines, [LayerType.SPEC_AUGMENT]],
  [Activity, [LayerType.CONFORMER_BLOCK, LayerType.WAVENET_BLOCK]],
  [FileAudio, [LayerType.RVC_ENCODER, LayerType.WAV2VEC2_ENC]],
  [Speaker, [LayerType.VOCODER]],
  [Mic, [LayerType.AUDIO_EMBEDDING]],
  [Activity, [LayerType.SINC_CONV]],
  // Detection / YOLO
  [Layers, [LayerType.C2F_BLOCK, LayerType.SPPF_BLOCK, LayerType.DARKNET_BLOCK]],
  [Target, [LayerType.DETECT_HEAD, LayerType.ANCHOR_BOX]],
  [Component, [LayerType.NMS]],
  // 3D
  [Cuboid, [LayerType.NERF_BLOCK, LayerType.GAUSSIAN_SPLAT, LayerType.TRIPLANE_ENC]],
  [Scan, [LayerType.POINTNET_BLOCK, LayerType.POINT_TRANSFORMER, LayerType.MESH_CONV]],
  // OCR
  [ScanText, [LayerType.TPS_TRANSFORM, LayerType.CRNN_BLOCK, LayerType.CTC_DECODER]],
  // Robotics / motion
  [Mountain, [LayerType.DEPTH_DECODER, LayerType.DISPARITY_HEAD]],
  [Wind, [LayerType.OPTICAL_FLOW, LayerType.VELOCITY_HEAD]],
  [Orbit, [LayerType.KALMAN_FILTER]],
  // `Map` is the lucide-react icon imported at the top of the file (it shadows the global Map)
  [Map, [LayerType.BEV_TRANSFORM]],
  [Radar, [LayerType.RADAR_ENCODER]],
  // Video / generation
  [Film, [LayerType.VIDEO_DIFFUSION_BLOCK]],
  [Timer, [LayerType.SPATIO_TEMPORAL_ATTN]],
  [Video, [LayerType.VIDEO_TOKENIZER]],
  [FastForward, [LayerType.FRAME_INTERPOLATOR]],
  [Move, [LayerType.TEMPORAL_SHIFT]],
  [Clapperboard, [LayerType.NON_LOCAL_BLOCK]],
  [Merge, [LayerType.MULTIMODAL_FUSION]],
  // Graph
  [Share2, [LayerType.GCN_CONV, LayerType.GRAPH_SAGE, LayerType.GAT_CONV, LayerType.GIN_CONV]],
  // Physics
  [Atom, [LayerType.NEURAL_ODE, LayerType.HAMILTONIAN_NN, LayerType.PINN_LINEAR]],
  [Dna, [LayerType.PROTEIN_FOLDING]],
  // Reinforcement learning
  [Gamepad2, [LayerType.DUELING_HEAD, LayerType.PPO_HEAD, LayerType.SAC_HEAD]],
  // Spiking
  [Zap, [LayerType.LIF_NEURON, LayerType.SPIKING_LAYER]],
  // Advanced / niche
  [
    Sparkles,
    [
      LayerType.CAPSULE,
      LayerType.HYPER_NET,
      LayerType.MAMBA_BLOCK,
      LayerType.RWKV_BLOCK,
      LayerType.HOPFIELD,
      LayerType.NORMALIZING_FLOW,
      LayerType.DNC_MEMORY,
      LayerType.ARCFACE,
      LayerType.ECHO_STATE,
    ],
  ],
  // Recurrent / embedding / normalization
  [GitBranch, [LayerType.LSTM, LayerType.GRU]],
  [Database, [LayerType.EMBEDDING]],
  [
    Sliders,
    [LayerType.LAYERNORM, LayerType.BATCHNORM, LayerType.GROUPNORM, LayerType.INSTANCENORM, LayerType.RMSNORM],
  ],
  // Shape manipulation
  [Type, [LayerType.FLATTEN, LayerType.UNFLATTEN]],
  [FlipVertical, [LayerType.RESHAPE]],
  [RefreshCw, [LayerType.PERMUTE]],
  // Merge
  [
    Combine,
    [
      LayerType.CONCAT,
      LayerType.ADD,
      LayerType.SUBTRACT,
      LayerType.MULTIPLY,
      LayerType.AVERAGE,
      LayerType.MAXIMUM,
      LayerType.MINIMUM,
    ],
  ],
  [X, [LayerType.DOT]],
  [Calculator, [LayerType.LAMBDA]],
  [Terminal, [LayerType.CUSTOM]],
  [MinusCircle, [LayerType.IDENTITY, LayerType.DROPOUT, LayerType.SPATIAL_DROPOUT, LayerType.DROPPATH]],
];

/**
 * Returns the header icon element for a layer type, rendered with `w-4 h-4`.
 *
 * @param type - Layer type of the node being drawn.
 * @returns The matching icon from `ICON_GROUPS`, or `Box` when the type is unlisted.
 */
const getIcon = (type: LayerType) => {
  const group = ICON_GROUPS.find(([, members]) => members.includes(type));
  const Icon = group ? group[0] : Box;
  return <Icon className="w-4 h-4" />;
};
/** Border/glow classes used when a layer type belongs to no colour family. */
const DEFAULT_NODE_COLOR = 'border-slate-500 shadow-slate-500/20';

/**
 * Ordered colour families: Tailwind border + shadow classes per group of layer types.
 *
 * `getColor` returns the classes of the first family containing the type. The class
 * strings are written out in full on purpose: Tailwind only generates classes it can
 * find verbatim in source files.
 */
const COLOR_GROUPS: ReadonlyArray<readonly [string, readonly LayerType[]]> = [
  ['border-emerald-500 shadow-emerald-500/20', [LayerType.INPUT]],
  ['border-red-500 shadow-red-500/20', [LayerType.OUTPUT]],
  // Convolution, pooling and resampling - blue
  [
    'border-blue-500 shadow-blue-500/20',
    [
      LayerType.CONV2D,
      LayerType.CONV1D,
      LayerType.CONV3D,
      LayerType.CONV_TRANSPOSE2D,
      LayerType.DEFORMABLE_CONV,
      LayerType.SEPARABLE_CONV2D,
      LayerType.DEPTHWISE_CONV2D,
      LayerType.MAXPOOL,
      LayerType.MAXPOOL3D,
      LayerType.AVGPOOL,
      LayerType.GLOBAL_AVG_POOL,
      LayerType.ADAPTIVEAVGPOOL,
      LayerType.UPSAMPLE,
      LayerType.PIXEL_SHUFFLE,
    ],
  ],
  ['border-violet-500 shadow-violet-500/20', [LayerType.LINEAR, LayerType.ACTION_HEAD]],
  // TF / preprocessing - amber
  [
    'border-amber-600 shadow-amber-600/20',
    [
      LayerType.RESCALING,
      LayerType.RESIZING,
      LayerType.CENTER_CROP,
      LayerType.RANDOM_FLIP,
      LayerType.RANDOM_ROTATION,
      LayerType.RANDOM_ZOOM,
      LayerType.RANDOM_CONTRAST,
      LayerType.TEXT_VECTORIZATION,
      LayerType.NORMALIZATION_LAYER,
      LayerType.DISCRETIZATION,
      LayerType.CATEGORY_ENCODING,
    ],
  ],
  // Transformer / GenAI - gold
  [
    'border-yellow-500 shadow-yellow-500/20',
    [
      LayerType.ATTENTION,
      LayerType.CROSS_ATTENTION,
      LayerType.WINDOW_ATTENTION,
      LayerType.TRANSFORMER_BLOCK,
      LayerType.TRANSFORMER_ENCODER,
      LayerType.TRANSFORMER_DECODER,
      LayerType.MOE_BLOCK,
      LayerType.SE_BLOCK,
      LayerType.PATCH_EMBED,
      LayerType.SAM_PROMPT_ENCODER,
      LayerType.SAM_MASK_DECODER,
    ],
  ],
  // Recurrent, embeddings and positional encodings - orange
  [
    'border-orange-500 shadow-orange-500/20',
    [LayerType.LSTM, LayerType.GRU, LayerType.EMBEDDING, LayerType.TIME_EMBEDDING, LayerType.ROPE],
  ],
  // Normalization - cyan
  [
    'border-cyan-500 shadow-cyan-500/20',
    [LayerType.BATCHNORM, LayerType.LAYERNORM, LayerType.GROUPNORM, LayerType.INSTANCENORM, LayerType.RMSNORM],
  ],
  // Detection - rose
  [
    'border-rose-500 shadow-rose-500/20',
    [
      LayerType.C2F_BLOCK,
      LayerType.SPPF_BLOCK,
      LayerType.DARKNET_BLOCK,
      LayerType.DETECT_HEAD,
      LayerType.ANCHOR_BOX,
      LayerType.NMS,
    ],
  ],
  // Audio - indigo
  [
    'border-indigo-500 shadow-indigo-500/20',
    [
      LayerType.MEL_SPECTROGRAM,
      LayerType.STFT,
      LayerType.SPEC_AUGMENT,
      LayerType.CONFORMER_BLOCK,
      LayerType.WAVENET_BLOCK,
      LayerType.VOCODER,
      LayerType.RVC_ENCODER,
      LayerType.WAV2VEC2_ENC,
      LayerType.AUDIO_EMBEDDING,
      LayerType.SINC_CONV,
    ],
  ],
  // 3D - teal
  [
    'border-teal-500 shadow-teal-500/20',
    [
      LayerType.NERF_BLOCK,
      LayerType.POINTNET_BLOCK,
      LayerType.POINT_TRANSFORMER,
      LayerType.TRIPLANE_ENC,
      LayerType.GAUSSIAN_SPLAT,
      LayerType.MESH_CONV,
    ],
  ],
  // OCR - NOTE(review): reuses the transformer family's yellow-500, so OCR and
  // transformer nodes look identical; confirm this is intended.
  ['border-yellow-500 shadow-yellow-500/20', [LayerType.TPS_TRANSFORM, LayerType.CRNN_BLOCK, LayerType.CTC_DECODER]],
  // Robotics / motion - sky
  [
    'border-sky-500 shadow-sky-500/20',
    [
      LayerType.DEPTH_DECODER,
      LayerType.DISPARITY_HEAD,
      LayerType.OPTICAL_FLOW,
      LayerType.VELOCITY_HEAD,
      LayerType.KALMAN_FILTER,
      LayerType.BEV_TRANSFORM,
      LayerType.RADAR_ENCODER,
    ],
  ],
  // Video - fuchsia
  [
    'border-fuchsia-500 shadow-fuchsia-500/20',
    [
      LayerType.VIDEO_DIFFUSION_BLOCK,
      LayerType.SPATIO_TEMPORAL_ATTN,
      LayerType.VIDEO_TOKENIZER,
      LayerType.FRAME_INTERPOLATOR,
      LayerType.TEMPORAL_SHIFT,
      LayerType.NON_LOCAL_BLOCK,
      LayerType.MULTIMODAL_FUSION,
    ],
  ],
  // Graph - green
  [
    'border-green-500 shadow-green-500/20',
    [LayerType.GCN_CONV, LayerType.GRAPH_SAGE, LayerType.GAT_CONV, LayerType.GIN_CONV],
  ],
  // Physics - light blue
  [
    'border-blue-400 shadow-blue-400/20',
    [LayerType.NEURAL_ODE, LayerType.PINN_LINEAR, LayerType.HAMILTONIAN_NN, LayerType.PROTEIN_FOLDING],
  ],
  // Spiking - bright yellow
  ['border-yellow-400 shadow-yellow-400/20', [LayerType.LIF_NEURON, LayerType.SPIKING_LAYER]],
  // Reinforcement learning - purple
  ['border-purple-600 shadow-purple-600/20', [LayerType.DUELING_HEAD, LayerType.PPO_HEAD, LayerType.SAC_HEAD]],
  // Advanced / niche - white glow
  [
    'border-white shadow-white/30',
    [
      LayerType.CAPSULE,
      LayerType.HYPER_NET,
      LayerType.MAMBA_BLOCK,
      LayerType.RWKV_BLOCK,
      LayerType.HOPFIELD,
      LayerType.NORMALIZING_FLOW,
      LayerType.DNC_MEMORY,
      LayerType.ARCFACE,
      LayerType.ECHO_STATE,
    ],
  ],
  // Merge operations - pink
  [
    'border-pink-500 shadow-pink-500/20',
    [
      LayerType.CONCAT,
      LayerType.ADD,
      LayerType.SUBTRACT,
      LayerType.MULTIPLY,
      LayerType.AVERAGE,
      LayerType.MAXIMUM,
      LayerType.MINIMUM,
      LayerType.DOT,
    ],
  ],
  // User code - lime
  ['border-lime-500 shadow-lime-500/20', [LayerType.LAMBDA, LayerType.CUSTOM]],
  // Pass-through / shape ops - dark slate
  [
    'border-slate-600 shadow-slate-600/20',
    [LayerType.IDENTITY, LayerType.FLATTEN, LayerType.RESHAPE, LayerType.PERMUTE, LayerType.UNFLATTEN],
  ],
];

/**
 * Resolves the Tailwind border and shadow classes for a node's layer type.
 *
 * @param type - Layer type of the node being drawn.
 * @returns Classes of the first matching family in `COLOR_GROUPS`, otherwise
 *   `DEFAULT_NODE_COLOR`.
 */
const getColor = (type: LayerType): string => {
  for (const [classes, members] of COLOR_GROUPS) {
    if (members.includes(type)) {
      return classes;
    }
  }
  return DEFAULT_NODE_COLOR;
};
/** Parameter keys holding large code blobs; they are never shown in the node summary. */
const HIDDEN_SUMMARY_KEYS: ReadonlySet<string> = new Set(['definition_code', 'imports']);

/** Maximum number of parameter rows rendered in a node body. */
const MAX_SUMMARY_PARAMS = 3;

/**
 * Formats one parameter value for the compact summary row.
 *
 * Objects and arrays are JSON-encoded, so a tuple reads `[3,3]` rather than `3,3`
 * and a nested config is not shown as `[object Object]`. Everything else goes
 * through `String`. If JSON encoding throws (circular references, BigInt members),
 * the value falls back to `String`.
 */
const formatParamValue = (value: unknown): string => {
  if (value !== null && typeof value === 'object') {
    try {
      return JSON.stringify(value);
    } catch {
      return String(value);
    }
  }
  return String(value);
};

/**
 * React Flow node renderer for a single network layer.
 *
 * Renders an icon + label header, with a modality badge on INPUT nodes and a
 * short class-name badge on CUSTOM nodes. Below the header it shows up to
 * `MAX_SUMMARY_PARAMS` parameter rows, skipping the keys in `HIDDEN_SUMMARY_KEYS`.
 * INPUT nodes get no target handle and OUTPUT nodes get no source handle.
 */
const CustomNode = ({ data, selected }: NodeProps<NodeData>) => {
  const isInput = data.type === LayerType.INPUT;
  const isOutput = data.type === LayerType.OUTPUT;

  // Input modality badge; String() keeps a non-string value from breaking render.
  const inputBadge = isInput && data.params.modality ? (
    <span className="ml-auto text-[9px] px-1.5 py-0.5 rounded-full bg-slate-700 text-emerald-400 font-mono tracking-tighter uppercase">
      {String(data.params.modality)}
    </span>
  ) : null;

  // Custom layer badge shows only the last dotted segment of class_name
  // (e.g. the class part of a module path).
  const customBadge = data.type === LayerType.CUSTOM && data.params.class_name ? (
    <span className="ml-auto text-[8px] px-1.5 py-0.5 rounded-full bg-slate-700 text-lime-400 font-mono truncate max-w-[80px]">
      {String(data.params.class_name).split('.').pop()}
    </span>
  ) : null;

  // Drop hidden keys BEFORE limiting, so a code blob among the first entries
  // does not cost the summary one of its visible rows.
  const summaryEntries = Object.entries(data.params)
    .filter(([key]) => !HIDDEN_SUMMARY_KEYS.has(key))
    .slice(0, MAX_SUMMARY_PARAMS);

  return (
    <div className={`
      relative min-w-[150px] bg-slate-900 rounded-lg border-2
      transition-all duration-200 shadow-lg
      ${getColor(data.type)}
      ${selected ? 'ring-2 ring-white ring-offset-2 ring-offset-slate-900 scale-105' : ''}
    `}>
      {/* Input Handle */}
      {!isInput && (
        <Handle
          type="target"
          position={Position.Top}
          className="!bg-slate-400 !w-3 !h-3 !-top-2 !border-slate-900"
        />
      )}

      {/* Header */}
      <div className="flex items-center gap-2 px-3 py-2 border-b border-slate-800 bg-slate-800/50 rounded-t-md">
        <span className="text-slate-300">
          {getIcon(data.type)}
        </span>
        <span className="text-xs font-bold uppercase tracking-wider text-slate-200 truncate max-w-[100px]">
          {data.label}
        </span>
        {inputBadge}
        {customBadge}
      </div>

      {/* Body - Parameters Summary */}
      <div className="p-3">
        <div className="text-[10px] text-slate-400 font-mono space-y-1">
          {summaryEntries.map(([key, value]) => {
            const text = formatParamValue(value);
            // title attributes expose the full text that `truncate` cuts off.
            return (
              <div key={key} className="flex justify-between gap-4">
                <span className="opacity-70 truncate max-w-[80px]" title={key}>{key}:</span>
                <span className="text-slate-200 truncate max-w-[80px]" title={text}>{text}</span>
              </div>
            );
          })}
          {/* Placeholder depends on what is visible, not on the raw key count. */}
          {summaryEntries.length === 0 && (
            <span className="opacity-50 italic">No parameters</span>
          )}
        </div>
      </div>

      {/* Output Handle */}
      {!isOutput && (
        <Handle
          type="source"
          position={Position.Bottom}
          className="!bg-slate-400 !w-3 !h-3 !-bottom-2 !border-slate-900"
        />
      )}
    </div>
  );
};

export default memo(CustomNode);