// visualiser2/src/core/store.ts
// (repository page header residue: "Vishalpainjane's picture", "added files", 8a01471)
import { create } from 'zustand';
import { subscribeWithSelector } from 'zustand/middleware';
import type { NN3DModel, NN3DNode, NN3DEdge, VisualizationConfig, Position3D } from '@/schema/types';
/**
* Computed node data for rendering.
*
* A schema node (`NN3DNode`) plus the render state the store attaches to it:
* `loadModel` fills in every field, and the selection, hover, position and
* LOD actions replace individual entries with patched copies.
*/
export interface ComputedNode extends NN3DNode {
// Position used for rendering: the node's own `position` when it has one,
// otherwise a fallback the store derives from the node's index.
computedPosition: Position3D;
// Colour string resolved by `getNodeColor` when the model is loaded.
color: string;
// Per-axis scale produced by `calculateNodeScale` when the model is loaded.
scale: { x: number; y: number; z: number };
// Initialised to `true` on load; no action in this file changes it.
visible: boolean;
// True while this node is the one recorded in `selection.selectedNodeId`.
selected: boolean;
// True while this node is the one recorded in `selection.hoveredNodeId`.
hovered: boolean;
// Level of detail; starts at 0 and is overwritten by `updateNodeLOD`.
lod: number; // 0 = high detail, 1 = medium, 2 = low
}
/**
* Computed edge data for rendering.
*
* A schema edge (`NN3DEdge`) plus the endpoint positions and display flags
* the store derives for it.
*/
export interface ComputedEdge extends NN3DEdge {
// `computedPosition` of the node named by `source`; the store falls back to
// the origin at load time when that node is not found.
sourcePosition: Position3D;
// `computedPosition` of the node named by `target`; same origin fallback.
targetPosition: Position3D;
// Colour string; the store assigns '#888888' to every edge on load.
color: string;
// Initialised to `true` on load; no action in this file changes it.
visible: boolean;
// Set by `selectNode`: true when the edge's source or target is the selected node.
highlighted: boolean;
}
/**
* Camera state kept in the store.
*/
export interface CameraState {
// Camera location; replaced by `setCameraPosition`.
position: Position3D;
// Point associated with the camera as its target; replaced by `setCameraTarget`.
// NOTE(review): presumably the look-at point — confirm against the renderer.
target: Position3D;
// Zoom value; only ever set through `DEFAULT_CAMERA` in this file.
zoom: number;
}
/**
* Selection state: the ids of the currently selected/hovered items.
* Each field is `null` when nothing is selected or hovered.
*/
export interface SelectionState {
// Id last passed to `selectNode`.
selectedNodeId: string | null;
// Id last passed to `hoverNode`.
hoveredNodeId: string | null;
// Id last passed to `selectEdge`.
selectedEdgeId: string | null;
}
/**
* Visualization store state.
*
* The full shape of the visualizer store: the loaded model, the render data
* derived from it, view state (camera and selection), configuration, and the
* actions that update each of those.
*/
export interface VisualizerState {
// Model data
model: NN3DModel | null;
isLoading: boolean;
error: string | null;
// Computed data for rendering
// Keyed by node id.
computedNodes: Map<string, ComputedNode>;
computedEdges: ComputedEdge[];
// View state
camera: CameraState;
selection: SelectionState;
// Configuration
config: VisualizationConfig;
// Actions
// Replaces the model and rebuilds the computed nodes and edges from it.
loadModel: (model: NN3DModel) => void;
// Drops the model, the computed data and the selection.
clearModel: () => void;
setLoading: (loading: boolean) => void;
// Stores the error and also clears the loading flag.
setError: (error: string | null) => void;
// Selection actions
// Passing `null` clears the corresponding selection.
selectNode: (nodeId: string | null) => void;
hoverNode: (nodeId: string | null) => void;
selectEdge: (edgeId: string | null) => void;
// Camera actions
setCameraPosition: (position: Position3D) => void;
setCameraTarget: (target: Position3D) => void;
// Restores the default camera.
resetCamera: () => void;
// Config actions
// Shallow-merges the given fields into the current config.
updateConfig: (config: Partial<VisualizationConfig>) => void;
// Computed data actions
// Both maps are keyed by node id; ids without a matching node are ignored.
updateNodePositions: (positions: Map<string, Position3D>) => void;
updateNodeLOD: (lodMap: Map<string, number>) => void;
}
// Default configuration
// Used as the store's initial `config` and as the base that `loadModel`
// overlays with the model's own `visualization` settings.
const DEFAULT_CONFIG: VisualizationConfig = {
layout: 'layered',
theme: 'dark',
// Multiplied by a node's index to place nodes that carry no position of their own.
layerSpacing: 3.0,
nodeScale: 1.0,
showLabels: true,
showEdges: true,
edgeStyle: 'tube',
};
// Default camera state
// Used as the store's initial `camera` and restored by `resetCamera`.
const DEFAULT_CAMERA: CameraState = {
position: { x: 0, y: 5, z: 20 },
target: { x: 0, y: 0, z: 0 },
zoom: 1,
};
/**
 * Replaces the entry stored under `nodeId` with a copy that has `patch`
 * merged in. Ids that are not present in `nodes` are ignored.
 *
 * This writes into `nodes`, so callers must hand it a private copy of the
 * store's map, never the map currently held in state.
 */
