Spaces:
Build error
Build error
| /** | |
| * Architecture Layout Engine | |
| * | |
| * Creates a VGG-style 3D visualization where: | |
| * - Block HEIGHT represents spatial dimension (H×W shrinks through network) | |
| * - Block WIDTH (X axis) represents channel count (grows through network) | |
| * - Block DEPTH (Z axis) matches the block height for feature maps; dense layers are thin bars | |
| * - Position flows left-to-right (X axis) | |
| * - Non-linear architectures spread vertically (Y axis) for branches | |
| * | |
| * This creates the classic "funnel" effect seen in CNN architecture diagrams, | |
| * while also supporting ResNets, U-Nets, and other branching architectures. | |
| */ | |
| // ============================================================================ | |
| // Types | |
| // ============================================================================ | |
/**
 * Size of a tensor as used by the layout: spatial extent plus channel count.
 */
export interface TensorShape {
  /** Spatial height (e.g., 224, 112, 56...). */
  height: number;
  /** Spatial width (usually same as height for images). */
  width: number;
  /** Feature channels (e.g., 64, 128, 256, 512...). */
  channels: number;
}
/**
 * One drawable block of the architecture diagram, i.e. a single layer together
 * with its graph placement, tensor shapes and visual attributes.
 */
export interface LayerBlock {
  id: string;
  name: string;
  displayName: string;
  type: string;
  category: string;

  // --- Graph info ---
  /** Graph depth (for X positioning). */
  depth: number;
  /** Branch index at this depth (for Y offset). */
  branchIndex: number;

  // --- Tensor dimensions ---
  inputShape: TensorShape | null;
  outputShape: TensorShape | null;

  // --- Computed 3D dimensions for visualization ---
  position: {
    x: number;
    y: number;
    z: number;
  };
  dimensions: {
    width: number;
    height: number;
    depth: number;
  };

  // --- Visual properties ---
  color: string;
  opacity: number;

  // --- Layer params for display ---
  params: Record<string, unknown>;
  numParameters: number;

  // --- Label info ---
  label: string;
  dimensionLabel: string;
}
/**
 * Complete layout result: the blocks to draw, the edges between them, and the
 * overall extents of the scene.
 */
export interface ArchitectureLayout {
  /** One entry per laid-out layer. */
  blocks: LayerBlock[];
  /** Edges between blocks, with resolved start and end points. */
  connections: {
    from: string;
    to: string;
    fromPos: { x: number; y: number; z: number };
    toPos: { x: number; y: number; z: number };
    isSkipConnection: boolean;
  }[];
  /** Axis-aligned bounding box enclosing every block. */
  bounds: {
    minX: number;
    maxX: number;
    minY: number;
    maxY: number;
    minZ: number;
    maxZ: number;
  };
  /** Midpoint of `bounds`. */
  center: { x: number; y: number; z: number };
  totalLayers: number;
  modelName: string;
  /** True when no node has more than one parent or more than one child. */
  isLinear: boolean;
}
| // ============================================================================ | |
| // Color Scheme (professional, distinct but muted colors) | |
| // ============================================================================ | |
/**
 * Hex colours keyed by lower-case layer type or category name.
 * Always read this table through `lookupArchColor` when the key comes from
 * data, so inherited `Object.prototype` members are never mistaken for colours.
 */
