/**
 * Neural Network 3D Scene
 *
 * Main visualization component that renders the hierarchical model.
 * Uses React Three Fiber for declarative 3D rendering.
 *
 * NOTE(review): this copy of the file appears to have lost most of its
 * `<...>` spans: JSX elements are gone (see the empty `return ( );` bodies
 * and dangling attribute fragments such as `{ e.stopPropagation(); onClick(); }}`),
 * and so are TypeScript generic arguments (e.g. `React.FC =` and
 * `useRef(null)` with no type argument). The render output shown here is
 * therefore incomplete — restore the markup from version control before
 * changing behavior.
 */
import React, { useRef, useMemo, useCallback, useState, useEffect } from 'react';
import { Canvas, useFrame, useThree } from '@react-three/fiber';
import {
  OrbitControls,
  Environment,
  Text,
  Html,
  Line,
  QuadraticBezierLine,
} from '@react-three/drei';
import * as THREE from 'three';
import type { ModelHierarchy, HierarchyNode } from '@/core/model-hierarchy';
import type { LayoutResult, NodeLayout, ConnectionLayout, Position3D } from '@/core/layout-engine';
import { computeFullLayout } from '@/core/layout-engine';
import {
  getLayerColor,
  getGeometryType,
} from '@/core/layer-geometry';

// ============================================================================
// Types
// ============================================================================

/**
 * Public props for NeuralScene.
 */
export interface NeuralSceneProps {
  /** Model to visualize; when null, NeuralScene renders a placeholder message instead of the scene. */
  hierarchy: ModelHierarchy | null;
  /** Level forwarded to `computeFullLayout(hierarchy, level)`; NeuralScene defaults it to 3. */
  level?: 1 | 2 | 3;
  /** Whether node labels are shown (SceneContent defaults this to true). */
  showLabels?: boolean;
  /** Whether connections are drawn (SceneContent defaults this to true). */
  showConnections?: boolean;
  /** Whether flow particles are drawn; only takes effect when showConnections is also true. Defaults to false. */
  animateFlow?: boolean;
  /** Called with the node id when a node is clicked. */
  onNodeClick?: (nodeId: string) => void;
  /** Called with the node id when the pointer enters a node, and with null when it leaves. */
  onNodeHover?: (nodeId: string | null) => void;
  /** Id of the selected node, supplied by the caller. */
  selectedNodeId?: string | null;
  /** Id of the hovered node, supplied by the caller; the scene keeps no hover state of its own and uses this for the tooltip. */
  hoveredNodeId?: string | null;
}

// ============================================================================
// Layer Node Component
// ============================================================================

/** Props for a single rendered hierarchy node. */
interface LayerNodeProps {
  node: HierarchyNode;
  /** Computed placement/size for this node; `dimensions` drives the geometry. */
  layout: NodeLayout;
  isSelected: boolean;
  isHovered: boolean;
  showLabel: boolean;
  onClick: () => void;
  onPointerEnter: () => void;
  onPointerLeave: () => void;
}

/**
 * Renders one hierarchy node as a mesh whose shape comes from
 * `getGeometryType(node.category)` and whose color comes from
 * `getLayerColor(node.category)`. The click and pointer-enter handlers call
 * `stopPropagation()` before notifying the parent; pointer-leave does not.
 */
