// testarcbuilder / components/CustomNode.tsx
// Last updated by wuhp: "Update components/CustomNode.tsx" (commit 1ec8015, verified)
import React, { memo } from 'react';
import { Handle, Position, NodeProps } from 'reactflow';
import {
Layers, Box, Activity, Zap, ArrowRight, Grid, Minimize,
Database, GitBranch, Type, Combine, Maximize,
ArrowUpCircle, Sliders, RefreshCcw, Brain, Crosshair, Network, Clock, Eye, Workflow,
Terminal, MinusCircle, Scaling, BoxSelect, Wifi,
Mic, Speaker, Radio, Cuboid, Target, Scan,
AudioLines, FileAudio, Hexagon, Component,
ScanText, Mountain, Move, Radar, Map, Orbit, Wind,
Film, Video, FastForward, Timer, Clapperboard, Merge,
Share2, Atom, Dna, Gamepad2, Sparkles, FlipVertical, RefreshCw,
Scissors, Hash, Sigma, Calculator, BarChart3, Binary, X, Circle
} from 'lucide-react';
import { NodeData, LayerType } from '../types';
/**
 * Chooses which lucide-react icon component represents a layer type.
 *
 * Related layer types deliberately share an icon (e.g. every activation uses
 * `Zap`); any type not listed here falls back to `Box`.
 */
