Spaces:
Build error
Build error
| import type { NN3DModel, NN3DNode, NN3DEdge } from '@/schema/types'; | |
| import { parseNN3DModel, validateModelSemantics } from '@/schema/validator'; | |
| import { | |
| detectFormatFromExtension, | |
| detectFormatFromContent, | |
| isSupportedExtension, | |
| getFormatDisplayName, | |
| SUPPORTED_EXTENSIONS, | |
| } from './formats'; | |
| import { OnnxParser } from './formats/onnx-parser'; | |
| import { SafeTensorsParser } from './formats/safetensors-parser'; | |
| import { PyTorchParser } from './formats/pytorch-parser'; | |
| import { KerasParser } from './formats/keras-parser'; | |
| import { | |
| isBackendAvailable, | |
| analyzeUniversal, | |
| type ModelArchitecture, | |
| type LayerInfo | |
| } from './api-client'; | |
/**
 * Re-export of the extension allow-list from './formats', so UI code can build
 * file pickers without importing the format registry directly.
 */
export { SUPPORTED_EXTENSIONS };
/**
 * Result of the Python-backend availability probe: `null` until the first
 * probe settles, then cached for the lifetime of this module.
 */
let backendAvailable: boolean | null = null;

/**
 * Probe currently in flight, shared so that concurrent callers trigger a
 * single `isBackendAvailable()` request instead of one each.
 */
let backendProbe: Promise<boolean> | null = null;

/**
 * Report whether the Python analysis backend can be used.
 *
 * The first call probes via `isBackendAvailable()` and caches the answer.
 * A probe that rejects is treated as "unavailable" (and cached as such), so
 * callers fall back to the JavaScript parsers instead of failing outright.
 *
 * @returns `true` when the backend probe succeeded and reported availability.
 */
async function checkBackend(): Promise<boolean> {
  if (backendAvailable !== null) {
    return backendAvailable;
  }
  if (backendProbe === null) {
    backendProbe = (async () => {
      let available: boolean;
      try {
        available = await isBackendAvailable();
      } catch (error: unknown) {
        // Without this, a network error would escape to loadModelFromFile,
        // which calls checkBackend() outside its own try/catch.
        console.warn('[NN3D] Backend availability check failed:', error);
        available = false;
      }
      backendAvailable = available;
      if (available) {
        console.log('[NN3D] Python backend available - using enhanced model analysis');
      } else {
        console.log('[NN3D] Python backend unavailable - using JavaScript parsers');
      }
      return available;
    })();
  }
  return backendProbe;
}
/**
 * Backend layer class names (PyTorch and Keras/TensorFlow conventions) mapped
 * to NN3D node types. Built once at module load. A `Map` is used so that
 * inherited `Object.prototype` members (e.g. "constructor") can never be
 * returned as a match.
 */
