Spaces:
Runtime error
Runtime error
File size: 6,590 Bytes
0ce9643 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 | import type { CanvasNode, CanvasEdge } from './ScanCanvas';
import type { LayerStructure, ConnectionInfo, LayerWeightStats } from '../../types/scan';
/**
 * Build a brain-shaped (sagittal view) layout for model layers.
 *
 * Mapping (SVG coords: right = anterior/front, left = posterior/back):
 * - Embedding → brain stem (bottom center)
 * - Attention layers → outer ring (cortex)
 * - MLP layers → inner ring (white matter, 62% of the outer ring)
 * - Output → frontal pole (top-right)
 *
 * Attention and MLP layers are spread evenly, in array order, along an arc
 * running occipital (back/left) → parietal (top) → frontal (front/right).
 * A ring holding a single layer places it at the middle of the arc.
 *
 * Layers whose `layer_type` is not one of `'embedding'`, `'attention'`,
 * `'mlp'` or `'output'` get no node. Connections with an endpoint that has
 * no node are dropped.
 *
 * @param layers - Layers to place.
 * @param connections - Candidate edges, matched to nodes by `from_id` / `to_id`.
 * @param canvasWidth - Width used to scale horizontal positions.
 * @param canvasHeight - Height used to scale vertical positions.
 * @param weightLayers - Optional weight statistics. When a layer has entries
 *   here, its node `ratio` is the mean `l2_norm` of those entries divided by
 *   the largest `l2_norm` overall (floored at 1); otherwise `ratio` falls back
 *   to the layer's share of the largest `param_count`.
 * @returns The placed nodes, the edges between placed nodes (strength = mean
 *   of the two endpoint ratios), and the SVG path data of the brain outline.
 */
export function buildBrainLayout(
  layers: LayerStructure[],
  connections: ConnectionInfo[],
  canvasWidth: number,
  canvasHeight: number,
  weightLayers?: LayerWeightStats[],
): { nodes: CanvasNode[]; edges: CanvasEdge[]; brainPath: string } {
  // Folded instead of spread into Math.max so a very long input cannot exceed
  // the engine's argument-count limit. Floored at 1 so the division in
  // makeNode never divides by zero.
  const maxParam = layers.reduce((max, l) => Math.max(max, l.param_count), 1);

  // Weight lookup: layer_id → normalized mean l2_norm
  const weightMap = new Map<string, number>();
  if (weightLayers) {
    const totals = new Map<string, { sum: number; count: number }>();
    let maxNorm = 1;
    for (const w of weightLayers) {
      maxNorm = Math.max(maxNorm, w.l2_norm);
      let entry = totals.get(w.layer_id);
      if (entry === undefined) {
        entry = { sum: 0, count: 0 };
        totals.set(w.layer_id, entry);
      }
      entry.sum += w.l2_norm;
      entry.count += 1;
    }
    for (const [id, { sum, count }] of totals) {
      weightMap.set(id, sum / count / maxNorm);
    }
  }

  // Brain geometry — center in upper portion, leave room for brain stem below
  const cx = canvasWidth * 0.50;
  const cy = canvasHeight * 0.38;
  const rx = canvasWidth * 0.28;
  const ry = canvasHeight * 0.25;
  const innerScale = 0.62; // MLP ring as fraction of outer ring

  // Separate layers by type
  const embedding = layers.filter((l) => l.layer_type === 'embedding');
  const output = layers.filter((l) => l.layer_type === 'output');
  const attention = layers.filter((l) => l.layer_type === 'attention');
  const mlp = layers.filter((l) => l.layer_type === 'mlp');

  const nodes: CanvasNode[] = [];
  const makeNode = (layer: LayerStructure, x: number, y: number): CanvasNode => {
    const paramRatio = layer.param_count / maxParam;
    return {
      id: layer.layer_id,
      x,
      y,
      // Radius grows with the layer's share of the largest parameter count.
      radius: 4 + paramRatio * 8,
      layerType: layer.layer_type,
      layerIndex: layer.layer_index,
      paramCount: layer.param_count,
      // Prefer the weight-derived ratio; fall back to the parameter share.
      ratio: weightMap.get(layer.layer_id) ?? paramRatio,
    };
  };

  // Arc from occipital (back/left, 160°) through parietal (top, 90°) to frontal (front/right, 20°)
  // In standard math angles with SVG y-inversion: positive sin → upward
  // arcSpan is negative → clockwise traversal through the top
  const arcStart = (160 / 180) * Math.PI;
  const arcEnd = (20 / 180) * Math.PI;
  const arcSpan = arcEnd - arcStart; // negative

  // Embedding → brain stem (bottom center)
  for (const layer of embedding) {
    nodes.push(makeNode(layer, cx, cy + ry + canvasHeight * 0.12));
  }

  // Attention → outer cortex ring, then MLP → inner ring (white matter).
  // Both rings share the same arc; only the radial scale differs.
  const rings = [
    [attention, 1],
    [mlp, innerScale],
  ] as const;
  for (const [ring, scale] of rings) {
    ring.forEach((layer, i) => {
      // A lone layer sits at the middle of the arc.
      const t = ring.length > 1 ? i / (ring.length - 1) : 0.5;
      const angle = arcStart + t * arcSpan;
      const x = cx + rx * scale * Math.cos(angle);
      const y = cy - ry * scale * Math.sin(angle); // SVG y inverted
      nodes.push(makeNode(layer, x, y));
    });
  }

  // Output → frontal pole (top-right)
  for (const layer of output) {
    nodes.push(makeNode(layer, cx + rx * 0.55, cy - ry * 0.70));
  }

  // Build edges, keeping only connections whose two endpoints were placed.
  const nodeMap = new Map(nodes.map((n) => [n.id, n] as const));
  const edges: CanvasEdge[] = [];
  for (const c of connections) {
    const from = nodeMap.get(c.from_id);
    const to = nodeMap.get(c.to_id);
    if (from === undefined || to === undefined) continue;
    edges.push({
      fromId: c.from_id,
      toId: c.to_id,
      strength: (from.ratio + to.ratio) / 2,
    });
  }

  const brainPath = generateBrainPath(cx, cy, rx, ry, canvasHeight);
  return { nodes, edges, brainPath };
}
/**
 * Produce SVG path data for a stylized brain silhouette (sagittal view).
 *
 * The outline is laid out around the ellipse given by (cx, cy, rx, ry), with
 * control points expressed as fractions of the radii enlarged by 22%, plus a
 * brain stem reaching below the ellipse.
 *
 * Shape, traced clockwise from the bottom of the brain stem:
 * stem → occipital (back/left) → parietal (top) → frontal (front/right)
 * → temporal (front-bottom) → brain stem.
 *
 * @param cx - Ellipse center x.
 * @param cy - Ellipse center y.
 * @param rx - Ellipse horizontal radius.
 * @param ry - Ellipse vertical radius.
 * @param canvasHeight - Canvas height; 15% of it sets how far the stem
 *   extends below the ellipse.
 * @returns Path data made of one `M`, seven cubic `C` segments and a closing
 *   `Z`, joined by single spaces.
 */
