/**
 * Architecture Scene - VGG-style 3D Visualization
 *
 * Renders neural network architecture as 3D blocks where:
 * - Block HEIGHT = spatial dimension (shrinks through network)
 * - Block DEPTH = channel count (grows through network)
 * - Position flows left-to-right showing the data transformation
 *
 * NOTE(review): this copy of the file appears to have lost its JSX element
 * tags and several generic type arguments (e.g. `React.FC` / `useRef(null)` /
 * `useState(null)` / `Record` below have no `<...>`), so it will not
 * type-check or render as-is. The surviving markup fragments are left exactly
 * as found — restore the full markup from version control before building.
 */
import React, { useMemo, useRef, useState } from 'react';
import { Canvas, useFrame, useThree } from '@react-three/fiber';
import {
  OrbitControls,
  Environment,
  Text,
  Html,
  RoundedBox,
  Line,
} from '@react-three/drei';
import * as THREE from 'three';
import { computeArchitectureLayout, type ArchitectureLayout, type LayerBlock } from '@/core/arch-layout';

// ============================================================================
// Types
// ============================================================================

/**
 * Named camera presets handled by CameraController.
 * When undefined, the camera is framed automatically from the layout.
 */
export type CameraView = 'front' | 'top' | 'side' | 'isometric' | 'back' | 'bottom';

/** Props for {@link ArchScene}. */
export interface ArchSceneProps {
  /** Parsed model to visualize; `null` renders the "drop a model file" placeholder. */
  architecture: {
    name: string;
    framework: string;
    totalParameters: number;
    inputShape?: number[] | null;
    outputShape?: number[] | null;
    layers: Array<{
      id: string;
      name: string;
      type: string;
      category: string;
      inputShape: number[] | null;
      outputShape: number[] | null;
      params: Record;
      numParameters: number;
    }>;
    connections: Array<{
      source: string;
      target: string;
    }>;
  } | null;
  /** Show layer-name labels above blocks (ArchScene default: true). */
  showLabels?: boolean;
  /** Show dimension labels below blocks (ArchScene default: true). */
  showDimensions?: boolean;
  /** Draw connection lines between blocks (ArchScene default: true). */
  showConnections?: boolean;
  /** Called with the layer id when a block is clicked. */
  onLayerClick?: (layerId: string) => void;
  /** Called with the hovered layer id, or `null` when the pointer leaves a block. */
  onLayerHover?: (layerId: string | null) => void;
  /** Id of the layer to render as selected (ArchScene default: null). */
  selectedLayerId?: string | null;
  // Camera control props
  cameraView?: CameraView;
  /** NOTE(review): declared but not destructured by ArchScene, so currently ignored. */
  onCameraChange?: () => void;
}

// ============================================================================
// Layer Block Component
// ============================================================================

interface LayerBlockMeshProps {
  block: LayerBlock;
  isSelected: boolean;
  isHovered: boolean;
  showLabel: boolean;
  showDimension: boolean;
  onClick: () => void;
  onHover: (hovered: boolean) => void;
}

/**
 * One layer rendered as a block, with:
 * - a hover/selection animation (slight scale-up and lightening toward white),
 * - an optional name label and an optional dimension label,
 * - a hover tooltip with name, type, output shape and parameter count.
 *
 * Click and pointer-enter events stop propagation, so only the front-most
 * block reacts.
 */