const LAYER_TYPE_MAP: ReadonlyMap<string, string> = new Map(
  Object.entries({
    // PyTorch layers
    Linear: 'linear',
    Conv1d: 'conv1d',
    Conv2d: 'conv2d',
    Conv3d: 'conv3d',
    ConvTranspose2d: 'convTranspose2d',
    BatchNorm1d: 'batchNorm1d',
    BatchNorm2d: 'batchNorm2d',
    BatchNorm3d: 'batchNorm3d',
    LayerNorm: 'layerNorm',
    GroupNorm: 'groupNorm',
    ReLU: 'relu',
    LeakyReLU: 'leakyRelu',
    GELU: 'gelu',
    Sigmoid: 'sigmoid',
    Tanh: 'tanh',
    Softmax: 'softmax',
    Dropout: 'dropout',
    MaxPool1d: 'maxPool1d',
    MaxPool2d: 'maxPool2d',
    AvgPool2d: 'avgPool2d',
    AdaptiveAvgPool2d: 'adaptiveAvgPool',
    LSTM: 'lstm',
    GRU: 'gru',
    RNN: 'rnn',
    Embedding: 'embedding',
    MultiheadAttention: 'multiHeadAttention',
    Transformer: 'transformer',
    Flatten: 'flatten',
    // Keras/TensorFlow layers
    InputLayer: 'input',
    Dense: 'dense',
    Conv2D: 'conv2d',
    Conv1D: 'conv1d',
    Conv3D: 'conv3d',
    MaxPooling2D: 'maxPool2d',
    MaxPooling1D: 'maxPool1d',
    AveragePooling2D: 'avgPool2d',
    GlobalAveragePooling1D: 'globalAvgPool',
    GlobalAveragePooling2D: 'globalAvgPool',
    // Approximation: rendered as a regular 2D max-pool.
    GlobalMaxPooling2D: 'maxPool2d',
    BatchNormalization: 'batchNorm2d',
    // NOTE(review): every Keras Activation layer is shown as ReLU, whatever
    // its configured activation function. Confirm whether the backend exposes
    // the activation name in `params` so this can be refined.
    Activation: 'relu',
    Add: 'add',
    Concatenate: 'concat',
    Multiply: 'multiply',
    ZeroPadding2D: 'pad',
    UpSampling2D: 'upsample',
    Reshape: 'reshape',
    Permute: 'reshape',
    SeparableConv2D: 'separableConv2d',
    DepthwiseConv2D: 'depthwiseConv2d',
    Conv2DTranspose: 'convTranspose2d',
    SimpleRNN: 'rnn',
    // Approximation: the wrapped recurrent layer may not actually be an LSTM.
    Bidirectional: 'lstm',
    TimeDistributed: 'custom',
    Lambda: 'custom',
    SpatialDropout1D: 'dropout',
    SpatialDropout2D: 'dropout',
    AlphaDropout: 'dropout',
  })
);

/**
 * Convert a backend layer type name to an NN3D node type.
 *
 * Known PyTorch and Keras class names are translated via LAYER_TYPE_MAP.
 * Any other name is returned lower-cased.
 *
 * @param layer - Layer description returned by the backend.
 * @returns The NN3D type string (not guaranteed to be a known NN3D type).
 */
function mapLayerType(layer: LayerInfo): string {
  return LAYER_TYPE_MAP.get(layer.type) ?? layer.type.toLowerCase();
}
/**
 * Backend (snake_case) parameter names mapped to the camelCase names shown in
 * the UI. Names not listed here are passed through unchanged.
 */
const LAYER_PARAM_DISPLAY_KEYS: ReadonlyMap<string, string> = new Map([
  ['in_features', 'inFeatures'],
  ['out_features', 'outFeatures'],
  ['in_channels', 'inChannels'],
  ['out_channels', 'outChannels'],
  ['kernel_size', 'kernelSize'],
  ['hidden_size', 'hiddenSize'],
  ['input_size', 'inputSize'],
  ['num_layers', 'numLayers'],
  ['batch_first', 'batchFirst'],
]);

/** Framework identifiers recognised in model metadata (besides `'custom'`). */
const KNOWN_FRAMEWORKS = ['pytorch', 'tensorflow', 'keras', 'onnx', 'jax'] as const;

/**
 * Normalise a backend framework label to a metadata framework value.
 * Matching is case-insensitive; anything unrecognised becomes `'custom'`.
 */
function normalizeFramework(raw: unknown): (typeof KNOWN_FRAMEWORKS)[number] | 'custom' {
  const lowered = typeof raw === 'string' ? raw.toLowerCase() : '';
  return KNOWN_FRAMEWORKS.find((candidate) => candidate === lowered) ?? 'custom';
}

/**
 * Convert one backend layer into an NN3D node, placed `index * 3` along x.
 *
 * `params` holds display-friendly (camelCase) keys plus a formatted
 * `totalParams`. `attributes` keeps the raw backend keys together with the
 * numeric `parameters` count and the backend-assigned `category`.
 */