const ARCH_COLORS: Record<string, string> = {
  // Core computation layers - blues
  convolution: '#5B8BD9', // Soft Blue (Conv layers)
  conv2d: '#5B8BD9', // Soft Blue
  conv1d: '#7BA3E0', // Light Blue
  // Fully connected - greens
  linear: '#6BAF6B', // Soft Green (Fully Connected)
  dense: '#6BAF6B', // Soft Green
  fc: '#6BAF6B', // Soft Green
  // Spatial reduction - coral/salmon
  pooling: '#E07070', // Coral Red (Max/Avg Pooling)
  maxpool: '#E07070', // Coral Red
  avgpool: '#E08850', // Soft Orange
  // Activations - warm amber
  activation: '#D9A740', // Soft Amber
  relu: '#D9A740', // Soft Amber
  sigmoid: '#C99030', // Deeper Amber
  softmax: '#E0A050', // Light Amber
  // Normalization - teals
  normalization: '#50A8A0', // Soft Teal
  batchnorm: '#50A8A0', // Soft Teal
  layernorm: '#60B8B0', // Light Teal
  // Regularization - slate
  regularization: '#708090', // Slate Gray
  dropout: '#708090', // Slate Gray
  // Special layers - purple
  attention: '#9070C0', // Soft Purple
  multiheadattention: '#9070C0',
  // Embeddings - cyan
  embedding: '#50A8C0', // Soft Cyan
  // Recurrent - rose
  recurrent: '#C070A0', // Soft Rose
  lstm: '#C070A0', // Soft Rose
  gru: '#D080B0', // Light Rose
  // Reshaping - cool gray
  reshape: '#808890', // Cool Gray
  flatten: '#808890', // Cool Gray
  // Input/Output
  input: '#50B080', // Soft Emerald
  output: '#D06070', // Soft Rose
  // Padding/Other
  padding: '#8070B0', // Soft Violet
  concat: '#A060C0', // Soft Fuchsia
  add: '#A060C0', // Soft Fuchsia (for residual adds)
  // Default
  other: '#909090', // Neutral Gray
};

/**
 * Look up a colour by exact key, considering only the table's own entries.
 * A plain `ARCH_COLORS[key]` would also resolve inherited members such as
 * `constructor` or `__proto__`, which are truthy but not colour strings.
 */
function lookupArchColor(key: string): string | undefined {
  return Object.prototype.hasOwnProperty.call(ARCH_COLORS, key)
    ? ARCH_COLORS[key]
    : undefined;
}

/**
 * Resolve the display colour for a layer.
 *
 * Resolution order:
 *  1. exact match of the lower-cased `type`,
 *  2. substring match of the lower-cased `type` against known layer families,
 *  3. exact match of the lower-cased `category`,
 *  4. the neutral `other` colour.
 *
 * @param category Layer category (matched case-insensitively).
 * @param type Optional concrete layer type (matched case-insensitively).
 * @returns A hex colour string; always a string, even for unknown input.
 */
export function getArchColor(category: string, type?: string): string {
  const lowerCategory = category.toLowerCase();
  const lowerType = (type ?? '').toLowerCase();

  if (lowerType) {
    // First try to match the specific type.
    const exact = lookupArchColor(lowerType);
    if (exact) return exact;

    // Then check the type for partial matches; order matters, first hit wins.
    if (lowerType.includes('conv')) return ARCH_COLORS.convolution;
    if (lowerType.includes('pool')) return ARCH_COLORS.pooling;
    if (lowerType.includes('relu')) return ARCH_COLORS.relu;
    if (lowerType.includes('norm')) return ARCH_COLORS.normalization;
    if (lowerType.includes('drop')) return ARCH_COLORS.dropout;
    if (lowerType.includes('attention')) return ARCH_COLORS.attention;
    if (lowerType.includes('embed')) return ARCH_COLORS.embedding;
    if (lowerType.includes('lstm') || lowerType.includes('gru')) return ARCH_COLORS.recurrent;
    if (lowerType.includes('flatten')) return ARCH_COLORS.flatten;
    if (lowerType.includes('dense') || lowerType.includes('linear') || lowerType.includes('fc')) {
      return ARCH_COLORS.linear;
    }
  }

  // Fall back to category, then to the neutral default.
  return lookupArchColor(lowerCategory) ?? ARCH_COLORS.other;
}
| // ============================================================================ | |
| // Dimension Parsing | |
| // ============================================================================ | |
/**
 * Convert a raw shape array into a `TensorShape`, dispatching on its rank.
 *
 * Interpretation by rank:
 *  - 4: `[B, C, H, W]` (PyTorch channel-first layout is assumed)
 *  - 3: `[B, dim1, dim2]` mapped to height × width with a single channel
 *  - 2: `[B, features]` mapped to 1 × 1 × features
 *  - 1: `[features]` mapped to 1 × 1 × features
 *
 * @param shape Raw shape, or null/undefined when unknown.
 * @returns The parsed shape, or null for a missing, empty or higher-rank input.
 */