const iconComponentFor = (layerType: LayerType): typeof Box => {
  switch (layerType) {
    case LayerType.INPUT:
      return ArrowRight;
    case LayerType.OUTPUT:
      return Circle;

    // Convolutions
    case LayerType.CONV1D:
      return Activity;
    case LayerType.CONV2D:
      return Layers;
    case LayerType.CONV3D:
      return Box;
    case LayerType.CONV_TRANSPOSE2D:
      return ArrowUpCircle;
    case LayerType.DEFORMABLE_CONV:
      return Hexagon;
    case LayerType.SEPARABLE_CONV2D:
    case LayerType.DEPTHWISE_CONV2D:
      return Layers;

    case LayerType.LINEAR:
      return Grid;

    // Activations
    case LayerType.SIGMOID:
      return Sigma;
    case LayerType.RELU:
    case LayerType.LEAKYRELU:
    case LayerType.PRELU:
    case LayerType.GELU:
    case LayerType.SILU:
    case LayerType.SWIGLU:
    case LayerType.TANH:
    case LayerType.SOFTPLUS:
    case LayerType.SOFTSIGN:
      return Zap;

    // Pooling / resampling
    case LayerType.MAXPOOL:
    case LayerType.MAXPOOL3D:
    case LayerType.AVGPOOL:
    case LayerType.ADAPTIVEAVGPOOL:
      return Minimize;
    case LayerType.GLOBAL_AVG_POOL:
      return BoxSelect;
    case LayerType.UPSAMPLE:
      return Maximize;
    case LayerType.PIXEL_SHUFFLE:
    case LayerType.RESCALING:
    case LayerType.RESIZING:
      return Scaling;

    // TF-style preprocessing
    case LayerType.CENTER_CROP:
      return Scissors;
    case LayerType.RANDOM_FLIP:
    case LayerType.RANDOM_ROTATION:
    case LayerType.RANDOM_ZOOM:
    case LayerType.RANDOM_CONTRAST:
      return RefreshCw;
    case LayerType.TEXT_VECTORIZATION:
      return Type;
    case LayerType.NORMALIZATION_LAYER:
      return BarChart3;
    case LayerType.DISCRETIZATION:
      return Binary;
    case LayerType.CATEGORY_ENCODING:
      return Hash;

    // Transformer / generative blocks
    case LayerType.ATTENTION:
    case LayerType.CROSS_ATTENTION:
    case LayerType.WINDOW_ATTENTION:
      return RefreshCcw;
    case LayerType.TRANSFORMER_BLOCK:
    case LayerType.TRANSFORMER_ENCODER:
    case LayerType.TRANSFORMER_DECODER:
      return Brain;
    case LayerType.MOE_BLOCK:
      return Network;
    case LayerType.ACTION_HEAD:
      return Crosshair;
    case LayerType.SE_BLOCK:
      return Wifi;
    case LayerType.PATCH_EMBED:
    case LayerType.SAM_PROMPT_ENCODER:
      return Eye;
    case LayerType.SAM_MASK_DECODER:
      return Workflow;
    case LayerType.TIME_EMBEDDING:
    case LayerType.ROPE:
      return Clock;

    // Audio
    case LayerType.MEL_SPECTROGRAM:
    case LayerType.STFT:
      return Radio;
    case LayerType.SPEC_AUGMENT:
      return AudioLines;
    case LayerType.CONFORMER_BLOCK:
    case LayerType.WAVENET_BLOCK:
      return Activity;
    case LayerType.RVC_ENCODER:
    case LayerType.WAV2VEC2_ENC:
      return FileAudio;
    case LayerType.VOCODER:
      return Speaker;
    case LayerType.AUDIO_EMBEDDING:
      return Mic;
    case LayerType.SINC_CONV:
      return Activity;

    // Detection (YOLO family)
    case LayerType.C2F_BLOCK:
    case LayerType.SPPF_BLOCK:
    case LayerType.DARKNET_BLOCK:
      return Layers;
    case LayerType.DETECT_HEAD:
    case LayerType.ANCHOR_BOX:
      return Target;
    case LayerType.NMS:
      return Component;

    // 3D
    case LayerType.NERF_BLOCK:
    case LayerType.GAUSSIAN_SPLAT:
    case LayerType.TRIPLANE_ENC:
      return Cuboid;
    case LayerType.POINTNET_BLOCK:
    case LayerType.POINT_TRANSFORMER:
    case LayerType.MESH_CONV:
      return Scan;

    // OCR
    case LayerType.TPS_TRANSFORM:
    case LayerType.CRNN_BLOCK:
    case LayerType.CTC_DECODER:
      return ScanText;

    // Robotics / motion
    case LayerType.DEPTH_DECODER:
    case LayerType.DISPARITY_HEAD:
      return Mountain;
    case LayerType.OPTICAL_FLOW:
    case LayerType.VELOCITY_HEAD:
      return Wind;
    case LayerType.KALMAN_FILTER:
      return Orbit;
    case LayerType.BEV_TRANSFORM:
      // `Map` here is the lucide-react icon imported above, not the global Map.
      return Map;
    case LayerType.RADAR_ENCODER:
      return Radar;

    // Video / generation
    case LayerType.VIDEO_DIFFUSION_BLOCK:
      return Film;
    case LayerType.SPATIO_TEMPORAL_ATTN:
      return Timer;
    case LayerType.VIDEO_TOKENIZER:
      return Video;
    case LayerType.FRAME_INTERPOLATOR:
      return FastForward;
    case LayerType.TEMPORAL_SHIFT:
      return Move;
    case LayerType.NON_LOCAL_BLOCK:
      return Clapperboard;
    case LayerType.MULTIMODAL_FUSION:
      return Merge;

    // Graph networks
    case LayerType.GCN_CONV:
    case LayerType.GRAPH_SAGE:
    case LayerType.GAT_CONV:
    case LayerType.GIN_CONV:
      return Share2;

    // Physics-informed
    case LayerType.NEURAL_ODE:
    case LayerType.HAMILTONIAN_NN:
    case LayerType.PINN_LINEAR:
      return Atom;
    case LayerType.PROTEIN_FOLDING:
      return Dna;

    // Reinforcement learning heads
    case LayerType.DUELING_HEAD:
    case LayerType.PPO_HEAD:
    case LayerType.SAC_HEAD:
      return Gamepad2;

    // Spiking
    case LayerType.LIF_NEURON:
    case LayerType.SPIKING_LAYER:
      return Zap;

    // Advanced / niche architectures
    case LayerType.CAPSULE:
    case LayerType.HYPER_NET:
    case LayerType.MAMBA_BLOCK:
    case LayerType.RWKV_BLOCK:
    case LayerType.HOPFIELD:
    case LayerType.NORMALIZING_FLOW:
    case LayerType.DNC_MEMORY:
    case LayerType.ARCFACE:
    case LayerType.ECHO_STATE:
      return Sparkles;

    // Recurrent / embeddings / normalisation
    case LayerType.LSTM:
    case LayerType.GRU:
      return GitBranch;
    case LayerType.EMBEDDING:
      return Database;
    case LayerType.LAYERNORM:
    case LayerType.BATCHNORM:
    case LayerType.GROUPNORM:
    case LayerType.INSTANCENORM:
    case LayerType.RMSNORM:
      return Sliders;

    // Shape manipulation
    case LayerType.FLATTEN:
    case LayerType.UNFLATTEN:
      return Type;
    case LayerType.RESHAPE:
      return FlipVertical;
    case LayerType.PERMUTE:
      return RefreshCw;

    // Merge operations
    case LayerType.CONCAT:
    case LayerType.ADD:
    case LayerType.SUBTRACT:
    case LayerType.MULTIPLY:
    case LayerType.AVERAGE:
    case LayerType.MAXIMUM:
    case LayerType.MINIMUM:
      return Combine;
    case LayerType.DOT:
      return X;

    // User-defined code
    case LayerType.LAMBDA:
      return Calculator;
    case LayerType.CUSTOM:
      return Terminal;

    // Pass-through / regularisation
    case LayerType.IDENTITY:
    case LayerType.DROPOUT:
    case LayerType.SPATIAL_DROPOUT:
    case LayerType.DROPPATH:
      return MinusCircle;

    default:
      return Box;
  }
};

