Clémentine
Init
ffdff5d
<div class="d3-mmlu-heatmap">
<div class="heatmap-container"></div>
<div class="legend-container"></div>
</div>
<style>
.d3-mmlu-heatmap {
position: relative;
margin: 24px 0;
}
.d3-mmlu-heatmap .heatmap-container {
width: 100%;
}
.d3-mmlu-heatmap .legend-container {
margin-top: 8px;
padding: 0 8px;
}
.d3-mmlu-heatmap .legend-title {
font-size: 12px;
font-weight: 600;
color: var(--text-color);
margin-bottom: 12px;
text-align: center;
}
.d3-mmlu-heatmap .legend-grid {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 8px 24px;
font-size: 11px;
color: var(--text-color);
}
.d3-mmlu-heatmap .legend-column {
display: flex;
flex-direction: column;
gap: 8px;
}
.d3-mmlu-heatmap .legend-item {
display: flex;
align-items: flex-start;
gap: 8px;
}
.d3-mmlu-heatmap .legend-label {
font-weight: 700;
min-width: 20px;
}
.d3-mmlu-heatmap .legend-text {
flex: 1;
line-height: 1.4;
}
.d3-mmlu-heatmap .axis-label {
fill: var(--text-color);
font-size: 11px;
font-weight: 600;
}
.d3-mmlu-heatmap .cell-text {
fill: var(--text-color);
font-size: 10px;
font-weight: 600;
pointer-events: none;
}
@media (max-width: 768px) {
.d3-mmlu-heatmap .legend-grid {
grid-template-columns: 1fr;
}
}
</style>
<script>
(() => {
// Make sure D3 is available, then run `cb`.
// - If a usable `window.d3` already exists, `cb` runs synchronously and its result is returned.
// - Otherwise a single <script id="d3-cdn-script"> is injected (or reused when one is already
//   in the document) and `cb` runs once that tag fires its `load` event.
const ensureD3 = (cb) => {
  const d3IsUsable = () => Boolean(window.d3) && typeof window.d3.select === 'function';

  if (d3IsUsable()) return cb();

  const SCRIPT_ID = 'd3-cdn-script';
  let tag = document.getElementById(SCRIPT_ID);
  if (!tag) {
    // First caller on the page: create the shared loader tag.
    tag = document.createElement('script');
    tag.id = SCRIPT_ID;
    tag.src = 'https://cdn.jsdelivr.net/npm/d3@7/dist/d3.min.js';
    document.head.appendChild(tag);
  }

  // Only call back when the global is really usable.
  const runWhenUsable = () => {
    if (d3IsUsable()) cb();
  };
  tag.addEventListener('load', runWhenUsable, { once: true });
  if (window.d3) runWhenUsable();
};
const bootstrap = () => {
// Locate the root element this script instance should mount into.
// First guess: the element immediately preceding the running <script>.
// NOTE(review): in the markup visible in this file the element right before the <script>
// is the <style> block, not the .d3-mmlu-heatmap div, so this first guess does not match
// and the fallback query below is what actually finds the container.
// NOTE(review): document.currentScript is presumably null when bootstrap runs from the
// `load` / `DOMContentLoaded` callbacks rather than during script evaluation — confirm.
const scriptEl = document.currentScript;
let container = scriptEl ? scriptEl.previousElementSibling : null;
if (!(container && container.classList && container.classList.contains('d3-mmlu-heatmap'))) {
// Fallback: the last .d3-mmlu-heatmap in the document that has not been mounted yet.
const cs = Array.from(document.querySelectorAll('.d3-mmlu-heatmap')).filter(el => !(el.dataset && el.dataset.mounted === 'true'));
container = cs[cs.length - 1] || null;
}
if (!container) return;
// Flag the container as mounted so another bootstrap pass cannot render into it twice.
if (container.dataset) {
if (container.dataset.mounted === 'true') return;
container.dataset.mounted = 'true';
}
// Tooltip: one absolutely-positioned <div> per container, reused if it already exists.
// The container needs a non-static position so the tooltip's translate() offsets are
// relative to it; an inline position that is already set is left untouched.
container.style.position = container.style.position || 'relative';
let tip = container.querySelector('.d3-tooltip');
let tipInner;
if (!tip) {
tip = document.createElement('div');
tip.className = 'd3-tooltip';
// Starts fully transparent and translated far off-screen; the cell mousemove handler
// moves it next to the pointer and sets opacity to 1.
Object.assign(tip.style, {
position: 'absolute',
top: '0px',
left: '0px',
transform: 'translate(-9999px, -9999px)',
pointerEvents: 'none',
padding: '8px 10px',
borderRadius: '8px',
fontSize: '12px',
lineHeight: '1.35',
border: '1px solid var(--border-color)',
background: 'var(--surface-bg)',
color: 'var(--text-color)',
boxShadow: '0 4px 24px rgba(0,0,0,.18)',
opacity: '0',
transition: 'opacity .12s ease'
});
// Inner wrapper receives the tooltip HTML.
tipInner = document.createElement('div');
tipInner.className = 'd3-tooltip__inner';
tipInner.style.textAlign = 'left';
tip.appendChild(tipInner);
container.appendChild(tip);
} else {
// Reuse an existing tooltip; write into the tooltip itself if it has no inner wrapper.
tipInner = tip.querySelector('.d3-tooltip__inner') || tip;
}
// Heatmap container (no card): build the SVG skeleton once; render() only updates its contents.
const heatmapContainer = container.querySelector('.heatmap-container');
const svg = d3.select(heatmapContainer).append('svg').attr('width', '100%').style('display', 'block');
// NOTE(review): `defs` is appended to the SVG but not referenced anywhere else in the visible code.
const defs = svg.append('defs');
// gRoot is translated by the margins in updateSize(); cells and axis labels live in separate groups under it.
const gRoot = svg.append('g');
const gCells = gRoot.append('g');
const gAxes = gRoot.append('g');
// Hard-coded data (5 models x 8 prompt formats), transcribed from the source image.
// `models` supplies the row labels (Y axis), in the same order as the rows of `matrix`.
const models = [
'Mistral-7B-v0.1',
'Qwen1.5-7B',
'gemma-7b',
'phi-2',
'DeciLM-7B'
];
// `promptFormats` supplies the columns. On the chart each format is shown as a letter
// (A, B, C, ...) and the full text is listed in the HTML legend.
const promptFormats = [
'...? -> choice1/choice2/...',
'Q:...? A: -> choice1/choice2/...',
'Question: ...? Answer: -> choice1/choice2/...',
'Question: ...? Choices: ... Answer: -> choice1/choice2/...',
'Question: ...? Choices: A. ... Answer: -> choice1/choice2/...',
'Question: ...? Choices: (A) ... Answer: -> choice1/choice2/...',
'Question: ...? Choices: A. ... Answer: -> A/B/C/D',
'Question: ...? Choices: (A) Answer: -> (A)/(B)/(C)/(D)'
];
// matrix[r][c] is the score of models[r] under promptFormats[c].
const matrix = [
[49.0, 50.5, 52.1, 54.5, 56.4, 55.4, 55.5, 57.0], // Mistral-7B-v0.1
[37.6, 41.8, 43.5, 47.9, 50.8, 51.2, 22.9, 47.7], // Qwen1.5-7B
[44.6, 48.0, 47.6, 53.5, 54.2, 54.9, 56.4, 50.7], // gemma-7b
[39.1, 44.3, 46.5, 46.1, 47.1, 48.4, 51.7, 45.8], // phi-2
[43.6, 48.9, 49.5, 51.0, 51.3, 52.0, 52.8, 52.3] // DeciLM-7B
];
// Colors: diverging palette (purple for low, yellow for high)
/**
 * Return `count` colours running from dark purple (low) to yellow (high).
 *
 * Uses `window.ColorPalettes.getColors('diverging', count)` when that helper is present;
 * otherwise falls back to a built-in piecewise-linear ramp through five RGB stops.
 *
 * @param {number} count - Number of colours requested.
 * @returns {string[]} Colours; the built-in ramp emits `rgb(r, g, b)` strings.
 *   With the built-in ramp, `count <= 0` gives `[]` and `count === 1` gives the first
 *   (dark purple) stop.
 */
const getDivergingColors = (count) => {
  try {
    if (window.ColorPalettes && typeof window.ColorPalettes.getColors === 'function') {
      return window.ColorPalettes.getColors('diverging', count);
    }
  } catch (_) {
    // Deliberate best effort: if the shared palette helper is missing or throws,
    // use the built-in ramp below.
  }

  // Fallback ramp stops, evenly spaced at t = 0, 0.25, 0.5, 0.75, 1:
  // dark purple -> purple -> blue-green -> green -> yellow.
  const STOPS = [
    [75, 0, 130],
    [125, 30, 180],
    [50, 130, 100],
    [100, 200, 50],
    [255, 150, 0],
  ];
  const SEGMENT = 0.25;
  const LAST_SEGMENT = STOPS.length - 2;

  const colors = [];
  for (let i = 0; i < count; i++) {
    // A single colour has no span to interpolate over; without this guard the
    // division is 0 / 0 and every channel becomes NaN.
    const t = count > 1 ? i / (count - 1) : 0;
    const seg = Math.min(LAST_SEGMENT, Math.floor(t / SEGMENT));
    const local = (t - seg * SEGMENT) / SEGMENT;
    const [r, g, b] = STOPS[seg].map(
      (from, ch) => Math.round(from + (STOPS[seg + 1][ch] - from) * local),
    );
    colors.push(`rgb(${r}, ${g}, ${b})`);
  }
  return colors;
};
// Palette is computed once, with 10 colour steps.
const palette = getDivergingColors(10);
// Last measured container width; 900 is only the starting value and is overwritten in updateSize().
let width = 900;
const margin = { top: 10, right: 20, bottom: 20, left: 100 }; // Left margin is the widest: the model names are drawn in it
/**
 * Measure the container and size the SVG to fit the grid exactly.
 *
 * Side effects: updates the outer `width`, sets the SVG `viewBox`,
 * `preserveAspectRatio`, `width` and `height`, and translates `gRoot` by the
 * left/top margins.
 *
 * @returns {{innerWidth: number, innerHeight: number}} Width of the cell grid, and
 *   height of the grid plus the strip reserved for the X-axis labels.
 */
function updateSize() {
  width = container.clientWidth || 900;

  const rowCount = models.length;
  const colCount = promptFormats.length;

  // Cells are square: the usable width (capped at 600) is divided by the larger
  // of the two grid dimensions.
  const usableWidth = Math.min(width - margin.left - margin.right, 600);
  const side = usableWidth / Math.max(rowCount, colCount);
  const gridW = side * colCount;
  const gridH = side * rowCount;

  // Vertical space kept below the grid for the X-axis letters.
  const xLabelStrip = 15;

  // The viewBox matches the content exactly, so the SVG scales without empty bands.
  const viewW = margin.left + gridW + margin.right;
  const viewH = margin.top + gridH + xLabelStrip + margin.bottom;

  svg.attr('viewBox', `0 0 ${viewW} ${viewH}`);
  svg.attr('preserveAspectRatio', 'xMidYMin meet');
  svg.style('width', '100%');
  svg.style('height', 'auto');

  gRoot.attr('transform', `translate(${margin.left},${margin.top})`);

  return { innerWidth: gridW, innerHeight: gridH + xLabelStrip };
}
/**
 * Build the value -> fill colour function for the heatmap cells.
 *
 * Two strategies:
 *  - When `palette` is non-empty and `window.ColorPalettes.getColors` exists, a D3
 *    quantile scale maps values onto `palette`, using 11 sample points drawn from
 *    the sorted values at S-curve-warped positions.
 *  - Otherwise values are normalised linearly over [minV, maxV] (clamped), warped
 *    with an S-curve, and interpolated along a purple -> green -> yellow ramp.
 *
 * @param {number[]} values - All cell values (used by the quantile branch only).
 * @param {number} minV - Smallest value (used by the fallback branch only).
 * @param {number} maxV - Largest value (used by the fallback branch only).
 * @returns {(v: number) => string} Function returning a colour for a value.
 */
function getColorScale(values, minV, maxV) {
  // S-shaped warp of u in [0, 1]: compresses both halves towards the midpoint.
  const sCurve = (u, exponent) => (
    u < 0.5
      ? Math.pow(u * 2, exponent) / 2
      : 0.5 + Math.pow((u - 0.5) * 2, exponent) / 2
  );

  const sharedPaletteActive = palette.length > 0
    && window.ColorPalettes
    && typeof window.ColorPalettes.getColors === 'function';

  if (sharedPaletteActive) {
    const ascending = [...values].sort((a, b) => a - b);
    const lastIndex = ascending.length - 1;
    // 11 sample values (steps 0/10 .. 10/10) picked at warped positions in the sorted data.
    const samples = [];
    for (let step = 0; step <= 10; step += 1) {
      const position = Math.floor(sCurve(step / 10, 1.5) * lastIndex);
      samples.push(ascending[Math.min(position, lastIndex)]);
    }
    const quantileScale = d3.scaleQuantile().domain(samples).range(palette);
    return (v) => quantileScale(v);
  }

  // Fallback: linear normalisation followed by the ramp below.
  const normalise = d3.scaleLinear()
    .domain([minV, maxV])
    .range([0, 1])
    .clamp(true);

  // Ramp stops at t = 0, 0.25, 0.5, 0.75, 1: purple (low) -> green (mid) -> yellow (high).
  const ramp = [
    [75, 0, 130],
    [125, 30, 180],
    [50, 130, 100],
    [100, 200, 50],
    [255, 150, 0],
  ];

  return (v) => {
    const warped = sCurve(normalise(v), 1.8);

    // Pick the quarter of the ramp this value falls into.
    let seg = 3;
    if (warped < 0.25) {
      seg = 0;
    } else if (warped < 0.5) {
      seg = 1;
    } else if (warped < 0.75) {
      seg = 2;
    }

    const local = (warped - seg * 0.25) / 0.25;
    const from = ramp[seg];
    const to = ramp[seg + 1];
    const r = Math.round(from[0] + (to[0] - from[0]) * local);
    const g = Math.round(from[1] + (to[1] - from[1]) * local);
    const b = Math.round(from[2] + (to[2] - from[2]) * local);
    return `rgb(${r}, ${g}, ${b})`;
  };
}
/**
 * Choose a text colour (white or near-black) that stays readable on `bgColor`.
 *
 * Accepted background formats:
 *  - `rgb(r, g, b)` / `rgba(r, g, b, a)`, comma- or space-separated (alpha is ignored)
 *  - hex: `#rgb`, `#rgba`, `#rrggbb`, `#rrggbbaa` (alpha is ignored)
 * Anything else (empty, null, named colours, malformed input) yields the dark text
 * colour. Never throws.
 *
 * @param {*} bgColor - Background colour string.
 * @returns {string} `'#ffffff'` for dark backgrounds, `'#0e1116'` otherwise.
 */
function chooseReadableTextColor(bgColor) {
  const DARK_TEXT = '#0e1116';
  const LIGHT_TEXT = '#ffffff';
  try {
    const input = String(bgColor || '').trim();
    let channels = null;

    const fnMatch = input.match(/rgba?\(([^)]+)\)/i);
    if (fnMatch) {
      // Split on commas, whitespace or the "/" that precedes alpha; keep r, g, b only.
      channels = fnMatch[1]
        .split(/[\s,\/]+/)
        .filter((part) => part !== '')
        .slice(0, 3)
        .map((part) => Number.parseFloat(part));
    } else {
      const hexMatch = input.match(/^#([0-9a-f]{3,4}|[0-9a-f]{6}|[0-9a-f]{8})$/i);
      if (hexMatch) {
        const digits = hexMatch[1];
        // Expand shorthand (#rgb / #rgba) by doubling each digit.
        const full = digits.length <= 4
          ? [...digits].map((ch) => ch + ch).join('')
          : digits;
        channels = [0, 2, 4].map((start) => Number.parseInt(full.slice(start, start + 2), 16));
      }
    }

    if (!channels || channels.length < 3 || !channels.every((c) => Number.isFinite(c))) {
      return DARK_TEXT;
    }

    const [r, g, b] = channels;
    // Weighted sum of the channels, scaled to 0..1.
    const luminance = (0.299 * r + 0.587 * g + 0.114 * b) / 255;
    return luminance < 0.5 ? LIGHT_TEXT : DARK_TEXT;
  } catch (_) {
    // String() can throw on exotic inputs; the dark default keeps rendering going.
    return DARK_TEXT;
  }
}
// Draw (or redraw) the whole widget: heatmap cells, axis labels and the HTML legend.
// Called once at start-up and again on every resize notification. Cells are updated
// through a keyed D3 data join; axis labels and the legend are torn down and rebuilt
// on each call.
function render() {
// NOTE(review): innerHeight is destructured but not used anywhere in this function.
const { innerWidth, innerHeight } = updateSize();
const nRows = models.length;
const nCols = promptFormats.length;
// Calculate cell size to make each cell square
// (innerWidth is the grid width returned by updateSize()).
const maxDim = Math.max(nRows, nCols);
const cellSize = innerWidth / maxDim;
const gridWidth = cellSize * nCols;
const gridHeight = cellSize * nRows;
// Both offsets are currently zero, so the transforms built from them are no-ops.
const gridOffsetX = 0;
const gridOffsetY = 0;
// Band scales keyed by column / row index, with a small gap between cells.
const x = d3.scaleBand()
.domain(d3.range(nCols))
.range([0, gridWidth])
.paddingInner(0.08);
const y = d3.scaleBand()
.domain(d3.range(nRows))
.range([0, gridHeight])
.paddingInner(0.08);
// Flatten matrix data into one record per cell, tracking the min and max on the way
const flatData = [];
let minVal = Infinity, maxVal = -Infinity;
for (let r = 0; r < nRows; r++) {
for (let c = 0; c < nCols; c++) {
const value = matrix[r][c];
if (value < minVal) minVal = value;
if (value > maxVal) maxVal = value;
flatData.push({ r, c, value, model: models[r], format: promptFormats[c] });
}
}
const colorScale = getColorScale(flatData.map(d => d.value), minVal, maxVal);
gCells.attr('transform', `translate(${gridOffsetX}, ${gridOffsetY})`);
// Data join keyed by "row-col", so existing cells are reused across re-renders.
const cells = gCells.selectAll('g.cell')
.data(flatData, d => `${d.r}-${d.c}`);
const cellsEnter = cells.enter()
.append('g')
.attr('class', 'cell');
// Event handlers are attached only when a cell is first created.
cellsEnter.append('rect')
.attr('rx', 3)
.attr('ry', 3)
.on('mousemove', (event, d) => {
// Pointer position relative to the container, which the tooltip is positioned against.
const [px, py] = d3.pointer(event, container);
// innerHTML is fed only from the hard-coded models / promptFormats / matrix values.
tipInner.innerHTML = `<strong>${d.model}</strong><br/>${d.format}<br/>Score: ${d.value.toFixed(1)}`;
tip.style.transform = `translate(${px + 10}px, ${py + 10}px)`;
tip.style.opacity = '1';
})
.on('mouseleave', () => {
// Only fades the tooltip out; its transform keeps the last position.
tip.style.opacity = '0';
});
cellsEnter.append('text')
.attr('class', 'cell-text')
.attr('text-anchor', 'middle')
.attr('dominant-baseline', 'middle');
// Position and colour both new and existing cells.
const cellsMerged = cellsEnter.merge(cells);
cellsMerged.select('rect')
.attr('x', d => x(d.c))
.attr('y', d => y(d.r))
.attr('width', Math.max(1, x.bandwidth()))
.attr('height', Math.max(1, y.bandwidth()))
.attr('fill', d => colorScale(d.value))
.attr('stroke', 'var(--border-color)')
.attr('stroke-width', 0.5);
cellsMerged.select('text')
.attr('x', d => x(d.c) + x.bandwidth() / 2)
.attr('y', d => y(d.r) + y.bandwidth() / 2)
.text(d => d.value.toFixed(1))
.style('fill', function(d) {
// Pick the label colour from the computed fill of the sibling rect, falling back
// to the colour scale when no rect is found.
try {
const rect = this.parentNode.querySelector('rect');
const bg = rect ? getComputedStyle(rect).fill : colorScale(d.value);
return chooseReadableTextColor(bg);
} catch (_) {
return '#0e1116';
}
});
cells.exit().remove();
// Axes: cleared and rebuilt on every render
gAxes.selectAll('*').remove();
gAxes.attr('transform', `translate(${gridOffsetX}, ${gridOffsetY})`);
// X-axis labels (prompt formats), drawn just below the grid as single letters
gAxes.append('g')
.selectAll('text')
.data(promptFormats)
.join('text')
.attr('class', 'axis-label')
.attr('text-anchor', 'middle')
.attr('x', (_, i) => x(i) + x.bandwidth() / 2)
.attr('y', gridHeight + 12)
.text((d, i) => String.fromCharCode(65 + i)); // One letter per prompt format: A..H for the current 8 formats
// Y-axis labels (models), right-aligned in the left margin
gAxes.append('g')
.selectAll('text')
.data(models)
.join('text')
.attr('class', 'axis-label')
.attr('text-anchor', 'end')
.attr('x', -10)
.attr('y', (_, i) => y(i) + y.bandwidth() / 2)
.attr('dominant-baseline', 'middle')
.text(d => d);
// Update HTML legend (emptied and rebuilt on every render)
const legendContainer = container.querySelector('.legend-container');
legendContainer.innerHTML = '';
const legendTitle = document.createElement('div');
legendTitle.className = 'legend-title';
legendTitle.textContent = 'Prompt Formats:';
legendContainer.appendChild(legendTitle);
const legendGrid = document.createElement('div');
legendGrid.className = 'legend-grid';
// Column 1: A, B, C, D (first 4)
const column1 = document.createElement('div');
column1.className = 'legend-column';
// Column 2: E, F, G, H (last 4)
const column2 = document.createElement('div');
column2.className = 'legend-column';
promptFormats.forEach((format, i) => {
const item = document.createElement('div');
item.className = 'legend-item';
const label = document.createElement('span');
label.className = 'legend-label';
// Same letter as the X-axis label for this column.
label.textContent = `${String.fromCharCode(65 + i)}.`;
const text = document.createElement('span');
text.className = 'legend-text';
text.textContent = format;
item.appendChild(label);
item.appendChild(text);
// First 4 go to column 1, rest go to column 2
// NOTE(review): the split index is hard-coded to 4, which halves the current
// 8 formats; it does not follow promptFormats.length.
if (i < 4) {
column1.appendChild(item);
} else {
column2.appendChild(item);
}
});
legendGrid.appendChild(column1);
legendGrid.appendChild(column2);
legendContainer.appendChild(legendGrid);
}
// Initial render + resize handling
render();
const rerender = () => render();
// Prefer a ResizeObserver on the container; fall back to the window `resize`
// event when ResizeObserver is not available.
// NOTE(review): neither the observer nor the listener is ever removed in the visible code.
if (window.ResizeObserver) {
const ro = new ResizeObserver(() => rerender());
ro.observe(container);
} else {
window.addEventListener('resize', rerender);
}
};
// Start-up: if the document is still loading, wait for DOMContentLoaded before
// loading D3 and bootstrapping; otherwise start immediately.
if (document.readyState === 'loading') {
document.addEventListener('DOMContentLoaded', () => ensureD3(bootstrap), { once: true });
} else {
ensureD3(bootstrap);
}
})();
</script>