export function parseTensorShape(shape: number[] | null | undefined): TensorShape | null {
  if (!shape) {
    return null;
  }

  switch (shape.length) {
    case 4:
      // Leading entry is the batch size and is ignored.
      return { height: shape[2], width: shape[3], channels: shape[1] };
    case 3:
      // Could be [B, H, W] or [B, seq, features]; both map the same way.
      return { height: shape[1], width: shape[2], channels: 1 };
    case 2:
      return { height: 1, width: 1, channels: shape[1] };
    case 1:
      return { height: 1, width: 1, channels: shape[0] };
    default:
      // Rank 0 or rank > 4 is not understood.
      return null;
  }
}
/**
 * Estimate a layer's output shape from its parameters and its input shape.
 *
 * This is a simplified model: kernel size and padding are ignored, only the
 * stride (and, for conv/dense layers, the output width) is taken into account.
 *
 *  - pooling: global pooling collapses to 1×1; otherwise both spatial extents
 *    are divided by the stride (default 2), channels unchanged
 *  - convolution: channels become `filters`/`out_channels`/`outChannels` when
 *    given; spatial extents are divided by the stride when it exceeds 1
 *  - linear/dense: 1×1 with `units`/`out_features`/`outFeatures` channels
 *  - flatten: 1×1 with height·width·channels channels
 *  - anything else (including reshape): input shape passed through
 *
 * Spatial extents never drop below 1, and strides that are zero, negative or
 * not a finite number are ignored in favour of the default.
 *
 * @param layer Layer descriptor; `params` may be missing at runtime.
 * @param prevShape Shape feeding this layer, or null to assume 224×224×3.
 * @returns The inferred output shape.
 */
export function inferShapeFromParams(
  layer: { type: string; category: string; params: Record<string, unknown> },
  prevShape: TensorShape | null
): TensorShape {
  const params = layer.params || {};

  // First candidate that is a positive finite number; a tuple such as
  // strides: [2, 2] contributes its first entry.
  const firstPositive = (...candidates: unknown[]): number | undefined => {
    for (const candidate of candidates) {
      const value: unknown = Array.isArray(candidate) ? candidate[0] : candidate;
      if (typeof value === 'number' && Number.isFinite(value) && value > 0) {
        return value;
      }
    }
    return undefined;
  };

  // Missing, non-finite or negative extents are unknown and take the fallback.
  // A known extent is kept but clamped to 1, so a collapsed map is never
  // mistaken for "missing" and reset to the default input size.
  const knownDim = (value: number | undefined, fallback: number): number =>
    typeof value === 'number' && Number.isFinite(value) && value >= 0
      ? Math.max(1, value)
      : fallback;

  // Integer down-sampling that bottoms out at a 1-wide extent instead of 0.
  const shrink = (extent: number, factor: number): number =>
    Math.max(1, Math.floor(extent / factor));

  // Extract common parameters (framework-specific spellings side by side).
  const filters = firstPositive(params.filters, params.out_channels, params.outChannels);
  const units = firstPositive(params.units, params.out_features, params.outFeatures);
  const stride = firstPositive(params.strides, params.stride);

  // Start with previous shape or default.
  let height = knownDim(prevShape?.height, 224);
  let width = knownDim(prevShape?.width, 224);
  let channels = knownDim(prevShape?.channels, 3);

  const category = layer.category.toLowerCase();
  const type = layer.type.toLowerCase();

  if (category === 'pooling' || type.includes('pool')) {
    if (type.includes('global')) {
      // Global pooling reduces each feature map to a single value.
      height = 1;
      width = 1;
    } else {
      const poolStride = stride ?? 2;
      height = shrink(height, poolStride);
      width = shrink(width, poolStride);
    }
    // Channels stay the same for pooling.
  } else if (category === 'convolution' || type.includes('conv')) {
    if (filters !== undefined) channels = filters;
    const convStride = stride ?? 1;
    if (convStride > 1) {
      height = shrink(height, convStride);
      width = shrink(width, convStride);
    }
  } else if (category === 'linear' || type.includes('dense') || type.includes('linear')) {
    // Linear/dense flattens to 1×1×features.
    height = 1;
    width = 1;
    if (units !== undefined) channels = units;
  } else if (type.includes('flatten')) {
    const totalFeatures = height * width * channels;
    height = 1;
    width = 1;
    channels = totalFeatures;
  }
  // Reshape and every other layer kind keep the incoming shape.

  return { height, width, channels };
}
| // ============================================================================ | |
| // Layout Calculation | |
| // ============================================================================ | |
/**
 * Tunable constants for the layout: scale factors, spacing and size clamps,
 * all expressed in scene units.
 */