function layerToNode(layer: ModelArchitecture['layers'][number], index: number): NN3DNode {
  const params: Record<string, unknown> = {};
  for (const [key, value] of Object.entries(layer.params ?? {})) {
    params[LAYER_PARAM_DISPLAY_KEYS.get(key) ?? key] = value;
  }
  if (layer.numParameters > 0) {
    params.totalParams = layer.numParameters.toLocaleString();
  }

  const attributes: Record<string, unknown> = { ...layer.params };
  if (layer.numParameters > 0) {
    attributes.parameters = layer.numParameters;
  }
  // Visualization uses the backend's category to choose how to render the node.
  attributes.category = layer.category;

  return {
    id: layer.id,
    name: layer.name,
    type: mapLayerType(layer) as NN3DNode['type'],
    inputShape: layer.inputShape || undefined,
    outputShape: layer.outputShape || undefined,
    params,
    attributes,
    position: { x: index * 3, y: 0, z: 0 },
  };
}

/**
 * Convert a backend architecture description into an NN3DModel.
 *
 * Nodes are laid out left-to-right in backend layer order. Each connection
 * becomes an edge `edge_<i>`, carrying its tensor shape when known.
 *
 * @param arch - Architecture returned by the backend analysis endpoint.
 * @returns A model using the layered layout.
 */
function architectureToNN3DModel(arch: ModelArchitecture): NN3DModel {
  const nodes: NN3DNode[] = arch.layers.map((layer, index) => layerToNode(layer, index));

  const edges: NN3DEdge[] = arch.connections.map((conn, index) => ({
    id: `edge_${index}`,
    source: conn.source,
    target: conn.target,
    attributes: conn.tensorShape ? { tensorShape: conn.tensorShape } : undefined,
  }));

  return {
    version: '1.0.0',
    metadata: {
      name: arch.name,
      description: `${arch.framework} model with ${arch.totalParameters.toLocaleString()} parameters (${arch.trainableParameters.toLocaleString()} trainable)`,
      framework: normalizeFramework(arch.framework),
      created: new Date().toISOString(),
      totalParams: arch.totalParameters,
      trainableParams: arch.trainableParameters,
      inputShape: arch.inputShape || undefined,
      outputShape: arch.outputShape || undefined,
    },
    graph: {
      nodes,
      edges,
    },
    visualization: {
      layout: 'layered',
      layerSpacing: 2.5,
    },
  };
}
/**
 * Client-side parsers, consulted in this order by loadModelFromFile when the
 * Python backend is unavailable or cannot analyze the file.
 */
const FORMAT_PARSERS = [OnnxParser, SafeTensorsParser, PyTorchParser, KerasParser];
/**
 * Extensions the universal backend endpoint can analyze, grouped by source
 * framework for readability.
 */
const BACKEND_EXTENSIONS_BY_FRAMEWORK = {
  pytorch: ['.pt', '.pth', '.ckpt', '.bin', '.model'],
  onnx: ['.onnx'],
  keras: ['.h5', '.hdf5', '.keras'],
  tensorflow: ['.pb'],
  safetensors: ['.safetensors'],
};

/** Flat list of every extension in BACKEND_EXTENSIONS_BY_FRAMEWORK. */
const BACKEND_SUPPORTED_EXTENSIONS: string[] = Object.values(BACKEND_EXTENSIONS_BY_FRAMEWORK).flat();
/**
 * Try to analyze `file` with the Python backend.
 *
 * Resolves to `null` (so the caller can fall back to the JS parsers) when the
 * backend is unreachable, its availability probe fails, it reports an
 * unsuccessful result, or the analysis throws.
 */
async function analyzeWithBackend(file: File, ext: string): Promise<NN3DModel | null> {
  try {
    if (!(await checkBackend())) {
      return null;
    }
    console.log(`Analyzing ${ext} model with universal backend endpoint...`);
    const result = await analyzeUniversal(file);
    if (!result.success) {
      console.warn('Backend returned unsuccessful result', result.message ?? '');
      return null;
    }
    console.log(`[OK] Backend analysis complete: ${result.model_type}`);
    console.log(`  Layers: ${result.architecture.layers.length}`);
    console.log(`  Parameters: ${result.architecture.totalParameters.toLocaleString()}`);
    if (result.message) {
      console.info(result.message);
    }
    return architectureToNN3DModel(result.architecture);
  } catch (error) {
    console.warn('Backend analysis failed, falling back to JS parser:', error);
    return null;
  }
}