/**
 * Renders the header icon for a layer type, sized with Tailwind's `w-4 h-4`.
 */
const getIcon = (type: LayerType) => {
  const Icon = iconComponentFor(type);
  return <Icon className="w-4 h-4" />;
};
/** A colour rule: every layer type in `types` is drawn with `classes`. */
type ColorRule = {
  readonly types: readonly LayerType[];
  readonly classes: string;
};

/**
 * Border and glow classes per layer family.
 *
 * Rules are checked top to bottom and the first rule that lists a type wins.
 * Class strings are written out in full so Tailwind's source scanner can see them.
 */
const COLOR_RULES: readonly ColorRule[] = [
  { types: [LayerType.INPUT], classes: 'border-emerald-500 shadow-emerald-500/20' },
  { types: [LayerType.OUTPUT], classes: 'border-red-500 shadow-red-500/20' },
  // Convolution, pooling and spatial resampling: blue
  {
    types: [
      LayerType.CONV2D,
      LayerType.CONV1D,
      LayerType.CONV3D,
      LayerType.CONV_TRANSPOSE2D,
      LayerType.DEFORMABLE_CONV,
      LayerType.SEPARABLE_CONV2D,
      LayerType.DEPTHWISE_CONV2D,
      LayerType.MAXPOOL,
      LayerType.MAXPOOL3D,
      LayerType.AVGPOOL,
      LayerType.GLOBAL_AVG_POOL,
      LayerType.ADAPTIVEAVGPOOL,
      LayerType.UPSAMPLE,
      LayerType.PIXEL_SHUFFLE,
    ],
    classes: 'border-blue-500 shadow-blue-500/20',
  },
  // Dense layers: violet
  {
    types: [LayerType.LINEAR, LayerType.ACTION_HEAD],
    classes: 'border-violet-500 shadow-violet-500/20',
  },
  // TF-style preprocessing: amber
  {
    types: [
      LayerType.RESCALING,
      LayerType.RESIZING,
      LayerType.CENTER_CROP,
      LayerType.RANDOM_FLIP,
      LayerType.RANDOM_ROTATION,
      LayerType.RANDOM_ZOOM,
      LayerType.RANDOM_CONTRAST,
      LayerType.TEXT_VECTORIZATION,
      LayerType.NORMALIZATION_LAYER,
      LayerType.DISCRETIZATION,
      LayerType.CATEGORY_ENCODING,
    ],
    classes: 'border-amber-600 shadow-amber-600/20',
  },
  // Transformer / generative blocks: gold
  {
    types: [
      LayerType.ATTENTION,
      LayerType.CROSS_ATTENTION,
      LayerType.WINDOW_ATTENTION,
      LayerType.TRANSFORMER_BLOCK,
      LayerType.TRANSFORMER_ENCODER,
      LayerType.TRANSFORMER_DECODER,
      LayerType.MOE_BLOCK,
      LayerType.SE_BLOCK,
      LayerType.PATCH_EMBED,
      LayerType.SAM_PROMPT_ENCODER,
      LayerType.SAM_MASK_DECODER,
    ],
    classes: 'border-yellow-500 shadow-yellow-500/20',
  },
  // Recurrent layers and embeddings: orange
  {
    types: [
      LayerType.LSTM,
      LayerType.GRU,
      LayerType.EMBEDDING,
      LayerType.TIME_EMBEDDING,
      LayerType.ROPE,
    ],
    classes: 'border-orange-500 shadow-orange-500/20',
  },
  // Normalisation: cyan
  {
    types: [
      LayerType.BATCHNORM,
      LayerType.LAYERNORM,
      LayerType.GROUPNORM,
      LayerType.INSTANCENORM,
      LayerType.RMSNORM,
    ],
    classes: 'border-cyan-500 shadow-cyan-500/20',
  },
  // Detection: rose
  {
    types: [
      LayerType.C2F_BLOCK,
      LayerType.SPPF_BLOCK,
      LayerType.DARKNET_BLOCK,
      LayerType.DETECT_HEAD,
      LayerType.ANCHOR_BOX,
      LayerType.NMS,
    ],
    classes: 'border-rose-500 shadow-rose-500/20',
  },
  // Audio: indigo
  {
    types: [
      LayerType.MEL_SPECTROGRAM,
      LayerType.STFT,
      LayerType.SPEC_AUGMENT,
      LayerType.CONFORMER_BLOCK,
      LayerType.WAVENET_BLOCK,
      LayerType.VOCODER,
      LayerType.RVC_ENCODER,
      LayerType.WAV2VEC2_ENC,
      LayerType.AUDIO_EMBEDDING,
      LayerType.SINC_CONV,
    ],
    classes: 'border-indigo-500 shadow-indigo-500/20',
  },
  // 3D: teal
  {
    types: [
      LayerType.NERF_BLOCK,
      LayerType.POINTNET_BLOCK,
      LayerType.POINT_TRANSFORMER,
      LayerType.TRIPLANE_ENC,
      LayerType.GAUSSIAN_SPLAT,
      LayerType.MESH_CONV,
    ],
    classes: 'border-teal-500 shadow-teal-500/20',
  },
  // OCR: same yellow as the transformer family
  {
    types: [LayerType.TPS_TRANSFORM, LayerType.CRNN_BLOCK, LayerType.CTC_DECODER],
    classes: 'border-yellow-500 shadow-yellow-500/20',
  },
  // Robotics / motion: sky
  {
    types: [
      LayerType.DEPTH_DECODER,
      LayerType.DISPARITY_HEAD,
      LayerType.OPTICAL_FLOW,
      LayerType.VELOCITY_HEAD,
      LayerType.KALMAN_FILTER,
      LayerType.BEV_TRANSFORM,
      LayerType.RADAR_ENCODER,
    ],
    classes: 'border-sky-500 shadow-sky-500/20',
  },
  // Video: fuchsia
  {
    types: [
      LayerType.VIDEO_DIFFUSION_BLOCK,
      LayerType.SPATIO_TEMPORAL_ATTN,
      LayerType.VIDEO_TOKENIZER,
      LayerType.FRAME_INTERPOLATOR,
      LayerType.TEMPORAL_SHIFT,
      LayerType.NON_LOCAL_BLOCK,
      LayerType.MULTIMODAL_FUSION,
    ],
    classes: 'border-fuchsia-500 shadow-fuchsia-500/20',
  },
  // Graph networks: green
  {
    types: [LayerType.GCN_CONV, LayerType.GRAPH_SAGE, LayerType.GAT_CONV, LayerType.GIN_CONV],
    classes: 'border-green-500 shadow-green-500/20',
  },
  // Physics-informed: light blue
  {
    types: [
      LayerType.NEURAL_ODE,
      LayerType.PINN_LINEAR,
      LayerType.HAMILTONIAN_NN,
      LayerType.PROTEIN_FOLDING,
    ],
    classes: 'border-blue-400 shadow-blue-400/20',
  },
  // Spiking: bright yellow
  {
    types: [LayerType.LIF_NEURON, LayerType.SPIKING_LAYER],
    classes: 'border-yellow-400 shadow-yellow-400/20',
  },
  // Reinforcement learning heads: purple
  {
    types: [LayerType.DUELING_HEAD, LayerType.PPO_HEAD, LayerType.SAC_HEAD],
    classes: 'border-purple-600 shadow-purple-600/20',
  },
  // Advanced / niche architectures: white glow
  {
    types: [
      LayerType.CAPSULE,
      LayerType.HYPER_NET,
      LayerType.MAMBA_BLOCK,
      LayerType.RWKV_BLOCK,
      LayerType.HOPFIELD,
      LayerType.NORMALIZING_FLOW,
      LayerType.DNC_MEMORY,
      LayerType.ARCFACE,
      LayerType.ECHO_STATE,
    ],
    classes: 'border-white shadow-white/30',
  },
  // Merge operations: pink
  {
    types: [
      LayerType.CONCAT,
      LayerType.ADD,
      LayerType.SUBTRACT,
      LayerType.MULTIPLY,
      LayerType.AVERAGE,
      LayerType.MAXIMUM,
      LayerType.MINIMUM,
      LayerType.DOT,
    ],
    classes: 'border-pink-500 shadow-pink-500/20',
  },
  // User-defined code: lime
  { types: [LayerType.LAMBDA], classes: 'border-lime-500 shadow-lime-500/20' },
  { types: [LayerType.CUSTOM], classes: 'border-lime-500 shadow-lime-500/20' },
  // Shape manipulation / pass-through: dark slate
  {
    types: [
      LayerType.IDENTITY,
      LayerType.FLATTEN,
      LayerType.RESHAPE,
      LayerType.PERMUTE,
      LayerType.UNFLATTEN,
    ],
    classes: 'border-slate-600 shadow-slate-600/20',
  },
];

