File size: 4,631 Bytes
0ce9643
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import type { CanvasNode, CanvasEdge } from './ScanCanvas';
import type { LayerStructure, ConnectionInfo, LayerWeightStats } from '../../types/scan';

/**
 * Build a force-directed network layout using a simple iterative simulation.
 *
 * Nodes repel each other while edges act as springs, producing an organic
 * graph layout. The simulation is synchronous (no D3 force dependency) and
 * uses no randomness, so identical inputs always produce identical layouts.
 *
 * @param layers - Layers to place; one node is emitted per layer, in input order.
 * @param connections - Layer-to-layer connections. Each connection whose two
 *   endpoints are both present in `layers` acts as a spring during the
 *   simulation and becomes an edge; connections naming unknown layer ids are
 *   skipped.
 * @param canvasWidth - Canvas width in pixels.
 * @param canvasHeight - Canvas height in pixels.
 * @param weightLayers - Optional weight statistics. When given, a node's
 *   `ratio` is its layer's mean `l2_norm` divided by the largest single
 *   `l2_norm` across all entries; layers without stats fall back to their
 *   parameter-count ratio.
 * @returns Nodes positioned inside the canvas (kept 50px from each edge) and
 *   the edges between them, each edge's `strength` being the mean of its
 *   endpoints' `ratio`.
 */
export function buildNetworkLayout(
  layers: LayerStructure[],
  connections: ConnectionInfo[],
  canvasWidth: number,
  canvasHeight: number,
  weightLayers?: LayerWeightStats[],
): { nodes: CanvasNode[]; edges: CanvasEdge[] } {
  // reduce instead of Math.max(...spread): spreading a very large array as
  // call arguments can exceed the engine's argument limit. The floor of 1
  // keeps the division below safe when every layer has zero parameters.
  const maxParam = layers.reduce((max, l) => Math.max(max, l.param_count), 1);
  const weightRatios = computeWeightRatios(weightLayers);

  // Resolve connection endpoints to layer indices once rather than on every
  // simulation iteration. Order is preserved, so forces accumulate identically.
  const idToIdx = new Map(layers.map((l, i) => [l.layer_id, i]));
  const springs: Array<[number, number]> = [];
  for (const c of connections) {
    const from = idToIdx.get(c.from_id);
    const to = idToIdx.get(c.to_id);
    if (from !== undefined && to !== undefined) springs.push([from, to]);
  }

  const positions = initialCirclePositions(layers.length, canvasWidth, canvasHeight);
  runForceSimulation(positions, springs, canvasWidth, canvasHeight);

  const nodes: CanvasNode[] = layers.map((layer, i) => {
    const paramRatio = layer.param_count / maxParam;
    return {
      id: layer.layer_id,
      x: positions[i].x,
      y: positions[i].y,
      radius: 4 + paramRatio * 8,
      layerType: layer.layer_type,
      layerIndex: layer.layer_index,
      paramCount: layer.param_count,
      ratio: weightRatios.get(layer.layer_id) ?? paramRatio,
    };
  });

  // Only connections whose endpoints both became nodes are rendered.
  const nodeById = new Map(nodes.map((nd) => [nd.id, nd]));
  const edges: CanvasEdge[] = [];
  for (const c of connections) {
    const from = nodeById.get(c.from_id);
    const to = nodeById.get(c.to_id);
    if (from === undefined || to === undefined) continue;
    edges.push({
      fromId: c.from_id,
      toId: c.to_id,
      strength: (from.ratio + to.ratio) / 2,
    });
  }

  return { nodes, edges };
}

/** A mutable 2-D point in canvas pixel coordinates. */
type Point = { x: number; y: number };

/**
 * Map each layer id to its mean weight L2 norm divided by the largest single
 * L2 norm in `weightLayers`. Returns an empty map when no stats are given.
 */