/**
 * Try each registered JS parser that claims `file`, in FORMAT_PARSERS order.
 *
 * A parser that reports failure does not stop the search; later parsers still
 * get a chance. If none succeeds, the first reported failure is thrown.
 * Resolves to `null` when no parser claimed the file or none reported an error.
 */
async function parseWithFallbackParsers(file: File): Promise<NN3DModel | null> {
  let firstFailure: Error | null = null;
  for (const parser of FORMAT_PARSERS) {
    if (!(await parser.canParse(file))) {
      continue;
    }
    const result = await parser.parse(file);
    if (result.success && result.model) {
      if (result.warnings?.length) {
        console.warn('Model loading warnings:', result.warnings);
      }
      if (result.inferredStructure) {
        console.info('Model structure was inferred from weights. Some details may be approximate.');
      }
      return result.model;
    }
    if (result.error && firstFailure === null) {
      firstFailure = new Error(result.error);
    }
  }
  if (firstFailure !== null) {
    throw firstFailure;
  }
  return null;
}

/**
 * Load a model from a user-supplied file, auto-detecting its format.
 *
 * Resolution order:
 *  1. Native NN3D/JSON content (by extension or content sniffing) is parsed
 *     and validated directly.
 *  2. Extensions in BACKEND_SUPPORTED_EXTENSIONS go to the Python backend when
 *     it is available.
 *  3. Otherwise, or if the backend fails, the registered JS parsers are tried.
 *
 * @param file - File chosen or dropped by the user.
 * @returns The parsed model.
 * @throws Error if the extension is unsupported, native parsing/validation
 *   fails, or no parser could produce a model.
 */
export async function loadModelFromFile(file: File): Promise<NN3DModel> {
  const ext = '.' + file.name.split('.').pop()?.toLowerCase();

  if (!isSupportedExtension(file.name)) {
    throw new Error(
      `Unsupported file format: ${ext}\n\n` +
        `Supported formats:\n${SUPPORTED_EXTENSIONS.join(', ')}`
    );
  }

  const formatInfo = detectFormatFromExtension(file.name);
  const contentCategory = await detectFormatFromContent(file);

  if (formatInfo.category === 'native' || contentCategory === 'native') {
    return parseModelFromString(await file.text());
  }

  if (BACKEND_SUPPORTED_EXTENSIONS.includes(ext)) {
    const backendModel = await analyzeWithBackend(file, ext);
    if (backendModel) {
      return backendModel;
    }
  }

  const parsedModel = await parseWithFallbackParsers(file);
  if (parsedModel) {
    return parsedModel;
  }

  throw new Error(
    `Unable to parse ${getFormatDisplayName(formatInfo.category)} file.\n\n` +
      (formatInfo.conversionHint || 'Please convert to .nn3d or .onnx format.')
  );
}
/**
 * Fetch a serialized NN3D (JSON) model from `url`, then parse and validate it.
 *
 * @param url - Location of the model document.
 * @returns The validated model.
 * @throws Error when the HTTP response is not OK, or when parsing/validation fails.
 */
export async function loadModelFromUrl(url: string): Promise<NN3DModel> {
  const response = await fetch(url);
  if (response.ok) {
    const body = await response.text();
    return parseModelFromString(body);
  }
  const { status, statusText } = response;
  throw new Error(`Failed to fetch model: ${status} ${statusText}`);
}
/**
 * Parse an NN3D JSON document and validate it.
 *
 * Schema validation failures are fatal. Semantic validation problems are only
 * logged as warnings, and the model is still returned.
 *
 * @param jsonString - Raw JSON text of the model.
 * @returns The parsed model.
 * @throws Error listing every schema violation as `path: message` lines.
 */
