import { create } from 'zustand';
import { subscribeWithSelector } from 'zustand/middleware';
import type {
  NN3DModel,
  NN3DNode,
  NN3DEdge,
  VisualizationConfig,
  Position3D,
  LayerType,
} from '@/schema/types';
import { LAYER_CATEGORIES, DEFAULT_CATEGORY_COLORS } from '@/schema/types';

/**
 * Computed node data for rendering: the schema node plus its derived
 * position, appearance, and interaction flags.
 */
export interface ComputedNode extends NN3DNode {
  computedPosition: Position3D;
  color: string;
  scale: { x: number; y: number; z: number };
  visible: boolean;
  selected: boolean;
  hovered: boolean;
  lod: number; // 0 = high detail, 1 = medium, 2 = low
}

/**
 * Computed edge data for rendering: the schema edge plus the positions of
 * its source/target nodes and display flags.
 */
export interface ComputedEdge extends NN3DEdge {
  sourcePosition: Position3D;
  targetPosition: Position3D;
  color: string;
  visible: boolean;
  highlighted: boolean;
}

/**
 * Camera state
 */
export interface CameraState {
  position: Position3D;
  target: Position3D;
  zoom: number;
}

/**
 * Selection state. Each field holds an id, or null when nothing is
 * selected/hovered.
 */
export interface SelectionState {
  selectedNodeId: string | null;
  hoveredNodeId: string | null;
  selectedEdgeId: string | null;
}

/**
 * Visualization store state
 */
export interface VisualizerState {
  // Model data
  model: NN3DModel | null;
  isLoading: boolean;
  error: string | null;

  // Computed data for rendering
  /** Computed nodes keyed by node id. */
  computedNodes: Map<string, ComputedNode>;
  computedEdges: ComputedEdge[];

  // View state
  camera: CameraState;
  selection: SelectionState;

  // Configuration
  config: VisualizationConfig;

  // Actions
  /** Replace the current model, rebuild computed nodes/edges, and clear the selection. */
  loadModel: (model: NN3DModel) => void;
  /** Drop the model, its computed data, and the selection. */
  clearModel: () => void;
  setLoading: (loading: boolean) => void;
  /** Set (or clear) the error message; also ends the loading state. */
  setError: (error: string | null) => void;

  // Selection actions
  /** Select a node (or clear with null) and highlight the edges touching it. */
  selectNode: (nodeId: string | null) => void;
  /** Mark a node as hovered (or clear with null), un-hovering the previous one. */
  hoverNode: (nodeId: string | null) => void;
  selectEdge: (edgeId: string | null) => void;

  // Camera actions
  setCameraPosition: (position: Position3D) => void;
  setCameraTarget: (target: Position3D) => void;
  resetCamera: () => void;

  // Config actions
  /** Shallow-merge the given fields into the current config. */
  updateConfig: (config: Partial<VisualizationConfig>) => void;

  // Computed data actions
  /** Move nodes (keyed by id) and re-anchor the edges attached to them. */
  updateNodePositions: (positions: Map<string, Position3D>) => void;
  /** Set the level of detail for nodes keyed by id. */
  updateNodeLOD: (lodMap: Map<string, number>) => void;
}

/** Fallback vertical spacing used to stack nodes that have no explicit position. */
const DEFAULT_LAYER_SPACING = 3.0;

// Default configuration
const DEFAULT_CONFIG: VisualizationConfig = {
  layout: 'layered',
  theme: 'dark',
  layerSpacing: DEFAULT_LAYER_SPACING,
  nodeScale: 1.0,
  showLabels: true,
  showEdges: true,
  edgeStyle: 'tube',
};

// Default camera state
const DEFAULT_CAMERA: CameraState = {
  position: { x: 0, y: 5, z: 20 },
  target: { x: 0, y: 0, z: 0 },
  zoom: 1,
};

/** Returns a fresh "nothing selected, nothing hovered" selection object. */
function emptySelection(): SelectionState {
  return { selectedNodeId: null, hoveredNodeId: null, selectedEdgeId: null };
}

/**
 * Merge `patch` into node `id` of `nodes` when `id` is non-null and present.
 * This mutates `nodes`, so callers must pass a copy they own. Never pass the
 * map held by the current store state.
 */