const LayerBlockMesh: React.FC = ({
  block,
  isSelected,
  isHovered,
  showLabel,
  showDimension,
  onClick,
  onHover,
}) => {
  // Presumably attached to the main block mesh in the (stripped) markup — confirm.
  const meshRef = useRef(null);
  // 0 = idle, 1 = fully highlighted; eased over time in useFrame below.
  const [hoverAnim, setHoverAnim] = useState(0);

  // Animate on hover
  // Ease hoverAnim toward 1 while hovered or selected and toward 0 otherwise
  // (rate 5/s, so a full transition takes ~0.2 s), and scale the mesh by up to 5%.
  // NOTE(review): setHoverAnim is called on every frame, so this component may
  // re-render every frame while the value is changing; R3F's performance
  // guidance recommends mutating refs inside useFrame instead. The scale also
  // reads the hoverAnim captured at the last render, so it lags by one frame.
  useFrame((_, delta) => {
    if (isHovered || isSelected) {
      setHoverAnim(prev => Math.min(1, prev + delta * 5));
    } else {
      setHoverAnim(prev => Math.max(0, prev - delta * 5));
    }
    if (meshRef.current) {
      const scale = 1 + hoverAnim * 0.05;
      meshRef.current.scale.setScalar(scale);
    }
  });

  const { width, height, depth } = block.dimensions;
  const color = new THREE.Color(block.color);
  const edgeColor = new THREE.Color(block.color).multiplyScalar(1.4); // Brighter edges
  // Lighten color on hover
  // Blend up to 30% toward white at full highlight.
  if (hoverAnim > 0) {
    color.lerp(new THREE.Color('#ffffff'), hoverAnim * 0.3);
  }

  // NOTE(review): JSX element tags are missing from the markup below in this copy.
  return ( {/* Glow/shadow base for depth */} {/* Main block */} { e.stopPropagation(); onClick(); }} onPointerEnter={(e) => { e.stopPropagation(); onHover(true); }} onPointerLeave={() => onHover(false)} > {/* Edge highlight - subtle wireframe */} {/* Selection outline */} {isSelected && ( )} {/* Top label (layer name) */} {showLabel && ( {block.label} )} {/* Bottom label (dimensions) */} {showDimension && block.dimensionLabel && ( {block.dimensionLabel} )} {/* Hover tooltip */} {isHovered && (
{block.name}
Type: {block.type}
{block.outputShape && (
Output: {block.outputShape.height}×{block.outputShape.width}×{block.outputShape.channels}
)} {block.numParameters > 0 && (
Params: {block.numParameters.toLocaleString()}
)}
)}
  );
};

// ============================================================================
// Connection Lines (supports skip connections with curves)
// ============================================================================

interface ConnectionLineProps {
  from: { x: number; y: number; z: number };
  to: { x: number; y: number; z: number };
  isSkipConnection: boolean;
}

/**
 * One edge between two blocks.
 *
 * Short edges between blocks at (nearly) the same height get a near-straight
 * path. Everything else gets an arc that bows out in +Z, then lifts gently in
 * +Y. "Everything else" means:
 * - skip connections,
 * - edges whose endpoints differ in Y by more than 0.1,
 * - any edge 3 units or longer.
 *
 * Skip connections bow further out (2.0 instead of 0.8) and use warmer colours.
 */
const ConnectionLine: React.FC = ({ from, to, isSkipConnection }) => {
  // For skip connections or connections between different Y levels, use a curve
  const needsCurve = isSkipConnection || Math.abs(from.y - to.y) > 0.1;

  // Recomputed whenever `from`/`to` are new object references.
  const points: [number, number, number][] = useMemo(() => {
    const distance = Math.hypot(to.x - from.x, to.y - from.y, to.z - from.z);

    if (!needsCurve && distance < 3) {
      // Short, level connection: a 5-point near-straight path with short
      // horizontal lead-in/lead-out segments at each end (plain points — any
      // smoothing depends on how they are rendered).
      const midX = (from.x + to.x) / 2;
      return [
        [from.x, from.y, from.z],
        [from.x + 0.2, from.y, from.z],
        [midX, from.y, from.z],
        [to.x - 0.2, to.y, to.z],
        [to.x, to.y, to.z],
      ];
    }

    // Create a curved path for skip connections or long connections
    const midX = (from.x + to.x) / 2;
    const offsetZ = isSkipConnection ? 2.0 : 0.8; // Arc out in Z for visibility
    const offsetY = (to.y - from.y) / 2;
    const curveHeight = Math.min(distance * 0.15, 1.0); // Gentle curve based on distance

    // More control points for smoother curves.
    // The Y lift ramps 0 → 0.3 → 1 → 1 → 1 → 0.3 → 0 (× curveHeight). The apex
    // must include curveHeight too; otherwise the path drops back to
    // straight-line height at the midpoint and forms a double hump.
    return [
      [from.x, from.y, from.z],
      [from.x + 0.4, from.y + curveHeight * 0.3, from.z + offsetZ * 0.2],
      [from.x + (midX - from.x) * 0.4, from.y + offsetY * 0.5 + curveHeight, from.z + offsetZ * 0.6],
      [midX, from.y + offsetY + curveHeight, from.z + offsetZ],
      [to.x - (to.x - midX) * 0.4, to.y - offsetY * 0.5 + curveHeight, to.z + offsetZ * 0.6],
      [to.x - 0.4, to.y + curveHeight * 0.3, to.z + offsetZ * 0.2],
      [to.x, to.y, to.z],
    ];
  }, [from, to, needsCurve, isSkipConnection]);

  // Connection colors - subtle gray for normal, warm for skip
  const normalColor = '#607080'; // Muted slate
  const skipColor = '#D08050'; // Soft orange for residual/skip
  const glowColor = isSkipConnection ? '#E0A070' : '#708090';

  // NOTE(review): JSX element tags are missing from the markup below in this copy.
  return ( {/* Glow effect - thicker background line */} {/* Main connection line */} {/* Arrow head at destination */} );
};

