import React, { useCallback, useState } from 'react'; import { useVisualizerStore } from '@/core/store'; import { loadModelFromFile, SUPPORTED_EXTENSIONS } from '@/core/loader'; import { computeLayout } from '@/core/layout'; import { isSupportedExtension, getFormatDisplayName, detectFormatFromExtension } from '@/core/formats'; import { SavedModelsPanel } from '@/components/visualization/SavedModelsPanel'; import styles from './DropZone.module.css'; // Example CNN model for demo purposes const EXAMPLE_MODEL = { name: 'EXAMPLE_CNN', framework: 'demo', totalParameters: 1234567, trainableParameters: 1234567, inputShape: [1, 3, 224, 224], outputShape: [1, 1000], layers: [ { id: 'input', name: 'input', type: 'Input', category: 'input', inputShape: null, outputShape: [1, 3, 224, 224], params: {}, numParameters: 0, trainable: false }, { id: 'conv1', name: 'conv1', type: 'Conv2d', category: 'convolution', inputShape: [1, 3, 224, 224], outputShape: [1, 64, 112, 112], params: { filters: 64, kernelSize: [7, 7], strides: [2, 2], padding: 'same' }, numParameters: 9472, trainable: true }, { id: 'bn1', name: 'bn1', type: 'BatchNorm2d', category: 'normalization', inputShape: [1, 64, 112, 112], outputShape: [1, 64, 112, 112], params: { epsilon: 0.00001, momentum: 0.1 }, numParameters: 256, trainable: true }, { id: 'relu1', name: 'relu1', type: 'ReLU', category: 'activation', inputShape: [1, 64, 112, 112], outputShape: [1, 64, 112, 112], params: {}, numParameters: 0, trainable: false }, { id: 'pool1', name: 'pool1', type: 'MaxPool2d', category: 'pooling', inputShape: [1, 64, 112, 112], outputShape: [1, 64, 56, 56], params: { kernelSize: [3, 3], strides: [2, 2] }, numParameters: 0, trainable: false }, { id: 'conv2', name: 'conv2', type: 'Conv2d', category: 'convolution', inputShape: [1, 64, 56, 56], outputShape: [1, 128, 56, 56], params: { filters: 128, kernelSize: [3, 3], strides: [1, 1], padding: 'same' }, numParameters: 73856, trainable: true }, { id: 'bn2', name: 'bn2', type: 'BatchNorm2d', category: 'normalization', inputShape: [1, 128, 56, 56], outputShape: [1, 128, 56, 56], params: { epsilon: 0.00001, momentum: 0.1 }, numParameters: 512, trainable: true }, { id: 'relu2', name: 'relu2', type: 'ReLU', category: 'activation', inputShape: [1, 128, 56, 56], outputShape: [1, 128, 56, 56], params: {}, numParameters: 0, trainable: false }, { id: 'conv3', name: 'conv3', type: 'Conv2d', category: 'convolution', inputShape: [1, 128, 56, 56], outputShape: [1, 256, 28, 28], params: { filters: 256, kernelSize: [3, 3], strides: [2, 2], padding: 'same' }, numParameters: 295168, trainable: true }, { id: 'bn3', name: 'bn3', type: 'BatchNorm2d', category: 'normalization', inputShape: [1, 256, 28, 28], outputShape: [1, 256, 28, 28], params: { epsilon: 0.00001, momentum: 0.1 }, numParameters: 1024, trainable: true }, { id: 'relu3', name: 'relu3', type: 'ReLU', category: 'activation', inputShape: [1, 256, 28, 28], outputShape: [1, 256, 28, 28], params: {}, numParameters: 0, trainable: false }, { id: 'conv4', name: 'conv4', type: 'Conv2d', category: 'convolution', inputShape: [1, 256, 28, 28], outputShape: [1, 512, 14, 14], params: { filters: 512, kernelSize: [3, 3], strides: [2, 2], padding: 'same' }, numParameters: 1180160, trainable: true }, { id: 'bn4', name: 'bn4', type: 'BatchNorm2d', category: 'normalization', inputShape: [1, 512, 14, 14], outputShape: [1, 512, 14, 14], params: { epsilon: 0.00001, momentum: 0.1 }, numParameters: 2048, trainable: true }, { id: 'relu4', name: 'relu4', type: 'ReLU', category: 'activation', inputShape: [1, 512, 14, 14], outputShape: [1, 512, 14, 14], params: {}, numParameters: 0, trainable: false }, { id: 'gap', name: 'global_avg_pool', type: 'AdaptiveAvgPool2d', category: 'pooling', inputShape: [1, 512, 14, 14], outputShape: [1, 512, 1, 1], params: { outputSize: [1, 1] }, numParameters: 0, trainable: false }, { id: 'flatten', name: 'flatten', type: 'Flatten', category: 'reshape', inputShape: [1, 512, 1, 1], outputShape: [1, 512], params: {}, numParameters: 0, trainable: false }, { id: 'fc', name: 'fc', type: 'Linear', category: 'linear', inputShape: [1, 512], outputShape: [1, 1000], params: { inFeatures: 512, outFeatures: 1000, bias: true }, numParameters: 513000, trainable: true }, { id: 'softmax', name: 'softmax', type: 'Softmax', category: 'output', inputShape: [1, 1000], outputShape: [1, 1000], params: { dim: 1 }, numParameters: 0, trainable: false }, ], connections: [ { source: 'input', target: 'conv1', tensorShape: [1, 3, 224, 224] }, { source: 'conv1', target: 'bn1', tensorShape: [1, 64, 112, 112] }, { source: 'bn1', target: 'relu1', tensorShape: [1, 64, 112, 112] }, { source: 'relu1', target: 'pool1', tensorShape: [1, 64, 112, 112] }, { source: 'pool1', target: 'conv2', tensorShape: [1, 64, 56, 56] }, { source: 'conv2', target: 'bn2', tensorShape: [1, 128, 56, 56] }, { source: 'bn2', target: 'relu2', tensorShape: [1, 128, 56, 56] }, { source: 'relu2', target: 'conv3', tensorShape: [1, 128, 56, 56] }, { source: 'conv3', target: 'bn3', tensorShape: [1, 256, 28, 28] }, { source: 'bn3', target: 'relu3', tensorShape: [1, 256, 28, 28] }, { source: 'relu3', target: 'conv4', tensorShape: [1, 256, 28, 28] }, { source: 'conv4', target: 'bn4', tensorShape: [1, 512, 14, 14] }, { source: 'bn4', target: 'relu4', tensorShape: [1, 512, 14, 14] }, { source: 'relu4', target: 'gap', tensorShape: [1, 512, 14, 14] }, { source: 'gap', target: 'flatten', tensorShape: [1, 512, 1, 1] }, { source: 'flatten', target: 'fc', tensorShape: [1, 512] }, { source: 'fc', target: 'softmax', tensorShape: [1, 1000] }, ], }; interface DropZoneProps { children?: React.ReactNode; } /** * File drop zone overlay when no model is loaded */ export function DropZone({ children }: DropZoneProps) { const [isDragging, setIsDragging] = useState(false); const [isLoading, setIsLoading] = useState(false); const [loadingMessage, setLoadingMessage] = useState(''); const [showSavedModels, setShowSavedModels] = useState(false); const model = useVisualizerStore(state => state.model); const loadModel = useVisualizerStore(state => state.loadModel); const updateNodePositions = useVisualizerStore(state => state.updateNodePositions); const config = useVisualizerStore(state => state.config); const handleDragOver = useCallback((e: React.DragEvent) => { e.preventDefault(); e.stopPropagation(); setIsDragging(true); }, []); const handleDragLeave = useCallback((e: React.DragEvent) => { e.preventDefault(); e.stopPropagation(); setIsDragging(false); }, []); const processFile = useCallback(async (file: File) => { // Check file extension if (!isSupportedExtension(file.name)) { const ext = '.' + file.name.split('.').pop()?.toLowerCase(); alert( `Unsupported file format: ${ext}\n\n` + `Supported formats:\n` + `• .nn3d, .json - Native format\n` + `• .onnx - ONNX models\n` + `• .safetensors - SafeTensors (LLM weights)\n` + `• .pt, .pth, .ckpt - PyTorch (requires conversion)\n` + `• .h5, .hdf5 - Keras/TensorFlow (requires conversion)` ); return; } const formatInfo = detectFormatFromExtension(file.name); const formatName = getFormatDisplayName(formatInfo.category); setIsLoading(true); setLoadingMessage(`Loading ${formatName} model...`); try { const loadedModel = await loadModelFromFile(file); loadModel(loadedModel); setLoadingMessage('Computing layout...'); const layoutResult = computeLayout(loadedModel, { type: config.layout || 'layered', layerSpacing: config.layerSpacing || 3, }); updateNodePositions(layoutResult.positions); } catch (error) { console.error('Failed to load model:', error); alert(`Failed to load model:\n\n${error instanceof Error ? error.message : 'Unknown error'}`); } finally { setIsLoading(false); setLoadingMessage(''); } }, [loadModel, updateNodePositions, config]); const handleDrop = useCallback(async (e: React.DragEvent) => { e.preventDefault(); e.stopPropagation(); setIsDragging(false); const files = e.dataTransfer.files; if (files.length === 0) return; await processFile(files[0]); }, [processFile]); const handleFileInput = useCallback((e: React.ChangeEvent) => { const files = e.target.files; if (files && files.length > 0) { processFile(files[0]); } }, [processFile]); // Handle loading a saved model from the panel const handleLoadSavedModel = useCallback((architecture: any) => { // Create a model structure from the saved architecture (matching App.tsx format) const savedModel = { version: '1.0', metadata: { name: architecture.name, framework: architecture.framework, totalParams: architecture.totalParameters, trainableParams: architecture.trainableParameters, inputShape: architecture.inputShape, outputShape: architecture.outputShape, }, graph: { nodes: architecture.layers.map((layer: any) => ({ id: layer.id, name: layer.name, type: layer.type, inputShape: layer.inputShape, outputShape: layer.outputShape, params: layer.params || {}, attributes: { category: layer.category, parameters: layer.numParameters, }, })), edges: architecture.connections?.map((conn: any, idx: number) => ({ id: `edge-${idx}`, source: conn.source, target: conn.target, tensorShape: conn.tensorShape, })) || [], }, }; loadModel(savedModel as any); // Compute layout for the loaded model const layoutResult = computeLayout(savedModel as any, { type: config.layout || 'layered', layerSpacing: config.layerSpacing || 3, }); updateNodePositions(layoutResult.positions); }, [loadModel, updateNodePositions, config]); // Handle loading the example model const handleLoadExample = useCallback(() => { handleLoadSavedModel(EXAMPLE_MODEL); }, [handleLoadSavedModel]); return (
{children} {/* Loading overlay */} {isLoading && (

{loadingMessage}

)} {!model && !isLoading && (
[NN]

NN3D_VISUALIZER

// DROP A MODEL FILE TO VISUALIZE

.nn3d .onnx .safetensors .pt/.pth .h5

SUPPORTS: ONNX | SAFETENSORS | PYTORCH | KERAS

[*] 3D_VISUALIZATION [*] INTERACTIVE_EXPLORATION [*] LAYER_DETAILS
)} {isDragging && model && (
Drop to load new model
)} {/* Saved Models Panel */} {showSavedModels && ( setShowSavedModels(false)} /> )}
); } export default DropZone;