const LAYOUT_CONFIG = {
  /** Scale applied to spatial dimensions (large). */
  spatialScale: 0.05,
  /** Scale applied to the channel dimension. */
  channelScale: 0.008,
  /** Base thickness of each layer block. */
  layerThickness: 0.5,

  /** Gap between layers (X) - wide for clear labels. */
  layerSpacing: 1.8,
  /** Gap between parallel branches (Y). */
  branchSpacing: 3.0,
  /** Extra gap between groups (conv blocks). */
  groupSpacing: 2.0,

  /** Minimum visible size. */
  minBlockSize: 0.3,
  /** Max spatial size for large feature maps. */
  maxSpatialSize: 8.0,
  /** Max channel size for deep networks. */
  maxChannelSize: 6.0,

  /** Thickness of pooling layers (thinner than the base). */
  poolingThickness: 0.3,
  /** Thickness of activation layers. */
  activationThickness: 0.15,
};
| // ============================================================================ | |
| // Graph Analysis for Non-Linear Architectures | |
| // ============================================================================ | |
/** Per-layer topology record produced by `buildGraph`. */
interface GraphNode {
  id: string;
  /** Longest distance (in edges) from any root; drives the X column. */
  depth: number;
  /** Offset among nodes sharing a depth, centred on 0; drives the Y position. */
  branchIndex: number;
  parents: string[];
  children: string[];
}

/**
 * Build a graph from layers and connections and place every node.
 *
 * Depth is the length of the longest path from any root (a node with no
 * parents). Nodes are resolved in topological order, so a node's depth is only
 * final once all of its parents are; this keeps a merge node behind the
 * longest branch feeding it, e.g. the add node of a residual block.
 *
 * Connections naming an unknown layer id are ignored. Nodes that cannot be
 * ordered (they sit on or behind a cycle) receive consecutive depths after
 * the resolved part of the graph, in layer order.
 *
 * @param layers Layers to place; only `id` is read.
 * @param connections Directed edges from `source` to `target`.
 * @returns Map from layer id to its graph node.
 */
function buildGraph(
  layers: Array<{ id: string }>,
  connections: Array<{ source: string; target: string }>
): Map<string, GraphNode> {
  const graph = new Map<string, GraphNode>();

  // Initialize nodes.
  for (const layer of layers) {
    graph.set(layer.id, {
      id: layer.id,
      depth: 0,
      branchIndex: 0,
      parents: [],
      children: [],
    });
  }

  // Build adjacency, skipping edges with a missing endpoint.
  for (const conn of connections) {
    const parent = graph.get(conn.source);
    const child = graph.get(conn.target);
    if (parent && child) {
      parent.children.push(conn.target);
      child.parents.push(conn.source);
    }
  }

  // Kahn's algorithm: a node becomes ready once every incoming edge has been
  // processed. `ready` doubles as the work queue; `head` walks it so that no
  // O(n) shift() is needed.
  const pendingParents = new Map<string, number>();
  const ready: string[] = [];
  graph.forEach(node => {
    pendingParents.set(node.id, node.parents.length);
    if (node.parents.length === 0) ready.push(node.id);
  });

  const resolved = new Set<string>();
  let maxDepth = -1;
  for (let head = 0; head < ready.length; head++) {
    const node = graph.get(ready[head])!;
    resolved.add(node.id);
    maxDepth = Math.max(maxDepth, node.depth);

    for (const childId of node.children) {
      const child = graph.get(childId)!;
      // Longest path: keep the deepest parent seen so far.
      child.depth = Math.max(child.depth, node.depth + 1);
      const remaining = pendingParents.get(childId)! - 1;
      pendingParents.set(childId, remaining);
      if (remaining === 0) ready.push(childId);
    }
  }

  // Nodes that never became ready are part of, or downstream of, a cycle.
  let fallbackDepth = maxDepth + 1;
  for (const layer of layers) {
    if (!resolved.has(layer.id)) {
      graph.get(layer.id)!.depth = fallbackDepth++;
      resolved.add(layer.id);
    }
  }

  // Assign branch indices for nodes at the same depth.
  const depthGroups = new Map<number, string[]>();
  graph.forEach(node => {
    const group = depthGroups.get(node.depth) || [];
    group.push(node.id);
    depthGroups.set(node.depth, group);
  });
  depthGroups.forEach(nodeIds => {
    nodeIds.forEach((id, index) => {
      // Center branches around 0.
      graph.get(id)!.branchIndex = index - (nodeIds.length - 1) / 2;
    });
  });

  return graph;
}
/**
 * Report whether the graph is purely sequential, i.e. no node fans in from
 * more than one parent or fans out to more than one child.
 */
