joelniklaus HF Staff commited on
Commit
17a96eb
·
1 Parent(s): 3ee1064

added sankey visualization for the start of the experiments

Browse files
app/src/content/chapters/experiments.mdx CHANGED
@@ -14,7 +14,14 @@ import FigRef from "../../components/FigRef.astro";
14
 
15
  ## Experiments
16
 
17
- With the infrastructure and setup in place, we now systematically work through our research questions. We start by benchmarking existing datasets and dissecting what makes their prompts tick. Then we test our own prompt designs, explore how the rephrasing model (size, family, generation) affects quality, and investigate the interplay between synthetic and original data. Along the way, we stumble into some surprising findings about typos and template collapse.
 
 
 
 
 
 
 
18
 
19
  ### How Do Existing Datasets Compare?
20
 
 
14
 
15
  ## Experiments
16
 
17
+ With the infrastructure and setup in place, we now systematically work through our research questions. <FigRef target="experiment-overview" /> shows the full landscape of our experiments as a flow from source datasets through prompt strategies to model families. We start by benchmarking existing datasets and dissecting what makes their prompts tick. Then we test our own prompt designs, explore how the rephrasing model (size, family, generation) affects quality, and investigate the interplay between synthetic and original data. Along the way, we stumble into some surprising findings about typos and template collapse.
18
+
19
+ <HtmlEmbed
20
+ id="experiment-overview"
21
+ src="experiment-overview.html"
22
+ data="rephrasing_metadata.json"
23
+ desc="Flow of experiments from source datasets through prompt strategies to model families. Hover over nodes and links to see experiment counts."
24
+ />
25
 
26
  ### How Do Existing Datasets Compare?
27
 
