/**
 * Neural Network Layer Visualization Components
 * Enhanced 3D representation for different layer types
 *
 * NOTE(review): in this copy of the file every component's `return (...)` has lost
 * its JSX elements — only `{...}` expression children, JSX comments and bare text
 * remain, so those returns do not parse. They are kept verbatim below (as are the
 * locals they may reference); presumably the scene graph needs restoring from
 * version control.
 */
import { useMemo, useRef } from 'react';
import * as THREE from 'three';
import { useFrame } from '@react-three/fiber';
import { Text, Sphere, Cylinder, Box, Torus } from '@react-three/drei';
import type { NN3DNode } from '@/schema/types';

export interface NeuralLayerProps {
  node: NN3DNode & { computedPosition: { x: number; y: number; z: number } };
  color: string;
  selected?: boolean;
  hovered?: boolean;
  showNeurons?: boolean;
  maxNeurons?: number;
  onClick?: () => void;
  onPointerOver?: () => void;
  onPointerOut?: () => void;
}

/** Color multiplier applied to a selected layer. */
const SELECTED_BRIGHTNESS = 1.3;
/** Color multiplier applied to a hovered (but not selected) layer. */
const HOVERED_BRIGHTNESS = 1.15;
/** Fraction of the remaining distance to the target scale covered each frame. */
const SCALE_LERP_ALPHA = 0.1;
/** Default scale a hovered layer eases toward. */
const DEFAULT_HOVER_SCALE = 1.1;
/** Conv block spin speed in radians per second (≈ the old 0.002 rad/frame at 60 fps). */
const CONV_SPIN_SPEED = 0.12;

/**
 * Scratch vector reused by every hover animation so that `useFrame` callbacks do
 * not allocate a new Vector3 per component per frame. Sharing it is safe because
 * it is written and consumed synchronously inside a single `animateHoverScale` call.
 */
const scratchScale = new THREE.Vector3();

/**
 * Ease `group`'s scale toward `hoverScale` while hovered, and back toward 1 otherwise.
 *
 * @param group - Object whose uniform scale is animated.
 * @param hovered - Whether the pointer is over the layer.
 * @param hoverScale - Target scale while hovered.
 */
function animateHoverScale(group: THREE.Object3D, hovered: boolean, hoverScale = DEFAULT_HOVER_SCALE): void {
  const target = hovered ? hoverScale : 1.0;
  group.scale.lerp(scratchScale.setScalar(target), SCALE_LERP_ALPHA);
}

/**
 * Brighten the base color for selection/hover feedback.
 * Selection wins over hover. The unmodified `baseColor` instance is returned when neither applies.
 */
function getDisplayColor(baseColor: THREE.Color, selected: boolean, hovered: boolean): THREE.Color {
  if (selected) return baseColor.clone().multiplyScalar(SELECTED_BRIGHTNESS);
  if (hovered) return baseColor.clone().multiplyScalar(HOVERED_BRIGHTNESS);
  return baseColor;
}

/**
 * Read the first spatial dimension of a size attribute given either as a scalar
 * (`3`, `"3"`) or as a tuple (`[3, 3]`).
 *
 * @returns The value as a positive finite number, or `fallback` when it is missing or invalid.
 */
function firstDimension(value: unknown, fallback: number): number {
  const raw = Array.isArray(value) ? value[0] : value;
  const size = Number(raw);
  return Number.isFinite(size) && size > 0 ? size : fallback;
}

/**
 * Determine neuron count from layer attributes.
 *
 * Takes the first truthy value among the known output-width attribute names.
 * Returns 16 when none is set, or when the value found is not a number.
 * `||` is intentional: a width of 0 is treated as "unset".
 */
function getNeuronCount(node: NN3DNode): number {
  const attrs = node.attributes || {};
  // Check various attribute names
  const count =
    attrs.out_features ||
    attrs.outFeatures ||
    attrs.out_channels ||
    attrs.outChannels ||
    attrs.hidden_size ||
    attrs.hiddenSize ||
    attrs.units ||
    attrs.num_features ||
    attrs.numFeatures ||
    attrs.embed_dim ||
    attrs.embedDim ||
    16; // Default
  return typeof count === 'number' ? count : 16;
}

