smol-training-playbook / app /src /content /embeds /d3-rl-full-length.html
lewtun's picture
lewtun HF Staff
Fix fig
0cdc3b1
<div class="d3-grpo-full-length"></div>
<style>
/*
 * Styles for the dual-axis reward / length chart. Every rule is scoped under
 * .d3-grpo-full-length, the class of the mount element.
 * NOTE(review): --axis-color, --tick-color, --grid-color, --text-color,
 * --border-color, --surface-bg and --primary-color are not defined anywhere in
 * this file; presumably the host page theme supplies them - confirm.
 */
.d3-grpo-full-length {
width: 100%;
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
position: relative;
}
.d3-grpo-full-length svg {
display: block;
width: 100%;
}
/* Axes: the domain path is hidden, only tick lines and tick labels are shown. */
.d3-grpo-full-length .axis path {
stroke: none;
}
.d3-grpo-full-length .axis line {
stroke: var(--axis-color);
shape-rendering: crispEdges;
}
.d3-grpo-full-length .axis text {
fill: var(--tick-color);
font-size: 11px;
}
/* Dashed horizontal grid lines. */
.d3-grpo-full-length .grid line {
stroke: var(--grid-color);
stroke-dasharray: 2,2;
}
/* The script in this file never creates an element with this class. */
.d3-grpo-full-length .confidence-band {
opacity: 0.15;
}
/* Data series paths; the stroke colour itself is set from the script. */
.d3-grpo-full-length .line {
fill: none;
stroke-width: 2;
stroke-linejoin: round;
stroke-linecap: round;
}
.d3-grpo-full-length .axis-label {
fill: var(--text-color);
font-size: 12px;
font-weight: 600;
}
/* --reward-color and --length-color are written by the script at render time. */
.d3-grpo-full-length .axis-label.reward {
fill: var(--reward-color);
}
.d3-grpo-full-length .axis-label.length {
fill: var(--length-color);
}
/* Footer row that hosts the legend. */
.d3-grpo-full-length .header {
display: flex;
align-items: center;
justify-content: space-between;
flex-wrap: wrap;
gap: 16px;
margin-top: 12px;
padding-top: 12px;
border-top: 1px solid var(--border-color);
}
.d3-grpo-full-length .legend {
display: flex;
flex-direction: column;
align-items: flex-start;
gap: 6px;
}
.d3-grpo-full-length .legend-title {
font-size: 12px;
font-weight: 700;
color: var(--text-color);
}
.d3-grpo-full-length .legend .items {
display: flex;
flex-wrap: wrap;
gap: 8px 14px;
}
.d3-grpo-full-length .legend .item {
display: inline-flex;
align-items: center;
gap: 6px;
white-space: nowrap;
font-size: 12px;
color: var(--text-color);
cursor: pointer;
user-select: none;
opacity: 1;
transition: opacity 0.2s ease;
}
/* The script in this file never adds the "dimmed" class. */
.d3-grpo-full-length .legend .item.dimmed {
opacity: 0.3;
}
.d3-grpo-full-length .legend .swatch {
width: 14px;
height: 14px;
border-radius: 3px;
border: 1px solid var(--border-color);
}
/* Control styles: the script's makeControls() is a no-op, so nothing in this file renders these. */
.d3-grpo-full-length .controls {
display: flex;
gap: 16px;
align-items: center;
justify-content: flex-end;
flex-wrap: wrap;
}
.d3-grpo-full-length .controls .control-group {
display: flex;
flex-direction: column;
align-items: flex-start;
gap: 6px;
}
.d3-grpo-full-length .controls label {
font-size: 12px;
font-weight: 700;
color: var(--text-color);
}
.d3-grpo-full-length .controls .toggle-group {
display: flex;
gap: 8px;
align-items: center;
}
.d3-grpo-full-length .controls .toggle-btn {
padding: 6px 12px;
font-size: 12px;
border: 1px solid var(--border-color);
border-radius: 8px;
background: var(--surface-bg);
color: var(--text-color);
cursor: pointer;
transition: all 0.2s ease;
}
/* Hover and active share the same highlighted appearance. */
.d3-grpo-full-length .controls .toggle-btn:hover {
background: var(--primary-color);
color: white;
border-color: var(--primary-color);
}
.d3-grpo-full-length .controls .toggle-btn.active {
background: var(--primary-color);
color: white;
border-color: var(--primary-color);
}
</style>
<script>
(() => {
/**
 * Run `cb` once D3 is usable on `window`.
 *
 * When `window.d3.select` is already a function the callback runs immediately
 * and its return value is passed through. Otherwise a single shared
 * <script id="d3-cdn-script"> pointing at the jsDelivr build is created (or
 * reused if one is already in the document) and `cb` runs from its `load`
 * event, provided D3 is actually usable by then.
 *
 * @param {Function} cb - invoked with no arguments once D3 is available.
 */
const ensureD3 = (cb) => {
  const d3IsUsable = () => Boolean(window.d3) && typeof window.d3.select === 'function';
  if (d3IsUsable()) return cb();

  const SCRIPT_ID = 'd3-cdn-script';
  let loaderTag = document.getElementById(SCRIPT_ID);
  if (!loaderTag) {
    loaderTag = document.createElement('script');
    loaderTag.id = SCRIPT_ID;
    loaderTag.src = 'https://cdn.jsdelivr.net/npm/d3@7/dist/d3.min.js';
    document.head.appendChild(loaderTag);
  }

  const runWhenUsable = () => {
    if (d3IsUsable()) cb();
  };
  loaderTag.addEventListener('load', runWhenUsable, { once: true });
  // Some `d3` global exists but failed the check above: re-check right away.
  if (window.d3) runWhenUsable();
};
const bootstrap = () => {
const scriptEl = document.currentScript;
let container = scriptEl ? scriptEl.previousElementSibling : null;
if (!(container && container.classList && container.classList.contains('d3-grpo-full-length'))) {
const candidates = Array.from(document.querySelectorAll('.d3-grpo-full-length'))
.filter((el) => !(el.dataset && el.dataset.mounted === 'true'));
container = candidates[candidates.length - 1] || null;
}
if (!container) return;
if (container.dataset) {
if (container.dataset.mounted === 'true') return;
container.dataset.mounted = 'true';
}
// Data loading configuration.
// Climb from the container to the nearest ancestor (or the container itself)
// that carries a non-empty data-datafiles attribute; ends as null if none does.
let mountEl = container;
for (; mountEl; mountEl = mountEl.parentElement) {
  if (mountEl.getAttribute?.('data-datafiles')) break;
}
// The attribute holds either a JSON array (when it starts with "[") or a
// single path string. Any failure while reading/parsing leaves this as null.
let providedData = null;
try {
  const rawAttr = mountEl && mountEl.getAttribute ? mountEl.getAttribute('data-datafiles') : null;
  const trimmedAttr = rawAttr ? rawAttr.trim() : '';
  if (trimmedAttr) {
    if (trimmedAttr.startsWith('[')) {
      providedData = JSON.parse(rawAttr);
    } else {
      providedData = trimmedAttr;
    }
  }
} catch (_) {}
const DEFAULT_CSV = '/data/rl_reward_curves.csv';
/**
 * Turn a bare data-file name into a site-absolute path under /data/.
 *
 * Values that are not non-empty strings are returned untouched. Paths that
 * already start with "/" and full URLs ("scheme://...") are returned as they
 * are; prefixing a full URL would produce an unusable "/data/https://..." path.
 *
 * @param {*} p - candidate path.
 * @returns {*} the normalised path, or `p` itself when it is not a non-empty string.
 */
const ensureDataPrefix = (p) => {
  if (typeof p !== 'string' || !p) return p;
  // Already site-absolute (this also covers protocol-relative "//host/...").
  if (p.startsWith('/')) return p;
  // Fully qualified URL: leave it alone.
  if (/^[a-z][a-z\d+.-]*:\/\//i.test(p)) return p;
  // Otherwise, prefix with /data/
  return `/data/${p}`;
};
/**
 * Normalise a data-file spec into a list of paths.
 *
 * @param {*} inp - an array of paths, a single path string, or anything else.
 * @returns {Array|null} every entry passed through ensureDataPrefix; null when
 *   `inp` is neither an array nor a string.
 */
const normalizeInput = (inp) => {
  if (Array.isArray(inp)) return inp.map(ensureDataPrefix);
  if (typeof inp === 'string') return [ensureDataPrefix(inp)];
  return null;
};
// Candidate CSV locations derived from data-datafiles:
//  - array attribute  -> every entry, normalised;
//  - string attribute -> that single path, normalised (DEFAULT_CSV as a safety net);
//  - no attribute     -> the default file at a few relative asset locations.
// NOTE(review): nothing in the visible file reads CSV_PATHS; the loader further
// down uses its own hard-coded path lists.
const CSV_PATHS = (() => {
  if (Array.isArray(providedData)) {
    return normalizeInput(providedData);
  }
  if (typeof providedData === 'string') {
    return normalizeInput(providedData) || [DEFAULT_CSV];
  }
  return [
    DEFAULT_CSV,
    './assets/data/rl_reward_curves.csv',
    '../assets/data/rl_reward_curves.csv',
    '../../assets/data/rl_reward_curves.csv'
  ];
})();
/**
 * Fetch the first path that answers with an OK response and return its body text.
 *
 * Paths are tried strictly in order, one at a time. A non-OK status, a failed
 * request, or a failure while reading the body is recorded and the next path
 * is tried.
 *
 * @param {Iterable<string>} paths - candidate URLs, most preferred first.
 * @returns {Promise<string>} the text of the first successful response.
 * @throws {Error} when every path failed; the message lists each attempt.
 */
const fetchFirstAvailable = async (paths) => {
  const failures = [];
  for (const candidate of paths) {
    try {
      const response = await fetch(candidate, { cache: 'no-cache' });
      if (!response.ok) {
        failures.push(`${candidate}: ${response.status}`);
        continue;
      }
      return await response.text();
    } catch (err) {
      failures.push(`${candidate}: ${err.message}`);
    }
  }
  throw new Error(`CSV not found. Tried:\n${failures.join('\n')}`);
};
// Tooltip setup: reuse a .d3-tooltip already inside the container, otherwise
// build one. `tip` is the positioned box, `tipInner` receives the HTML content.
container.style.position = container.style.position || 'relative';
let tip = container.querySelector('.d3-tooltip');
let tipInner;
if (tip) {
  // Existing tooltip: write into its inner wrapper when present, else into the box itself.
  tipInner = tip.querySelector('.d3-tooltip__inner') || tip;
} else {
  const tooltipStyle = {
    position: 'absolute',
    top: '0px',
    left: '0px',
    // Parked far off-screen until the first mousemove positions it.
    transform: 'translate(-9999px, -9999px)',
    pointerEvents: 'none',
    padding: '8px 10px',
    borderRadius: '8px',
    fontSize: '12px',
    lineHeight: '1.35',
    border: '1px solid var(--border-color)',
    background: 'var(--surface-bg)',
    color: 'var(--text-color)',
    boxShadow: '0 4px 24px rgba(0,0,0,.18)',
    opacity: '0',
    transition: 'opacity .12s ease',
    zIndex: '1000'
  };
  tip = document.createElement('div');
  tip.className = 'd3-tooltip';
  for (const [property, value] of Object.entries(tooltipStyle)) {
    tip.style[property] = value;
  }
  tipInner = document.createElement('div');
  tipInner.className = 'd3-tooltip__inner';
  tipInner.style.textAlign = 'left';
  tip.appendChild(tipInner);
  container.appendChild(tip);
}
// SVG setup: one <svg> appended to the container with a root group that is
// translated by the margins in updateSize(). The child groups are appended in
// paint order, so grid lines sit below the series and axes are drawn last.
const svg = d3.select(container).append('svg').attr('width', '100%').style('display', 'block');
const gRoot = svg.append('g');
const gGrid = gRoot.append('g').attr('class', 'grid');
// Created here, but nothing in the visible file draws into this group.
const gBands = gRoot.append('g').attr('class', 'bands');
const gLines = gRoot.append('g').attr('class', 'lines');
const gAxes = gRoot.append('g').attr('class', 'axes');
// State shared by the functions below.
// width/height are placeholders; updateSize() overwrites them on every render.
let width = 800, height = 400;
const margin = { top: 16, right: 64, bottom: 56, left: 64 };
let rawData = {}; // Raw CSV text per metric; written by the loader, not read in the visible file
let rewardSeries = []; // Filled by parseData(..., 'reward')
let lengthSeries = []; // Filled by parseData(..., any other metric type)
// Keys have the form "reward-<name>" / "length-<name>"; nothing in the visible file adds to this set.
let hiddenSeries = new Set();
// When true, render() plots runningAvgPoints instead of the raw points; never reassigned in the visible file.
let showRunningAverage = true;
const RUNNING_AVG_WINDOW = 50; // steps
// Fixed colors for the reward (left axis) and length (right axis) series
const REWARD_COLOR = '#4E79A7';
const LENGTH_COLOR = '#F28E2B';
/**
 * Return a categorical color list.
 *
 * Delegates to window.ColorPalettes.getColors('categorical', count) when that
 * helper exists; otherwise returns a fixed six-color list regardless of `count`.
 *
 * @param {number} count - number of colors requested from the shared palette.
 * @returns {*} whatever the shared palette returns, or the six fallback hex strings.
 */
const getColors = (count) => {
  const palettes = window.ColorPalettes;
  const hasSharedPalette = Boolean(palettes && palettes.getColors);
  if (!hasSharedPalette) {
    // Fallback colors
    return ['#4E79A7', '#F28E2B', '#E15759', '#76B7B2', '#59A14F', '#EDC948'];
  }
  return palettes.getColors('categorical', count);
};
/**
 * Smooth a series with a trailing window measured in steps.
 *
 * For each input point, every point of the series whose step lies in
 * [step - windowSize, step] is collected and the d3.mean of its mean/min/max
 * fields becomes the output point at that step. The input is not modified.
 *
 * @param {Array<{step:number, mean:number, min:number, max:number}>} points
 * @param {number} windowSize - width of the trailing window, in steps.
 * @returns {Array<{step:number, mean:*, min:*, max:*}>} one entry per input
 *   point whose window is non-empty, in input order.
 */
function calculateRunningAverage(points, windowSize) {
  const smoothed = [];
  for (const { step: currentStep } of points) {
    const windowStart = currentStep - windowSize;
    const inWindow = points.filter((p) => p.step >= windowStart && p.step <= currentStep);
    if (inWindow.length === 0) continue;
    const average = (field) => d3.mean(inWindow, (p) => p[field]);
    smoothed.push({
      step: currentStep,
      mean: average('mean'),
      min: average('min'),
      max: average('max')
    });
  }
  return smoothed;
}
/**
 * Parse one exported metrics CSV and store the resulting series.
 *
 * Expected columns: `train/global_step` plus, for every run, a mean column
 * named `<run> - <metric>` and optional `<run> - <metric>__MIN` /
 * `<run> - <metric>__MAX` columns. Run columns are recognised by their exact
 * ` - <metric>` ending, so run names may themselves contain "MIN", "MAX" or
 * " - " without being dropped or truncated.
 *
 * The result is written to `rewardSeries` when metricType is 'reward' and to
 * `lengthSeries` for any other value; nothing is returned.
 *
 * @param {string} csvText - raw CSV text.
 * @param {string} metricType - 'reward' selects train/reward; anything else
 *   selects train/completions/mean_terminated_length.
 * @throws {Error} when the CSV contains no data rows.
 */
function parseData(csvText, metricType) {
  const rows = d3.csvParse(csvText);
  if (rows.length === 0) {
    // Fail with a readable message instead of a TypeError on rows[0].
    throw new Error(`No data rows found in the ${metricType} CSV`);
  }
  // Determine metric column suffix based on type
  const metricSuffix = metricType === 'reward'
    ? 'train/reward'
    : 'train/completions/mean_terminated_length';
  const columnSuffix = ` - ${metricSuffix}`;
  // d3.csvParse exposes the header row as `columns`; fall back to the first row's keys.
  const headers = rows.columns ?? Object.keys(rows[0]);
  // The __MIN / __MAX companions do not end with the bare suffix, so they are skipped here.
  const runNames = headers
    .filter((h) => h.endsWith(columnSuffix))
    .map((h) => h.slice(0, -columnSuffix.length));
  // For v18.00, just use a simple label
  const displayNameMap = {
    'grpo-SmollM3-3B-GRPO-no-think-v18.00': 'No Penalty'
  };
  // Build series data using train/global_step for x-axis
  const series = runNames.map((runName) => {
    const meanCol = `${runName}${columnSuffix}`;
    const minCol = `${meanCol}__MIN`;
    const maxCol = `${meanCol}__MAX`;
    const points = rows
      // Skip rows where this run has no value logged at that step.
      .filter((row) => row['train/global_step'] && row[meanCol])
      .map((row) => ({
        step: +row['train/global_step'],
        mean: +row[meanCol],
        min: +row[minCol],
        max: +row[maxCol]
      }))
      .filter((p) => !Number.isNaN(p.step) && !Number.isNaN(p.mean));
    return {
      name: displayNameMap[runName] || runName,
      fullName: runName,
      points,
      runningAvgPoints: calculateRunningAverage(points, RUNNING_AVG_WINDOW)
    };
  });
  // Store in appropriate series array
  if (metricType === 'reward') {
    rewardSeries = series;
  } else {
    lengthSeries = series;
  }
}
/**
 * Re-measure the container and resize the SVG to match.
 *
 * Width follows the container (800 when it measures 0); height keeps a 2.5:1
 * ratio with a 320px floor. Updates the shared `width`/`height` state and
 * re-applies the margin translation to the root group.
 *
 * @returns {{innerWidth:number, innerHeight:number}} plot area inside the margins.
 */
function updateSize() {
  const measuredWidth = container.clientWidth;
  width = measuredWidth || 800;
  height = Math.max(320, Math.round(width / 2.5));

  svg.attr('width', width);
  svg.attr('height', height);

  const { top, right, bottom, left } = margin;
  gRoot.attr('transform', `translate(${left},${top})`);

  const innerWidth = width - left - right;
  const innerHeight = height - top - bottom;
  return { innerWidth, innerHeight };
}
/**
 * Draw (or redraw) the whole dual-axis chart at the container's current size.
 *
 * Reward series use the left y-axis, length series the right one, and both
 * share the x-axis (training step). Does nothing until both series lists are
 * non-empty. Also (re)binds the tooltip handlers on the SVG so they always
 * close over the scales of the latest render.
 */
function render() {
  const { innerWidth, innerHeight } = updateSize();
  if (rewardSeries.length === 0 || lengthSeries.length === 0) return;

  // Filter visible series for both metrics
  const visibleReward = rewardSeries.filter((s) => !hiddenSeries.has(`reward-${s.name}`));
  const visibleLength = lengthSeries.filter((s) => !hiddenSeries.has(`length-${s.name}`));

  // Select which points to use based on running average toggle
  const getPoints = (s) => (showRunningAverage ? s.runningAvgPoints : s.points);

  // Get all points for domain calculation
  const allRewardPoints = visibleReward.flatMap((s) => getPoints(s));
  const allLengthPoints = visibleLength.flatMap((s) => getPoints(s));

  // X scale (shared)
  const maxStep = Math.max(
    d3.max(allRewardPoints, (d) => d.step) || 0,
    d3.max(allLengthPoints, (d) => d.step) || 0
  );
  const xScale = d3.scaleLinear()
    .domain([0, maxStep || 1])
    .range([0, innerWidth])
    .nice();

  // Y domain with 5% head-room on both ends. The padding always moves away
  // from the data (a negative minimum is pushed further down rather than
  // towards zero), and an empty point list falls back to [0, 1] instead of
  // producing a NaN domain.
  const paddedDomain = (points) => {
    const lo = d3.min(points, (d) => d.mean);
    const hi = d3.max(points, (d) => d.mean);
    if (lo === undefined || hi === undefined) return [0, 1];
    return [
      lo >= 0 ? lo * 0.95 : lo * 1.05,
      hi >= 0 ? hi * 1.05 : hi * 0.95
    ];
  };

  // Y scales (dual axes)
  const yScaleReward = d3.scaleLinear()
    .domain(paddedDomain(allRewardPoints))
    .range([innerHeight, 0]);
  const yScaleLength = d3.scaleLinear()
    .domain(paddedDomain(allLengthPoints))
    .range([innerHeight, 0]);

  // Grid (based on left axis - reward)
  gGrid.selectAll('.grid-y').data([0])
    .join('g')
    .attr('class', 'grid grid-y')
    .call(d3.axisLeft(yScaleReward)
      .tickSize(-innerWidth)
      .tickFormat('')
    )
    .call((g) => g.select('.domain').remove());

  // Line generators
  const rewardLine = d3.line()
    .x((d) => xScale(d.step))
    .y((d) => yScaleReward(d.mean))
    .curve(d3.curveMonotoneX);
  const lengthLine = d3.line()
    .x((d) => xScale(d.step))
    .y((d) => yScaleLength(d.mean))
    .curve(d3.curveMonotoneX);

  // Render reward lines
  gLines.selectAll('.line.reward')
    .data(visibleReward, (d) => `reward-${d.name}`)
    .join('path')
    .attr('class', 'line reward')
    .attr('d', (d) => rewardLine(getPoints(d)))
    .attr('stroke', REWARD_COLOR);

  // Render length lines
  gLines.selectAll('.line.length')
    .data(visibleLength, (d) => `length-${d.name}`)
    .join('path')
    .attr('class', 'line length')
    .attr('d', (d) => lengthLine(getPoints(d)))
    .attr('stroke', LENGTH_COLOR);

  // Axes
  gAxes.selectAll('.x-axis').data([0])
    .join('g')
    .attr('class', 'x-axis axis')
    .attr('transform', `translate(0,${innerHeight})`)
    .call(d3.axisBottom(xScale).ticks(Math.min(10, Math.floor(innerWidth / 80))));
  gAxes.selectAll('.y-axis-left').data([0])
    .join('g')
    .attr('class', 'y-axis-left axis')
    .call(d3.axisLeft(yScaleReward).ticks(8))
    .call((g) => g.selectAll('.tick line').attr('stroke', REWARD_COLOR).attr('opacity', 0.3))
    .call((g) => g.selectAll('.tick text').attr('fill', REWARD_COLOR));
  gAxes.selectAll('.y-axis-right').data([0])
    .join('g')
    .attr('class', 'y-axis-right axis')
    .attr('transform', `translate(${innerWidth},0)`)
    .call(d3.axisRight(yScaleLength).ticks(8))
    .call((g) => g.selectAll('.tick line').attr('stroke', LENGTH_COLOR).attr('opacity', 0.3))
    .call((g) => g.selectAll('.tick text').attr('fill', LENGTH_COLOR));

  // Axis labels
  gAxes.selectAll('.x-label').data([0])
    .join('text')
    .attr('class', 'x-label axis-label')
    .attr('text-anchor', 'middle')
    .attr('x', innerWidth / 2)
    .attr('y', innerHeight + 45)
    .text('Training step');
  gAxes.selectAll('.y-label-left').data([0])
    .join('text')
    .attr('class', 'y-label-left axis-label reward')
    .attr('text-anchor', 'middle')
    .attr('transform', `translate(-48,${innerHeight / 2}) rotate(-90)`)
    .text('Reward');
  gAxes.selectAll('.y-label-right').data([0])
    .join('text')
    .attr('class', 'y-label-right axis-label length')
    .attr('text-anchor', 'middle')
    .attr('transform', `translate(${innerWidth + 48},${innerHeight / 2}) rotate(90)`)
    .text('Mean Terminated Length');

  // Set CSS variables for colors
  // NOTE(review): these are written on the document root, so they are shared
  // with everything else on the page that reads the same variable names.
  document.documentElement.style.setProperty('--reward-color', REWARD_COLOR);
  document.documentElement.style.setProperty('--length-color', LENGTH_COLOR);

  // Tooltip interactions
  const bisect = d3.bisector((d) => d.step).left;
  // Sample closest to `step`, or null for an empty series. bisect yields the
  // insertion index, so the candidates are the neighbours on either side of
  // it; this also covers the first and the last sample.
  // Assumes points are ordered by step, as bisect requires.
  const nearestPoint = (points, step) => {
    if (points.length === 0) return null;
    const idx = bisect(points, step);
    const before = points[idx - 1];
    const after = points[idx];
    if (!before) return after;
    if (!after) return before;
    return step - before.step <= after.step - step ? before : after;
  };
  const tooltipRow = (color, label, valueStr) =>
    `<div style="margin-top:4px"><span style="color:${color}">●</span> ${label}: ${valueStr}</div>`;

  svg.on('mousemove', (event) => {
    const [mx] = d3.pointer(event, gRoot.node());
    const step = xScale.invert(mx);
    let tooltipHtml = `<strong>Step: ${Math.round(step)}</strong>`;
    if (showRunningAverage) {
      tooltipHtml += ` <span style="font-weight:normal;font-size:11px">(${RUNNING_AVG_WINDOW}-step avg)</span>`;
    }
    tooltipHtml += `<br/>`;
    // Add reward values
    visibleReward.forEach((s) => {
      const p = nearestPoint(getPoints(s), step);
      if (p) {
        tooltipHtml += tooltipRow(REWARD_COLOR, 'Reward', `${(p.mean * 100).toFixed(1)}%`);
      }
    });
    // Add length values
    visibleLength.forEach((s) => {
      const p = nearestPoint(getPoints(s), step);
      if (p) {
        tooltipHtml += tooltipRow(LENGTH_COLOR, 'Length', `${p.mean.toFixed(1)} tokens`);
      }
    });
    tipInner.innerHTML = tooltipHtml;

    // Place the tooltip next to the pointer, flipping it when it would leave the chart.
    const tipBounds = tip.getBoundingClientRect();
    const [px, py] = d3.pointer(event, container);
    let tipX = px + 12;
    let tipY = py - 12;
    if (tipX + tipBounds.width > width - 10) {
      tipX = px - tipBounds.width - 12;
    }
    if (tipY - tipBounds.height < 10) {
      tipY = py + 20;
    }
    tip.style.transform = `translate(${tipX}px, ${tipY}px)`;
    tip.style.opacity = '1';
  });
  svg.on('mouseleave', () => {
    tip.style.opacity = '0';
    tip.style.transform = 'translate(-9999px, -9999px)';
  });
}
/**
 * Build (or rebuild) the legend below the chart.
 *
 * Ensures the structure container > .header > .legend > (.legend-title, .items)
 * exists, reusing any element already present, then replaces the contents of
 * .items with one swatch + label entry per metric.
 */
function makeLegend() {
  // First descendant of `parent` with the given class, created and appended when missing.
  const ensureChild = (parent, className) => {
    const existing = parent.querySelector(`.${className}`);
    if (existing) return existing;
    const created = document.createElement('div');
    created.className = className;
    parent.appendChild(created);
    return created;
  };

  const header = ensureChild(container, 'header');
  const legend = ensureChild(header, 'legend');
  const title = ensureChild(legend, 'legend-title');
  title.textContent = 'Legend';
  const items = ensureChild(legend, 'items');
  items.innerHTML = '';

  const entries = [
    { label: 'Reward', color: REWARD_COLOR },
    { label: 'Mean Terminated Length', color: LENGTH_COLOR }
  ];
  for (const { label, color } of entries) {
    const item = document.createElement('span');
    item.className = 'item';
    const swatch = document.createElement('span');
    swatch.className = 'swatch';
    swatch.style.background = color;
    const text = document.createElement('span');
    text.textContent = label;
    item.appendChild(swatch);
    item.appendChild(text);
    items.appendChild(item);
  }
}
// Intentional no-op: this dual-axis embed renders no interactive controls.
// The loader still calls it once after the data has been parsed, so it has to
// stay defined.
function makeControls() {
// No controls needed for dual-axis view
// This function is kept for consistency but does nothing
}
// Load both datasets.
// Each file is looked up at the site root first, then at a few relative asset
// locations; fetchFirstAvailable() returns the first one that responds OK.
const candidatePaths = (fileName) => [
  `/data/${fileName}`,
  `./assets/data/${fileName}`,
  `../assets/data/${fileName}`,
  `../../assets/data/${fileName}`
];
const REWARD_PATHS = candidatePaths('rl_reward_full_length.csv');
const LENGTH_PATHS = candidatePaths('rl_mean_terminated_length_full_length.csv');

(async () => {
  try {
    // Both downloads are started together; rendering waits for both.
    const [rewardCsvText, lengthCsvText] = await Promise.all([
      fetchFirstAvailable(REWARD_PATHS),
      fetchFirstAvailable(LENGTH_PATHS)
    ]);
    // Store both datasets
    rawData.reward = rewardCsvText;
    rawData.length = lengthCsvText;
    // Parse both datasets
    parseData(rewardCsvText, 'reward');
    parseData(lengthCsvText, 'length');
    makeLegend();
    makeControls();
    render();
    // Responsiveness: redraw on container resize, or on window resize when
    // ResizeObserver is unavailable.
    if (window.ResizeObserver) {
      const ro = new ResizeObserver(() => render());
      ro.observe(container);
    } else {
      window.addEventListener('resize', render);
    }
  } catch (err) {
    // Any failure while loading, parsing or drawing is shown inside the container.
    const pre = document.createElement('pre');
    pre.style.color = '#f44336';
    pre.style.fontSize = '12px';
    pre.style.padding = '12px';
    pre.textContent = `Error loading data: ${err.message}`;
    container.appendChild(pre);
  }
})();
};
// Start-up: make sure D3 is available, then mount the chart. While the
// document is still being parsed this is deferred to DOMContentLoaded.
const start = () => ensureD3(bootstrap);
if (document.readyState !== 'loading') {
  start();
} else {
  document.addEventListener('DOMContentLoaded', () => start(), { once: true });
}
})();
</script>