joelniklaus HF Staff commited on
Commit
f7eff68
·
1 Parent(s): 1c77cab

improved pipeline display

Browse files
app/src/content/embeds/d3-pipeline.html CHANGED
@@ -129,14 +129,14 @@
129
  const W = container.clientWidth || 820;
130
  const s = Math.min(1, W / 820);
131
 
132
- const nw = Math.round(166 * s), nh = Math.round(48 * s);
133
  const nr = Math.round(10 * s);
134
  const gp = Math.round(10 * s); // group padding
135
  const gr = Math.round(10 * s); // group corner radius
136
  const glh = Math.round(22 * s); // group label height
137
- const ng = Math.round(8 * s); // node gap within group
138
- const cg = Math.round(28 * s); // column gap
139
- const rg = Math.round(16 * s); // row gap between groups
140
 
141
  // Three columns: left (exec + inference), center (input + pipeline), right (output)
142
  const leftW = nw + gp * 2;
@@ -150,7 +150,7 @@
150
  const rightX = offsetX + leftW + cg + centerW + cg;
151
 
152
  // -- Center column: Input (1 node) + Pipeline (3 nodes)
153
- let y = Math.round(6 * s);
154
  const inputNode = nodes.find(n => n.id === 'hf_in');
155
  inputNode._x = centerX + gp; inputNode._y = y + glh + gp;
156
  inputNode._w = nw; inputNode._h = nh; inputNode._r = nr;
@@ -172,16 +172,15 @@
172
  pipeGroup._w = centerW; pipeGroup._h = pipeH; pipeGroup._r = gr;
173
 
174
  // -- Left column: Execution + Inference
175
- // Vertically center the left column with the pipeline
176
  const execNodes = ['local', 'slurm'].map(id => nodes.find(n => n.id === id));
177
  const execH = glh + gp * 2 + execNodes.length * nh + (execNodes.length - 1) * ng;
178
  const inferNodes = ['rollout', 'vllm'].map(id => nodes.find(n => n.id === id));
179
  const inferH = glh + gp * 2 + inferNodes.length * nh + (inferNodes.length - 1) * ng;
