import { useMemo } from 'react';
import { useVisualizerStore } from '@/core/store';
import { getLayerComponent } from './LayerGeometry';
import { getLayerMeshComponent } from './NeuralLayerMesh';
import type { ComputedNode } from '@/core/store';

/**
 * Props for individual layer node component
 */
interface LayerNodeProps {
  /** Node to render; this file reads its `id`, `type`, `params` and `visible` fields. */
  node: ComputedNode;
}

/**
 * Single layer node in the 3D scene.
 *
 * Chooses between two component lookups for `node.type`:
 * `getLayerMeshComponent` when `node.params` defines at least one of the size
 * attributes checked below, otherwise `getLayerComponent`. Renders nothing
 * (`null`) when `node.visible` is falsy. Pointer handlers forward to the
 * store's `selectNode` / `hoverNode` actions.
 */
function LayerNode({ node }: LayerNodeProps) {
  // Store actions used by the pointer handlers below.
  const selectNode = useVisualizerStore(state => state.selectNode);
  const hoverNode = useVisualizerStore(state => state.hoverNode);
  // NOTE(review): `config` is not referenced anywhere else in the visible code.
  // It may be consumed by JSX props that are missing from this view — confirm
  // before removing it.
  const config = useVisualizerStore(state => state.config);

  // Decide whether the enhanced visualization applies. The hook is evaluated
  // before the visibility early-return, so it runs on every render.
  // The result is truthy when `node.params` exists and defines any of the
  // attributes below (each is checked in snake_case and camelCase spelling).
  // Because of the `&&`, a falsy `node.params` is returned as-is rather than
  // as a strict boolean.
  const useEnhancedVisualization = useMemo(() => {
    // Use enhanced visualization for models with proper layer attributes
    return node.params && (
      node.params.out_features !== undefined ||
      node.params.outFeatures !== undefined ||
      node.params.out_channels !== undefined ||
      node.params.outChannels !== undefined ||
      node.params.hidden_size !== undefined ||
      node.params.hiddenSize !== undefined ||
      node.params.num_heads !== undefined ||
      node.params.numHeads !== undefined
    );
  }, [node.params]);

  // Hidden nodes render nothing.
  if (!node.visible) return null;

  // Enhanced visualization
  if (useEnhancedVisualization) {
    const NeuralComponent = getLayerMeshComponent(node.type);
    // NOTE(review): the opening of the JSX element (presumably
    // `<NeuralComponent … onClick={() =>`) is missing from this view; only the
    // trailing attributes survive. Restore it from the original file — the
    // element name and any additional props cannot be determined from here.
    return (
      selectNode(node.id)}
      onPointerOver={() => hoverNode(node.id)}
      onPointerOut={() => hoverNode(null)}
      />
    );
  }

  // Original visualization
  const LayerComponent = getLayerComponent(node.type);
  // NOTE(review): same truncation as above — the element opening (presumably
  // `<LayerComponent … onClick={() =>`) is missing from this view.
  return (
    selectNode(node.id)}
    onPointerOver={() => hoverNode(node.id)}
    onPointerOut={() => hoverNode(null)}
    />
  );
}

/**
 * Renders all layer nodes in the network.
 *
 * Reads `computedNodes` from the visualizer store and memoizes an array built
 * from `computedNodes.values()`; the array is rebuilt only when the
 * `computedNodes` reference changes.
 */
export function LayerNodes() {
  const computedNodes = useVisualizerStore(state => state.computedNodes);
  const nodeArray = useMemo(() => Array.from(computedNodes.values()), [computedNodes]);

  // NOTE(review): the wrapper element and the per-node element inside the
  // `map` callback are missing from this view (presumably one `LayerNode` per
  // entry, keyed by node id — confirm against the original file).
  return (
    {nodeArray.map(node => (
    ))}
  );
}

export { LayerNode };