function generateBrainPath(
  cx: number, cy: number,
  rx: number, ry: number,
  canvasHeight: number,
): string {
  // Radii of the silhouette: 22% larger than the node ellipse.
  const marginX = rx * 1.22;
  const marginY = ry * 1.22;
  // Lowest point of the brain stem; the path both starts and ends here.
  const stemBottom = cy + ry + canvasHeight * 0.15;

  // Formats one cubic Bézier segment: two control points, then the end point.
  // Coordinate pairs are separated by a comma followed by a newline.
  const cubic = (
    c1x: number, c1y: number,
    c2x: number, c2y: number,
    endX: number, endY: number,
  ): string => `C ${c1x} ${c1y},\n${c2x} ${c2y},\n${endX} ${endY}`;

  const commands: string[] = [];

  // Start at the brain stem bottom.
  commands.push(`M ${cx} ${stemBottom}`);

  // Left/back: up the stem through the cerebellum area to the lower occipital.
  commands.push(cubic(
    cx - marginX * 0.12, stemBottom - marginY * 0.15,
    cx - marginX * 0.42, cy + marginY * 0.65,
    cx - marginX * 0.55, cy + marginY * 0.35,
  ));

  // Occipital (back of brain): keep rising.
  commands.push(cubic(
    cx - marginX * 0.72, cy + marginY * 0.05,
    cx - marginX * 0.88, cy - marginY * 0.15,
    cx - marginX * 0.88, cy - marginY * 0.38,
  ));

  // Occipital to parietal (back-top to top).
  commands.push(cubic(
    cx - marginX * 0.88, cy - marginY * 0.62,
    cx - marginX * 0.60, cy - marginY * 0.92,
    cx - marginX * 0.18, cy - marginY * 0.98,
  ));

  // Parietal to frontal (top to front-top).
  commands.push(cubic(
    cx + marginX * 0.18, cy - marginY * 1.02,
    cx + marginX * 0.58, cy - marginY * 0.95,
    cx + marginX * 0.82, cy - marginY * 0.68,
  ));

  // Frontal (front of brain): bend downward.
  commands.push(cubic(
    cx + marginX * 0.96, cy - marginY * 0.42,
    cx + marginX * 1.00, cy - marginY * 0.08,
    cx + marginX * 0.92, cy + marginY * 0.18,
  ));

  // Temporal (front-bottom): head down toward the stem.
  commands.push(cubic(
    cx + marginX * 0.82, cy + marginY * 0.42,
    cx + marginX * 0.55, cy + marginY * 0.60,
    cx + marginX * 0.35, cy + marginY * 0.55,
  ));

  // Temporal back to the brain stem bottom.
  commands.push(cubic(
    cx + marginX * 0.18, cy + marginY * 0.50,
    cx + marginX * 0.10, stemBottom - marginY * 0.10,
    cx, stemBottom,
  ));

  commands.push('Z');
  return commands.join(' ');
}
|