const LayerNode: React.FC = React.memo(({
  node,
  layout,
  isSelected,
  isHovered,
  showLabel,
  onClick,
  onPointerEnter,
  onPointerLeave,
}) => {
  const meshRef = useRef(null);
  const [hoverScale, setHoverScale] = useState(1);

  // Animate hover scale
  // Eases hoverScale toward 1.1 while hovered and back to 1 otherwise.
  // NOTE(review): setHoverScale runs on every frame, so this component
  // re-renders at frame rate even after the value has settled; R3F recommends
  // mutating the object (e.g. meshRef.current.scale) inside useFrame instead.
  // Also, `delta * 10` exceeds 1 when a frame takes longer than 0.1 s
  // (assuming delta is in seconds), which makes lerp overshoot the target.
  useFrame((_, delta) => {
    const targetScale = isHovered ? 1.1 : 1;
    setHoverScale(prev => THREE.MathUtils.lerp(prev, targetScale, delta * 10));
  });

  const color = getLayerColor(node.category);
  const geometryType = getGeometryType(node.category);
  const dims = layout.dimensions;

  // Create geometry based on type
  // BoxGeometry takes (x, y, z) sizes, so layout `depth` runs along X and
  // `width` along Z for the box-like shapes. `dims` is a memo dependency
  // compared by reference: a new dimensions object rebuilds the geometry.
  // NOTE(review): geometries replaced by this memo are not disposed in this
  // block — confirm they are released elsewhere to avoid leaking GPU buffers.
  const geometry = useMemo(() => {
    switch (geometryType) {
      case 'plate':
        return new THREE.BoxGeometry(dims.depth * 0.3, dims.height, dims.width);
      case 'cylinder':
        return new THREE.CylinderGeometry(dims.width / 2.5, dims.width / 2.5, dims.height, 16);
      case 'hexagon':
        // Six radial segments give a hexagonal prism; its Y-axis height comes from dims.depth.
        return new THREE.CylinderGeometry(dims.width / 2, dims.width / 2, dims.depth, 6);
      case 'pyramid':
        // Four radial segments give a four-sided pyramid.
        return new THREE.ConeGeometry(dims.width / 2, dims.height, 4);
      case 'container':
        return new THREE.BoxGeometry(dims.depth * 2, dims.height, dims.width * 1.5);
      case 'prism':
      case 'box':
      default:
        return new THREE.BoxGeometry(dims.depth, dims.height, dims.width);
    }
  }, [geometryType, dims]);

  const isContainer = node.type === 'group';

  // NOTE(review): the JSX tree for this node is missing from this copy; only
  // attribute fragments and section comments remain below.
  return ( { e.stopPropagation(); onClick(); }} onPointerEnter={(e) => { e.stopPropagation(); onPointerEnter(); }} onPointerLeave={onPointerLeave} > {isContainer ? ( ) : ( )} {/* Selection outline */} {isSelected && ( )} {/* Label */} {showLabel && ( {node.displayName} )} );
});

LayerNode.displayName = 'LayerNode';

// ============================================================================
// Connection Component
// ============================================================================

/** Props for one rendered edge between two nodes. */
interface ConnectionProps {
  connection: ConnectionLayout;
  /** Currently unused by Connection (destructured as `_animate`). */
  animate: boolean;
  /** Line color for regular connections; skip connections ignore it. */
  color?: string;
}

/**
 * Draws one connection. Skip connections use the fixed color '#00BCD4' at
 * opacity 0.6; other connections use `color` (default '#4A90D9') at 0.8.
 * When the layout provides control points, only the first one is read.
 */
const Connection: React.FC = React.memo(({
  connection,
  animate: _animate, // Reserved for future animation features
  color = '#4A90D9',
}) => {
  const { sourcePosition, targetPosition, controlPoints, isSkipConnection } = connection;
  const lineColor = isSkipConnection ? '#00BCD4' : color;
  const opacity = isSkipConnection ? 0.6 : 0.8;

  // For curved connections
  // NOTE(review): midPoint is presumably the control point for a
  // QuadraticBezierLine (imported above); the JSX that consumes it is missing.
  if (controlPoints.length > 0) {
    const midPoint = controlPoints[0];
    return ( );
  }

  // Straight line
  return ( );
});

Connection.displayName = 'Connection';

// ============================================================================
// Flow Particles (animated data flow)
// ============================================================================

/** Props for the particle overlay that animates along connections. */
interface FlowParticlesProps {
  connections: ConnectionLayout[];
  /** Multiplier on particle travel rate (default 1). */
  speed?: number;
}

/**
 * Moves one particle per connection from its source position toward its
 * target position, wrapping back to the source on arrival. Each frame adds
 * `delta * speed * 0.5` to a particle's progress, so one trip takes
 * `2 / speed` units of `delta` (seconds, assuming R3F's frame delta).
 *
 * NOTE(review): particles interpolate in a straight line and ignore
 * `controlPoints`, so they will not follow curved connections.
 */