/** Classes for any layer type that no rule above claims. */
const DEFAULT_COLOR_CLASSES = 'border-slate-500 shadow-slate-500/20';

/**
 * Returns the Tailwind border and shadow classes for a node of the given layer type.
 */
const getColor = (type: LayerType) => {
  const matchingRule = COLOR_RULES.find((rule) => rule.types.includes(type));
  return matchingRule ? matchingRule.classes : DEFAULT_COLOR_CLASSES;
};
/** Param keys whose values are code blobs, too large to preview in the node body. */
const HIDDEN_SUMMARY_KEYS = new Set(['definition_code', 'imports']);

/** Maximum number of parameters previewed in a node's body. */
const MAX_SUMMARY_PARAMS = 3;

/**
 * Formats a parameter value for the one-line preview.
 *
 * Objects and arrays are JSON-encoded so they don't render as "[object Object]".
 * If JSON encoding fails (e.g. circular data or BigInt members), it falls back to String().
 */
const formatParamValue = (value: unknown): string => {
  if (value !== null && typeof value === 'object') {
    try {
      return JSON.stringify(value) ?? String(value);
    } catch {
      return String(value);
    }
  }
  return String(value);
};

/**
 * React Flow node renderer for a single layer in the architecture graph.
 *
 * Renders a header (type icon, label and an optional modality or custom-class
 * badge) and a short preview of the node's parameters. Input nodes omit the
 * target handle and output nodes omit the source handle.
 */