function isLinearArchitecture(graph: Map<string, GraphNode>): boolean {
  return Array.from(graph.values()).every(
    ({ parents, children }) => parents.length <= 1 && children.length <= 1
  );
}
/**
 * Decide whether an edge is a skip connection, meaning its two endpoints are
 * more than one depth level apart. Edges with an unknown endpoint are never
 * treated as skips.
 */
function isSkipConnection(
  fromId: string,
  toId: string,
  graph: Map<string, GraphNode>
): boolean {
  const source = graph.get(fromId);
  const target = graph.get(toId);
  if (source === undefined || target === undefined) {
    return false;
  }

  const depthGap = target.depth - source.depth;
  return depthGap > 1 || depthGap < -1;
}
/**
 * Compute the 3D layout for an architecture visualization.
 * Handles both linear (sequential) and non-linear (branching) architectures.
 *
 * For every layer, in the order given:
 *  1. resolve input/output shapes (explicit shape, else the first parent's
 *     output, else the model input shape; missing outputs are inferred),
 *  2. derive block dimensions from the output shape and the layer kind,
 *  3. place the block: X comes from the node's graph depth (one column per
 *     depth, allocated the first time that depth is encountered), Y from its
 *     branch index, Z is always 0.
 * Connections are then resolved to start/end points on the facing sides of
 * their blocks, and the bounding box and its centre are computed.
 *
 * @param architecture Model description; connections that reference an
 *   unknown layer id are dropped.
 * @returns Blocks, connections, bounds and summary flags for rendering.
 */