/**
 * Get input size from layer attributes.
 *
 * Uses the first truthy value among the known input-width attribute names. Returns 16
 * when none is set or the value is not a number. Not currently used by any component.
 */
function _getInputSize(node: NN3DNode): number {
  const attrs = node.attributes || {};
  const count =
    attrs.in_features ||
    attrs.inFeatures ||
    attrs.in_channels ||
    attrs.inChannels ||
    attrs.input_size ||
    attrs.inputSize ||
    16;
  return typeof count === 'number' ? count : 16;
}
// Reference the helper once so the unused-function lint stays quiet.
void _getInputSize;

/**
 * Dense/Linear Layer - shows as a grid of neurons.
 * At most `maxNeurons` are drawn. The label reports the layer's full neuron count.
 */
export function DenseLayerMesh({
  node,
  color,
  selected = false,
  hovered = false,
  showNeurons = true,
  maxNeurons = 64,
  onClick,
  onPointerOver,
  onPointerOut,
}: NeuralLayerProps) {
  const groupRef = useRef<THREE.Group>(null);
  const neuronCount = Math.min(getNeuronCount(node), maxNeurons);

  // Smallest near-square grid that holds every drawn neuron.
  const cols = Math.ceil(Math.sqrt(neuronCount));
  const rows = Math.ceil(neuronCount / cols);
  const spacing = 0.15;
  const neuronRadius = 0.06;

  // Animate on hover
  useFrame(() => {
    if (groupRef.current) {
      animateHoverScale(groupRef.current, hovered);
    }
  });

  // Grid positions centered on the origin, filled row by row.
  const neurons = useMemo(() => {
    const positions: [number, number, number][] = [];
    for (let i = 0; i < neuronCount; i++) {
      const row = Math.floor(i / cols);
      const col = i % cols;
      const x = (col - (cols - 1) / 2) * spacing;
      const y = (row - (rows - 1) / 2) * spacing;
      positions.push([x, y, 0]);
    }
    return positions;
  }, [neuronCount, cols, rows]);

  const baseColor = new THREE.Color(color);
  const displayColor = getDisplayColor(baseColor, selected, hovered);

  return ( {/* Background plane */} {/* Neurons */} {showNeurons && neurons.map((pos, i) => ( ))} {/* Selection indicator */} {selected && ( )} {/* Label */} {node.name} {/* Neuron count */} {getNeuronCount(node)} neurons );
}

/**
 * Convolutional Layer - shows as a 3D block with feature maps.
 * Channels are capped at 64, and one stacked map is drawn per 8 channels (max 8).
 */
export function ConvLayerMesh({
  node,
  color,
  selected = false,
  hovered = false,
  onClick,
  onPointerOver,
  onPointerOut,
}: NeuralLayerProps) {
  const groupRef = useRef<THREE.Group>(null);
  const attrs = node.attributes || {};
  // Falls back to 32 when the attribute is missing or not numeric (avoids NaN downstream).
  const outChannels = Math.min(Number(attrs.out_channels || attrs.outChannels) || 32, 64);
  const kernelSize = attrs.kernel_size || attrs.kernelSize || [3, 3];
  // Coerced to a number so `kSize + 1` below is arithmetic, never string concatenation.
  const kSize = firstDimension(kernelSize, 3);

  // Stack of feature maps
  const layers = Math.min(Math.ceil(outChannels / 8), 8);
  const layerSpacing = 0.08;
  const mapSize = 0.5 + Math.log2(kSize + 1) * 0.2;

  useFrame((_, delta) => {
    if (groupRef.current) {
      animateHoverScale(groupRef.current, hovered);
      // Scale by frame delta so the spin speed is independent of the display refresh rate.
      groupRef.current.rotation.y += delta * CONV_SPIN_SPEED;
    }
  });

  const baseColor = new THREE.Color(color);
  const displayColor = getDisplayColor(baseColor, selected, hovered);

  return ( {/* Stacked feature maps */} {Array.from({ length: layers }).map((_, i) => { const z = (i - (layers - 1) / 2) * layerSpacing; const alpha = 0.3 + (i / layers) * 0.5; return ( ); })} {/* Kernel indicator */} {/* Selection indicator */} {selected && ( )} {/* Label */} {node.name} {outChannels}ch / {kSize}×{kSize} );
}

