Spaces:
Runtime error
Runtime error
| import type { CanvasNode, CanvasEdge } from './ScanCanvas'; | |
| import type { LayerStructure, ConnectionInfo, LayerWeightStats } from '../../types/scan'; | |
/**
 * Lay out model layers on a brain-shaped canvas (sagittal view).
 *
 * Placement (SVG coordinates: right = anterior/front, left = posterior/back):
 *  - embedding layers -> brain stem, below the centre of the ellipse
 *  - attention layers -> outer arc (cortex), spread evenly from the
 *    occipital side (160 deg) over the top (90 deg) to the frontal side (20 deg)
 *  - mlp layers       -> same arc, scaled inward (white matter)
 *  - output layers    -> fixed point in the upper-right (frontal pole)
 *
 * Layers whose `layer_type` is none of the four above produce no node, and
 * connections touching a layer without a node produce no edge.
 *
 * @param layers        Layers to place.
 * @param connections   Candidate edges between layers, by layer id.
 * @param canvasWidth   Canvas width used to scale the layout.
 * @param canvasHeight  Canvas height used to scale the layout.
 * @param weightLayers  Optional weight statistics. When given, a node's `ratio`
 *                      is the mean `l2_norm` of its entries divided by the
 *                      largest `l2_norm` overall (the divisor is at least 1).
 *                      Nodes without entries fall back to their parameter ratio.
 * @returns The placed nodes, the edges between placed nodes, and the
 *          silhouette path produced by `generateBrainPath`.
 */
export function buildBrainLayout(
  layers: LayerStructure[],
  connections: ConnectionInfo[],
  canvasWidth: number,
  canvasHeight: number,
  weightLayers?: LayerWeightStats[],
): { nodes: CanvasNode[]; edges: CanvasEdge[]; brainPath: string } {
  // Largest parameter count, floored at 1 so the ratio below never divides by 0.
  const largestParamCount = Math.max(...layers.map((layer) => layer.param_count), 1);

  // layer id -> mean L2 norm, normalised by the largest single norm.
  const normalisedNorm = new Map<string, number>();
  if (weightLayers) {
    const normsByLayer = new Map<string, number[]>();
    for (const stat of weightLayers) {
      const bucket = normsByLayer.get(stat.layer_id);
      if (bucket) {
        bucket.push(stat.l2_norm);
      } else {
        normsByLayer.set(stat.layer_id, [stat.l2_norm]);
      }
    }
    const largestNorm = Math.max(...weightLayers.map((stat) => stat.l2_norm), 1);
    normsByLayer.forEach((norms, layerId) => {
      let total = 0;
      for (const value of norms) {
        total += value;
      }
      const mean = total / norms.length;
      normalisedNorm.set(layerId, mean / largestNorm);
    });
  }

  // Ellipse the rings are drawn on: centred in the upper part of the canvas
  // so the brain stem fits underneath.
  const centerX = canvasWidth * 0.50;
  const centerY = canvasHeight * 0.38;
  const radiusX = canvasWidth * 0.28;
  const radiusY = canvasHeight * 0.25;
  const mlpRingScale = 0.62; // inner ring size as a fraction of the outer ring

  // Arc runs from 160 deg (back/left) to 20 deg (front/right); the sweep is
  // negative, i.e. clockwise on screen through the top of the ellipse.
  const startAngle = (160 / 180) * Math.PI;
  const endAngle = (20 / 180) * Math.PI;
  const sweep = endAngle - startAngle;

  const nodes: CanvasNode[] = [];

  const addNode = (layer: LayerStructure, x: number, y: number): void => {
    const paramRatio = layer.param_count / largestParamCount;
    nodes.push({
      id: layer.layer_id,
      x,
      y,
      radius: 4 + paramRatio * 8,
      layerType: layer.layer_type,
      layerIndex: layer.layer_index,
      paramCount: layer.param_count,
      ratio: normalisedNorm.get(layer.layer_id) ?? paramRatio,
    });
  };

  // Spread every layer of the given type evenly along the arc at `ringScale`;
  // a lone layer sits in the middle of the arc.
  const placeAlongArc = (layerType: string, ringScale: number): void => {
    const group = layers.filter((layer) => layer.layer_type === layerType);
    const lastIndex = group.length - 1;
    group.forEach((layer, index) => {
      const t = group.length > 1 ? index / lastIndex : 0.5;
      const angle = startAngle + t * sweep;
      addNode(
        layer,
        centerX + radiusX * ringScale * Math.cos(angle),
        centerY - radiusY * ringScale * Math.sin(angle), // SVG y axis points down
      );
    });
  };

  // Push order matters to callers iterating `nodes`:
  // embedding, attention, mlp, output.
  for (const layer of layers) {
    if (layer.layer_type === 'embedding') {
      addNode(layer, centerX, centerY + radiusY + canvasHeight * 0.12);
    }
  }
  placeAlongArc('attention', 1);
  placeAlongArc('mlp', mlpRingScale);
  for (const layer of layers) {
    if (layer.layer_type === 'output') {
      addNode(layer, centerX + radiusX * 0.55, centerY - radiusY * 0.70);
    }
  }

  // Index nodes by id (a later duplicate id overrides an earlier one).
  const nodeById = new Map<string, CanvasNode>();
  for (const node of nodes) {
    nodeById.set(node.id, node);
  }

  // Keep only connections whose two endpoints were placed; edge strength is
  // the mean of the endpoint ratios.
  const edges: CanvasEdge[] = [];
  for (const connection of connections) {
    const source = nodeById.get(connection.from_id);
    const target = nodeById.get(connection.to_id);
    if (source === undefined || target === undefined) {
      continue;
    }
    edges.push({
      fromId: connection.from_id,
      toId: connection.to_id,
      strength: (source.ratio + target.ratio) / 2,
    });
  }

  return {
    nodes,
    edges,
    brainPath: generateBrainPath(centerX, centerY, radiusX, radiusY, canvasHeight),
  };
}
/**
 * Generate a stylized brain silhouette (sagittal view) as SVG path data.
 *
 * The outline is built on an ellipse 22% larger than the node ellipse
 * (`rx`, `ry`) and is traced clockwise on screen:
 * brain stem -> occipital (back/left) -> parietal (top) -> frontal
 * (front/right) -> temporal (front-bottom) -> back to the brain stem.
 *
 * The result is a single line of the form
 * `M x y C x1 y1, x2 y2, x y ... Z`, with no embedded newlines, so the
 * string does not depend on how this source file is formatted.
 *
 * @param cx            X coordinate of the node ellipse centre.
 * @param cy            Y coordinate of the node ellipse centre.
 * @param rx            Horizontal radius of the node ellipse.
 * @param ry            Vertical radius of the node ellipse.
 * @param canvasHeight  Canvas height; sets how far the stem reaches below the ellipse.
 * @returns Closed SVG path data made of one move-to and seven cubic Bezier segments.
 */