app/src/content/embeds/experiment-overview.html ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div class="d3-experiment-overview" style="width:100%;margin:10px 0;aspect-ratio:1/1;min-height:520px;"></div>
2
+ <style>
3
+ .d3-experiment-overview { position: relative; font-family: system-ui, -apple-system, sans-serif; }
4
+ </style>
5
+ <script>
6
+ (() => {
7
+ const ensureD3 = (cb) => {
8
+ if (window.d3 && typeof window.d3.select === 'function' && typeof window.d3.sankey === 'function') return cb();
9
+ const loadSankey = () => {
10
+ if (typeof window.d3.sankey === 'function') return cb();
11
+ let s2 = document.getElementById('d3-sankey-cdn');
12
+ if (!s2) {
13
+ s2 = document.createElement('script');
14
+ s2.id = 'd3-sankey-cdn';
15
+ s2.src = 'https://cdn.jsdelivr.net/npm/d3-sankey@0.12.3/dist/d3-sankey.min.js';
16
+ document.head.appendChild(s2);
17
+ }
18
+ s2.addEventListener('load', cb, { once: true });
19
+ };
20
+ let s = document.getElementById('d3-cdn-script');
21
+ if (!s) {
22
+ s = document.createElement('script');
23
+ s.id = 'd3-cdn-script';
24
+ s.src = 'https://cdn.jsdelivr.net/npm/d3@7/dist/d3.min.js';
25
+ document.head.appendChild(s);
26
+ }
27
+ if (window.d3 && typeof window.d3.select === 'function') { loadSankey(); return; }
28
+ s.addEventListener('load', loadSankey, { once: true });
29
+ };
30
+
31
+ const bootstrap = () => {
32
+ const scriptEl = document.currentScript;
33
+ let container = scriptEl ? scriptEl.previousElementSibling : null;
34
+ if (!(container && container.classList && container.classList.contains('d3-experiment-overview'))) {
35
+ const cs = Array.from(document.querySelectorAll('.d3-experiment-overview'))
36
+ .filter(el => !(el.dataset && el.dataset.mounted === 'true'));
37
+ container = cs[cs.length - 1] || null;
38
+ }
39
+ if (!container) return;
40
+ if (container.dataset) {
41
+ if (container.dataset.mounted === 'true') return;
42
+ container.dataset.mounted = 'true';
43
+ }
44
+
45
+ // Read data path from HtmlEmbed attribute
46
+ let mountEl = container;
47
+ while (mountEl && !mountEl.getAttribute?.('data-datafiles')) mountEl = mountEl.parentElement;
48
+ const dataAttr = mountEl?.getAttribute?.('data-datafiles');
49
+ const dataPaths = dataAttr
50
+ ? [dataAttr.includes('/') ? dataAttr : `/data/${dataAttr}`]
51
+ : ['/data/rephrasing_metadata.json', './assets/data/rephrasing_metadata.json', '../assets/data/rephrasing_metadata.json', '../../assets/data/rephrasing_metadata.json'];
52
+
53
+ const fetchFirst = async (paths) => {
54
+ for (const p of paths) {
55
+ try { const r = await fetch(p, { cache: 'no-cache' }); if (r.ok) return r.json(); } catch(_) {}
56
+ }
57
+ throw new Error('Data not found');
58
+ };
59
+
60
+ fetchFirst(dataPaths).then(data => buildChart(data)).catch(err => {
61
+ container.innerHTML = `<pre style="color:red;padding:12px;">Error loading data: ${err.message}</pre>`;
62
+ });
63
+
64
+ function buildChart(rawData) {
65
+ // Map source dataset strings to display names
66
+ const sourceMap = {
67
+ 'fineweb-edu-hq-20BT': 'FW-Edu HQ',
68
+ 'fineweb-edu-lq-20BT': 'FW-Edu LQ',
69
+ 'dclm-37BT': 'DCLM',
70
+ 'cosmopedia-25BT': 'Cosmopedia',
71
+ };
72
+
73
+ // Map prompt paths to display names and categories
74
+ const promptMap = {
75
+ 'format/tutorial.md': { name: 'Tutorial', cat: 'Format' },
76
+ 'format/faq.md': { name: 'FAQ', cat: 'Format' },
77
+ 'format/math.md': { name: 'Math', cat: 'Format' },
78
+ 'format/table.md': { name: 'Table', cat: 'Format' },
79
+ 'format/commentary.md': { name: 'Commentary', cat: 'Format' },
80
+ 'format/discussion.md': { name: 'Discussion', cat: 'Format' },
81
+ 'format/article.md': { name: 'Article', cat: 'Format' },
82
+ 'nemotron/diverse_qa_pairs.md': { name: 'Diverse QA', cat: 'Nemotron' },
83
+ 'nemotron/knowledge_list.md': { name: 'Knowledge List', cat: 'Nemotron' },
84
+ 'nemotron/wikipedia_style_rephrasing.md': { name: 'Wikipedia Style', cat: 'Nemotron' },
85
+ 'nemotron/extract_knowledge.md': { name: 'Extract Knowledge', cat: 'Nemotron' },
86
+ 'nemotron/distill.md': { name: 'Distill', cat: 'Nemotron' },
87
+ 'rewire/guided_rewrite_original.md': { name: 'Guided Rewrite', cat: 'REWIRE' },
88
+ 'rewire/guided_rewrite_improved.md': { name: 'Guided Rewrite+', cat: 'REWIRE' },
89
+ };
90
+
91
+ // Map model IDs to family names
92
+ const modelFamilyMap = (modelId) => {
93
+ if (modelId.includes('gemma')) return 'Gemma';
94
+ if (modelId.includes('Qwen') || modelId.includes('qwen')) return 'Qwen';
95
+ if (modelId.includes('Falcon') || modelId.includes('falcon')) return 'Falcon';
96
+ if (modelId.includes('granite') || modelId.includes('Granite')) return 'Granite';
97
+ if (modelId.includes('Llama') || modelId.includes('llama')) return 'Llama';
98
+ if (modelId.includes('SmolLM') || modelId.includes('smollm')) return 'SmolLM2';
99
+ return modelId;
100
+ };
101
+
102
+ // Build link counts from data
103
+ const linkCounts = {};
104
+ const key = (a, b) => `${a}|||${b}`;
105
+
106
+ rawData.forEach(exp => {
107
+ const src = sourceMap[exp.source_dataset];
108
+ const promptInfo = promptMap[exp.prompt];
109
+ const family = modelFamilyMap(exp.model);
110
+ if (!src || !promptInfo) return;
111
+
112
+ const spKey = key(src, promptInfo.name);
113
+ linkCounts[spKey] = (linkCounts[spKey] || 0) + 1;
114
+
115
+ const pmKey = key(promptInfo.name, family);
116
+ linkCounts[pmKey] = (linkCounts[pmKey] || 0) + 1;
117
+ });
118
+
119
+ // Collect unique names in order
120
+ const sources = [...new Set(rawData.map(e => sourceMap[e.source_dataset]).filter(Boolean))];
121
+ const prompts = [...new Set(rawData.map(e => promptMap[e.prompt]?.name).filter(Boolean))];
122
+ const models = [...new Set(rawData.map(e => modelFamilyMap(e.model)).filter(Boolean))];
123
+
124
+ // Build node list
125
+ const nodes = [];
126
+ sources.forEach(name => nodes.push({ name, col: 'source' }));
127
+ prompts.forEach(name => {
128
+ const info = Object.values(promptMap).find(p => p.name === name);
129
+ nodes.push({ name, col: 'prompt', cat: info?.cat || 'Other' });
130
+ });
131
+ models.forEach(name => nodes.push({ name, col: 'model' }));
132
+
133
+ const ni = (name) => nodes.findIndex(n => n.name === name);
134
+
135
+ // Build links
136
+ const links = [];
137
+ Object.entries(linkCounts).forEach(([k, value]) => {
138
+ const [from, to] = k.split('|||');
139
+ const s = ni(from), t = ni(to);
140
+ if (s >= 0 && t >= 0) links.push({ source: s, target: t, value });
141
+ });
142
+
143
+ // Colors
144
+ const sourceColors = { 'FW-Edu HQ': '#6B8DB5', 'FW-Edu LQ': '#B58B9B', 'DCLM': '#7B82C8', 'Cosmopedia': '#8BA878' };
145
+ const catColors = { 'Format': '#4EA5B7', 'Nemotron': '#D4A850', 'REWIRE': '#C87878' };
146
+ const familyColors = { 'Gemma': '#4EA5B7', 'Qwen': '#8B7BE8', 'SmolLM2': '#E8C44A', 'Falcon': '#E889AB', 'Granite': '#5BC0A4', 'Llama': '#D09090' };
147
+
148
+ const nodeColor = (d) => {
149
+ if (d.col === 'source') return sourceColors[d.name] || '#888';
150
+ if (d.col === 'prompt') return catColors[d.cat] || '#888';
151
+ if (d.col === 'model') return familyColors[d.name] || '#888';
152
+ return '#888';
153
+ };
154
+
155
+ // SVG
156
+ const svg = d3.select(container).append('svg').attr('width', '100%').style('display', 'block');
157
+
158
+ const render = () => {
159
+ const width = container.clientWidth || 800;
160
+ const height = Math.max(520, width);
161
+ svg.attr('width', width).attr('height', height);
162
+ svg.selectAll('*').remove();
163
+
164
+ const isDark = document.documentElement.getAttribute('data-theme') === 'dark';
165
+ const textColor = isDark ? 'rgba(255,255,255,0.78)' : 'rgba(0,0,0,0.68)';
166
+ const mutedText = isDark ? 'rgba(255,255,255,0.35)' : 'rgba(0,0,0,0.30)';
167
+ const linkOpacity = isDark ? 0.20 : 0.18;
168
+ const linkHoverOpacity = isDark ? 0.50 : 0.45;
169
+ const fontSize = Math.max(10, Math.min(14, width / 65));
170
+
171
+ const ml = width * 0.005, mr = width * 0.01;
172
+ const mt = height * 0.04, mb = height * 0.01;
173
+
174
+ const sankeyGen = d3.sankey()
175
+ .nodeId(d => d.index)
176
+ .nodeWidth(Math.max(8, width * 0.012))
177
+ .nodePadding(Math.max(3, height * 0.012))
178
+ .nodeSort(null)
179
+ .extent([[ml, mt], [width - mr, height - mb]]);
180
+
181
+ const graph = sankeyGen({
182
+ nodes: nodes.map((d, i) => ({ ...d, index: i })),
183
+ links: links.map(d => ({ ...d }))
184
+ });
185
+
186
+ // Column headers
187
+ const modelNodes = graph.nodes.filter(n => n.col === 'model');
188
+ const colLabels = [
189
+ { text: 'Source Dataset', x: graph.nodes.filter(n => n.col === 'source')[0]?.x0 || ml, anchor: 'start' },
190
+ { text: 'Prompt Strategy', x: graph.nodes.filter(n => n.col === 'prompt')[0]?.x1 || width * 0.35, anchor: 'end' },
191
+ { text: 'Model Family', x: (modelNodes[0]?.x1 || width * 0.75), anchor: 'end' },
192
+ ];
193
+ svg.selectAll('text.col-header')
194
+ .data(colLabels).join('text')
195
+ .attr('class', 'col-header')
196
+ .attr('x', d => d.x).attr('y', mt - 8)
197
+ .attr('text-anchor', d => d.anchor)
198
+ .attr('fill', mutedText)
199
+ .attr('font-size', (fontSize * 1.4) + 'px')
200
+ .attr('font-weight', '700')
201
+ .attr('font-family', 'system-ui, -apple-system, sans-serif')
202
+ .attr('letter-spacing', '0.5px')
203
+ .attr('text-transform', 'uppercase')
204
+ .text(d => d.text);
205
+
206
+ // Category brackets for prompts
207
+ const catGroups = {};
208
+ graph.nodes.filter(n => n.col === 'prompt').forEach(n => {
209
+ if (!catGroups[n.cat]) catGroups[n.cat] = { min: Infinity, max: -Infinity };
210
+ catGroups[n.cat].min = Math.min(catGroups[n.cat].min, n.y0);
211
+ catGroups[n.cat].max = Math.max(catGroups[n.cat].max, n.y1);
212
+ });
213
+ const bracketX = (graph.nodes.find(n => n.col === 'prompt')?.x1 || 0) + 5;
214
+ Object.entries(catGroups).forEach(([cat, { min: y0, max: y1 }]) => {
215
+ const midY = (y0 + y1) / 2;
216
+ svg.append('line')
217
+ .attr('x1', bracketX).attr('x2', bracketX)
218
+ .attr('y1', y0 + 2).attr('y2', y1 - 2)
219
+ .attr('stroke', catColors[cat]).attr('stroke-width', 1.5)
220
+ .attr('stroke-opacity', 0.35).attr('stroke-linecap', 'round');
221
+ svg.append('text')
222
+ .attr('x', bracketX + 4).attr('y', midY)
223
+ .attr('dominant-baseline', 'central')
224
+ .attr('fill', catColors[cat]).attr('fill-opacity', 0.45)
225
+ .attr('font-size', (fontSize * 1.3) + 'px')
226
+ .attr('font-weight', '600')
227
+ .attr('font-family', 'system-ui, -apple-system, sans-serif')
228
+ .attr('letter-spacing', '0.3px')
229
+ .text(cat);
230
+ });
231
+
232
+ // Links
233
+ const gLinks = svg.append('g').attr('class', 'links');
234
+ const linkPath = d3.sankeyLinkHorizontal();
235
+ const linkEls = gLinks.selectAll('path')
236
+ .data(graph.links).join('path')
237
+ .attr('d', linkPath)
238
+ .attr('fill', 'none')
239
+ .attr('stroke', d => nodeColor(d.source))
240
+ .attr('stroke-width', d => Math.max(1, d.width))
241
+ .attr('stroke-opacity', linkOpacity)
242
+ .style('mix-blend-mode', isDark ? 'screen' : 'multiply');
243
+
244
+ // Nodes
245
+ const gNodes = svg.append('g').attr('class', 'nodes');
246
+ const nodeEls = gNodes.selectAll('rect')
247
+ .data(graph.nodes).join('rect')
248
+ .attr('x', d => d.x0).attr('y', d => d.y0)
249
+ .attr('width', d => d.x1 - d.x0)
250
+ .attr('height', d => Math.max(1, d.y1 - d.y0))
251
+ .attr('fill', d => nodeColor(d))
252
+ .attr('fill-opacity', 0.85).attr('rx', 2)
253
+ .attr('stroke', d => nodeColor(d))
254
+ .attr('stroke-width', 0.5).attr('stroke-opacity', 0.3);
255
+
256
+ // Node labels (interactive, same hover as node rects)
257
+ const gLabels = svg.append('g').attr('class', 'labels');
258
+ graph.nodes.forEach(d => {
259
+ const midY = (d.y0 + d.y1) / 2;
260
+ const isSource = d.col === 'source';
261
+ let labelX, anchor;
262
+ if (isSource) { labelX = d.x1 + 5; anchor = 'start'; }
263
+ else { labelX = d.x0 - 5; anchor = 'end'; }
264
+
265
+ const totalIn = (d.targetLinks || []).reduce((s, l) => s + l.value, 0);
266
+ const totalOut = (d.sourceLinks || []).reduce((s, l) => s + l.value, 0);
267
+ const total = Math.max(totalIn, totalOut);
268
+
269
+ gLabels.append('text')
270
+ .datum(d)
271
+ .attr('class', 'node-label')
272
+ .attr('x', labelX).attr('y', midY - (total > 1 ? fontSize * 0.3 : 0))
273
+ .attr('text-anchor', anchor).attr('dominant-baseline', 'central')
274
+ .attr('fill', textColor)
275
+ .attr('font-size', fontSize + 'px').attr('font-weight', '600')
276
+ .attr('font-family', 'system-ui, -apple-system, sans-serif')
277
+ .style('cursor', 'pointer')
278
+ .text(d.name);
279
+
280
+ if (total > 1) {
281
+ gLabels.append('text')
282
+ .datum(d)
283
+ .attr('class', 'node-label')
284
+ .attr('x', labelX).attr('y', midY + fontSize * 0.55)
285
+ .attr('text-anchor', anchor).attr('dominant-baseline', 'central')
286
+ .attr('fill', mutedText)
287
+ .attr('font-size', (fontSize * 0.8) + 'px')
288
+ .attr('font-family', 'system-ui, -apple-system, sans-serif')
289
+ .style('cursor', 'pointer')
290
+ .text(total + ' exp.');
291
+ }
292
+ });
293
+
294
+ // Tooltip
295
+ container.style.position = container.style.position || 'relative';
296
+ let tip = container.querySelector('.d3-tooltip');
297
+ let tipInner;
298
+ if (!tip) {
299
+ tip = document.createElement('div');
300
+ tip.className = 'd3-tooltip';
301
+ Object.assign(tip.style, {
302
+ position: 'absolute', top: '0px', left: '0px',
303
+ transform: 'translate(-9999px, -9999px)',
304
+ pointerEvents: 'none', padding: '8px 12px', borderRadius: '10px',
305
+ fontSize: '12px', lineHeight: '1.4',
306
+ border: '1px solid var(--border-color)',
307
+ background: 'var(--surface-bg)', color: 'var(--text-color)',
308
+ boxShadow: '0 6px 24px rgba(0,0,0,.25)',
309
+ opacity: '0', transition: 'opacity .12s ease',
310
+ backdropFilter: 'saturate(1.12) blur(8px)',
311
+ zIndex: '20', maxWidth: '280px'
312
+ });
313
+ tipInner = document.createElement('div');
314
+ tipInner.className = 'd3-tooltip__inner';
315
+ tip.appendChild(tipInner);
316
+ container.appendChild(tip);
317
+ } else {
318
+ tipInner = tip.querySelector('.d3-tooltip__inner') || tip;
319
+ }
320
+
321
+ const positionTip = (ev) => {
322
+ const [mx, my] = d3.pointer(ev, container);
323
+ const bw = tip.offsetWidth || 220, bh = tip.offsetHeight || 60;
324
+ const ox = (mx + bw + 20 > width) ? -(bw + 12) : 12;
325
+ const oy = (my + bh + 20 > height) ? -(bh + 12) : 14;
326
+ tip.style.transform = `translate(${Math.round(mx + ox)}px, ${Math.round(my + oy)}px)`;
327
+ };
328
+ const showTip = (ev, html) => { tipInner.innerHTML = html; tip.style.opacity = '1'; positionTip(ev); };
329
+ const hideTip = () => { tip.style.opacity = '0'; tip.style.transform = 'translate(-9999px, -9999px)'; };
330
+
331
+ // Interaction
332
+ linkEls
333
+ .on('mouseenter', function (ev, d) {
334
+ linkEls.attr('stroke-opacity', l => l === d ? linkHoverOpacity * 1.5 : linkOpacity * 0.3);
335
+ showTip(ev, `<b>${d.source.name}</b> \u2192 <b>${d.target.name}</b><br/><span style="color:var(--muted-color);">${d.value} experiment${d.value > 1 ? 's' : ''}</span>`);
336
+ })
337
+ .on('mousemove', positionTip)
338
+ .on('mouseleave', function () { linkEls.attr('stroke-opacity', linkOpacity); hideTip(); });
339
+
340
+ // Shared node hover handlers (used by both rects and labels)
341
+ const onNodeEnter = function (ev, d) {
342
+ const connected = new Set();
343
+ (d.sourceLinks || []).forEach(l => connected.add(l.index));
344
+ (d.targetLinks || []).forEach(l => connected.add(l.index));
345
+ linkEls.attr('stroke-opacity', l => connected.has(l.index) ? linkHoverOpacity : linkOpacity * 0.15);
346
+ const totalIn = (d.targetLinks || []).reduce((s, l) => s + l.value, 0);
347
+ const totalOut = (d.sourceLinks || []).reduce((s, l) => s + l.value, 0);
348
+ const total = Math.max(totalIn, totalOut);
349
+ let info = `<b style="font-size:14px;">${d.name}</b>`;
350
+ if (d.cat) info += ` <span style="color:${catColors[d.cat]};font-size:12px;">(${d.cat})</span>`;
351
+ info += `<br/><span style="color:var(--muted-color);">${total} experiment${total > 1 ? 's' : ''}</span>`;
352
+ showTip(ev, info);
353
+ };
354
+ const onNodeLeave = function () { linkEls.attr('stroke-opacity', linkOpacity); hideTip(); };
355
+
356
+ nodeEls.style('cursor', 'pointer')
357
+ .on('mouseenter', onNodeEnter).on('mousemove', positionTip).on('mouseleave', onNodeLeave);
358
+
359
+ gLabels.selectAll('.node-label')
360
+ .on('mouseenter', onNodeEnter).on('mousemove', positionTip).on('mouseleave', onNodeLeave);
361
+ };
362
+
363
+ if (window.ResizeObserver) new ResizeObserver(() => render()).observe(container);
364
+ else window.addEventListener('resize', render);
365
+ render();
366
+ }
367
+ };
368
+
369
+ if (document.readyState === 'loading') {
370
+ document.addEventListener('DOMContentLoaded', () => ensureD3(bootstrap), { once: true });
371
+ } else { ensureD3(bootstrap); }
372
+ })();
373
+ </script>