export function parseModelFromString(jsonString: string): NN3DModel {
  const parsed = parseNN3DModel(jsonString);
  const model = parsed.model;

  if (!model || !parsed.validation.valid) {
    const details = parsed.validation.errors
      .map((issue) => `${issue.path}: ${issue.message}`)
      .join('\n');
    throw new Error(`Model validation failed:\n${details}`);
  }

  const semantic = validateModelSemantics(model);
  if (semantic.valid) {
    return model;
  }

  const warnings = semantic.errors
    .map((issue) => `${issue.path}: ${issue.message}`)
    .join('\n');
  console.warn(`Model semantic warnings:\n${warnings}`);
  return model;
}
/**
 * Serialize a model to JSON text.
 *
 * @param model - Model to serialize.
 * @param pretty - When true (the default), indent with two spaces; otherwise emit compact JSON.
 * @returns The JSON document.
 */
export function exportModelToString(model: NN3DModel, pretty = true): string {
  const indentation = pretty ? 2 : undefined;
  return JSON.stringify(model, undefined, indentation);
}
/**
 * Trigger a browser download of `model` as a pretty-printed JSON file.
 *
 * @param model - Model to save.
 * @param filename - Name suggested to the browser (default `model.nn3d`).
 */
export function downloadModel(model: NN3DModel, filename = 'model.nn3d'): void {
  const json = exportModelToString(model);
  const blob = new Blob([json], { type: 'application/json' });
  const url = URL.createObjectURL(blob);

  const anchor = document.createElement('a');
  anchor.href = url;
  anchor.download = filename;
  document.body.appendChild(anchor);
  try {
    anchor.click();
  } finally {
    // Always detach the temporary anchor, even if click() throws.
    document.body.removeChild(anchor);
    // Revoking the object URL synchronously right after click() can cancel
    // the download in some browsers. Release it on a later task instead.
    setTimeout(() => URL.revokeObjectURL(url), 1000);
  }
}
/**
 * Attach drag-and-drop file handling to `element`.
 *
 * Only the first dropped file is considered. It is passed to `onFile` when its
 * extension (compared case-insensitively) is in `options.accept`; otherwise a
 * warning is logged.
 *
 * @param element - Drop target.
 * @param onFile - Receives the accepted file.
 * @param options.accept - Allowed extensions including the leading dot
 *   (default `['.nn3d', '.json']`).
 * @param options.onDragOver - Called for every `dragover` event on the target.
 * @param options.onDragLeave - Called when the drag leaves the target, and on drop.
 * @returns A cleanup function that removes all listeners added here.
 */
export function createFileDropHandler(
  element: HTMLElement,
  onFile: (file: File) => void,
  options: { accept?: string[]; onDragOver?: () => void; onDragLeave?: () => void } = {}
): () => void {
  const { accept = ['.nn3d', '.json'], onDragOver, onDragLeave } = options;
  // Extensions read from file names are lower-cased below, so normalise the
  // allow-list the same way; otherwise an entry like '.JSON' could never match.
  const accepted = new Set(accept.map((ext) => ext.toLowerCase()));

  const handleDragOver = (e: DragEvent) => {
    e.preventDefault();
    e.stopPropagation();
    onDragOver?.();
  };

  const handleDragLeave = (e: DragEvent) => {
    e.preventDefault();
    e.stopPropagation();
    // 'dragleave' also fires (and bubbles from children) when the pointer
    // moves onto a descendant of the drop target. Ignore those transitions so
    // the caller's hover state does not flicker. If the browser reports no
    // relatedTarget, the event is treated as a real leave.
    const next = e.relatedTarget;
    if (next instanceof Node && element.contains(next)) {
      return;
    }
    onDragLeave?.();
  };

  const handleDrop = (e: DragEvent) => {
    e.preventDefault();
    e.stopPropagation();
    onDragLeave?.();

    const file = e.dataTransfer?.files?.[0];
    if (!file) {
      return;
    }
    const ext = '.' + file.name.split('.').pop()?.toLowerCase();
    if (accepted.has(ext)) {
      onFile(file);
    } else {
      console.warn(`Unsupported file type: ${ext}`);
    }
  };

  element.addEventListener('dragover', handleDragOver);
  element.addEventListener('dragleave', handleDragLeave);
  element.addEventListener('drop', handleDrop);

  return () => {
    element.removeEventListener('dragover', handleDragOver);
    element.removeEventListener('dragleave', handleDragLeave);
    element.removeEventListener('drop', handleDrop);
  };
}