// Page metadata captured with this file: Spaces — Build error; file size 3,237 bytes; 8a01471.
import { useMemo } from 'react';
import { useVisualizerStore } from '@/core/store';
import { getLayerComponent } from './LayerGeometry';
import { getLayerMeshComponent } from './NeuralLayerMesh';
import type { ComputedNode } from '@/core/store';
/**
 * Props accepted by the LayerNode component.
 */
interface LayerNodeProps {
/**
 * The computed layer to draw. LayerNode reads its id, name, type, params,
 * shapes, color, scale, lod, computedPosition and the visible / selected /
 * hovered flags from this object.
 */
node: ComputedNode;
}
/**
 * Draws one network layer at its computed position in the 3D scene.
 *
 * A node whose params define at least one of out_features, out_channels,
 * hidden_size or num_heads (snake_case or camelCase spelling) is drawn with
 * the component returned by `getLayerMeshComponent`; every other node is drawn
 * with the component returned by `getLayerComponent`. A node that is not
 * visible renders nothing.
 */
function LayerNode({ node }: LayerNodeProps) {
  const selectNode = useVisualizerStore(state => state.selectNode);
  const hoverNode = useVisualizerStore(state => state.hoverNode);
  const config = useVisualizerStore(state => state.config);

  // Truthy when params exists and carries any of the layer attributes listed
  // below; recomputed only when the params reference changes.
  const hasEnhancedAttributes = useMemo(() => {
    const params = node.params;
    return params && (
      params.out_features !== undefined ||
      params.outFeatures !== undefined ||
      params.out_channels !== undefined ||
      params.outChannels !== undefined ||
      params.hidden_size !== undefined ||
      params.hiddenSize !== undefined ||
      params.num_heads !== undefined ||
      params.numHeads !== undefined
    );
  }, [node.params]);

  // All hooks have run by this point, so bailing out here is safe.
  if (!node.visible) return null;

  // Pieces shared by both render paths.
  const { x, y, z } = node.computedPosition;
  const handleClick = () => selectNode(node.id);
  const handlePointerOver = () => hoverNode(node.id);
  const handlePointerOut = () => hoverNode(null);

  if (hasEnhancedAttributes) {
    const EnhancedMesh = getLayerMeshComponent(node.type);

    return (
      <group position={[x, y, z]}>
        {/* The wrapping group already carries the position, so the inner
            node is handed a zeroed one (presumably to avoid a double offset). */}
        <EnhancedMesh
          node={{
            id: node.id,
            name: node.name,
            type: node.type,
            attributes: node.params,
            computedPosition: { x: 0, y: 0, z: 0 }
          }}
          color={node.color}
          selected={node.selected}
          hovered={node.hovered}
          showNeurons={config.showLabels}
          onClick={handleClick}
          onPointerOver={handlePointerOver}
          onPointerOut={handlePointerOut}
        />
      </group>
    );
  }

  // Fallback path: geometry chosen from the layer type alone.
  const BasicLayer = getLayerComponent(node.type);

  return (
    <group position={[x, y, z]}>
      <BasicLayer
        type={node.type}
        params={node.params}
        inputShape={node.inputShape}
        outputShape={node.outputShape}
        color={node.color}
        scale={node.scale}
        selected={node.selected}
        hovered={node.hovered}
        lod={node.lod}
        showLabel={config.showLabels}
        label={node.name}
        onClick={handleClick}
        onPointerOver={handlePointerOver}
        onPointerOut={handlePointerOut}
      />
    </group>
  );
}
/**
 * Renders every computed node held in the visualizer store as a LayerNode,
 * grouped under a single "layer-nodes" group.
 */
export function LayerNodes() {
  const nodesById = useVisualizerStore(state => state.computedNodes);

  // Materialise the collection's values once per change of the store reference.
  const nodeList = useMemo(() => Array.from(nodesById.values()), [nodesById]);

  const renderedNodes = nodeList.map(entry => (
    <LayerNode key={entry.id} node={entry} />
  ));

  return <group name="layer-nodes">{renderedNodes}</group>;
}
// LayerNode is declared without an inline `export`; expose it as a named export here.
export { LayerNode };
|