| """ |
| D3.js visualization module for interactive token attention visualization. |
| """ |
|
|
| def create_d3_visualization(data): |
| """ |
| Generate a complete, self-contained HTML string with embedded D3.js visualization. |
| |
| Args: |
| data (dict): JSON structure with nodes and links from prepare_d3_data() |
| |
| Returns: |
| str: Complete HTML string with embedded D3.js, CSS, and JavaScript |
| """ |
| |
| |
| input_nodes = [node for node in data.get('nodes', []) if node.get('type') == 'input'] |
| output_nodes = [node for node in data.get('nodes', []) if node.get('type') == 'output'] |
| links = data.get('links', []) |
| |
| |
| width = 800 |
| height = max(400, max(len(input_nodes), len(output_nodes)) * 50 + 100) |
| |
| |
| input_x = 100 |
| output_x = width - 100 |
| |
| |
| def get_y_pos(index, total): |
| if total <= 1: |
| return height // 2 |
| return 80 + (index * (height - 160)) / (total - 1) |
| |
| |
| svg_html = f""" |
| <div style='display: flex; flex-direction: column; align-items: center; border: 1px solid #ddd; padding: 20px; margin: 10px; background: white; border-radius: 8px;'> |
| <div style='text-align: center; margin-bottom: 15px;'> |
| <h3 style='margin: 0; color: #333;'>Token Attention Visualization</h3> |
| <p style='margin: 5px 0; color: #666;'>Step {data.get('step', 0) + 1} | {len(input_nodes)} input → {len(output_nodes)} output | {len(links)} connections</p> |
| </div> |
| |
| <svg width="{width}" height="{height}" style='border: 1px solid #eee; background: #fafafa; display: block;'> |
| <!-- Background grid --> |
| <defs> |
| <pattern id="grid" width="20" height="20" patternUnits="userSpaceOnUse"> |
| <path d="M 20 0 L 0 0 0 20" fill="none" stroke="#f0f0f0" stroke-width="1"/> |
| </pattern> |
| </defs> |
| <rect width="100%" height="100%" fill="url(#grid)" /> |
| |
| <!-- Column headers --> |
| <text x="{input_x}" y="30" text-anchor="middle" font-size="16" font-weight="bold" fill="#4285f4">Input Tokens</text> |
| <text x="{output_x}" y="30" text-anchor="middle" font-size="16" font-weight="bold" fill="#ea4335">Output Tokens</text> |
| """ |
| |
| |
| for link in links: |
| |
| source_node = next((n for n in input_nodes + output_nodes if n['id'] == link['source']), None) |
| target_node = next((n for n in input_nodes + output_nodes if n['id'] == link['target']), None) |
| |
| if source_node and target_node: |
| |
| if source_node['type'] == 'input': |
| source_idx = next((i for i, n in enumerate(input_nodes) if n['id'] == source_node['id']), 0) |
| source_y = get_y_pos(source_idx, len(input_nodes)) |
| source_x_pos = input_x + 20 |
| else: |
| source_idx = next((i for i, n in enumerate(output_nodes) if n['id'] == source_node['id']), 0) |
| source_y = get_y_pos(source_idx, len(output_nodes)) |
| source_x_pos = output_x - 20 |
| |
| if target_node['type'] == 'input': |
| target_idx = next((i for i, n in enumerate(input_nodes) if n['id'] == target_node['id']), 0) |
| target_y = get_y_pos(target_idx, len(input_nodes)) |
| target_x_pos = input_x - 20 |
| else: |
| target_idx = next((i for i, n in enumerate(output_nodes) if n['id'] == target_node['id']), 0) |
| target_y = get_y_pos(target_idx, len(output_nodes)) |
| target_x_pos = output_x - 20 |
| |
| |
| stroke_width = max(1, min(8, link['weight'] * 20)) |
| opacity = max(0.3, min(1.0, link['weight'] * 2)) |
| color = "#4285f4" if link['type'] == 'input_to_output' else "#ea4335" |
| |
| |
| svg_html += f''' |
| <line x1="{source_x_pos}" y1="{source_y}" x2="{target_x_pos}" y2="{target_y}" |
| stroke="{color}" stroke-width="{stroke_width}" opacity="{opacity}"/> |
| ''' |
| |
| |
| for i, node in enumerate(input_nodes): |
| y = get_y_pos(i, len(input_nodes)) |
| token_text = node['token'] |
| |
| |
| if token_text.startswith('Ġ'): |
| token_text = token_text[1:] |
| if token_text.startswith('▁'): |
| token_text = token_text[1:] |
| if token_text.startswith('##'): |
| token_text = token_text[2:] |
| |
| if len(token_text) > 15: |
| token_text = token_text[:13] + "..." |
| |
| svg_html += f''' |
| <g> |
| <circle cx="{input_x}" cy="{y}" r="12" fill="#4285f4" stroke="#1a73e8" stroke-width="2" opacity="0.9"/> |
| <text x="{input_x - 20}" y="{y + 4}" text-anchor="end" font-size="12" fill="#333" font-weight="bold">{token_text}</text> |
| </g> |
| ''' |
| |
| |
| for i, node in enumerate(output_nodes): |
| y = get_y_pos(i, len(output_nodes)) |
| token_text = node['token'] |
| |
| |
| if token_text.startswith('Ġ'): |
| token_text = token_text[1:] |
| if token_text.startswith('▁'): |
| token_text = token_text[1:] |
| if token_text.startswith('##'): |
| token_text = token_text[2:] |
| |
| if len(token_text) > 15: |
| token_text = token_text[:13] + "..." |
| |
| svg_html += f''' |
| <g> |
| <circle cx="{output_x}" cy="{y}" r="12" fill="#ea4335" stroke="#d33b2c" stroke-width="2" opacity="0.9"/> |
| <text x="{output_x + 20}" y="{y + 4}" text-anchor="start" font-size="12" fill="#333" font-weight="bold">{token_text}</text> |
| </g> |
| ''' |
| |
| |
| svg_html += ''' |
| </svg> |
| |
| <div style='margin-top: 20px; padding: 16px; background: #f8fafc; border: 1px solid #e2e8f0; border-radius: 8px;'> |
| <div style='display: flex; justify-content: center; align-items: center; gap: 32px; font-size: 12px; color: #64748b; font-family: Inter, sans-serif;'> |
| <div style='display: flex; align-items: center; gap: 8px;'> |
| <div style='width: 16px; height: 2px; background: #4285f4; border-radius: 1px;'></div> |
| <span style='color: #1e293b; font-weight: 500;'>Input → Output</span> |
| </div> |
| <div style='display: flex; align-items: center; gap: 8px;'> |
| <div style='display: flex; gap: 2px;'> |
| <div style='width: 8px; height: 1px; background: #64748b;'></div> |
| <div style='width: 8px; height: 2px; background: #64748b;'></div> |
| <div style='width: 8px; height: 3px; background: #64748b;'></div> |
| </div> |
| <span style='color: #1e293b; font-weight: 500;'>Line thickness = weight</span> |
| </div> |
| </div> |
| </div> |
| </div> |
| ''' |
| |
| return svg_html |
|
|
def create_d3_visualization_old(data):
    """
    OLD VERSION - Generate a complete, standalone HTML page for the token graph.

    Despite the name, no D3.js is loaded: the embedded script is a debug
    renderer that builds a plain SVG string in the browser and assigns it to
    the container's innerHTML.

    Args:
        data (dict): JSON structure with nodes and links from prepare_d3_data().
            Must be JSON-serializable; values exposing ``tolist()`` are
            converted with it and anything else unserializable is stringified.

    Returns:
        str: Complete HTML document with inline CSS and JavaScript.
    """
    import json

    def _to_jsonable(obj):
        # Fallback for values json can't encode natively (e.g. numpy/torch
        # scalars or arrays, which expose tolist()); otherwise stringify so
        # the page still renders.
        tolist = getattr(obj, "tolist", None)
        if callable(tolist):
            return tolist()
        return str(obj)

    # JSON (unlike Python repr) is valid JavaScript: None/True/False become
    # null/true/false. Escaping <, > and & keeps a token such as "</script>"
    # from terminating the inline <script> element early.
    data_json = (
        json.dumps(data, default=_to_jsonable)
        .replace("<", "\\u003c")
        .replace(">", "\\u003e")
        .replace("&", "\\u0026")
    )

    # NOTE(review): resetView() is never defined, and the button/info panel
    # below are overwritten when the script assigns container.innerHTML.
    html_template = f"""
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <style>
        .visualization-container {{
            width: 100%;
            height: 600px;
            border: 1px solid #ddd;
            border-radius: 8px;
            background: #fafafa;
            position: relative;
            overflow: hidden;
        }}

        .node {{
            cursor: pointer;
            stroke-width: 2px;
        }}

        .node.input {{
            fill: #4285f4;
            stroke: #1a73e8;
        }}

        .node.output {{
            fill: #ea4335;
            stroke: #d33b2c;
        }}

        .node.highlighted {{
            stroke-width: 4px;
            stroke: #ff6d00;
        }}

        .node.dimmed {{
            opacity: 0.3;
        }}

        .link {{
            stroke: #666;
            stroke-opacity: 0.6;
            fill: none;
        }}

        .link.input-to-output {{
            stroke: #4285f4;
        }}

        .link.output-to-output {{
            stroke: #ea4335;
        }}

        .link.highlighted {{
            stroke-opacity: 1;
            stroke-width: 3px;
        }}

        .link.dimmed {{
            stroke-opacity: 0.1;
        }}

        .token-label {{
            font-family: 'Courier New', monospace;
            font-size: 12px;
            text-anchor: middle;
            dominant-baseline: central;
            fill: white;
            font-weight: bold;
            pointer-events: none;
        }}

        .reset-btn {{
            position: absolute;
            top: 10px;
            right: 10px;
            padding: 8px 16px;
            background: #4285f4;
            color: white;
            border: none;
            border-radius: 4px;
            cursor: pointer;
            font-size: 12px;
            z-index: 100;
        }}

        .reset-btn:hover {{
            background: #1a73e8;
        }}

        .info-panel {{
            position: absolute;
            bottom: 10px;
            left: 10px;
            background: rgba(255, 255, 255, 0.9);
            padding: 8px 12px;
            border-radius: 4px;
            font-size: 11px;
            font-family: Arial, sans-serif;
            border: 1px solid #ddd;
        }}
    </style>
</head>
<body>
    <div class="visualization-container" id="viz-container">
        <button class="reset-btn" onclick="resetView()">Reset View</button>
        <div class="info-panel">
            <div>Step: {data.get('step', 0) + 1} / {data.get('total_steps', 1)}</div>
            <div>Nodes: {len(data.get('nodes', []))} | Links: {len(data.get('links', []))}</div>
            <div>Click nodes to filter connections</div>
        </div>
        <svg id="visualization"></svg>
    </div>

    <script>
        // Simple visualization without D3 first - just to test
        const data = {data_json};
        // Tolerate payloads without nodes/links (the Python side defaults them too).
        data.nodes = data.nodes || [];
        data.links = data.links || [];

        // Escape token text before it is concatenated into innerHTML markup.
        function escapeHtml(value) {{
            return String(value)
                .replace(/&/g, "&amp;")
                .replace(/</g, "&lt;")
                .replace(/>/g, "&gt;")
                .replace(/"/g, "&quot;");
        }}

        // Create simple HTML visualization
        const container = document.getElementById("viz-container");
        let html = "<div style='padding: 20px;'>";
        html += "<h3>Debug Info</h3>";
        html += "<p>Nodes: " + data.nodes.length + "</p>";
        html += "<p>Links: " + data.links.length + "</p>";

        // Simple SVG without D3
        html += "<svg width='800' height='400' style='border: 1px solid #ccc; background: white;'>";

        // Draw input nodes (left side)
        const inputNodes = data.nodes.filter(n => n.type === "input");
        const outputNodes = data.nodes.filter(n => n.type === "output");

        inputNodes.forEach((node, i) => {{
            const y = 50 + i * 40;
            html += `<circle cx="50" cy="${{y}}" r="15" fill="#4285f4" stroke="#1a73e8" stroke-width="2"/>`;
            html += `<text x="80" y="${{y + 5}}" font-size="12" fill="black">${{escapeHtml(node.token)}}</text>`;
        }});

        // Draw output nodes (right side)
        outputNodes.forEach((node, i) => {{
            const y = 50 + i * 40;
            html += `<circle cx="700" cy="${{y}}" r="15" fill="#ea4335" stroke="#d33b2c" stroke-width="2"/>`;
            html += `<text x="620" y="${{y + 5}}" font-size="12" fill="black" text-anchor="end">${{escapeHtml(node.token)}}</text>`;
        }});

        // Draw links
        data.links.forEach(link => {{
            const sourceNode = data.nodes.find(n => n.id === link.source);
            const targetNode = data.nodes.find(n => n.id === link.target);
            if (sourceNode && targetNode) {{
                const sourceIdx = sourceNode.type === "input" ?
                    inputNodes.findIndex(n => n.id === sourceNode.id) :
                    outputNodes.findIndex(n => n.id === sourceNode.id);
                const targetIdx = targetNode.type === "input" ?
                    inputNodes.findIndex(n => n.id === targetNode.id) :
                    outputNodes.findIndex(n => n.id === targetNode.id);

                const sourceX = sourceNode.type === "input" ? 65 : 685;
                const targetX = targetNode.type === "input" ? 65 : 685;
                const sourceY = 50 + sourceIdx * 40;
                const targetY = 50 + targetIdx * 40;

                const strokeWidth = Math.max(1, link.weight * 10);
                const color = link.type === "input_to_output" ? "#4285f4" : "#ea4335";

                html += `<line x1="${{sourceX}}" y1="${{sourceY}}" x2="${{targetX}}" y2="${{targetY}}" stroke="${{color}}" stroke-width="${{strokeWidth}}" opacity="0.6"/>`;
            }}
        }});

        html += "</svg>";
        html += "</div>";

        container.innerHTML = html;

    </script>
</body>
</html>
"""

    return html_template