const FlowParticles: React.FC = ({ connections, speed = 1 }) => {
  const particlesRef = useRef(null);
  // Per-particle progress along its connection, kept in [0, 1).
  const progressRef = useRef(new Float32Array(connections.length));

  // Re-seeded only when the number of connections changes; a new array of
  // the same length keeps the existing progress values.
  useEffect(() => {
    // Initialize random progress for each particle
    progressRef.current = new Float32Array(connections.length);
    for (let i = 0; i < connections.length; i++) {
      progressRef.current[i] = Math.random();
    }
  }, [connections.length]);

  // Writes each particle's position straight into the geometry's position
  // attribute (no React state involved) and marks it for re-upload.
  useFrame((_, delta) => {
    if (!particlesRef.current) return;
    const positions = particlesRef.current.geometry.attributes.position as THREE.BufferAttribute;
    for (let i = 0; i < connections.length; i++) {
      progressRef.current[i] = (progressRef.current[i] + delta * speed * 0.5) % 1;
      const t = progressRef.current[i];
      const conn = connections[i];
      const { sourcePosition, targetPosition } = conn;
      // Linear interpolation along connection
      const x = sourcePosition.x + (targetPosition.x - sourcePosition.x) * t;
      const y = sourcePosition.y + (targetPosition.y - sourcePosition.y) * t;
      const z = sourcePosition.z + (targetPosition.z - sourcePosition.z) * t;
      positions.setXYZ(i, x, y, z);
    }
    positions.needsUpdate = true;
  });

  // Initial buffer of xyz triples with every particle at its source; presumably
  // fed to the missing points geometry JSX below.
  const particlePositions = useMemo(() => {
    const positions = new Float32Array(connections.length * 3);
    connections.forEach((conn, i) => {
      positions[i * 3] = conn.sourcePosition.x;
      positions[i * 3 + 1] = conn.sourcePosition.y;
      positions[i * 3 + 2] = conn.sourcePosition.z;
    });
    return positions;
  }, [connections]);

  // NOTE(review): the particle JSX is missing from this copy.
  return ( );
};

// ============================================================================
// Camera Controller
// ============================================================================

/** Props for the camera placement helper. */
interface CameraControllerProps {
  target: Position3D;
  initialPosition: Position3D;
}

/**
 * Places the R3F camera at `initialPosition` and points it at `target`;
 * renders nothing. The effect re-runs whenever the camera or either prop
 * object changes identity, so callers should pass stable objects (e.g. from
 * a memoized layout) to avoid snapping the camera back on every render.
 *
 * NOTE(review): OrbitControls (imported in this file) keeps its own target;
 * confirm it does not override the lookAt applied here.
 */
const CameraController: React.FC = ({ target, initialPosition }) => {
  const { camera } = useThree();

  useEffect(() => {
    camera.position.set(initialPosition.x, initialPosition.y, initialPosition.z);
    camera.lookAt(target.x, target.y, target.z);
  }, [camera, target, initialPosition]);

  return null;
};

// ============================================================================
// Info Tooltip
// ============================================================================

/** Props for the hover tooltip. */
interface NodeTooltipProps {
  node: HierarchyNode;
  /** Anchor point for the tooltip; its consumer is in the missing JSX. */
  position: Position3D;
}

/**
 * Shows a node's display name; its type (`layerData?.type`, falling back to
 * `category` when that is missing or empty); its total parameter count
 * formatted with `toLocaleString()`; and, when `inputShape` is present, its
 * input → output shape, with '?' when the output shape is absent or empty.
 *
 * NOTE(review): the surrounding markup (presumably a drei `Html` overlay
 * anchored at `position`) is missing from this copy.
 */
const NodeTooltip: React.FC = ({ node, position }) => {
  return (
{node.displayName}
Type: {node.layerData?.type || node.category}
Params: {node.totalParams.toLocaleString()}
{node.inputShape && (
Shape: [{node.inputShape.join(', ')}] → [{node.outputShape?.join(', ') || '?'}]
)}
  );
};