180
- const leftTotalH = execH + rg + inferH;
181
- const leftCenterY = pipeTop + pipeH / 2;
182
- const leftTop = Math.max(pipeTop, leftCenterY - leftTotalH / 2);
183
-
184
- const execTop = leftTop;
185
  execNodes.forEach((n, i) => {
186
  n._x = leftX + gp; n._y = execTop + glh + gp + i * (nh + ng);
187
  n._w = nw; n._h = nh; n._r = nr;
@@ -190,7 +189,6 @@
190
  execGroup._x = leftX; execGroup._y = execTop;
191
  execGroup._w = leftW; execGroup._h = execH; execGroup._r = gr;
192
 
193
- const inferTop = execTop + execH + rg;
194
  inferNodes.forEach((n, i) => {
195
  n._x = leftX + gp; n._y = inferTop + glh + gp + i * (nh + ng);
196
  n._w = nw; n._h = nh; n._r = nr;
@@ -199,10 +197,11 @@
199
  inferGroup._x = leftX; inferGroup._y = inferTop;
200
  inferGroup._w = leftW; inferGroup._h = inferH; inferGroup._r = gr;
201
 
202
- // -- Right column: Output (vertically centered with pipeline)
203
  const outNodes = ['hf_out', 'card', 'monitor'].map(id => nodes.find(n => n.id === id));
204
  const outH = glh + gp * 2 + outNodes.length * nh + (outNodes.length - 1) * ng;
205
- const outTop = pipeTop + (pipeH - outH) / 2;
 
206
  outNodes.forEach((n, i) => {
207
  n._x = rightX + gp; n._y = outTop + glh + gp + i * (nh + ng);
208
  n._w = nw; n._h = nh; n._r = nr;
@@ -211,29 +210,40 @@
211
  outGroup._x = rightX; outGroup._y = outTop;
212
  outGroup._w = rightW; outGroup._h = outH; outGroup._r = gr;
213
 
 
 
 
 
 
 
 
 
 
 
214
  const maxY = Math.max(
215
  ...nodes.map(n => n._y + n._h + gp),
216
  ...groups.map(g => g._y + g._h)
217
  );
218
- svg.attr('height', maxY + Math.round(8 * s));
219
 
220
  return s;
221
  }
222
 
223
- function anchor(n, side) {
224
- if (side === 'top') return { x: n._x + n._w / 2, y: n._y };
225
- if (side === 'bottom') return { x: n._x + n._w / 2, y: n._y + n._h };
226
- if (side === 'left') return { x: n._x, y: n._y + n._h / 2 };
227
- if (side === 'right') return { x: n._x + n._w, y: n._y + n._h / 2 };
 
228
  }
229
 
230
- function bezier(a, b, orient) {
231
- if (orient === 'v') {
232
- const d = (b.y - a.y) * 0.45;
233
- return `M${a.x},${a.y} C${a.x},${a.y + d} ${b.x},${b.y - d} ${b.x},${b.y}`;
234
- }
235
- const d = (b.x - a.x) * 0.4;
236
- return `M${a.x},${a.y} C${a.x + d},${a.y} ${b.x - d},${b.y} ${b.x},${b.y}`;
237
  }
238
 
239
  function edgePath(e) {
@@ -241,18 +251,19 @@
241
  const t = nodes.find(n => n.id === e.to);
242
  if (!f || !t) return '';
243
 
244
- // Explicit routing for each edge
245
- if (e.from === 'hf_in' && e.to === 'read') return bezier(anchor(f, 'bottom'), anchor(t, 'top'), 'v');
246
- if (e.from === 'read' && e.to === 'transform') return bezier(anchor(f, 'bottom'), anchor(t, 'top'), 'v');
247
- if (e.from === 'transform' && e.to === 'write') return bezier(anchor(f, 'bottom'), anchor(t, 'top'), 'v');
248
- if (e.from === 'transform' && e.to === 'rollout') return bezier(anchor(f, 'left'), anchor(t, 'right'), 'h');
249
- if (e.from === 'rollout' && e.to === 'vllm') return bezier(anchor(f, 'bottom'), anchor(t, 'top'), 'v');
250
- if (e.from === 'write' && e.to === 'hf_out') return bezier(anchor(f, 'right'), anchor(t, 'left'), 'h');
251
- if (e.from === 'write' && e.to === 'card') return bezier(anchor(f, 'right'), anchor(t, 'left'), 'h');
252
- if (e.from === 'write' && e.to === 'monitor') return bezier(anchor(f, 'right'), anchor(t, 'left'), 'h');
253
-
254
- // Fallback
255
- return bezier(anchor(f, 'right'), anchor(t, 'left'), 'h');
 
256
  }
257
 
258
  function render() {
@@ -281,13 +292,13 @@
281
  .attr('stroke', d => d.id === 'pipeline' ? c.pipeBd : c.groupBd)
282
  .attr('stroke-width', 1);
283
  gM.select('.grp-icon')
284
- .attr('x', d => d._x + Math.round(10 * s))
285
- .attr('y', d => d._y + Math.round(19 * s))
286
  .style('font-size', fsIcon + 'px')
287
  .text(d => d.icon);
288
  gM.select('.group-label')
289
- .attr('x', d => d._x + Math.round(10 * s) + fsIcon + Math.round(3 * s))
290
- .attr('y', d => d._y + Math.round(19 * s))
291
  .style('font-size', fsGrp + 'px')
292
  .text(d => d.label);
293
  gSel.exit().remove();
 
129
  const W = container.clientWidth || 820;
130
  const s = Math.min(1, W / 820);
131
 
132
+ const nw = Math.round(200 * s), nh = Math.round(60 * s);
133
  const nr = Math.round(10 * s);
134
  const gp = Math.round(10 * s); // group padding
135
  const gr = Math.round(10 * s); // group corner radius
136
  const glh = Math.round(22 * s); // group label height
137
+ const ng = Math.round(7 * s); // node gap within group
138
+ const cg = Math.round(70 * s); // column gap
139
+ const rg = Math.round(14 * s); // row gap between groups
140
 
141
  // Three columns: left (exec + inference), center (input + pipeline), right (output)
142
  const leftW = nw + gp * 2;
 
150
  const rightX = offsetX + leftW + cg + centerW + cg;
151
 
152
  // -- Center column: Input (1 node) + Pipeline (3 nodes)
153
+ let y = Math.round(4 * s);
154
  const inputNode = nodes.find(n => n.id === 'hf_in');
155
  inputNode._x = centerX + gp; inputNode._y = y + glh + gp;
156
  inputNode._w = nw; inputNode._h = nh; inputNode._r = nr;
 
172
  pipeGroup._w = centerW; pipeGroup._h = pipeH; pipeGroup._r = gr;
173
 
174
  // -- Left column: Execution + Inference
175
+ // Position so inference engine bottom aligns with write node
176
  const execNodes = ['local', 'slurm'].map(id => nodes.find(n => n.id === id));
177
  const execH = glh + gp * 2 + execNodes.length * nh + (execNodes.length - 1) * ng;
178
  const inferNodes = ['rollout', 'vllm'].map(id => nodes.find(n => n.id === id));
179
  const inferH = glh + gp * 2 + inferNodes.length * nh + (inferNodes.length - 1) * ng;
180
+ const writeNode = nodes.find(n => n.id === 'write');
181
+ const inferBottom = writeNode._y + writeNode._h + gp;
182
+ const inferTop = inferBottom - inferH;
183
+ const execTop = inferTop - rg - execH;
 
184
  execNodes.forEach((n, i) => {
185
  n._x = leftX + gp; n._y = execTop + glh + gp + i * (nh + ng);
186
  n._w = nw; n._h = nh; n._r = nr;
 
189
  execGroup._x = leftX; execGroup._y = execTop;
190
  execGroup._w = leftW; execGroup._h = execH; execGroup._r = gr;
191
 
 
192
  inferNodes.forEach((n, i) => {
193
  n._x = leftX + gp; n._y = inferTop + glh + gp + i * (nh + ng);
194
  n._w = nw; n._h = nh; n._r = nr;
 
197
  inferGroup._x = leftX; inferGroup._y = inferTop;
198
  inferGroup._w = leftW; inferGroup._h = inferH; inferGroup._r = gr;
199
 
200
+ // -- Right column: Output (align bottom with write node)
201
  const outNodes = ['hf_out', 'card', 'monitor'].map(id => nodes.find(n => n.id === id));
202
  const outH = glh + gp * 2 + outNodes.length * nh + (outNodes.length - 1) * ng;
203
+ const outBottom = writeNode._y + writeNode._h + gp;
204
+ const outTop = outBottom - outH;
205
  outNodes.forEach((n, i) => {
206
  n._x = rightX + gp; n._y = outTop + glh + gp + i * (nh + ng);
207
  n._w = nw; n._h = nh; n._r = nr;
 
210
  outGroup._x = rightX; outGroup._y = outTop;
211
  outGroup._w = rightW; outGroup._h = outH; outGroup._r = gr;
212
 
213
+ const minY = Math.min(
214
+ ...nodes.map(n => n._y),
215
+ ...groups.map(g => g._y)
216
+ );
217
+ if (minY < 0) {
218
+ // Shift everything down so nothing is clipped
219
+ const shift = -minY + Math.round(4 * s);
220
+ nodes.forEach(n => { n._y += shift; });
221
+ groups.forEach(g => { g._y += shift; });
222
+ }
223
  const maxY = Math.max(
224
  ...nodes.map(n => n._y + n._h + gp),
225
  ...groups.map(g => g._y + g._h)
226
  );
227
+ svg.attr('height', maxY + Math.round(4 * s));
228
 
229
  return s;
230
  }
231
 
232
+ function pt(n, side, offset) {
233
+ const o = offset || 0;
234
+ if (side === 'top') return { x: n._x + n._w / 2 + o, y: n._y };
235
+ if (side === 'bottom') return { x: n._x + n._w / 2 + o, y: n._y + n._h };
236
+ if (side === 'left') return { x: n._x, y: n._y + n._h / 2 + o };
237
+ if (side === 'right') return { x: n._x + n._w, y: n._y + n._h / 2 + o };
238
  }
239
 
240
+ function hBez(a, b) {
241
+ const mx = (a.x + b.x) / 2;
242
+ return `M${a.x},${a.y} C${mx},${a.y} ${mx},${b.y} ${b.x},${b.y}`;
243
+ }
244
+ function vBez(a, b) {
245
+ const my = (a.y + b.y) / 2;
246
+ return `M${a.x},${a.y} C${a.x},${my} ${b.x},${my} ${b.x},${b.y}`;
247
  }
248
 
249
  function edgePath(e) {
 
251
  const t = nodes.find(n => n.id === e.to);
252
  if (!f || !t) return '';
253
 
254
+ if (e.from === 'hf_in' && e.to === 'read') return vBez(pt(f,'bottom'), pt(t,'top'));
255
+ if (e.from === 'read' && e.to === 'transform') return vBez(pt(f,'bottom'), pt(t,'top'));
256
+ if (e.from === 'transform' && e.to === 'write') return vBez(pt(f,'bottom'), pt(t,'top'));
257
+ if (e.from === 'transform' && e.to === 'rollout') return hBez(pt(f,'left'), pt(t,'right'));
258
+ if (e.from === 'rollout' && e.to === 'vllm') return vBez(pt(f,'bottom'), pt(t,'top'));
259
+
260
+ // Fan out from Write: top/center/bottom of right edge
261
+ const sp = Math.round(f._h * 0.28);
262
+ if (e.from === 'write' && e.to === 'hf_out') return hBez(pt(f,'right', -sp), pt(t,'left'));
263
+ if (e.from === 'write' && e.to === 'card') return hBez(pt(f,'right'), pt(t,'left'));
264
+ if (e.from === 'write' && e.to === 'monitor') return hBez(pt(f,'right', sp), pt(t,'left'));
265
+
266
+ return hBez(pt(f,'right'), pt(t,'left'));
267
  }
268
 
269
  function render() {
 
292
  .attr('stroke', d => d.id === 'pipeline' ? c.pipeBd : c.groupBd)
293
  .attr('stroke-width', 1);
294
  gM.select('.grp-icon')
295
+ .attr('x', d => d._x + Math.round(6 * s))
296
+ .attr('y', d => d._y + Math.round(15 * s))
297
  .style('font-size', fsIcon + 'px')
298
  .text(d => d.icon);
299
  gM.select('.group-label')
300
+ .attr('x', d => d._x + Math.round(6 * s) + fsIcon + Math.round(3 * s))
301
+ .attr('y', d => d._y + Math.round(15 * s))
302
  .style('font-size', fsGrp + 'px')
303
  .text(d => d.label);
304
  gSel.exit().remove();