/**
 * Pooling Layer - shows as a compressed block.
 */
export function PoolingLayerMesh({
  node,
  color,
  selected = false,
  hovered = false,
  onClick,
  onPointerOver,
  onPointerOut,
}: NeuralLayerProps) {
  const groupRef = useRef<THREE.Group>(null);
  const attrs = node.attributes || {};
  const poolSize = attrs.kernel_size || attrs.kernelSize || [2, 2];
  const pSize = firstDimension(poolSize, 2);

  useFrame(() => {
    if (groupRef.current) {
      animateHoverScale(groupRef.current, hovered);
    }
  });

  const baseColor = new THREE.Color(color);
  const displayColor = getDisplayColor(baseColor, selected, hovered);

  return ( {/* Funnel shape representation */} {/* Grid overlay */} {selected && ( )} {node.name} );
}

/**
 * Normalization Layer - shows as a flat processing block.
 */
export function NormLayerMesh({
  node,
  color,
  selected = false,
  hovered = false,
  onClick,
  onPointerOver,
  onPointerOut,
}: NeuralLayerProps) {
  const groupRef = useRef<THREE.Group>(null);

  useFrame(() => {
    if (groupRef.current) {
      animateHoverScale(groupRef.current, hovered);
    }
  });

  const baseColor = new THREE.Color(color);
  const displayColor = getDisplayColor(baseColor, selected, hovered);

  return ( {/* Normalization symbol - horizontal lines */} {[-0.1, 0, 0.1].map((y, i) => ( ))} {selected && ( )} {node.name} );
}

/**
 * Activation Layer - shows as a small function block.
 * It grows more on hover (1.2) than other layers, and shows a symbol chosen from `node.type`.
 */
export function ActivationLayerMesh({
  node,
  color,
  selected = false,
  hovered = false,
  onClick,
  onPointerOver,
  onPointerOut,
}: NeuralLayerProps) {
  const groupRef = useRef<THREE.Group>(null);

  useFrame(() => {
    if (groupRef.current) {
      animateHoverScale(groupRef.current, hovered, 1.2);
    }
  });

  const baseColor = new THREE.Color(color);
  const displayColor = getDisplayColor(baseColor, selected, hovered);

  // Draw activation function shape
  const activationType = node.type.toLowerCase();

  return ( {/* Sphere for activation */} {/* Function symbol */} {activationType === 'relu' ? 'ƒ' : activationType === 'sigmoid' ? 'σ' : activationType === 'tanh' ? 'tanh' : activationType === 'softmax' ? 'soft' : 'ƒ'} {selected && ( )} {node.name} );
}

/**
 * LSTM/GRU/RNN Layer - shows as a recurrent block.
 */
export function RecurrentLayerMesh({
  node,
  color,
  selected = false,
  hovered = false,
  onClick,
  onPointerOver,
  onPointerOut,
}: NeuralLayerProps) {
  const groupRef = useRef<THREE.Group>(null);
  const attrs = node.attributes || {};
  const hiddenSize = attrs.hidden_size || attrs.hiddenSize || 128;

  useFrame((_, delta) => {
    if (groupRef.current) {
      animateHoverScale(groupRef.current, hovered);
      // Slow rotation to show recurrence
      groupRef.current.rotation.y += delta * 0.3;
    }
  });

  const baseColor = new THREE.Color(color);
  const displayColor = getDisplayColor(baseColor, selected, hovered);

  return ( {/* Main cell block */} {/* Recurrence loop */} {/* Arrow indicator */} {selected && ( )} {node.name} {`h=${hiddenSize}`} );
}

/**
 * Attention/Transformer Layer - shows as a multi-head attention block.
 * Up to 12 heads are placed evenly on a ring around the center.
 */