// ============================================================================
// Camera Controller - Full user control with view presets
// ============================================================================

interface CameraControllerProps {
  bounds: ArchitectureLayout['bounds'];
  center: { x: number; y: number; z: number };
  isLinear: boolean;
  cameraView?: CameraView;
}

/**
 * Positions the camera for the selected preset, aims it (and the orbit-controls
 * target) at `center`, and otherwise leaves the camera to the user.
 *
 * NOTE(review): the effect below re-runs, and snaps the camera back to the
 * preset, whenever `bounds` or `center` is a new object. Confirm the parent
 * passes stable references, or user orbiting may be reset on unrelated
 * re-renders such as hover changes.
 */
const CameraController: React.FC = ({ bounds, center, isLinear, cameraView }) => {
  const { camera } = useThree();
  // Presumably attached to <OrbitControls> in the (stripped) markup — confirm.
  const controlsRef = useRef(null);

  // Calculate optimal distance based on model size
  // Pick the distance at which the largest bounding-box dimension fits within
  // the vertical field of view, then:
  // - multiply by 1.8 for margin,
  // - clamp to at least 3.
  // The field of view falls back to 50° when the camera has no `fov`.
  const getOptimalDistance = React.useCallback(() => {
    const width = bounds.maxX - bounds.minX;
    const height = bounds.maxY - bounds.minY;
    const depth = bounds.maxZ - bounds.minZ;
    const maxDim = Math.max(width, height, depth);
    const fov = (camera as THREE.PerspectiveCamera).fov || 50;
    const fovRad = (fov / 2) * Math.PI / 180;
    return Math.max((maxDim / 2) / Math.tan(fovRad) * 1.8, 3);
  }, [bounds, camera]);

  // Apply camera view preset
  React.useEffect(() => {
    const distance = getOptimalDistance();
    let pos: [number, number, number];
    switch (cameraView) {
      case 'front':
        pos = [center.x, center.y, center.z + distance];
        break;
      case 'back':
        pos = [center.x, center.y, center.z - distance];
        break;
      case 'top':
        // The tiny Z offset presumably keeps the view direction from being
        // exactly parallel to camera.up, which would make lookAt degenerate.
        pos = [center.x, center.y + distance, center.z + 0.01];
        break;
      case 'bottom':
        pos = [center.x, center.y - distance, center.z + 0.01];
        break;
      case 'side':
        pos = [center.x + distance, center.y, center.z];
        break;
      case 'isometric':
        pos = [
          center.x + distance * 0.6,
          center.y + distance * 0.5,
          center.z + distance * 0.6,
        ];
        break;
      default:
        // Default: auto based on architecture type.
        // Linear models get a near-frontal view raised by 30% of the model's height;
        // other models get an elevated three-quarter view.
        if (isLinear) {
          const height = bounds.maxY - bounds.minY;
          pos = [center.x, center.y + height * 0.3, center.z + distance];
        } else {
          pos = [
            center.x + distance * 0.3,
            center.y + distance * 0.5,
            center.z + distance * 0.8,
          ];
        }
    }
    camera.position.set(...pos);
    camera.lookAt(center.x, center.y, center.z);
    camera.updateProjectionMatrix();
    // Keep the orbit controls pivoting around the same point the camera looks at.
    if (controlsRef.current) {
      controlsRef.current.target.set(center.x, center.y, center.z);
      controlsRef.current.update();
    }
  }, [bounds, center, camera, isLinear, cameraView, getOptimalDistance]);

  // NOTE(review): JSX element tags are missing from the markup below in this copy.
  return ( );
};

