import { useMemo } from 'react';
import { useVisualizerStore } from '@/core/store';
import { getLayerComponent } from './LayerGeometry';
import { getLayerMeshComponent } from './NeuralLayerMesh';
import type { ComputedNode } from '@/core/store';
/**
* Props accepted by the `LayerNode` component.
*/
type LayerNodeProps = {
  /** Computed node that the component renders. */
  node: ComputedNode;
};
/**
* Single layer node in the 3D scene.
*
* Reads the `selectNode` and `hoverNode` actions (and `config`) from the
* visualizer store, then renders nothing when `node.visible` is falsy.
* Otherwise the component is chosen with `getLayerMeshComponent(node.type)`
* when the node's params carry any of the size attributes checked below, and
* with `getLayerComponent(node.type)` when they do not.
*
* NOTE(review): the JSX element markup inside both `return (...)` statements
* appears to be missing from this copy of the file (only the tail of a handler
* calling `selectNode(node.id)` and the pointer-over/out handlers are visible)
* — confirm against the original source before editing.
*
* @param node - Computed node to render; `visible`, `params`, `type` and `id` are read.
* @returns `null` when the node is not visible, otherwise the rendered element.
*/
function LayerNode({ node }: LayerNodeProps) {
// All hooks are called before the early return below, so the hook order is
// the same on every render regardless of `node.visible`.
const selectNode = useVisualizerStore(state => state.selectNode);
const hoverNode = useVisualizerStore(state => state.hoverNode);
// NOTE(review): `config` is not referenced in the code visible here — confirm
// whether it is passed to the rendered component in the missing markup.
const config = useVisualizerStore(state => state.config);
// Decide between the enhanced and the original visualization.
// The existing intent, per the original comment, is to detect models that
// were analyzed by the backend (i.e. carry enhanced data).
const useEnhancedVisualization = useMemo(() => {
// Truthy when `node.params` exists and at least one size attribute is
// defined, in either snake_case or camelCase spelling. The expression is
// `params && (...)`, so the result may be a falsy non-boolean value; only its
// truthiness is used below.
return node.params && (
node.params.out_features !== undefined ||
node.params.outFeatures !== undefined ||
node.params.out_channels !== undefined ||
node.params.outChannels !== undefined ||
node.params.hidden_size !== undefined ||
node.params.hiddenSize !== undefined ||
node.params.num_heads !== undefined ||
node.params.numHeads !== undefined
);
}, [node.params]);
// Hidden nodes render nothing.
if (!node.visible) return null;
// Enhanced visualization: component looked up by node type from NeuralLayerMesh.
if (useEnhancedVisualization) {
const NeuralComponent = getLayerMeshComponent(node.type);
// Pointer-over reports this node's id to the store; pointer-out clears it with null.
return (
selectNode(node.id)}
onPointerOver={() => hoverNode(node.id)}
onPointerOut={() => hoverNode(null)}
/>
);
}
// Original visualization: component looked up by node type from LayerGeometry.
const LayerComponent = getLayerComponent(node.type);
// Same handler wiring as the enhanced branch above.
return (
selectNode(node.id)}
onPointerOver={() => hoverNode(node.id)}
onPointerOut={() => hoverNode(null)}
/>
);
}
/**
* Renders all layer nodes in the network.
*
* Reads `computedNodes` from the visualizer store and maps over its values to
* produce one rendered entry per node.
*
* NOTE(review): the JSX markup inside the `return (...)` (the wrapper element
* and the element produced for each node) appears to be missing from this copy
* of the file — confirm against the original source before editing.
*/
export function LayerNodes() {
// Presumably a Map keyed by node id (only `.values()` is used here) — TODO confirm.
const computedNodes = useVisualizerStore(state => state.computedNodes);
// Materialize the values into an array; the memo recomputes only when the
// `computedNodes` reference changes.
const nodeArray = useMemo(() => Array.from(computedNodes.values()), [computedNodes]);
return (
{nodeArray.map(node => (
))}
);
}
// Also expose the single-node component as a named export alongside `LayerNodes`.
export { LayerNode };