function computeWeightRatios(
  weightLayers: ReadonlyArray<{ layer_id: string; l2_norm: number }> | undefined,
): Map<string, number> {
  const ratios = new Map<string, number>();
  if (!weightLayers || weightLayers.length === 0) return ratios;

  const grouped = new Map<string, number[]>();
  let largest = 0;
  for (const w of weightLayers) {
    const norms = grouped.get(w.layer_id);
    if (norms) norms.push(w.l2_norm);
    else grouped.set(w.layer_id, [w.l2_norm]);
    largest = Math.max(largest, w.l2_norm);
  }

  // Guard only against a zero (or NaN) maximum. Flooring at 1 would leave
  // small-norm models un-normalized and out of scale with the param-count
  // fallback ratio, which always reaches 1 for the largest layer.
  const maxNorm = largest > 0 ? largest : 1;
  for (const [id, norms] of grouped) {
    const mean = norms.reduce((a, b) => a + b, 0) / norms.length;
    ratios.set(id, mean / maxNorm);
  }
  return ratios;
}

/**
 * Place `count` points evenly on a circle centered in the canvas, the first
 * one at the top. A single point is placed at the center: with no other nodes
 * and no springs it would otherwise never move away from the top of the circle.
 */
function initialCirclePositions(count: number, canvasWidth: number, canvasHeight: number): Point[] {
  const cx = canvasWidth / 2;
  const cy = canvasHeight / 2;
  const r = count > 1 ? Math.min(canvasWidth, canvasHeight) * 0.3 : 0;
  return Array.from({ length: count }, (_, i) => {
    const angle = (2 * Math.PI * i) / count - Math.PI / 2;
    return { x: cx + r * Math.cos(angle), y: cy + r * Math.sin(angle) };
  });
}

/**
 * Relax `positions` in place. Each step adds inverse-square pairwise
 * repulsion (scaled down linearly over the run) and linear spring forces
 * toward a rest length, damps velocities, moves the nodes, and clamps them to
 * the canvas minus a fixed margin.
 */
function runForceSimulation(
  positions: Point[],
  springs: ReadonlyArray<readonly [number, number]>,
  canvasWidth: number,
  canvasHeight: number,
): void {
  const iterations = 80;
  const repulsion = 3000;
  const springK = 0.02;
  const springLen = Math.min(canvasWidth, canvasHeight) * 0.12;
  const damping = 0.85;
  const padding = 50;
  // Golden angle (radians): spreads nudge directions for coincident pairs.
  const goldenAngle = Math.PI * (3 - Math.sqrt(5));

  const n = positions.length;
  const vx = new Float64Array(n);
  const vy = new Float64Array(n);

  for (let iter = 0; iter < iterations; iter++) {
    const temp = 1 - iter / iterations; // cooling: repulsion fades over the run

    // Repulsion between all pairs
    for (let i = 0; i < n; i++) {
      for (let j = i + 1; j < n; j++) {
        let dx = positions[j].x - positions[i].x;
        let dy = positions[j].y - positions[i].y;
        if (dx === 0 && dy === 0) {
          // Exactly coincident nodes (e.g. both clamped into the same corner)
          // have no direction to repel along and would stay stacked forever.
          // Push them apart along a deterministic, pair-dependent unit vector.
          const angle = (i + j) * goldenAngle;
          dx = Math.cos(angle);
          dy = Math.sin(angle);
        }
        const dist = Math.sqrt(dx * dx + dy * dy) || 1;
        const force = (repulsion * temp) / (dist * dist);
        dx = (dx / dist) * force;
        dy = (dy / dist) * force;
        vx[i] -= dx;
        vy[i] -= dy;
        vx[j] += dx;
        vy[j] += dy;
      }
    }

    // Spring forces along edges
    for (const [i, j] of springs) {
      let dx = positions[j].x - positions[i].x;
      let dy = positions[j].y - positions[i].y;
      const dist = Math.sqrt(dx * dx + dy * dy) || 1;
      const force = springK * (dist - springLen);
      dx = (dx / dist) * force;
      dy = (dy / dist) * force;
      vx[i] += dx;
      vy[i] += dy;
      vx[j] -= dx;
      vy[j] -= dy;
    }

    // Apply damped velocities, then clamp to the canvas minus the margin
    for (let i = 0; i < n; i++) {
      vx[i] *= damping;
      vy[i] *= damping;
      const p = positions[i];
      p.x = Math.max(padding, Math.min(canvasWidth - padding, p.x + vx[i]));
      p.y = Math.max(padding, Math.min(canvasHeight - padding, p.y + vy[i]));
    }
  }
}