| | <!DOCTYPE html> |
| | <html lang="en"> |
| | <head> |
| | <meta charset="UTF-8"> |
| | <meta name="viewport" content="width=device-width, initial-scale=1.0"> |
| | <title>RoPE Visualization</title> |
| | <script src="https://cdn.tailwindcss.com"></script> |
| | <script src="https://cdn.jsdelivr.net/npm/chart.js"></script> |
| | <script src="https://cdn.jsdelivr.net/npm/mathjs@11.6.0/lib/browser/math.js"></script> |
| | <style> |
| | .rope-vector { |
| | transition: all 0.3s ease; |
| | } |
| | .vector-container { |
| | perspective: 1000px; |
| | } |
| | .gradient-bg { |
| | background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%); |
| | } |
| | .control-panel { |
| | box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06); |
| | } |
| | </style> |
| | </head> |
| | <body class="gradient-bg min-h-screen p-6"> |
| | <div class="max-w-6xl mx-auto"> |
| | <div class="text-center mb-8"> |
| | <h1 class="text-4xl font-bold text-gray-800 mb-2">Rotary Positional Embedding (RoPE) Visualization</h1> |
| | <p class="text-lg text-gray-600">Interactive exploration of how RoPE encodes position information in transformer models</p> |
| | </div> |
| |
|
| | <div class="grid grid-cols-1 lg:grid-cols-3 gap-6"> |
| | |
| | <div class="bg-white rounded-xl p-6 control-panel"> |
| | <h2 class="text-xl font-semibold mb-4 text-gray-800">Configuration</h2> |
| | |
| | <div class="space-y-4"> |
| | <div> |
<label class="block text-sm font-medium text-gray-700 mb-1" for="dimension">Model Dimension (d)</label>
| | <select id="dimension" class="w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-indigo-500 focus:border-indigo-500"> |
| | <option value="4">4 (Simplified)</option> |
| | <option value="8">8</option> |
| | <option value="16">16</option> |
| | <option value="32">32</option> |
| | <option value="64" selected>64</option> |
| | <option value="128">128</option> |
| | </select> |
| | </div> |
| | |
| | <div> |
<label class="block text-sm font-medium text-gray-700 mb-1" for="baseFreq">Base Frequency (θ)</label>
| | <input type="range" id="baseFreq" min="1000" max="50000" step="1000" value="10000" class="w-full"> |
| | <div class="flex justify-between text-xs text-gray-500"> |
| | <span>1,000</span> |
| | <span id="baseFreqValue">10,000</span> |
| | <span>50,000</span> |
| | </div> |
| | </div> |
| | |
| | <div> |
<label class="block text-sm font-medium text-gray-700 mb-1" for="position">Position (p)</label>
| | <input type="range" id="position" min="0" max="1024" step="1" value="0" class="w-full"> |
| | <div class="flex justify-between text-xs text-gray-500"> |
| | <span>0</span> |
| | <span id="positionValue">0</span> |
| | <span>1024</span> |
| | </div> |
| | </div> |
| | |
| | <div> |
<label class="block text-sm font-medium text-gray-700 mb-1" for="relativePos">Relative Position (k)</label>
| | <input type="range" id="relativePos" min="-32" max="32" step="1" value="1" class="w-full"> |
| | <div class="flex justify-between text-xs text-gray-500"> |
| | <span>-32</span> |
| | <span id="relativePosValue">1</span> |
| | <span>32</span> |
| | </div> |
| | </div> |
| | |
| | <div class="pt-2"> |
| | <button id="animateBtn" class="w-full bg-indigo-600 text-white py-2 px-4 rounded-md hover:bg-indigo-700 transition-colors"> |
| | Animate Rotation |
| | </button> |
| | </div> |
| | </div> |
| | |
| | <div class="mt-6 pt-4 border-t border-gray-200"> |
| | <h3 class="text-lg font-medium text-gray-800 mb-2">Dot Product Analysis</h3> |
| | <div class="space-y-2"> |
| | <div class="flex justify-between"> |
| | <span class="text-sm text-gray-600">q · k (original):</span> |
| | <span id="originalDot" class="font-mono">0.00</span> |
| | </div> |
| | <div class="flex justify-between"> |
| | <span class="text-sm text-gray-600">R(q,p) · R(k,p+k):</span> |
| | <span id="rotatedDot" class="font-mono">0.00</span> |
| | </div> |
| | <div class="flex justify-between"> |
| | <span class="text-sm text-gray-600">Difference:</span> |
| | <span id="dotDifference" class="font-mono">0.00</span> |
| | </div> |
| | </div> |
| | </div> |
| | </div> |
| | |
| | |
| | <div class="bg-white rounded-xl p-6 vector-container"> |
| | <h2 class="text-xl font-semibold mb-4 text-gray-800">3D Vector Rotation</h2> |
| | <div class="relative h-64 w-full mb-4"> |
| | <canvas id="vectorCanvas" class="absolute inset-0"></canvas> |
| | </div> |
| | <div class="text-sm text-gray-600"> |
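<p>Visualization of how RoPE rotates the first pair of vector components (dimensions 0 and 1) within their 2D plane; the third component is not rotated and only adds a simple perspective cue. Every pair of dimensions is rotated the same way, by an angle equal to the position times that pair's frequency.</p>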
| | </div> |
| | </div> |
| | |
| | |
| | <div class="bg-white rounded-xl p-6"> |
| | <h2 class="text-xl font-semibold mb-4 text-gray-800">Frequency Spectrum</h2> |
| | <div class="relative h-64 w-full mb-4"> |
| | <canvas id="freqChart"></canvas> |
| | </div> |
| | <div class="text-sm text-gray-600"> |
| | <p>Shows the geometric progression of frequencies across dimensions. Lower dimensions have higher frequencies (shorter wavelengths) while higher dimensions have lower frequencies.</p> |
| | </div> |
| | </div> |
| | </div> |
| | |
| | |
| | <div class="mt-6 bg-white rounded-xl p-6"> |
| | <h2 class="text-xl font-semibold mb-4 text-gray-800">Vector Component Rotation</h2> |
| | <div class="overflow-x-auto"> |
| | <table class="min-w-full divide-y divide-gray-200"> |
| | <thead class="bg-gray-50"> |
| | <tr> |
| | <th scope="col" class="px-6 py-3 text-left text-xs font-medium text-gray-500 uppercase tracking-wider">Dimension Pair</th> |
| | <th scope="col" class="px-6 py-3 text-left text-xs font-medium text-gray-500 uppercase tracking-wider">Original Values</th> |
| | <th scope="col" class="px-6 py-3 text-left text-xs font-medium text-gray-500 uppercase tracking-wider">Rotated Values</th> |
| | <th scope="col" class="px-6 py-3 text-left text-xs font-medium text-gray-500 uppercase tracking-wider">Rotation Angle</th> |
| | <th scope="col" class="px-6 py-3 text-left text-xs font-medium text-gray-500 uppercase tracking-wider">Frequency</th> |
| | </tr> |
| | </thead> |
| | <tbody id="vectorDetails" class="bg-white divide-y divide-gray-200"> |
| | |
| | </tbody> |
| | </table> |
| | </div> |
| | </div> |
| | |
| | |
| | <div class="mt-6 bg-white rounded-xl p-6"> |
| | <h2 class="text-xl font-semibold mb-4 text-gray-800">How RoPE Works</h2> |
| | <div class="prose max-w-none text-gray-700"> |
| | <p>Rotary Positional Embedding (RoPE) encodes position information by rotating pairs of vector components:</p> |
| | <ul class="list-disc pl-5 space-y-1"> |
| | <li>Each pair of dimensions (2i, 2i+1) is treated as a 2D vector</li> |
<li>The vector is rotated by an angle φ = p · ωᵢ, where p is the token position and ωᵢ is the pair's frequency</li>
<li>Frequencies decrease geometrically with the pair index: ωᵢ = 1/θ<sup>2i/d</sup>, where θ is the base frequency and d is the model dimension</li>
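<li>This encodes relative position in the attention dot product: for a query rotated at position m and a key rotated at position n, the dot product depends only on the offset n − m (the relative position k set in the controls)</li>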
| | <li>The norm (length) of vectors remains unchanged, preserving semantic information</li> |
| | </ul> |
| | <p class="mt-4">The rotation matrix for each pair is:</p> |
<div class="bg-gray-100 p-4 rounded-md overflow-x-auto">
<pre class="text-sm"><code>Mᵢ = [ cos(pωᵢ)  -sin(pωᵢ) ]
     [ sin(pωᵢ)   cos(pωᵢ) ]</code></pre>
</div>
| | </div> |
| | </div> |
| | </div> |
| |
|
| | <script> |
| | |
| | document.addEventListener('DOMContentLoaded', function() { |
| | |
// ---- DOM references ---------------------------------------------------
const dimensionSelect = document.getElementById('dimension');
const baseFreqSlider = document.getElementById('baseFreq');
const baseFreqValue = document.getElementById('baseFreqValue');
const positionSlider = document.getElementById('position');
const positionValue = document.getElementById('positionValue');
const relativePosSlider = document.getElementById('relativePos');
const relativePosValue = document.getElementById('relativePosValue');
const animateBtn = document.getElementById('animateBtn');
const originalDot = document.getElementById('originalDot');
const rotatedDot = document.getElementById('rotatedDot');
const dotDifference = document.getElementById('dotDifference');
const vectorDetails = document.getElementById('vectorDetails');

// ---- Vector canvas ----------------------------------------------------
const vectorCanvas = document.getElementById('vectorCanvas');
const vectorCtx = vectorCanvas.getContext('2d');

// Match the canvas drawing buffer to its positioned container. The canvas
// is absolutely positioned, and as a replaced element it keeps its intrinsic
// size (300x150 until width/height are set), so reading the canvas's own
// offsetWidth/offsetHeight neither fills the container nor follows resizes.
function sizeVectorCanvas() {
    const container = vectorCanvas.parentElement;
    vectorCanvas.width = container.clientWidth;
    vectorCanvas.height = container.clientHeight;
}
sizeVectorCanvas();

// ---- Frequency chart --------------------------------------------------
const freqChartCanvas = document.getElementById('freqChart');
const freqChart = new Chart(freqChartCanvas, {
    type: 'line',
    data: { labels: [], datasets: [{ data: [], borderColor: '#4f46e5', tension: 0.1 }] },
    options: {
        responsive: true,
        maintainAspectRatio: false,
        plugins: { legend: { display: false } },
        scales: {
            // ωᵢ spans several orders of magnitude (from 1 down to roughly
            // 1/θ). On a log axis the geometric progression is a straight
            // line instead of most pairs collapsing onto zero.
            y: { type: 'logarithmic', title: { display: true, text: 'Frequency (ωᵢ)' } },
            x: { title: { display: true, text: 'Dimension Pair (i)' } }
        }
    }
});

// ---- Application state ------------------------------------------------
// Seed from the controls instead of hard-coded defaults: some browsers
// restore form control values on reload, and the computed views must agree
// with what the controls display.
let state = {
    dimension: parseInt(dimensionSelect.value, 10),
    baseFreq: parseInt(baseFreqSlider.value, 10),
    position: parseInt(positionSlider.value, 10),
    relativePos: parseInt(relativePosSlider.value, 10),
    isAnimating: false,
    animationFrame: null
};
baseFreqValue.textContent = state.baseFreq.toLocaleString();
positionValue.textContent = state.position;
relativePosValue.textContent = state.relativePos;

// ---- Event wiring -----------------------------------------------------
dimensionSelect.addEventListener('change', updateDimension);
baseFreqSlider.addEventListener('input', updateBaseFreq);
positionSlider.addEventListener('input', updatePosition);
relativePosSlider.addEventListener('input', updateRelativePos);
animateBtn.addEventListener('click', toggleAnimation);

// Initial render of every view.
updateAll();

// Re-measure the container on resize; the chart resizes itself
// (responsive: true), the hand-drawn canvas does not.
window.addEventListener('resize', function() {
    sizeVectorCanvas();
    renderVectorVisualization();
});
| | |
| | |
/**
 * Handler for the dimension <select>: stores the chosen model dimension
 * (as an integer) and redraws every view.
 */
function updateDimension() {
    const selected = dimensionSelect.value;
    state.dimension = Number.parseInt(selected);
    updateAll();
}
| | |
/**
 * Handler for the base-frequency slider: stores θ, shows it with
 * locale thousands separators, and redraws every view.
 */
function updateBaseFreq() {
    const base = Number.parseInt(baseFreqSlider.value);
    state.baseFreq = base;
    baseFreqValue.textContent = base.toLocaleString();
    updateAll();
}
| | |
/**
 * Handler for the position slider: mirrors the value into state and its
 * readout, then redraws every view.
 */
function updatePosition() {
    const pos = Number.parseInt(positionSlider.value);
    state.position = pos;
    positionValue.textContent = pos;
    updateAll();
}
| | |
/**
 * Handler for the relative-position slider: mirrors the offset into state
 * and its readout, then redraws every view.
 */
function updateRelativePos() {
    const offset = Number.parseInt(relativePosSlider.value);
    state.relativePos = offset;
    relativePosValue.textContent = offset;
    updateAll();
}
| | |
/**
 * Click handler for the animate button: flips the animating flag, relabels
 * the button, and either starts the frame loop or cancels the pending frame.
 */
function toggleAnimation() {
    const starting = !state.isAnimating;
    state.isAnimating = starting;
    animateBtn.textContent = starting ? 'Stop Animation' : 'Animate Rotation';

    if (!starting) {
        cancelAnimationFrame(state.animationFrame);
        return;
    }
    animate();
}
| | |
/**
 * One animation step: advances the position by one (wrapping back to 0
 * after the slider's maximum), syncs the slider and its readout, redraws
 * the position-dependent views, and schedules the next frame while
 * state.isAnimating is true.
 */
function animate() {
    if (!state.isAnimating) return;

    // Wrap after the slider's own maximum so the animation covers the same
    // 0..max range the user can pick by hand (the old `% 1024` never
    // reached the slider's max of 1024). Falls back to the old period if
    // the slider has no max attribute.
    const period = (Number(positionSlider.max) || 1023) + 1;
    state.position = (state.position + 1) % period;
    positionSlider.value = state.position;
    positionValue.textContent = state.position;

    // Only position-dependent views change while animating. The frequency
    // chart depends solely on dimension and base frequency, so rebuilding
    // it (and restarting its Chart.js update) every frame is wasted work.
    renderVectorVisualization();
    renderVectorDetails();
    calculateDotProducts();

    state.animationFrame = requestAnimationFrame(animate);
}
| | |
/**
 * Redraws every view from the current state: vector canvas, frequency
 * chart, component table, then the dot-product panel (in that order).
 */
function updateAll() {
    const renderers = [
        renderVectorVisualization,
        renderFrequencyChart,
        renderVectorDetails,
        calculateDotProducts
    ];
    renderers.forEach(render => render());
}
| | |
| | |
/**
 * Draws dimension pair (0, 1) of the sample vector on the vector canvas:
 * the unrotated vector in green and the same pair rotated by
 * position * ω₀ in blue. Component 2 is not rotated; it only scales the
 * projected length as a simple perspective cue.
 *
 * The sample vector is cached on `state` and regenerated only when the
 * model dimension changes. Previously a fresh random vector was drawn on
 * every call, so the "Original" arrow jumped on each slider tick and
 * animation frame and the rotation itself could not be followed.
 */
function renderVectorVisualization() {
    const ctx = vectorCtx;
    const width = vectorCanvas.width;
    const height = vectorCanvas.height;
    const centerX = width / 2;
    const centerY = height / 2;
    const scale = Math.min(width, height) * 0.3;

    ctx.clearRect(0, 0, width, height);

    // Light grey axes through the canvas centre.
    ctx.strokeStyle = '#e5e7eb';
    ctx.lineWidth = 1;

    ctx.beginPath();
    ctx.moveTo(0, centerY);
    ctx.lineTo(width, centerY);
    ctx.stroke();

    ctx.beginPath();
    ctx.moveTo(centerX, 0);
    ctx.lineTo(centerX, height);
    ctx.stroke();

    // One sample vector per model dimension keeps the picture stable.
    if (!state.sampleQ || state.sampleQ.length !== state.dimension) {
        state.sampleQ = Array.from({length: state.dimension}, () => math.random(-1, 1));
    }
    const vector = state.sampleQ;
    const frequencies = calculateFrequencies();

    const dim1 = 0;
    const dim2 = 1;
    const dim3 = 2;

    const x1 = vector[dim1];
    const y1 = vector[dim2];
    const z1 = vector[dim3];

    // Pair (dim1, dim2) has pair index dim1 / 2; only that pair is rotated.
    const angle = state.position * frequencies[Math.floor(dim1 / 2)];
    const cos = math.cos(angle);
    const sin = math.sin(angle);
    const rotX1 = x1 * cos - y1 * sin;
    const rotY1 = x1 * sin + y1 * cos;
    const rotZ1 = z1;

    // Model coordinates -> canvas pixels (canvas y grows downwards);
    // z only stretches the arrow slightly.
    const project = (x, y, z) => {
        const perspective = 1 + z * 0.2;
        return {
            x: centerX + x * scale * perspective,
            y: centerY - y * scale * perspective
        };
    };

    const origProj = project(x1, y1, z1);
    ctx.strokeStyle = '#10b981';
    ctx.lineWidth = 2;
    ctx.beginPath();
    ctx.moveTo(centerX, centerY);
    ctx.lineTo(origProj.x, origProj.y);
    ctx.stroke();

    const rotProj = project(rotX1, rotY1, rotZ1);
    ctx.strokeStyle = '#3b82f6';
    ctx.lineWidth = 2;
    ctx.beginPath();
    ctx.moveTo(centerX, centerY);
    ctx.lineTo(rotProj.x, rotProj.y);
    ctx.stroke();

    ctx.fillStyle = '#111827';
    ctx.font = '12px sans-serif';
    ctx.fillText('Original', origProj.x + 5, origProj.y - 5);
    ctx.fillText('Rotated', rotProj.x + 5, rotProj.y + 15);

    ctx.fillStyle = '#6b7280';
    ctx.font = '10px sans-serif';
    ctx.fillText(`Dimensions ${dim1},${dim2}`, 10, 20);
}
| | |
/**
 * Pushes the current per-pair frequencies into the Chart.js line chart.
 *
 * Pairs are labelled 0-based so the x axis matches ωᵢ = 1/θ^(2i/d) and
 * the pair indices (2i, 2i+1) used in the component table; labelling with
 * i + 1 shifted every point by one pair.
 */
function renderFrequencyChart() {
    const frequencies = calculateFrequencies();

    freqChart.data.labels = frequencies.map((_, i) => i);
    freqChart.data.datasets[0].data = frequencies;
    freqChart.update();
}
| | |
/**
 * Fills the "Vector Component Rotation" table with up to five dimension
 * pairs of the sample vector. Each row shows the original values, the
 * values rotated by position * ωᵢ, the angle, and the pair frequency. A
 * trailing row reports how many pairs were omitted.
 *
 * The sample vector is cached on `state` and regenerated only when the
 * model dimension changes. Previously every render drew new random values,
 * so the "Original Values" column changed on each slider tick.
 */
function renderVectorDetails() {
    vectorDetails.innerHTML = '';

    const frequencies = calculateFrequencies();

    if (!state.sampleQ || state.sampleQ.length !== state.dimension) {
        state.sampleQ = Array.from({length: state.dimension}, () => math.random(-1, 1));
    }
    const vector = state.sampleQ;

    const totalPairs = Math.floor(state.dimension / 2);
    const pairsToShow = Math.min(5, totalPairs);

    for (let i = 0; i < pairsToShow; i++) {
        const dim1 = 2 * i;
        const dim2 = 2 * i + 1;

        const val1 = vector[dim1];
        const val2 = vector[dim2];

        // Standard 2D rotation of the pair by position * ωᵢ.
        const angle = state.position * frequencies[i];
        const cos = math.cos(angle);
        const sin = math.sin(angle);
        const rotVal1 = val1 * cos - val2 * sin;
        const rotVal2 = val1 * sin + val2 * cos;

        // Only numbers are interpolated below, so innerHTML is safe here.
        const row = document.createElement('tr');
        row.className = 'rope-vector';
        row.innerHTML = `
<td class="px-6 py-4 whitespace-nowrap text-sm font-mono text-gray-900">${dim1}, ${dim2}</td>
<td class="px-6 py-4 whitespace-nowrap text-sm font-mono text-gray-500">${val1.toFixed(3)}, ${val2.toFixed(3)}</td>
<td class="px-6 py-4 whitespace-nowrap text-sm font-mono text-blue-600">${rotVal1.toFixed(3)}, ${rotVal2.toFixed(3)}</td>
<td class="px-6 py-4 whitespace-nowrap text-sm font-mono text-gray-900">${angle.toFixed(5)} rad</td>
<td class="px-6 py-4 whitespace-nowrap text-sm font-mono text-gray-900">${frequencies[i].toExponential(3)}</td>
`;

        vectorDetails.appendChild(row);
    }

    if (pairsToShow < totalPairs) {
        const row = document.createElement('tr');
        row.innerHTML = `
<td colspan="5" class="px-6 py-2 text-center text-sm text-gray-500">... ${totalPairs - pairsToShow} more pairs ...</td>
`;
        vectorDetails.appendChild(row);
    }
}
| | |
/**
 * Updates the "Dot Product Analysis" panel. It shows the plain dot product
 * q · k, the dot product after rotating q at `position` and k at
 * `position + relativePos`, and the absolute difference between the two.
 *
 * q and k are cached on `state` and regenerated only when the model
 * dimension changes. applyRoPE rotates every pair of both vectors by
 * position · ωᵢ, so with fixed q and k the rotated dot product depends
 * only on relativePos. It stays constant (up to rounding) while the
 * position changes, and equals q · k when relativePos is 0. Drawing new
 * random vectors on every call, as before, hid exactly this property.
 */
function calculateDotProducts() {
    if (!state.sampleQ || state.sampleQ.length !== state.dimension) {
        state.sampleQ = Array.from({length: state.dimension}, () => math.random(-1, 1));
    }
    if (!state.sampleK || state.sampleK.length !== state.dimension) {
        state.sampleK = Array.from({length: state.dimension}, () => math.random(-1, 1));
    }
    const q = state.sampleQ;
    const k = state.sampleK;

    const frequencies = calculateFrequencies();

    const originalDotValue = dotProduct(q, k);

    const rotatedQ = applyRoPE(q, state.position, frequencies);
    const rotatedK = applyRoPE(k, state.position + state.relativePos, frequencies);
    const rotatedDotValue = dotProduct(rotatedQ, rotatedK);

    originalDot.textContent = originalDotValue.toFixed(4);
    rotatedDot.textContent = rotatedDotValue.toFixed(4);
    dotDifference.textContent = Math.abs(originalDotValue - rotatedDotValue).toFixed(4);
}
| | |
| | |
/**
 * Returns the RoPE frequency of each dimension pair for the current state:
 * ωᵢ = 1 / baseFreq^(2i / dimension) for i = 0 .. floor(dimension / 2) - 1.
 */
function calculateFrequencies() {
    const d = state.dimension;
    const base = state.baseFreq;
    const pairCount = Math.floor(d / 2);
    return Array.from({length: pairCount}, (_, pair) => 1 / Math.pow(base, (2 * pair) / d));
}
| | |
/**
 * Returns a rotated copy of `vector`: each pair (2i, 2i+1) is rotated in
 * its plane by position * frequencies[i]. With an odd length, the
 * trailing component has no partner and is copied through unchanged.
 * The input array is not modified.
 */
function applyRoPE(vector, position, frequencies) {
    const out = Array.from(vector);
    const pairCount = Math.floor(vector.length / 2);

    for (let pair = 0; pair < pairCount; pair++) {
        const even = vector[2 * pair];
        const odd = vector[2 * pair + 1];
        const theta = position * frequencies[pair];
        const c = math.cos(theta);
        const s = math.sin(theta);

        out[2 * pair] = even * c - odd * s;
        out[2 * pair + 1] = even * s + odd * c;
    }

    return out;
}
| | |
/**
 * Dot product of two equal-length numeric arrays: the sum of
 * a[i] * b[i] over a's indices, starting from 0.
 */
function dotProduct(a, b) {
    let total = 0;
    a.forEach((value, index) => {
        total += value * b[index];
    });
    return total;
}
| | }); |
| | </script> |
| | </body> |
| | </html> |