// ============================================================================
// Legend (fixed position HTML overlay - not in 3D space)
// ============================================================================

/**
 * Static colour legend for layer categories.
 * NOTE(review): these swatches are hard-coded here; presumably they should
 * match the block colours assigned by computeArchitectureLayout — confirm.
 */
const LegendOverlay: React.FC = () => {
  const items = [
    { color: '#5B8BD9', label: 'Conv' },
    { color: '#E07070', label: 'Pool' },
    { color: '#6BAF6B', label: 'FC/Linear' },
    { color: '#D9A740', label: 'Activation' },
    { color: '#50A8A0', label: 'Norm' },
    { color: '#9070C0', label: 'Attention' },
    { color: '#708090', label: 'Dropout' },
  ];

  // NOTE(review): JSX element tags are missing from the markup below in this copy.
  return (
LAYER_TYPES
{items.map(item => (
{item.label}
))}
  );
};

// ============================================================================
// Main Scene
// ============================================================================

/**
 * Everything rendered inside the 3D canvas: lighting, environment,
 * camera/controls, grid, connections and layer blocks.
 *
 * Tracks the hovered layer id locally (for tooltips and highlight) and
 * forwards every hover change to `onLayerHover`.
 */
const SceneContent: React.FC<{
  layout: ArchitectureLayout;
  showLabels: boolean;
  showDimensions: boolean;
  showConnections: boolean;
  selectedLayerId: string | null;
  onLayerClick: (id: string) => void;
  onLayerHover: (id: string | null) => void;
  cameraView?: CameraView;
}> = ({
  layout,
  showLabels,
  showDimensions,
  showConnections,
  selectedLayerId,
  onLayerClick,
  onLayerHover,
  cameraView,
}) => {
  const [hoveredId, setHoveredId] = useState(null);

  // Enter sets the hovered id; leave clears it to null.
  const handleHover = (id: string, hovered: boolean) => {
    const newId = hovered ? id : null;
    setHoveredId(newId);
    onLayerHover(newId);
  };

  // Calculate grid size based on model
  // The larger of the layout's X and Y extents, but never smaller than 10.
  const gridSize = Math.max(
    layout.bounds.maxX - layout.bounds.minX,
    layout.bounds.maxY - layout.bounds.minY,
    10
  );

  // NOTE(review): JSX element tags are missing from the markup below in this copy.
  return ( <> {/* Lighting */} {/* Environment */} {/* Camera setup with integrated controls */} {/* Grid - centered below model */} {/* Connections */} {showConnections && layout.connections.map((conn, i) => ( ))} {/* Layer blocks */} {layout.blocks.map(block => ( onLayerClick(block.id)} onHover={(h) => handleHover(block.id, h)} /> ))} );
};

// ============================================================================
// Exported Component
// ============================================================================

/**
 * Public entry point.
 *
 * Computes the 3D layout from `architecture`, memoized on the architecture
 * object's identity, and renders the scene. When `architecture` is null it
 * shows a "Drop a model file to visualize" placeholder instead.
 *
 * Missing click/hover callbacks are replaced with no-ops.
 * NOTE(review): `onCameraChange` is not destructured here, so it is ignored.
 */
export const ArchScene: React.FC = ({
  architecture,
  showLabels = true,
  showDimensions = true,
  showConnections = true,
  onLayerClick,
  onLayerHover,
  selectedLayerId = null,
  cameraView,
}) => {
  // Compute layout from architecture
  const layout = useMemo(() => {
    if (!architecture) return null;
    return computeArchitectureLayout(architecture);
  }, [architecture]);

  // NOTE(review): JSX element tags are missing from the markup below in this copy.
  if (!layout) {
    return (
Drop a model file to visualize
    );
  }

  return (
 {})} onLayerHover={onLayerHover || (() => {})} cameraView={cameraView} />
  );
};

export default ArchScene;