function patchComputedNode(
  nodes: Map<string, ComputedNode>,
  nodeId: string,
  patch: Partial<ComputedNode>,
): void {
  const node = nodes.get(nodeId);
  if (node) {
    nodes.set(nodeId, { ...node, ...patch });
  }
}

/** Builds a fresh selection object with nothing selected or hovered. */
function emptySelection(): SelectionState {
  return {
    selectedNodeId: null,
    hoveredNodeId: null,
    selectedEdgeId: null,
  };
}

/**
 * Main visualizer store.
 *
 * Created with zustand's `create`, wrapped in `subscribeWithSelector`.
 * Holds the loaded model, the derived node/edge render data, the camera,
 * the selection and the visualization config, plus the actions that update
 * them.
 *
 * Every action that changes `computedNodes` works on a copy of the map and
 * publishes that copy, so the map instance held by the previous state is
 * never modified after it has been published.
 */
export const useVisualizerStore = create<VisualizerState>()(
  subscribeWithSelector((set, get) => ({
    // Initial state
    model: null,
    isLoading: false,
    error: null,
    computedNodes: new Map(),
    computedEdges: [],
    camera: DEFAULT_CAMERA,
    selection: emptySelection(),
    config: DEFAULT_CONFIG,

    // Model actions

    /**
     * Installs `model` and rebuilds all render data from it.
     * The model's `visualization` settings are layered over the defaults.
     */
    loadModel: (model: NN3DModel) => {
      const config = { ...DEFAULT_CONFIG, ...model.visualization };
      // A model may carry an explicit undefined/null spacing, which would
      // otherwise turn the fallback positions into NaN or collapse them to 0.
      const layerSpacing = config.layerSpacing ?? DEFAULT_CONFIG.layerSpacing ?? 3.0;

      // Initialize computed nodes
      const computedNodes = new Map<string, ComputedNode>();
      model.graph.nodes.forEach((node, index) => {
        computedNodes.set(node.id, {
          ...node,
          // Nodes without their own position are stacked along y by index.
          computedPosition: node.position ?? { x: 0, y: index * layerSpacing, z: 0 },
          color: getNodeColor(node.type, config),
          scale: calculateNodeScale(node, config),
          visible: true,
          selected: false,
          hovered: false,
          lod: 0,
        });
      });

      // Initialize computed edges
      const computedEdges = model.graph.edges.map((edge, index) => {
        const sourceNode = computedNodes.get(edge.source);
        const targetNode = computedNodes.get(edge.target);
        return {
          ...edge,
          // Edges without an id get a positional one.
          id: edge.id || `edge-${index}`,
          // Endpoints that reference an unknown node are pinned to the origin.
          sourcePosition: sourceNode?.computedPosition || { x: 0, y: 0, z: 0 },
          targetPosition: targetNode?.computedPosition || { x: 0, y: 0, z: 0 },
          color: '#888888',
          visible: true,
          highlighted: false,
        };
      });

      set({
        model,
        config,
        computedNodes,
        computedEdges,
        // Every node/edge flag above starts cleared, so the recorded ids must
        // be cleared too; otherwise they would point at the previous model.
        selection: emptySelection(),
        isLoading: false,
        error: null,
      });
    },

    /** Drops the model, its render data and the selection. */
    clearModel: () => {
      set({
        model: null,
        computedNodes: new Map(),
        computedEdges: [],
        selection: emptySelection(),
      });
    },

    setLoading: (isLoading: boolean) => set({ isLoading }),

    // Recording an error (or clearing it) also ends the loading state.
    setError: (error: string | null) => set({ error, isLoading: false }),

    // Selection actions

    /**
     * Marks `nodeId` as the selected node (or clears the selection for
     * `null`) and highlights the edges that touch it.
     */
    selectNode: (nodeId: string | null) => {
      const { computedNodes, computedEdges, selection } = get();
      const nextNodes = new Map(computedNodes);

      // Update previous selection
      if (selection.selectedNodeId) {
        patchComputedNode(nextNodes, selection.selectedNodeId, { selected: false });
      }

      // Update new selection
      if (nodeId) {
        patchComputedNode(nextNodes, nodeId, { selected: true });
      }

      // Highlight connected edges
      const updatedEdges = computedEdges.map(edge => ({
        ...edge,
        highlighted: nodeId ? (edge.source === nodeId || edge.target === nodeId) : false,
      }));

      set({
        computedNodes: nextNodes,
        computedEdges: updatedEdges,
        selection: { ...selection, selectedNodeId: nodeId },
      });
    },

    /** Marks `nodeId` as the hovered node, or clears the hover for `null`. */
    hoverNode: (nodeId: string | null) => {
      const { computedNodes, selection } = get();
      const nextNodes = new Map(computedNodes);

      // Update previous hover
      if (selection.hoveredNodeId && selection.hoveredNodeId !== nodeId) {
        patchComputedNode(nextNodes, selection.hoveredNodeId, { hovered: false });
      }

      // Update new hover
      if (nodeId) {
        patchComputedNode(nextNodes, nodeId, { hovered: true });
      }

      set({
        computedNodes: nextNodes,
        selection: { ...selection, hoveredNodeId: nodeId },
      });
    },

    selectEdge: (edgeId: string | null) => {
      set(state => ({
        selection: { ...state.selection, selectedEdgeId: edgeId },
      }));
    },

    // Camera actions
    setCameraPosition: (position: Position3D) => {
      set(state => ({
        camera: { ...state.camera, position },
      }));
    },

    setCameraTarget: (target: Position3D) => {
      set(state => ({
        camera: { ...state.camera, target },
      }));
    },

    resetCamera: () => set({ camera: DEFAULT_CAMERA }),

    // Config actions
    updateConfig: (configUpdate: Partial<VisualizationConfig>) => {
      set(state => ({
        config: { ...state.config, ...configUpdate },
      }));
    },

    // Computed data actions

    /**
     * Applies new positions to the listed nodes and re-derives the endpoint
     * positions of every edge. Ids without a matching node are ignored.
     */
    updateNodePositions: (positions: Map<string, Position3D>) => {
      const { computedNodes, computedEdges } = get();
      const nextNodes = new Map(computedNodes);

      positions.forEach((position, nodeId) => {
        patchComputedNode(nextNodes, nodeId, { computedPosition: position });
      });

      // Update edge positions; an endpoint whose node is unknown keeps its old position.
      const updatedEdges = computedEdges.map(edge => {
        const sourceNode = nextNodes.get(edge.source);
        const targetNode = nextNodes.get(edge.target);
        return {
          ...edge,
          sourcePosition: sourceNode?.computedPosition || edge.sourcePosition,
          targetPosition: targetNode?.computedPosition || edge.targetPosition,
        };
      });

      set({
        computedNodes: nextNodes,
        computedEdges: updatedEdges,
      });
    },

    /** Applies new LOD values to the listed nodes; unknown ids are ignored. */
    updateNodeLOD: (lodMap: Map<string, number>) => {
      const nextNodes = new Map(get().computedNodes);
      lodMap.forEach((lod, nodeId) => {
        patchComputedNode(nextNodes, nodeId, { lod });
      });
      set({ computedNodes: nextNodes });
    },
  }))
);
// Helper functions
import { LAYER_CATEGORIES, DEFAULT_CATEGORY_COLORS, LayerType } from '@/schema/types';
/**
 * Resolves the display colour for a layer type.
 *
 * Lookup order: a colour-scheme entry for the exact layer type, then a
 * colour-scheme entry for the layer's category, then the default colour of
 * that category. Layer types without a known category are treated as 'other'.
 */