function patchNode(
  nodes: Map<string, ComputedNode>,
  id: string | null,
  patch: Partial<ComputedNode>,
): void {
  if (id === null) return;
  const node = nodes.get(id);
  if (node) {
    nodes.set(id, { ...node, ...patch });
  }
}

/**
 * Resolve a node colour: a per-layer-type override from `config.colorScheme`,
 * then a per-category override, then the category default.
 */
function getNodeColor(layerType: LayerType, config: VisualizationConfig): string {
  const category = LAYER_CATEGORIES[layerType] || 'other';
  return (
    config.colorScheme?.[layerType] ||
    config.colorScheme?.[category] ||
    DEFAULT_CATEGORY_COLORS[category]
  );
}

/**
 * Derive a node's box scale from the last two numeric dimensions of its
 * output shape. Each in-plane axis is capped at 2 before `config.nodeScale`
 * is applied. Falls back to a uniform scale when fewer than two usable
 * dimensions exist.
 */
function calculateNodeScale(
  node: NN3DNode,
  config: VisualizationConfig,
): { x: number; y: number; z: number } {
  // `||` also maps an explicit nodeScale of 0 to 1.0.
  const baseScale = config.nodeScale || 1.0;

  if (node.outputShape && node.outputShape.length > 0) {
    // Only finite positive dims are usable. Math.sqrt of a negative number
    // (e.g. a -1 "dynamic" placeholder, if the schema allows one) is NaN,
    // which would poison the scale.
    const dims = node.outputShape.filter(
      (d): d is number => typeof d === 'number' && Number.isFinite(d) && d > 0,
    );
    const [h, w] = dims.slice(-2);
    if (dims.length >= 2 && h !== undefined && w !== undefined) {
      return {
        x: Math.min(Math.sqrt(w) * 0.1, 2) * baseScale,
        y: Math.min(Math.sqrt(h) * 0.1, 2) * baseScale,
        z: 0.3 * baseScale,
      };
    }
  }

  return { x: baseScale, y: baseScale, z: 0.3 * baseScale };
}

/**
 * Main visualizer store.
 *
 * Every action that changes `computedNodes` builds a new Map rather than
 * mutating the one in the current state. Mutating it in place would also
 * rewrite the previous snapshot that subscribeWithSelector listeners
 * compare against.
 */