export function computeArchitectureLayout(
  architecture: {
    name: string;
    framework: string;
    totalParameters: number;
    inputShape?: number[] | null;
    outputShape?: number[] | null;
    layers: Array<{
      id: string;
      name: string;
      type: string;
      category: string;
      inputShape: number[] | null;
      outputShape: number[] | null;
      params: Record<string, unknown>;
      numParameters: number;
    }>;
    connections: Array<{
      source: string;
      target: string;
    }>;
  }
): ArchitectureLayout {
  const blocks: LayerBlock[] = [];
  const connections: ArchitectureLayout['connections'] = [];

  // First block laid out for each layer id, replacing linear scans of `blocks`.
  const blockById = new Map<string, LayerBlock>();

  // Build graph to analyze topology.
  const graph = buildGraph(architecture.layers, architecture.connections);
  const isLinear = isLinearArchitecture(graph);

  // Output shape of every processed layer, read by its children.
  const shapeMap = new Map<string, TensorShape>();
  const defaultShape = parseTensorShape(architecture.inputShape) || { height: 224, width: 224, channels: 3 };

  // Running count per category, used for names such as "Conv-3". A Map keeps
  // categories like "constructor" from colliding with Object.prototype.
  const blockCounter = new Map<string, number>();

  // Cumulative X position (since blocks have different widths).
  let currentX = 0;
  const positionXMap = new Map<number, number>(); // depth -> X position

  // Track bounds.
  let minX = Infinity, maxX = -Infinity;
  let minY = Infinity, maxY = -Infinity;
  let minZ = Infinity, maxZ = -Infinity;

  // First parent (in connection order) that already has a block.
  const findLaidOutParent = (parentIds: string[]): LayerBlock | undefined => {
    for (const parentId of parentIds) {
      const parentBlock = blockById.get(parentId);
      if (parentBlock) return parentBlock;
    }
    return undefined;
  };

  // Process each layer.
  architecture.layers.forEach((layer) => {
    const category = layer.category.toLowerCase();
    const type = layer.type.toLowerCase();
    const graphNode = graph.get(layer.id)!;

    // Parse shapes.
    let inputShape = parseTensorShape(layer.inputShape);
    let outputShape = parseTensorShape(layer.outputShape);

    // Get parent shape if available (only set once the parent was processed).
    let parentShape: TensorShape | null = null;
    if (graphNode.parents.length > 0) {
      parentShape = shapeMap.get(graphNode.parents[0]) || null;
    }

    // If shapes are not provided, fall back to the parent or infer from params.
    if (!inputShape) {
      inputShape = parentShape || { ...defaultShape };
    }
    if (!outputShape) {
      outputShape = inferShapeFromParams(layer, inputShape);
    }

    // Store output shape for children.
    shapeMap.set(layer.id, outputShape);

    // Block dimensions derived from the output tensor H × W × C:
    //   height (Y axis) = spatial size
    //   width  (X axis) = channel count
    //   depth  (Z axis) = same as height, giving a square side face
    const spatialH = outputShape.height;
    const spatialW = outputShape.width;
    const channels = outputShape.channels;

    // Spatial size → block height, square-root scaled and clamped.
    const spatialScale = Math.sqrt(Math.max(spatialH, spatialW)) * LAYOUT_CONFIG.spatialScale * 5;
    let blockHeight = Math.min(
      LAYOUT_CONFIG.maxSpatialSize,
      Math.max(LAYOUT_CONFIG.minBlockSize, spatialScale)
    );

    // Channels → block width; capped at 512 channels to avoid huge blocks.
    const cappedChannels = Math.min(channels, 512);
    let blockWidth = Math.min(
      LAYOUT_CONFIG.maxChannelSize,
      Math.max(LAYOUT_CONFIG.minBlockSize * 0.3, Math.sqrt(cappedChannels) * LAYOUT_CONFIG.channelScale * 5)
    );

    let blockDepth = blockHeight;

    if (type.includes('flatten')) {
      // Flatten is a transition - show as a compact, thin vertical bar.
      blockHeight = LAYOUT_CONFIG.minBlockSize * 2;
      blockWidth = LAYOUT_CONFIG.minBlockSize * 0.4;
      blockDepth = blockHeight;
    } else if (category === 'linear' || type.includes('dense') || type.includes('linear') || type.includes('fc')) {
      // Dense layers: vertical bar whose height grows with log2 of the units.
      const units = Math.min(channels, 4096); // Cap for visualization
      blockHeight = Math.min(
        LAYOUT_CONFIG.maxSpatialSize,
        Math.max(LAYOUT_CONFIG.minBlockSize * 1.5, Math.log2(units + 1) * 0.4)
      );
      blockWidth = LAYOUT_CONFIG.minBlockSize * 0.5;
      blockDepth = LAYOUT_CONFIG.minBlockSize * 0.5;
    } else if (category === 'pooling') {
      // Pooling layers inherit the parent's width (channels don't change).
      const parentBlock = findLaidOutParent(graphNode.parents);
      if (parentBlock) {
        blockWidth = parentBlock.dimensions.width;
      }
    } else if (category === 'activation' || category === 'normalization') {
      // Activation/norm layers are a thinner version of the parent.
      const parentBlock = findLaidOutParent(graphNode.parents);
      if (parentBlock) {
        blockWidth = parentBlock.dimensions.width * 0.5;
      }
    }

    // Generate display name.
    const categoryCount = (blockCounter.get(category) ?? 0) + 1;
    blockCounter.set(category, categoryCount);

    // "padding" contains the substring "add"; it must not be shown as a merge.
    const isAdd = type.includes('add') && !type.includes('padd');
    const isConcat = type.includes('concat');

    let displayName = layer.name;
    if (type.includes('conv')) {
      displayName = `Conv-${categoryCount}`;
    } else if (type.includes('pool')) {
      displayName = `Pool`;
    } else if (type.includes('dense') || type.includes('linear')) {
      displayName = `FC-${categoryCount}`;
    } else if (type.includes('flatten')) {
      displayName = 'Flatten';
    } else if (isAdd) {
      displayName = '⊕ Add';
    } else if (isConcat) {
      displayName = '⊕ Concat';
    } else if (type.includes('attention')) {
      displayName = 'Attention';
    }

    // Format dimension label: full H×W×C for feature maps, channels otherwise.
    let dimensionLabel = '';
    if (outputShape.height > 1) {
      dimensionLabel = `${outputShape.height}×${outputShape.width}×${outputShape.channels}`;
    } else {
      dimensionLabel = `${outputShape.channels}`;
    }

    // Calculate position based on graph topology. The first block seen at a
    // depth claims the column and advances the cursor by its own width.
    if (!positionXMap.has(graphNode.depth)) {
      positionXMap.set(graphNode.depth, currentX);
      currentX += blockWidth + LAYOUT_CONFIG.layerSpacing;
    }
    const posX = positionXMap.get(graphNode.depth)!;
    const posY = graphNode.branchIndex * LAYOUT_CONFIG.branchSpacing;
    const posZ = 0;

    // Create block.
    const block: LayerBlock = {
      id: layer.id,
      name: layer.name,
      displayName,
      type: layer.type,
      category,
      depth: graphNode.depth,
      branchIndex: graphNode.branchIndex,
      inputShape,
      outputShape,
      position: {
        x: posX,
        y: posY,
        z: posZ,
      },
      dimensions: {
        width: blockWidth,
        height: blockHeight,
        depth: blockDepth,
      },
      color: getArchColor(category, layer.type),
      opacity: 1.0,
      params: layer.params,
      numParameters: layer.numParameters,
      label: displayName,
      dimensionLabel,
    };
    blocks.push(block);
    if (!blockById.has(block.id)) {
      blockById.set(block.id, block);
    }

    // Update bounds; positions are block centres.
    minX = Math.min(minX, posX - blockWidth / 2);
    maxX = Math.max(maxX, posX + blockWidth / 2);
    minY = Math.min(minY, posY - blockHeight / 2);
    maxY = Math.max(maxY, posY + blockHeight / 2);
    minZ = Math.min(minZ, posZ - blockDepth / 2);
    maxZ = Math.max(maxZ, posZ + blockDepth / 2);
  });

  // Build connections with skip connection detection.
  architecture.connections.forEach(conn => {
    const fromBlock = blockById.get(conn.source);
    const toBlock = blockById.get(conn.target);
    if (fromBlock && toBlock) {
      connections.push({
        from: conn.source,
        to: conn.target,
        // Leave from the +X face of the source and enter the -X face of the target.
        fromPos: {
          x: fromBlock.position.x + fromBlock.dimensions.width / 2,
          y: fromBlock.position.y,
          z: fromBlock.position.z,
        },
        toPos: {
          x: toBlock.position.x - toBlock.dimensions.width / 2,
          y: toBlock.position.y,
          z: toBlock.position.z,
        },
        isSkipConnection: isSkipConnection(conn.source, conn.target, graph),
      });
    }
  });

  // Handle empty model: use a small finite box instead of ±Infinity.
  if (blocks.length === 0) {
    minX = 0; maxX = 1;
    minY = -1; maxY = 1;
    minZ = -1; maxZ = 1;
  }

  // Calculate center.
  const center = {
    x: (minX + maxX) / 2,
    y: (minY + maxY) / 2,
    z: (minZ + maxZ) / 2,
  };

  return {
    blocks,
    connections,
    bounds: { minX, maxX, minY, maxY, minZ, maxZ },
    center,
    totalLayers: architecture.layers.length,
    modelName: architecture.name,
    isLinear,
  };
}
/**
 * Placeholder for collapsing runs of related layers into "stages"
 * (e.g., Conv+ReLU+Conv+ReLU → "Conv Block 1").
 *
 * Grouping is not implemented yet: the given layout is handed back untouched.
 *
 * @param layout Layout to group.
 * @returns The same `layout` object that was passed in.
 */
export function groupLayersIntoStages(
  layout: ArchitectureLayout
): ArchitectureLayout {
  return layout;
}