// ============================================================================
// Main Scene Content
// ============================================================================

/** NeuralScene props plus the layout already computed for `hierarchy`. */
interface SceneContentProps extends NeuralSceneProps {
  layout: LayoutResult;
}

/**
 * Scene graph content (presumably mounted inside the R3F Canvas by
 * NeuralScene): one element per visible layout node, the connections when
 * `showConnections` is true, flow particles when both `animateFlow` and
 * `showConnections` are true, and a tooltip for `hoveredNodeId`. Layout
 * nodes whose id is not in `hierarchy.allNodes`, or whose layout has
 * `visible` false, are skipped.
 *
 * NOTE(review): the lighting, camera, controls, environment and grid
 * elements, and the element tags inside each JSX expression, are missing
 * from this copy; only their section comments remain.
 */
const SceneContent: React.FC = ({
  hierarchy,
  layout,
  showLabels = true,
  showConnections = true,
  animateFlow = false,
  onNodeClick,
  onNodeHover,
  selectedNodeId,
  hoveredNodeId,
}) => {
  // Memoized wrappers so the optional callbacks can be invoked unconditionally.
  const handleNodeClick = useCallback((nodeId: string) => {
    onNodeClick?.(nodeId);
  }, [onNodeClick]);

  const handleNodeHover = useCallback((nodeId: string | null) => {
    onNodeHover?.(nodeId);
  }, [onNodeHover]);

  // Get hovered node for tooltip
  // Both lookups must succeed for the tooltip to render.
  const hoveredNode = hoveredNodeId && hierarchy ? hierarchy.allNodes.get(hoveredNodeId) : null;
  const hoveredLayout = hoveredNodeId ? layout.nodes.get(hoveredNodeId) : null;

  // NOTE(review): the per-node pointer handlers below are fresh arrow
  // functions on every render, which defeats React.memo on the (presumed)
  // LayerNode children; consider passing nodeId plus the memoized
  // handleNodeClick/handleNodeHover instead.
  return ( <> {/* Lighting */} {/* Camera */} {/* Controls */} {/* Layer Nodes */} {Array.from(layout.nodes.entries()).map(([nodeId, nodeLayout]) => { const node = hierarchy?.allNodes.get(nodeId); if (!node || !nodeLayout.visible) return null; return ( handleNodeClick(nodeId)} onPointerEnter={() => handleNodeHover(nodeId)} onPointerLeave={() => handleNodeHover(null)} /> ); })} {/* Connections */} {showConnections && layout.connections.map(conn => ( ))} {/* Flow Animation */} {animateFlow && showConnections && ( )} {/* Hover Tooltip */} {hoveredNode && hoveredLayout && ( )} {/* Environment */} {/* Ground Grid */} );
};

// ============================================================================
// Main Component
// ============================================================================

/**
 * Top-level 3D view of a model hierarchy. Computes the layout for
 * `hierarchy` at `level` (default 3) and renders the scene. When
 * `hierarchy` is null, it renders a "Drop a model file to visualize"
 * placeholder instead.
 */
export const NeuralScene: React.FC = (props) => {
  const { hierarchy, level = 3 } = props;

  // Compute layout
  // The memo runs before the early return below because hooks must be called
  // unconditionally; the empty placeholder layout it builds for a null
  // hierarchy is never rendered.
  const layout = useMemo(() => {
    if (!hierarchy) {
      return {
        nodes: new Map(),
        connections: [],
        bounds: {
          min: { x: 0, y: 0, z: 0 },
          max: { x: 0, y: 0, z: 0 },
          center: { x: 0, y: 0, z: 0 },
        },
        cameraSuggestion: {
          position: { x: -10, y: 10, z: 20 },
          target: { x: 0, y: 0, z: 0 },
        },
      };
    }
    return computeFullLayout(hierarchy, level);
  }, [hierarchy, level]);

  if (!hierarchy) {
    return ( Drop a model file to visualize );
  }

  // NOTE(review): the scene markup (presumably Canvas + SceneContent) is
  // missing from this copy.
  return ( );
};

export default NeuralScene;