export function AttentionLayerMesh({
  node,
  color,
  selected = false,
  hovered = false,
  onClick,
  onPointerOver,
  onPointerOut,
}: NeuralLayerProps) {
  const groupRef = useRef<THREE.Group>(null);
  const attrs = node.attributes || {};
  // Falls back to 8 when the attribute is missing or not numeric (avoids a NaN label).
  const numHeads = Math.min(Number(attrs.num_heads || attrs.numHeads) || 8, 12);

  useFrame(() => {
    if (groupRef.current) {
      animateHoverScale(groupRef.current, hovered);
    }
  });

  const baseColor = new THREE.Color(color);
  const displayColor = getDisplayColor(baseColor, selected, hovered);

  // Create attention head indicators
  const heads = useMemo(() => {
    const positions: [number, number, number][] = [];
    const radius = 0.25;
    for (let i = 0; i < numHeads; i++) {
      const angle = (i / numHeads) * Math.PI * 2;
      positions.push([Math.cos(angle) * radius, Math.sin(angle) * radius, 0]);
    }
    return positions;
  }, [numHeads]);

  return ( {/* Center query block */} {/* Attention heads */} {heads.map((pos, i) => ( {/* Connection to center */} ))} {selected && ( )} {node.name} {numHeads} heads );
}

/**
 * Embedding Layer - shows as a lookup table.
 */
export function EmbeddingLayerMesh({
  node,
  color,
  selected = false,
  hovered = false,
  onClick,
  onPointerOver,
  onPointerOut,
}: NeuralLayerProps) {
  const groupRef = useRef<THREE.Group>(null);
  const attrs = node.attributes || {};
  const vocabSize = attrs.num_embeddings || attrs.numEmbeddings || 10000;
  const embedDim = attrs.embedding_dim || attrs.embeddingDim || 256;

  useFrame(() => {
    if (groupRef.current) {
      animateHoverScale(groupRef.current, hovered);
    }
  });

  const baseColor = new THREE.Color(color);
  const displayColor = getDisplayColor(baseColor, selected, hovered);

  return ( {/* Table representation - stacked rows */} {[0, 0.1, 0.2, 0.3].map((y, i) => ( ))} {/* Lookup arrow */} {selected && ( )} {node.name} {`${vocabSize}→${embedDim}`} );
}

/**
 * Generic/Other Layer - fallback.
 */
export function GenericLayerMesh({
  node,
  color,
  selected = false,
  hovered = false,
  onClick,
  onPointerOver,
  onPointerOut,
}: NeuralLayerProps) {
  const groupRef = useRef<THREE.Group>(null);

  useFrame(() => {
    if (groupRef.current) {
      animateHoverScale(groupRef.current, hovered);
    }
  });

  const baseColor = new THREE.Color(color);
  const displayColor = getDisplayColor(baseColor, selected, hovered);

  return ( {selected && ( )} {node.name} );
}

/**
 * Ordered routing rules: the first rule with a keyword contained in the lower-cased
 * node type wins, so order matters.
 *
 * `'layer'` is deliberately NOT a norm keyword. It matched names such as
 * `TransformerEncoderLayer` or `LSTMLayer` and sent them to `NormLayerMesh` before the
 * recurrent, attention and embedding rules were checked. `LayerNorm` is still caught by `'norm'`.
 */
const LAYER_MESH_RULES: ReadonlyArray<readonly [readonly string[], React.ComponentType<NeuralLayerProps>]> = [
  [['linear', 'dense', 'fc'], DenseLayerMesh],
  [['conv'], ConvLayerMesh],
  [['pool'], PoolingLayerMesh],
  [['norm', 'batch'], NormLayerMesh],
  [['relu', 'sigmoid', 'tanh', 'gelu', 'softmax', 'activation'], ActivationLayerMesh],
  [['lstm', 'gru', 'rnn'], RecurrentLayerMesh],
  [['attention', 'transformer'], AttentionLayerMesh],
  [['embed'], EmbeddingLayerMesh],
];

/**
 * Get the appropriate layer mesh component based on node type.
 *
 * @param nodeType - Layer type name; matched case-insensitively by substring.
 * @returns The first matching component from the routing rules, or `GenericLayerMesh`.
 */
export function getLayerMeshComponent(nodeType: string): React.ComponentType<NeuralLayerProps> {
  const type = nodeType.toLowerCase();
  const match = LAYER_MESH_RULES.find(([keywords]) => keywords.some((keyword) => type.includes(keyword)));
  return match ? match[1] : GenericLayerMesh;
}