const CustomNode = ({ data, selected }: NodeProps<NodeData>) => {
  const isInput = data.type === LayerType.INPUT;
  const isOutput = data.type === LayerType.OUTPUT;

  // Input nodes show their `modality` param as a pill in the header.
  const inputBadge = isInput && data.params.modality ? (
    <span className="ml-auto text-[9px] px-1.5 py-0.5 rounded-full bg-slate-700 text-emerald-400 font-mono tracking-tighter uppercase">
      {data.params.modality}
    </span>
  ) : null;

  // Custom layers show the last dot-separated segment of `class_name`.
  // The typeof guard stops a non-string value from crashing on `.split`.
  const customClassName: unknown = data.params.class_name;
  const customBadge =
    data.type === LayerType.CUSTOM && typeof customClassName === 'string' && customClassName ? (
      <span
        className="ml-auto text-[8px] px-1.5 py-0.5 rounded-full bg-slate-700 text-lime-400 font-mono truncate max-w-[80px]"
        title={customClassName}
      >
        {customClassName.split('.').pop()}
      </span>
    ) : null;

  // Drop hidden keys *before* truncating, so code blobs don't use up preview slots.
  const summaryEntries = Object.entries(data.params)
    .filter(([key]) => !HIDDEN_SUMMARY_KEYS.has(key))
    .slice(0, MAX_SUMMARY_PARAMS);

  return (
    <div className={`
      relative min-w-[150px] bg-slate-900 rounded-lg border-2
      transition-all duration-200 shadow-lg
      ${getColor(data.type)}
      ${selected ? 'ring-2 ring-white ring-offset-2 ring-offset-slate-900 scale-105' : ''}
    `}>
      {/* Input Handle: input nodes only emit, so they have no target */}
      {!isInput && (
        <Handle
          type="target"
          position={Position.Top}
          className="!bg-slate-400 !w-3 !h-3 !-top-2 !border-slate-900"
        />
      )}

      {/* Header */}
      <div className="flex items-center gap-2 px-3 py-2 border-b border-slate-800 bg-slate-800/50 rounded-t-md">
        <span className="text-slate-300">
          {getIcon(data.type)}
        </span>
        <span className="text-xs font-bold uppercase tracking-wider text-slate-200 truncate max-w-[100px]">
          {data.label}
        </span>
        {inputBadge}
        {customBadge}
      </div>

      {/* Body: parameter summary */}
      <div className="p-3">
        <div className="text-[10px] text-slate-400 font-mono space-y-1">
          {summaryEntries.map(([key, value]) => {
            const formatted = formatParamValue(value);
            return (
              <div key={key} className="flex justify-between gap-4">
                <span className="opacity-70 truncate max-w-[80px]" title={key}>{key}:</span>
                <span className="text-slate-200 truncate max-w-[80px]" title={formatted}>{formatted}</span>
              </div>
            );
          })}
          {/* Shown whenever nothing is previewable, including when every param is hidden */}
          {summaryEntries.length === 0 && (
            <span className="opacity-50 italic">No parameters</span>
          )}
        </div>
      </div>

      {/* Output Handle: output nodes only receive, so they have no source */}
      {!isOutput && (
        <Handle
          type="source"
          position={Position.Bottom}
          className="!bg-slate-400 !w-3 !h-3 !-bottom-2 !border-slate-900"
        />
      )}
    </div>
  );
};

export default memo(CustomNode);