function getNodeColor(layerType: LayerType, config: VisualizationConfig): string {
  const category = LAYER_CATEGORIES[layerType] || 'other';
  const scheme = config.colorScheme;

  const typeColor = scheme?.[layerType];
  if (typeColor) {
    return typeColor;
  }

  const categoryColor = scheme?.[category];
  if (categoryColor) {
    return categoryColor;
  }

  return DEFAULT_CATEGORY_COLORS[category];
}
/**
 * Derives a node's render scale from its output shape.
 *
 * When the shape has at least two usable numeric dimensions, the last two are
 * read as (height, width): x follows sqrt(width) and y follows sqrt(height),
 * each multiplied by 0.1 and capped at 2 before the base scale is applied.
 * Otherwise the node gets the base scale on x and y. z is always 0.3 times
 * the base scale.
 *
 * Only finite, positive numbers count as usable dimensions. Non-numeric
 * entries, zeros and negative values (such as -1) are skipped, because the
 * square root of a negative number is NaN and would produce a NaN scale.
 *
 * @param node - Node whose `outputShape` is inspected.
 * @param config - Supplies `nodeScale`; a missing or zero value is treated as 1.
 * @returns Scale factors for the x, y and z axes.
 */
function calculateNodeScale(node: NN3DNode, config: VisualizationConfig): { x: number; y: number; z: number } {
  const baseScale = config.nodeScale || 1.0;

  // Scale based on output shape if available
  if (node.outputShape && node.outputShape.length > 0) {
    const dims = node.outputShape.filter(
      (d): d is number => typeof d === 'number' && Number.isFinite(d) && d > 0,
    );
    if (dims.length >= 2) {
      const [h, w] = dims.slice(-2);
      return {
        x: Math.min(Math.sqrt(w) * 0.1, 2) * baseScale,
        y: Math.min(Math.sqrt(h) * 0.1, 2) * baseScale,
        z: 0.3 * baseScale,
      };
    }
  }

  return { x: baseScale, y: baseScale, z: 0.3 * baseScale };
}
// Selectors for optimized re-renders.
// Each one picks a single top-level slice of the store state, so a component
// can subscribe to just that slice.
export const selectModel = ({ model }: VisualizerState) => model;
export const selectComputedNodes = ({ computedNodes }: VisualizerState) => computedNodes;
export const selectComputedEdges = ({ computedEdges }: VisualizerState) => computedEdges;
export const selectConfig = ({ config }: VisualizerState) => config;
export const selectSelection = ({ selection }: VisualizerState) => selection;
export const selectCamera = ({ camera }: VisualizerState) => camera;
export const selectIsLoading = ({ isLoading }: VisualizerState) => isLoading;
export const selectError = ({ error }: VisualizerState) => error;