function generateBrainPath(
  cx: number, cy: number,
  rx: number, ry: number,
  canvasHeight: number,
): string {
  // Silhouette radii: 22% beyond the node ellipse.
  const mx = rx * 1.22;
  const my = ry * 1.22;
  // Bottom of the brain stem, 15% of the canvas height below the ellipse.
  const stemY = cy + ry + canvasHeight * 0.15;

  // One cubic Bezier segment: two control points, then the end point.
  const cubic = (
    c1x: number, c1y: number,
    c2x: number, c2y: number,
    endX: number, endY: number,
  ): string => `C ${c1x} ${c1y}, ${c2x} ${c2y}, ${endX} ${endY}`;

  return [
    // Start at the brain stem bottom.
    `M ${cx} ${stemY}`,
    // Left/back: stem up through the cerebellum area to the lower occipital.
    cubic(
      cx - mx * 0.12, stemY - my * 0.15,
      cx - mx * 0.42, cy + my * 0.65,
      cx - mx * 0.55, cy + my * 0.35,
    ),
    // Occipital (back of brain): continue upward.
    cubic(
      cx - mx * 0.72, cy + my * 0.05,
      cx - mx * 0.88, cy - my * 0.15,
      cx - mx * 0.88, cy - my * 0.38,
    ),
    // Occipital to parietal (back-top to top).
    cubic(
      cx - mx * 0.88, cy - my * 0.62,
      cx - mx * 0.60, cy - my * 0.92,
      cx - mx * 0.18, cy - my * 0.98,
    ),
    // Parietal to frontal (top to front-top).
    cubic(
      cx + mx * 0.18, cy - my * 1.02,
      cx + mx * 0.58, cy - my * 0.95,
      cx + mx * 0.82, cy - my * 0.68,
    ),
    // Frontal (front of brain): curve downward.
    cubic(
      cx + mx * 0.96, cy - my * 0.42,
      cx + mx * 1.00, cy - my * 0.08,
      cx + mx * 0.92, cy + my * 0.18,
    ),
    // Temporal (front-bottom): continue down toward the stem.
    cubic(
      cx + mx * 0.82, cy + my * 0.42,
      cx + mx * 0.55, cy + my * 0.60,
      cx + mx * 0.35, cy + my * 0.55,
    ),
    // Temporal back to the brain stem bottom, closing the outline.
    cubic(
      cx + mx * 0.18, cy + my * 0.50,
      cx + mx * 0.10, stemY - my * 0.10,
      cx, stemY,
    ),
    'Z',
  ].join(' ');
}