File size: 13,802 Bytes
b9e7b9b
 
 
 
 
 
c24ea90
 
b9e7b9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c24ea90
b9e7b9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c24ea90
 
b9e7b9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c24ea90
b9e7b9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c59bbe1
 
 
 
 
b9e7b9b
 
 
 
 
 
 
 
 
 
 
c59bbe1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9e7b9b
 
 
 
 
 
 
 
 
 
c59bbe1
b9e7b9b
c59bbe1
b9e7b9b
c59bbe1
b9e7b9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
<div class="d3-bar" style="width:100%;margin:10px 0;"></div>
<style>
  /* Row of <label><select></label> pairs the script appends below the chart svg. */
  .d3-bar .controls {
    margin-top: 12px;
    display: flex;
    gap: 16px;
    align-items: center;
    flex-wrap: wrap;
  }
  .d3-bar .controls label {
    font-size: 12px;
    color: var(--muted-color);
    display: flex;
    align-items: center;
    gap: 8px;
    white-space: nowrap;
    padding: 6px 10px;
  }
  /* Native dropdown arrow is hidden (appearance: none) and replaced by an inline SVG
     chevron drawn with a dark (#0f1115) stroke for the default theme. */
  .d3-bar .controls select {
    font-size: 12px;
    padding: 8px 28px 8px 10px;
    border: 1px solid var(--border-color);
    border-radius: 8px;
    background-color: var(--surface-bg);
    color: var(--text-color);
    background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='12' height='12' viewBox='0 0 24 24' fill='none' stroke='%230f1115' stroke-width='2' stroke-linecap='round' stroke-linejoin='round'%3E%3Cpolyline points='6 9 12 15 18 9'/%3E%3C/svg%3E");
    background-repeat: no-repeat;
    background-position: right 8px center;
    background-size: 12px;
    -webkit-appearance: none;
    -moz-appearance: none;
    appearance: none;
    cursor: pointer;
    transition: border-color .15s ease, box-shadow .15s ease;
  }
  /* Same chevron with a white stroke when the page is in dark theme. */
  [data-theme="dark"] .d3-bar .controls select {
    background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='12' height='12' viewBox='0 0 24 24' fill='none' stroke='%23ffffff' stroke-width='2' stroke-linecap='round' stroke-linejoin='round'%3E%3Cpolyline points='6 9 12 15 18 9'/%3E%3C/svg%3E");
  }
  .d3-bar .controls select:hover {
    border-color: var(--primary-color);
  }
  /* NOTE(review): the focus ring colour is a hard-coded rgba, not derived from --primary-color. */
  .d3-bar .controls select:focus {
    border-color: var(--primary-color);
    box-shadow: 0 0 0 3px rgba(232,137,171,.25);
    outline: none;
  }
  /* Text styling for the legend's foreignObject (class "legend" is set by the script). */
  .d3-bar .legend {
    font-size: 12px;
    line-height: 1.35;
    color: var(--text-color);
  }