export const useVisualizerStore = create<VisualizerState>()(
  subscribeWithSelector((set, get) => ({
    // Initial state
    model: null,
    isLoading: false,
    error: null,
    computedNodes: new Map<string, ComputedNode>(),
    computedEdges: [],
    camera: DEFAULT_CAMERA,
    selection: emptySelection(),
    config: DEFAULT_CONFIG,

    // Model actions
    loadModel: (model: NN3DModel) => {
      const config = { ...DEFAULT_CONFIG, ...model.visualization };
      // The spread lets a model override layerSpacing with `undefined`, so
      // fall back explicitly instead of asserting non-null (which gave NaN).
      const layerSpacing = config.layerSpacing ?? DEFAULT_LAYER_SPACING;

      const computedNodes = new Map<string, ComputedNode>();
      model.graph.nodes.forEach((node, index) => {
        computedNodes.set(node.id, {
          ...node,
          // Nodes without an explicit position are stacked along +y by index.
          computedPosition: node.position ?? { x: 0, y: index * layerSpacing, z: 0 },
          color: getNodeColor(node.type, config),
          scale: calculateNodeScale(node, config),
          visible: true,
          selected: false,
          hovered: false,
          lod: 0,
        });
      });

      const computedEdges: ComputedEdge[] = model.graph.edges.map((edge, index) => ({
        ...edge,
        // `||` on purpose: an empty-string id is replaced as well.
        id: edge.id || `edge-${index}`,
        sourcePosition: computedNodes.get(edge.source)?.computedPosition ?? { x: 0, y: 0, z: 0 },
        targetPosition: computedNodes.get(edge.target)?.computedPosition ?? { x: 0, y: 0, z: 0 },
        color: '#888888',
        visible: true,
        highlighted: false,
      }));

      set({
        model,
        config,
        computedNodes,
        computedEdges,
        // Ids from the previous model would otherwise linger while no new
        // node/edge carries the matching selected/highlighted flag.
        selection: emptySelection(),
        isLoading: false,
        error: null,
      });
    },

    clearModel: () => {
      set({
        model: null,
        computedNodes: new Map<string, ComputedNode>(),
        computedEdges: [],
        selection: emptySelection(),
      });
    },

    setLoading: (isLoading: boolean) => set({ isLoading }),

    setError: (error: string | null) => set({ error, isLoading: false }),

    // Selection actions
    selectNode: (nodeId: string | null) => {
      const { computedNodes, computedEdges, selection } = get();

      const nextNodes = new Map(computedNodes);
      // Clear first, then set: re-selecting the same node leaves it selected.
      patchNode(nextNodes, selection.selectedNodeId, { selected: false });
      patchNode(nextNodes, nodeId, { selected: true });

      // Highlight the edges touching the selected node.
      const nextEdges = computedEdges.map(edge => ({
        ...edge,
        highlighted: nodeId !== null && (edge.source === nodeId || edge.target === nodeId),
      }));

      set({
        computedNodes: nextNodes,
        computedEdges: nextEdges,
        selection: { ...selection, selectedNodeId: nodeId },
      });
    },

    hoverNode: (nodeId: string | null) => {
      const { computedNodes, selection } = get();

      const nextNodes = new Map(computedNodes);
      if (selection.hoveredNodeId !== nodeId) {
        patchNode(nextNodes, selection.hoveredNodeId, { hovered: false });
      }
      patchNode(nextNodes, nodeId, { hovered: true });

      set({
        computedNodes: nextNodes,
        selection: { ...selection, hoveredNodeId: nodeId },
      });
    },

    selectEdge: (edgeId: string | null) => {
      set(state => ({
        selection: { ...state.selection, selectedEdgeId: edgeId },
      }));
    },

    // Camera actions
    setCameraPosition: (position: Position3D) => {
      set(state => ({
        camera: { ...state.camera, position },
      }));
    },

    setCameraTarget: (target: Position3D) => {
      set(state => ({
        camera: { ...state.camera, target },
      }));
    },

    resetCamera: () => set({ camera: DEFAULT_CAMERA }),

    // Config actions
    updateConfig: (configUpdate: Partial<VisualizationConfig>) => {
      set(state => ({
        config: { ...state.config, ...configUpdate },
      }));
    },

    // Computed data actions
    updateNodePositions: (positions: Map<string, Position3D>) => {
      const { computedNodes, computedEdges } = get();

      const nextNodes = new Map(computedNodes);
      positions.forEach((position, nodeId) => {
        patchNode(nextNodes, nodeId, { computedPosition: position });
      });

      // Re-anchor edge endpoints. Endpoints whose node is missing keep their old position.
      const nextEdges = computedEdges.map(edge => ({
        ...edge,
        sourcePosition: nextNodes.get(edge.source)?.computedPosition ?? edge.sourcePosition,
        targetPosition: nextNodes.get(edge.target)?.computedPosition ?? edge.targetPosition,
      }));

      set({
        computedNodes: nextNodes,
        computedEdges: nextEdges,
      });
    },

    updateNodeLOD: (lodMap: Map<string, number>) => {
      const nextNodes = new Map(get().computedNodes);
      lodMap.forEach((lod, nodeId) => {
        patchNode(nextNodes, nodeId, { lod });
      });
      set({ computedNodes: nextNodes });
    },
  })),
);

// Selectors for optimized re-renders
export const selectModel = (state: VisualizerState) => state.model;
export const selectComputedNodes = (state: VisualizerState) => state.computedNodes;
export const selectComputedEdges = (state: VisualizerState) => state.computedEdges;
export const selectConfig = (state: VisualizerState) => state.config;
export const selectSelection = (state: VisualizerState) => state.selection;
export const selectCamera = (state: VisualizerState) => state.camera;
export const selectIsLoading = (state: VisualizerState) => state.isLoading;
export const selectError = (state: VisualizerState) => state.error;