smol-training-playbook / app /src /content /embeds /attention-mechanisms.html
loubnabnl's picture
loubnabnl HF Staff
Edit text inside embeds (#12)
3eb5825
<div class="kv-cache-diagrams"></div>
<style>
/* Root host element; the inline script in this file fills it with a legend and a grid of diagram cards. */
.kv-cache-diagrams {
font-family: 'Arial', sans-serif;
padding: 0;
}
/* Two equal-width columns for the diagram cards (collapsed to one column by the media query below). */
.kv-cache-diagrams .diagrams-grid {
display: grid;
grid-template-columns: repeat(2, 1fr);
gap: 20px;
}
/* One card per diagram: title stacked above the SVG. Colors come from custom properties with hard-coded fallbacks. */
.kv-cache-diagrams .diagram-container {
padding: 15px;
background: var(--page-bg, #fff);
border-radius: 12px;
border: 1px solid var(--border-color, #e0e0e0);
display: flex;
flex-direction: column;
}
/* Small upper-cased caption shown above each diagram. */
.kv-cache-diagrams .diagram-title {
font-size: 12px;
font-weight: 700;
text-align: center;
margin-bottom: 10px;
color: var(--text-color, #333);
text-transform: uppercase;
letter-spacing: 0.5px;
}
/* Element the script hands to SVG.js; min-height keeps the card from collapsing before anything is drawn. */
.kv-cache-diagrams .diagram-svg-container {
width: 100%;
height: auto;
min-height: 250px;
}
/* The generated <svg> stretches to fill its container. */
.kv-cache-diagrams .diagram-svg-container svg {
width: 100%;
height: 100%;
}
/* NOTE(review): no element with this class is created by the markup generated in this file - possibly a leftover. */
.kv-cache-diagrams .placeholder-content {
min-height: 250px;
display: flex;
align-items: center;
justify-content: center;
color: #999;
font-size: 14px;
}
/* Mobile breakpoint - single column */
@media (max-width: 900px) {
.kv-cache-diagrams .diagrams-grid {
grid-template-columns: 1fr;
}
}
/* Legend styles */
/* Legend block: title and items stacked vertically and centered above the diagram grid. */
.kv-cache-diagrams .legend {
display: flex;
flex-direction: column;
align-items: center;
gap: 8px;
margin-bottom: 20px;
}
/* Unlike .diagram-title, --text-color is used here without a fallback value. */
.kv-cache-diagrams .legend-title {
font-size: 12px;
font-weight: 700;
color: var(--text-color);
}
/* Row of legend entries; wraps onto further lines when the row is too narrow. */
.kv-cache-diagrams .legend .items {
display: flex;
flex-wrap: wrap;
gap: 8px 14px;
justify-content: center;
}
/* A single swatch + label pair, kept on one line. */
.kv-cache-diagrams .legend .item {
display: inline-flex;
align-items: center;
gap: 6px;
white-space: nowrap;
font-size: 12px;
color: var(--text-color);
}
/* Color chip. Only border width/style are set here; background and border color are supplied inline by the script. */
.kv-cache-diagrams .legend .swatch {
width: 18px;
height: 18px;
border-radius: 4px;
border-width: 1px;
border-style: solid;
position: relative;
}
/* Repeats position: relative from .swatch so the ::after overlay below is positioned against the chip. */
.kv-cache-diagrams .legend .swatch.hatched {
position: relative;
}
/* Diagonal stripe overlay for the "cached" swatch; stripes use currentColor, i.e. the inline `color` set on the swatch. */
.kv-cache-diagrams .legend .swatch.hatched::after {
content: '';
position: absolute;
inset: 0;
border-radius: 4px;
background-image: repeating-linear-gradient(135deg,
transparent,
transparent 3px,
currentColor 3px,
currentColor 4px);
opacity: 0.5;
}
</style>
<!-- Import SVG.js from CDN -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/svg.js/3.2.5/svg.min.js"></script>
<script>
(() => {
const bootstrap = () => {
const scriptEl = document.currentScript;
let container = scriptEl ? scriptEl.previousElementSibling : null;
if (!(container && container.classList && container.classList.contains('kv-cache-diagrams'))) {
const candidates = Array.from(document.querySelectorAll('.kv-cache-diagrams'))
.filter((el) => !(el.dataset && el.dataset.mounted === 'true'));
container = candidates[candidates.length - 1] || null;
}
if (!container) return;
if (container.dataset) {
if (container.dataset.mounted === 'true') return;
container.dataset.mounted = 'true';
}
// Apply categorical color palette
/**
 * Resolve the colors used for query, key and value nodes.
 *
 * Asks `window.ColorPalettes.getColors('categorical', 3)` for three colors when that
 * function exists. Any slot the palette does not supply (missing helper, short or empty
 * result, thrown error) falls back to the built-in color for that slot, so callers always
 * receive three usable values.
 *
 * @returns {{query: *, key: *, value: *}} One color per node type.
 */
const applyColorPalette = () => {
  // Single source of truth for the built-in colors.
  const fallback = {
    query: '#E889AB',
    key: '#4EA5B7',
    value: '#E38A42'
  };
  try {
    if (window.ColorPalettes && typeof window.ColorPalettes.getColors === 'function') {
      const colors = window.ColorPalettes.getColors('categorical', 3) || [];
      // `||` on purpose: undefined, null and '' are all unusable as a color.
      return {
        query: colors[0] || fallback.query,
        key: colors[1] || fallback.key,
        value: colors[2] || fallback.value
      };
    }
  } catch (e) {
    // Keep rendering with the built-in colors, but surface the cause for debugging.
    console.warn('ColorPalettes not available, using fallback colors', e);
  }
  return fallback;
};
// Resolve the node colors once; both the legend markup and CONFIG read from this object.
const paletteColors = applyColorPalette();
// Call the palette helper's optional refresh hook when it exists.
// NOTE(review): what refresh() does is not visible from this file.
const colorPalettes = window.ColorPalettes;
if (colorPalettes && typeof colorPalettes.refresh === 'function') {
  colorPalettes.refresh();
}
// Geometry of a single Q/K/V node (a narrow vertical rounded rectangle).
const nodeStyle = {
  width: 18,
  height: 42,
  rx: 2,
  fontSize: 13,
  strokeWidth: 1
};
// Connector settings; `gap` is the space left between a node edge and the connector end.
const arrowStyle = {
  length: 15,
  width: 1,
  headSize: 4,
  gap: 3
};
// Shared drawing configuration read by every helper and diagram function.
const CONFIG = {
  colors: {
    query: paletteColors.query,
    key: paletteColors.key,
    value: paletteColors.value,
    arrow: 'var(--text-color, #666)',
    // Gray stroke / fill used for the compressed latent KV node.
    clkv: '#adb5bd',
    clkvFill: '#e9ecef'
  },
  node: nodeStyle,
  arrow: arrowStyle
};
// Helper: Create a rounded rect node with text
/**
 * Draw one Q/K/V node centered on (x, y): a rounded rectangle, an optional diagonal
 * hatch overlay, and a single-letter label (Q, K or V).
 *
 * @param {object} svg - SVG.js container the node group is added to.
 * @param {number} x - Horizontal center of the node.
 * @param {number} y - Vertical center of the node.
 * @param {string} text - Unused; kept so existing call sites stay valid.
 * @param {string} type - 'query', 'key' or 'value'.
 * @param {boolean} [noCachedPattern=false] - When true, key/value nodes are drawn
 *   without the hatch overlay. Query nodes never get the overlay.
 * @returns {{group: object, rect: object, x: number, y: number}} The created group, its
 *   base rectangle, and the center coordinates that were passed in.
 */
function createNode(svg, x, y, text, type, noCachedPattern = false) {
  const { width, height, rx, strokeWidth } = CONFIG.node;
  const colors = CONFIG.colors;
  // One table instead of a switch plus an if/else chain. Key/value fills are lighter
  // (0.4) so the hatch overlay stays visible on top of them.
  const variants = {
    query: { letter: 'Q', color: colors.query, opacity: 0.6, hatched: false },
    key: { letter: 'K', color: colors.key, opacity: 0.4, hatched: !noCachedPattern },
    value: { letter: 'V', color: colors.value, opacity: 0.4, hatched: !noCachedPattern }
  };
  const variant = Object.prototype.hasOwnProperty.call(variants, type) ? variants[type] : null;
  // Unknown types keep the original behavior: fill and stroke are left undefined.
  const fill = variant ? { color: variant.color, opacity: variant.opacity } : undefined;
  const stroke = variant ? variant.color : undefined;
  const left = x - width / 2;
  const top = y - height / 2;

  const group = svg.group();
  const rect = group.rect(width, height)
    .move(left, top)
    .radius(rx)
    .fill(fill)
    .stroke({ color: stroke, width: strokeWidth });

  if (variant && variant.hatched) {
    // Each hatched node registers its own pattern under a random id.
    // slice(2, 11) yields the same characters as the deprecated substr(2, 9).
    const patternId = `hatch-${type}-${Math.random().toString(36).slice(2, 11)}`;
    createHatchPattern(svg, stroke, patternId);
    group.rect(width, height)
      .move(left, top)
      .radius(rx)
      .fill(`url(#${patternId})`)
      .stroke('none');
  }

  if (variant) {
    // Letter in the stroke color, dimmed so it does not dominate the node.
    group.text(variant.letter)
      .font({
        family: 'Arial, sans-serif',
        size: 10,
        anchor: 'middle',
        weight: '600',
        fill: stroke
      })
      .cx(x)
      .cy(y)
      .opacity(0.6);
  }
  return { group, rect, x, y };
}
// Helper: Create diagonal hatch pattern for cached nodes
/**
 * Register a 10x10 diagonal-stripe pattern on `svg` and give it the requested id.
 *
 * @param {object} svg - SVG.js container that owns the pattern.
 * @param {string} color - Stroke color of the stripe.
 * @param {string} id - Id assigned to the pattern so fills can reference it via url(#id).
 * @returns {object} The pattern object returned by `svg.pattern`.
 */
function createHatchPattern(svg, color, id) {
  const tileSize = 10;
  const hatch = svg.pattern(tileSize, tileSize, (tile) => {
    // One stroke per tile, from the bottom-left corner to the top-right corner.
    const stripe = tile.line(0, tileSize, tileSize, 0);
    stripe.stroke({ color, width: 2, opacity: 0.5 });
  });
  hatch.attr('id', id);
  hatch.attr('patternUnits', 'userSpaceOnUse');
  return hatch;
}
// Helper: Create arrow without arrowhead (just a line)
/**
 * Draw a plain connector (a line with no arrowhead) inside its own group.
 *
 * @param {object} svg - SVG.js container the group is added to.
 * @param {number} x1 - Start x.
 * @param {number} y1 - Start y.
 * @param {number} x2 - End x.
 * @param {number} y2 - End y.
 * @returns {object} The group that wraps the line.
 */
function createArrow(svg, x1, y1, x2, y2) {
  const lineStyle = {
    color: CONFIG.colors.arrow,
    width: CONFIG.arrow.width
  };
  const wrapper = svg.group();
  const connector = wrapper.line(x1, y1, x2, y2);
  // Half-transparent so connectors stay visually behind the nodes.
  connector.stroke(lineStyle).opacity(0.5);
  return wrapper;
}
// ========================================================================
// DIAGRAM 1: Individual Q-K-V Pairs (8 columns)
// ========================================================================
/**
 * Draw the diagram shown under the "Multi-head attention" title: eight columns, each
 * with its own value (top), key (middle) and query (bottom) node joined by connectors.
 *
 * @param {object} containerEl - Host element handed to SVG.js `addTo`.
 */
function drawDiagram1(containerEl) {
  const svg = SVG().addTo(containerEl).size('100%', '100%');
  const nodeHeight = CONFIG.node.height;
  const arrowGap = CONFIG.arrow.gap;
  const columns = 8;
  const columnSpacing = 40;
  const rowSpacing = 68;
  const padding = 5;
  const startX = 30 + padding;
  const startY = 20 + padding;
  // startX already includes the horizontal padding, so mirroring it on the right is
  // enough. Adding `padding` again (as before) made the viewBox 355 wide instead of the
  // 350 used by the other diagrams, leaving the columns off-center and scaled differently.
  const viewboxWidth = startX * 2 + columnSpacing * (columns - 1);
  const viewboxHeight = startY * 2 + rowSpacing * 2 + padding;
  svg.viewbox(0, 0, viewboxWidth, viewboxHeight);

  // Row centers are identical for every column.
  const yValue = startY;
  const yKey = startY + rowSpacing;
  const yQuery = startY + rowSpacing * 2;
  // Connectors stop `arrowGap` short of the node edges.
  const edgeOffset = nodeHeight / 2 + arrowGap;

  for (let i = 0; i < columns; i++) {
    const x = startX + i * columnSpacing;
    createNode(svg, x, yValue, '', 'value');
    createNode(svg, x, yKey, '', 'key');
    createNode(svg, x, yQuery, '', 'query');
    // Value to key, then key to query (top to bottom).
    createArrow(svg, x, yValue + edgeOffset, x, yKey - edgeOffset);
    createArrow(svg, x, yKey + edgeOffset, x, yQuery - edgeOffset);
  }
}
// ========================================================================
// DIAGRAM 2: Multi-Query Attention (8 Q -> 1 K -> 1 V)
// ========================================================================
/**
 * Draw the diagram shown under the "Multi-query attention" title: eight query nodes
 * whose connectors all converge on a single key node, with one value node above it.
 *
 * @param {object} containerEl - Host element handed to SVG.js `addTo`.
 */
function drawDiagram2(containerEl) {
  const svg = SVG().addTo(containerEl).size('100%', '100%');
  const halfNode = CONFIG.node.height / 2;
  const arrowGap = CONFIG.arrow.gap;

  // Layout constants.
  const padding = 5;
  const queries = 8;
  const querySpacing = 40;
  const rowSpacing = 68;
  const centerX = 170 + padding;
  const topY = 20 + padding;
  svg.viewbox(0, 0, 340 + padding * 2, topY * 2 + rowSpacing * 2 + padding);

  // Row centers: value on top, key in the middle, queries at the bottom.
  const yValue = topY;
  const yKey = topY + rowSpacing;
  const yQuery = topY + rowSpacing * 2;

  // The single shared key / value pair sits on the center line.
  createNode(svg, centerX, yKey, '', 'key');
  createNode(svg, centerX, yValue, '', 'value');
  createArrow(svg, centerX, yKey - halfNode - arrowGap, centerX, yValue + halfNode + arrowGap);

  // Queries are centered as a row; their connectors land 2 units apart under the key node.
  const firstQueryX = centerX - ((queries - 1) * querySpacing) / 2;
  const arrowSpacing = 2;
  const firstLandingX = centerX - ((queries - 1) * arrowSpacing) / 2;
  const lineTopY = yQuery - halfNode - arrowGap;
  const lineBottomY = yKey + halfNode + arrowGap;

  Array.from({ length: queries }, (_, index) => index).forEach((index) => {
    const queryX = firstQueryX + index * querySpacing;
    createNode(svg, queryX, yQuery, '', 'query');
    // Query-to-key connector, drawn directly (fainter than createArrow's lines).
    svg.line(queryX, lineTopY, firstLandingX + index * arrowSpacing, lineBottomY)
      .stroke({ color: CONFIG.colors.arrow, width: CONFIG.arrow.width })
      .opacity(0.4);
  });
}
// ========================================================================
// DIAGRAM 3: Grouped-Query Attention (8 Q -> 4 K -> 4 V)
// ========================================================================
/**
 * Draw the diagram shown under the "Grouped query attention" title: eight query nodes
 * split evenly across four key/value pairs, each pair shared by two queries.
 *
 * @param {object} containerEl - Host element handed to SVG.js `addTo`.
 */
function drawDiagram3(containerEl) {
  const svg = SVG().addTo(containerEl).size('100%', '100%');
  const nodeConfig = CONFIG.node;
  const arrowGap = CONFIG.arrow.gap;
  const queries = 8;
  const groups = 4; // Number of K-V pairs
  // Derived instead of hard-coded so the group assignment below cannot drift from
  // the counts above (it previously divided by a literal 2).
  const queriesPerK = queries / groups;
  const querySpacing = 40;
  const groupSpacing = 80;
  const padding = 5;
  const centerX = 170 + padding;
  const viewboxWidth = 340 + padding * 2;
  const rowSpacing = 68;
  const startY = 20 + padding;
  const viewboxHeight = startY * 2 + rowSpacing * 2 + padding;
  svg.viewbox(0, 0, viewboxWidth, viewboxHeight);

  // Row centers: value on top, key in the middle, queries at the bottom.
  const yValue = startY;
  const yKey = startY + rowSpacing;
  const yQuery = startY + rowSpacing * 2;

  // Both rows are centered on centerX.
  const startXQuery = centerX - ((queries - 1) * querySpacing) / 2;
  const startXGroup = centerX - ((groups - 1) * groupSpacing) / 2;

  // One key/value pair per group, joined by a connector.
  const groupCenters = [];
  for (let g = 0; g < groups; g++) {
    const xGroup = startXGroup + g * groupSpacing;
    createNode(svg, xGroup, yKey, '', 'key');
    createNode(svg, xGroup, yValue, '', 'value');
    groupCenters.push(xGroup);
    createArrow(svg,
      xGroup,
      yKey - nodeConfig.height / 2 - arrowGap,
      xGroup,
      yValue + nodeConfig.height / 2 + arrowGap
    );
  }

  // Query nodes and their connectors; landing points under each key are 3 units apart.
  const arrowSpacing = 3;
  const totalArrowWidth = (queriesPerK - 1) * arrowSpacing;
  const lineStartY = yQuery - nodeConfig.height / 2 - arrowGap;
  const lineEndY = yKey + nodeConfig.height / 2 + arrowGap;
  for (let i = 0; i < queries; i++) {
    const x = startXQuery + i * querySpacing;
    createNode(svg, x, yQuery, '', 'query');
    const groupIndex = Math.floor(i / queriesPerK);
    const indexInGroup = i % queriesPerK;
    // Spread arrival points horizontally, centered on the group's key node.
    const targetX = groupCenters[groupIndex] - totalArrowWidth / 2 + indexInGroup * arrowSpacing;
    // Drawn directly (fainter than createArrow's lines).
    svg.line(x, lineStartY, targetX, lineEndY)
      .stroke({ color: CONFIG.colors.arrow, width: CONFIG.arrow.width })
      .opacity(0.4);
  }
}
// ========================================================================
// DIAGRAM 4: Latent Compressed KV with Projection
// ========================================================================
/**
 * Draw the diagram shown under the "Multi-head latent attention" title: eight
 * unhatched V/K/Q columns, an outline around the V row and the K row, and a hatched
 * "Compressed latent KV" node on the right, linked to both outlines by projection lines.
 *
 * @param {object} containerEl - Host element handed to SVG.js `addTo`.
 */
function drawDiagram4(containerEl) {
  const svg = SVG().addTo(containerEl).size('100%', '100%');
  const nodeConfig = CONFIG.node;
  const arrowGap = CONFIG.arrow.gap;
  // Gray shared by the row outlines, the latent node and the projection marks.
  const latentColor = CONFIG.colors.clkv;
  const columns = 8;
  const columnSpacing = 28; // Tighter than the other diagrams to leave room on the right
  const rowSpacing = 68;
  const padding = 5;
  const startX = 30 + padding; // X of the first column
  const startY = 20 + padding;
  // Same 350-wide viewBox as the multi-query and grouped-query diagrams.
  const viewboxWidth = 340 + padding * 2;
  const viewboxHeight = startY * 2 + rowSpacing * 2 + padding;
  svg.viewbox(0, 0, viewboxWidth, viewboxHeight);

  // Row centers: value on top, key in the middle, query at the bottom.
  const yValue = startY;
  const yKey = startY + rowSpacing;
  const yQuery = startY + rowSpacing * 2;

  // Outlines around the whole V row and the whole K row.
  const groupPadding = 6;
  const groupX = startX - nodeConfig.width / 2 - groupPadding;
  const groupWidth = columnSpacing * (columns - 1) + nodeConfig.width + groupPadding * 2;
  const groupHeight = nodeConfig.height + groupPadding * 2;
  // NOTE(review): with the current constants the V outline starts at y = -2, i.e. just
  // above the viewBox origin - confirm it is not clipped in the rendered page.
  const vGroupTopY = yValue - nodeConfig.height / 2 - groupPadding;
  const kGroupTopY = yKey - nodeConfig.height / 2 - groupPadding;
  [vGroupTopY, kGroupTopY].forEach((top) => {
    svg.rect(groupWidth, groupHeight)
      .move(groupX, top)
      .radius(4)
      .fill('none')
      .stroke({ color: latentColor, width: 1 })
      .opacity(0.8);
  });

  // Q, K, V nodes; K and V are drawn without the hatch overlay in this diagram.
  const edgeOffset = nodeConfig.height / 2 + arrowGap;
  for (let i = 0; i < columns; i++) {
    const x = startX + i * columnSpacing;
    createNode(svg, x, yValue, '', 'value', true);
    createNode(svg, x, yKey, '', 'key', true);
    createNode(svg, x, yQuery, '', 'query');
    // Value to key, then key to query (top to bottom).
    createArrow(svg, x, yValue + edgeOffset, x, yKey - edgeOffset);
    createArrow(svg, x, yKey + edgeOffset, x, yQuery - edgeOffset);
  }

  // Compressed latent KV node: right of the last column, vertically between the V and K rows.
  const lastColumnX = startX + (columns - 1) * columnSpacing;
  const clkvX = lastColumnX + nodeConfig.width / 2 + 75;
  const clkvY = (yKey + yValue) / 2;
  const clkvWidth = nodeConfig.width;
  const clkvHeight = nodeConfig.height;
  const clkvLeft = clkvX - clkvWidth / 2;
  const clkvTop = clkvY - clkvHeight / 2;
  const clkvBottom = clkvY + clkvHeight / 2;

  const clkvGroup = svg.group();
  clkvGroup.rect(clkvWidth, clkvHeight)
    .move(clkvLeft, clkvTop)
    .radius(nodeConfig.rx)
    .fill({ color: CONFIG.colors.clkvFill, opacity: 0.15 })
    .stroke({ color: latentColor, width: CONFIG.node.strokeWidth });
  // Hatch overlay, the same treatment createNode gives cached K/V nodes.
  // slice(2, 11) yields the same characters as the deprecated substr(2, 9).
  const clkvPatternId = `hatch-clkv-${Math.random().toString(36).slice(2, 11)}`;
  createHatchPattern(svg, latentColor, clkvPatternId);
  clkvGroup.rect(clkvWidth, clkvHeight)
    .move(clkvLeft, clkvTop)
    .radius(nodeConfig.rx)
    .fill(`url(#${clkvPatternId})`)
    .stroke('none');

  // Two-line caption below the node.
  svg.text('Compressed\nlatent KV')
    .font({
      family: 'Arial, sans-serif',
      size: 9,
      anchor: 'middle',
      weight: '700',
      fill: 'var(--text-color, #666)',
      leading: '1.2em'
    })
    .cx(clkvX)
    .cy(clkvBottom + 28);

  // Projection lines (no arrowheads): top-right corner of the V outline to the node's
  // top-left corner, and bottom-right corner of the K outline to its bottom-left corner.
  const groupRightX = groupX + groupWidth;
  const kGroupBottomY = yKey + nodeConfig.height / 2 + groupPadding;
  svg.line(groupRightX, vGroupTopY, clkvLeft, clkvTop)
    .stroke({ color: latentColor, width: 1 })
    .opacity(0.7);
  svg.line(groupRightX, kGroupBottomY, clkvLeft, clkvBottom)
    .stroke({ color: latentColor, width: 1 })
    .opacity(0.7);

  // Single "Projection" label between the two lines: 5 units left of the midpoint
  // between the outlines' right edge and the node center.
  const centerTextX = (clkvX + groupRightX) / 2 - 5;
  const centerTextY = (vGroupTopY + kGroupBottomY) / 2;
  svg.text('Projection')
    .font({
      family: 'Arial, sans-serif',
      size: 9,
      anchor: 'middle',
      fill: latentColor,
      weight: '600'
    })
    .cx(centerTextX)
    .cy(centerTextY - 4);

  // Small left-pointing arrow under the label: a line plus a triangular head.
  const arrowY = centerTextY + 6;
  const arrowStartX = centerTextX + 15;
  const arrowEndX = centerTextX - 15;
  const arrowSize = 3;
  svg.line(arrowStartX, arrowY, arrowEndX, arrowY)
    .stroke({ color: latentColor, width: 1 });
  svg.polygon(`${arrowEndX},${arrowY} ${arrowEndX + arrowSize},${arrowY - arrowSize} ${arrowEndX + arrowSize},${arrowY + arrowSize}`)
    .fill(latentColor);
}
// Build HTML structure: a legend followed by a grid of four titled diagram cards.
// - Legend swatches reuse paletteColors, interpolated as-is into inline style attributes.
//   Their color-mix percentages (40% for values/keys, 60% for queries) mirror the fill
//   opacities createNode applies to the corresponding node types.
// - Each card contains an empty .diagram-svg-container whose id (diagram-1 ... diagram-4)
//   is how the drawing step that follows locates it.
// Assigning innerHTML replaces whatever the host element contained before.
container.innerHTML = `
<!-- Legend -->
<div class="legend">
<div class="legend-title">Legend</div>
<div class="items">
<span class="item">
<span class="swatch" style="background-color: color-mix(in srgb, ${paletteColors.value} 40%, transparent); border-color: ${paletteColors.value};"></span>
<span>Values</span>
</span>
<span class="item">
<span class="swatch" style="background-color: color-mix(in srgb, ${paletteColors.key} 40%, transparent); border-color: ${paletteColors.key};"></span>
<span>Keys</span>
</span>
<span class="item">
<span class="swatch" style="background-color: color-mix(in srgb, ${paletteColors.query} 60%, transparent); border-color: ${paletteColors.query};"></span>
<span>Queries</span>
</span>
<span class="item">
<span class="swatch hatched" style="background-color: transparent; border-color: #adb5bd; color: #adb5bd;"></span>
<span>Cached during inference</span>
</span>
</div>
</div>
<div class="diagrams-grid">
<!-- DIAGRAM 1: Multi-head Attention -->
<div class="diagram-container">
<div class="diagram-title">Multi-head attention</div>
<div class="diagram-svg-container" id="diagram-1"></div>
</div>
<!-- DIAGRAM 2: Multi-query Attention -->
<div class="diagram-container">
<div class="diagram-title">Multi-query attention</div>
<div class="diagram-svg-container" id="diagram-2"></div>
</div>
<!-- DIAGRAM 3: Grouped Query Attention -->
<div class="diagram-container">
<div class="diagram-title">Grouped query attention</div>
<div class="diagram-svg-container" id="diagram-3"></div>
</div>
<!-- DIAGRAM 4: Multi-head Latent Attention -->
<div class="diagram-container">
<div class="diagram-title">Multi-head latent attention</div>
<div class="diagram-svg-container" id="diagram-4"></div>
</div>
</div>
`;
/**
 * Draw the four diagrams into the placeholders that were just written into `container`.
 *
 * Placeholders are looked up inside `container` rather than document-wide: the ids
 * diagram-1 ... diagram-4 are not unique when this embed (or anything else using those
 * ids) appears more than once on a page, and a document-wide lookup would then draw into
 * another instance's placeholders. Nothing is drawn unless all four are found.
 */
function drawAllDiagrams() {
  if (typeof SVG === 'undefined') {
    // Previously this case was silent; say why the cards stay empty.
    console.warn('SVG.js is not loaded; attention diagrams were not drawn');
    return;
  }
  const placeholders = ['#diagram-1', '#diagram-2', '#diagram-3', '#diagram-4']
    .map((selector) => container.querySelector(selector));
  if (placeholders.some((el) => !el)) return;
  drawDiagram1(placeholders[0]);
  drawDiagram2(placeholders[1]);
  drawDiagram3(placeholders[2]);
  drawDiagram4(placeholders[3]);
}
// Draw after a short delay (unchanged from the original 50ms). This is a single attempt,
// not a retry loop: if SVG.js is still missing by then, drawAllDiagrams only warns.
setTimeout(drawAllDiagrams, 50);
};
// Run bootstrap right away when the document has already been parsed; otherwise
// run it once, when DOMContentLoaded fires. (Only DOM readiness is checked here.)
if (document.readyState !== 'loading') {
  bootstrap();
} else {
  document.addEventListener('DOMContentLoaded', bootstrap, { once: true });
}
})();
</script>