</style>
<script>
  (() => {
    /**
     * Invoke `cb` once D3 is available, injecting the jsDelivr D3 v7 <script> tag if needed.
     * The tag is looked up by id first, so several embeds on one page share a single download.
     * If D3 is already present, `cb` runs synchronously and its return value is returned.
     * @param {() => *} cb - called when `window.d3.select` is a function.
     */
    const ensureD3 = (cb) => {
      const hasD3 = () => Boolean(window.d3 && typeof window.d3.select === 'function');
      if (hasD3()) return cb();
      let s = document.getElementById('d3-cdn-script');
      if (!s) {
        s = document.createElement('script');
        s.id = 'd3-cdn-script';
        s.src = 'https://cdn.jsdelivr.net/npm/d3@7/dist/d3.min.js';
        document.head.appendChild(s);
      }
      s.addEventListener('load', () => { if (hasD3()) cb(); }, { once: true });
      // Surface a failed CDN fetch instead of leaving an empty chart with no hint why.
      s.addEventListener('error', () => { console.warn('d3-bar: failed to load D3 from', s.src); }, { once: true });
    };

    /**
     * Mount the chart into this embed's `.d3-bar` container: stacked bars of memory (GB) per
     * component (parameters, gradients, optimizer, activations) for each sequence length, plus
     * "Model Size" and "Recomputation" selectors that re-render it.
     * Requires the global `d3` (v7); it is invoked through ensureD3.
     * A container is flagged with `data-mounted="true"` so it is never drawn twice.
     */
    const bootstrap = () => {
      // Resolve this embed's mount <div>. `document.currentScript` is only set while the script
      // is first evaluated (it is null inside the DOMContentLoaded / d3 `load` callbacks), and
      // its immediate previous sibling is the <style> tag rather than the div, so walk back over
      // the preceding siblings. Otherwise take the first `.d3-bar` not mounted yet: always taking
      // the first `.d3-bar` would leave every later embed of this chart on the page empty.
      const findContainer = () => {
        const script = document.currentScript;
        for (let el = script ? script.previousElementSibling : null; el; el = el.previousElementSibling) {
          if (el.classList && el.classList.contains('d3-bar')) return el;
        }
        const candidates = Array.from(document.querySelectorAll('.d3-bar'));
        return candidates.find((el) => !(el.dataset && el.dataset.mounted === 'true')) || null;
      };
      const container = findContainer();
      if (!container) return;
      if (container.dataset) {
        if (container.dataset.mounted === 'true') return;
        container.dataset.mounted = 'true';
      }

      // Data, matching bar.py
      const seqLabels = ["1024","2048","4096","8192"];
      const seqScale = [1,2,4,8]; // each sequence length as a multiple of 1024
      const components = [
        { key: 'parameters',  color: 'rgb(78, 165, 183)' },
        { key: 'gradients',   color: 'rgb(227, 138, 66)' },
        { key: 'optimizer',   color: 'var(--primary-color)' },
        { key: 'activations', color: 'rgb(206, 192, 250)' },
      ];
      const modelSizes = ["1B","3B","8B","70B","405B"];
      const paramsMem = { "1B":4.0, "3B":13.3, "8B":26.0, "70B":244.0, "405B":1520.0 };
      const actCoeff = { "1B":3.6, "3B":9.3, "8B":46.2, "70B":145.7, "405B":1519.9 };
      const recomputeModes = ["none","selective","full"];

      // Activations scale with the square of the sequence multiplier; selective recomputation
      // keeps 1/4 of that and full recomputation 1/16.
      const activationsCurve = (sizeKey, mode) => {
        const coeff = actCoeff[sizeKey];
        let arr = seqScale.map((v) => coeff * (v * v));
        if (mode === 'selective') arr = arr.map((v) => v * 0.25);
        else if (mode === 'full') arr = arr.map((v) => v * (1 / 16));
        return arr;
      };
      // Gradients equal parameter memory and optimizer state is twice it; none depend on sequence length.
      const stackFor = (sizeKey, mode) => {
        const p = seqScale.map(() => paramsMem[sizeKey]);
        const g = seqScale.map(() => paramsMem[sizeKey]);
        const o = seqScale.map(() => 2*paramsMem[sizeKey]);
        const a = activationsCurve(sizeKey, mode);
        return { parameters: p, gradients: g, optimizer: o, activations: a };
      };

      const Y = {}; // Y[mode][size][component] => array indexed like seqLabels
      recomputeModes.forEach((m) => {
        Y[m] = {};
        modelSizes.forEach((s) => { Y[m][s] = stackFor(s, m); });
      });

      // Controls
      const makeSelect = (values) => {
        const sel = document.createElement('select');
        values.forEach((v) => {
          const o = document.createElement('option');
          o.value = v;
          o.textContent = v;
          sel.appendChild(o);
        });
        return sel;
      };
      const controls = document.createElement('div');
      controls.className = 'controls';
      const labelSize = document.createElement('label');
      labelSize.textContent = 'Model Size';
      const selSize = makeSelect(modelSizes);
      labelSize.appendChild(selSize);
      const labelRecomp = document.createElement('label');
      labelRecomp.textContent = 'Recomputation';
      const selRecomp = makeSelect(recomputeModes);
      labelRecomp.appendChild(selRecomp);

      // SVG scaffolding
      const svg = d3.select(container).append('svg').attr('width','100%').style('display','block');
      const gRoot = svg.append('g');
      const gGrid = gRoot.append('g').attr('class','grid');
      const gAxes = gRoot.append('g').attr('class','axes');
      const gBars = gRoot.append('g').attr('class','bars');
      const gLegend = gRoot.append('foreignObject').attr('class','legend');

      // Tooltip (an existing `.d3-tooltip` in the container is reused)
      container.style.position = container.style.position || 'relative';
      let tip = container.querySelector('.d3-tooltip');
      let tipInner;
      if (!tip) {
        tip = document.createElement('div');
        tip.className = 'd3-tooltip';
        Object.assign(tip.style, {
          position: 'absolute', top: '0px', left: '0px', transform: 'translate(-9999px, -9999px)',
          pointerEvents: 'none', padding: '8px 10px', borderRadius: '8px', fontSize: '12px', lineHeight: '1.35',
          border: '1px solid var(--border-color)', background: 'var(--surface-bg)', color: 'var(--text-color)',
          boxShadow: '0 4px 24px rgba(0,0,0,.18)', opacity: '0', transition: 'opacity .12s ease',
        });
        tipInner = document.createElement('div');
        tipInner.className = 'd3-tooltip__inner';
        tipInner.style.textAlign = 'left';
        tip.appendChild(tipInner);
        container.appendChild(tip);
      } else {
        tipInner = tip.querySelector('.d3-tooltip__inner') || tip;
      }
      // Hidden state: transparent and parked off-screen so it cannot overlap the chart.
      const hideTip = () => { tip.style.opacity = '0'; tip.style.transform = 'translate(-9999px, -9999px)'; };

      // State
      let currentSize = modelSizes[0];
      let currentMode = 'selective';
      selRecomp.value = currentMode;

      // Layout & scales
      let width = 800;
      let height = 360;
      const margin = { top: 16, right: 28, bottom: 56, left: 64 };
      const x0 = d3.scaleBand().paddingInner(0.25).paddingOuter(0.1); // groups (seq)
      const y = d3.scaleLinear();

      // Tallest stack across sequence lengths for the given size/mode, plus 5% headroom.
      function yMax(sizeKey, mode){
        const s = Y[mode][sizeKey];
        let max = 0;
        for (let i = 0; i < seqLabels.length; i++) {
          const sum = s.parameters[i] + s.gradients[i] + s.optimizer[i] + s.activations[i];
          if (sum > max) max = sum;
        }
        return max * 1.05;
      }

      // Colour key rendered as HTML inside an SVG foreignObject at the plot's top-left.
      function renderLegend(){
        const legendWidth = 160;
        const legendHeight = 84;
        gLegend.attr('x', 15).attr('y', -3).attr('width', legendWidth).attr('height', legendHeight);
        const root = gLegend.selectAll('div').data([0]).join('xhtml:div');
        root.html(`
          <div style="display:flex;flex-direction:column;gap:6px;">
            ${components.map(c => `<div style="display:flex;align-items:center;gap:8px;"><span style="width:18px;height:10px;background:${c.color};border-radius:2px;display:inline-block"></span><span>${c.key}</span></div>`).join('')}
          </div>
        `);
      }

      // Size the svg to the container, rescale x/y, and redraw grid, axes, axis labels and legend.
      // Colours follow the `data-theme` attribute on <html> as read at call time.
      function updateScales(){
        const isDark = document.documentElement.getAttribute('data-theme') === 'dark';
        const axisColor = isDark ? 'rgba(255,255,255,0.25)' : 'rgba(0,0,0,0.25)';
        const tickColor = isDark ? 'rgba(255,255,255,0.70)' : 'rgba(0,0,0,0.55)';
        const gridColor = isDark ? 'rgba(255,255,255,0.08)' : 'rgba(0,0,0,0.05)';
        const styleAxis = (g) => {
          g.selectAll('path, line').attr('stroke', axisColor);
          g.selectAll('text').attr('fill', tickColor).style('font-size','12px');
        };

        width = container.clientWidth || 800;
        height = Math.max(260, Math.round(width/3));
        svg.attr('width', width).attr('height', height);
        const innerWidth = width - margin.left - margin.right;
        const innerHeight = height - margin.top - margin.bottom;
        gRoot.attr('transform', `translate(${margin.left},${margin.top})`);

        x0.domain(seqLabels).range([0, innerWidth]);
        y.domain([0, yMax(currentSize, currentMode)]).range([innerHeight, 0]).nice();

        // Grid
        gGrid.selectAll('*').remove();
        gGrid.selectAll('line').data(y.ticks(6)).join('line')
          .attr('x1', 0).attr('x2', innerWidth).attr('y1', (d)=>y(d)).attr('y2', (d)=>y(d))
          .attr('stroke', gridColor).attr('stroke-width', 1).attr('shape-rendering', 'crispEdges');

        // Axes
        gAxes.selectAll('*').remove();
        gAxes.append('g').attr('transform', `translate(0,${innerHeight})`).call(d3.axisBottom(x0)).call(styleAxis);
        gAxes.append('g').call(d3.axisLeft(y).ticks(6).tickFormat(d3.format('~f'))).call(styleAxis);

        // Axis labels
        gAxes.append('text').attr('class','axis-label axis-label--x').attr('x', innerWidth/2).attr('y', innerHeight + 44).attr('text-anchor','middle').style('font-size','12px').style('fill', tickColor).text('Sequence Length');
        gAxes.append('text').attr('class','axis-label axis-label--y').attr('text-anchor','middle').attr('transform', `translate(${-52},${innerHeight/2}) rotate(-90)`).style('font-size','12px').style('fill', tickColor).text('Memory (GB)');

        renderLegend();
      }

      // SVG path for a w×h rectangle at (x, yTop) with only its top corners (isTop) and/or
      // bottom corners (isBottom) rounded; the radius is capped at half the smaller side.
      const rCorner = 4;
      const roundedPath = (x, yTop, w, h, isTop, isBottom) => {
        const r = Math.min(rCorner, Math.max(0, Math.min(w, h) / 2));
        const rT = isTop ? r : 0;
        const rB = isBottom ? r : 0;
        const left = x;
        const top = yTop;
        const right = x + w;
        const bottom = yTop + h;
        return `M${left + rT},${top}`
          + `H${right - rT}`
          + (rT ? `Q${right},${top} ${right},${top + rT}` : `V${top}`)
          + `V${bottom - rB}`
          + (rB ? `Q${right},${bottom} ${right - rB},${bottom}` : `H${right}`)
          + `H${left + rB}`
          + (rB ? `Q${left},${bottom} ${left},${bottom - rB}` : `V${bottom}`)
          + `V${top + rT}`
          + (rT ? `Q${left},${top} ${left + rT},${top}` : `H${left}`)
          + 'Z';
      };
      // Geometry of one stacked segment under the current scales (min 0.5px tall so it stays visible).
      const segmentPath = (d) => roundedPath(0, y(d.y1), x0.bandwidth(), Math.max(0.5, y(d.y0) - y(d.y1)), d.isTop, d.isBottom);

      function drawBars(){
        const stacks = Y[currentMode][currentSize];
        // Cumulative [y0, y1] extents per sequence length, stacked bottom-up in `components` order.
        const stacked = seqLabels.map((label, i) => {
          let acc = 0;
          const items = [];
          components.forEach((c, idx) => {
            const value = stacks[c.key][i];
            const y0 = acc;
            const y1 = acc + value;
            items.push({ key: c.key, color: c.color, i, y0, y1, xLabel: label, value, isBottom: idx === 0, isTop: idx === components.length - 1 });
            acc = y1;
          });
          return { label, items };
        });

        updateScales();

        const groups = gBars.selectAll('g.bar-group').data(stacked, (d) => d.label);
        const allGroups = groups.enter().append('g').attr('class','bar-group').merge(groups);
        allGroups.attr('transform', (d) => `translate(${x0(d.label)},0)`);
        groups.exit().remove();

        const bars = allGroups.selectAll('path.bar').data((d) => d.items, (d) => d.key);
        bars.enter().append('path').attr('class','bar')
          .attr('d', segmentPath)
          .attr('fill', (d) => d.color)
          .on('mouseenter', function(ev, d){
            d3.select(this).attr('stroke', 'rgba(0,0,0,0.85)').attr('stroke-width', 1);
            tipInner.innerHTML = `<div><strong>${d.key}</strong></div><div><strong>Seq</strong> ${d.xLabel}</div><div><strong>Mem</strong> ${d.value.toFixed(1)} GB</div>`;
            tip.style.opacity = '1';
          })
          .on('mousemove', (ev) => {
            const [mx, my] = d3.pointer(ev, container);
            const offsetX = 12;
            const offsetY = 12;
            tip.style.transform = `translate(${Math.round(mx+offsetX)}px, ${Math.round(my+offsetY)}px)`;
          })
          .on('mouseleave', function(){ hideTip(); d3.select(this).attr('stroke','none'); })
          .merge(bars)
          .transition().duration(200)
          .attr('d', segmentPath)
          .attr('fill', (d) => d.color);
        bars.exit().remove();
      }

      function update(){ drawBars(); }

      // Boot
      update();
      container.appendChild(controls);
      controls.appendChild(labelSize);
      controls.appendChild(labelRecomp);
      selSize.addEventListener('change', (e) => { currentSize = e.target.value; update(); });
      selRecomp.addEventListener('change', (e) => { currentMode = e.target.value; update(); });

      // Re-render on width changes only, one frame later. update() resizes the svg, which changes
      // the observed container's height; doing that synchronously inside the ResizeObserver callback
      // triggers "ResizeObserver loop completed with undelivered notifications".
      let lastWidth = container.clientWidth;
      let pendingFrame = 0;
      const rerender = () => {
        if (pendingFrame) return;
        pendingFrame = requestAnimationFrame(() => {
          pendingFrame = 0;
          const w = container.clientWidth;
          if (w === lastWidth) return;
          lastWidth = w;
          update();
        });
      };
      if (window.ResizeObserver) {
        const ro = new ResizeObserver(() => rerender());
        ro.observe(container);
      } else {
        window.addEventListener('resize', rerender);
      }

      // Axis/grid colours are chosen from `data-theme` at render time, so repaint when it flips.
      if (window.MutationObserver) {
        new MutationObserver(() => update()).observe(document.documentElement, { attributes: true, attributeFilter: ['data-theme'] });
      }
    };

    // Start once the DOM is parsed; ensureD3 then defers bootstrap until D3 itself is loaded.
    const startChart = () => ensureD3(bootstrap);
    if (document.readyState !== 'loading') {
      startChart();
    } else {
      document.addEventListener('DOMContentLoaded', startChart, { once: true });
    }
  })();
</script>